NnetRescaler Class Reference
Collaboration diagram for NnetRescaler:

Public Member Functions

 NnetRescaler (const NnetRescaleConfig &config, const std::vector< NnetExample > &examples, Nnet *nnet)
 
void Rescale ()
 

Private Member Functions

void FormatInput (const std::vector< NnetExample > &data, CuMatrix< BaseFloat > *input)
 Takes the input examples and packs them into a single matrix, written to *input (the original source comment refers to this as forward_data_[0]). More...
 
void RescaleComponent (int32 c, int32 num_chunks, CuMatrixBase< BaseFloat > *cur_data_in, CuMatrix< BaseFloat > *next_data)
 
void ComputeRelevantIndexes ()
 
BaseFloat GetTargetAvgDeriv (int32 c)
 

Private Attributes

const NnetRescaleConfig & config_
 
const std::vector< NnetExample > & examples_
 
Nnet * nnet_
 
std::vector< ChunkInfo > chunk_info_out_
 
std::set< int32 > relevant_indexes_
 

Detailed Description

Definition at line 26 of file rescale-nnet.cc.

Constructor & Destructor Documentation

◆ NnetRescaler()

NnetRescaler ( const NnetRescaleConfig &  config,
const std::vector< NnetExample > &  examples,
Nnet *  nnet 
)
inline

Member Function Documentation

◆ ComputeRelevantIndexes()

void ComputeRelevantIndexes ( )
private

Definition at line 89 of file rescale-nnet.cc.

References Nnet::GetComponent(), NnetRescaler::nnet_, Nnet::NumComponents(), and NnetRescaler::relevant_indexes_.

Referenced by NnetRescaler::NnetRescaler(), and NnetRescaler::Rescale().

89  {
90  for (int32 c = 0; c + 1 < nnet_->NumComponents(); c++)
91  if (dynamic_cast<AffineComponent*>(&nnet_->GetComponent(c)) != NULL &&
92  (dynamic_cast<NonlinearComponent*>(&nnet_->GetComponent(c+1)) != NULL &&
93  dynamic_cast<SoftmaxComponent*>(&nnet_->GetComponent(c+1)) == NULL))
94  relevant_indexes_.insert(c);
95 }
const Component & GetComponent(int32 c) const
Definition: nnet-nnet.cc:141
kaldi::int32 int32
int32 NumComponents() const
Returns number of components– think of this as similar to # of layers, but e.g.
Definition: nnet-nnet.h:69
std::set< int32 > relevant_indexes_
Definition: rescale-nnet.cc:51

◆ FormatInput()

void FormatInput ( const std::vector< NnetExample > &  data,
CuMatrix< BaseFloat > *  input 
)
private

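Takes the input examples and packs them into a single matrix, written to *input (the original source comment refers to this as forward_data_[0], a name that does not appear in the listing below). Each example contributes LeftContext() + 1 + RightContext() rows; when the speaker vector is non-empty it is written into the columns following the feature columns.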

Definition at line 56 of file rescale-nnet.cc.

References NnetRescaler::chunk_info_out_, Nnet::ComputeChunkInfo(), CuMatrixBase< Real >::CopyFromMat(), CuMatrixBase< Real >::CopyRowsFromVec(), Nnet::InputDim(), KALDI_ASSERT, Nnet::LeftContext(), NnetRescaler::nnet_, CuMatrix< Real >::Resize(), and Nnet::RightContext().

Referenced by NnetRescaler::NnetRescaler(), and NnetRescaler::Rescale().

57  {
58  KALDI_ASSERT(data.size() > 0);
59  int32 num_splice = nnet_->LeftContext() + 1 + nnet_->RightContext();
60  KALDI_ASSERT(data[0].input_frames.NumRows() == num_splice);
61 
62  int32 feat_dim = data[0].input_frames.NumCols(),
63  spk_dim = data[0].spk_info.Dim(),
64  tot_dim = feat_dim + spk_dim; // we append these at the neural net
65  // input... note, spk_dim might be 0.
66  KALDI_ASSERT(tot_dim == nnet_->InputDim());
67  int32 num_chunks = data.size();
68 
69  input->Resize(num_splice * num_chunks,
70  tot_dim);
71  for (int32 chunk = 0; chunk < num_chunks; chunk++) {
72  CuSubMatrix<BaseFloat> dest(*input,
73  chunk * num_splice, num_splice,
74  0, feat_dim);
75  Matrix<BaseFloat> src(data[chunk].input_frames);
76  dest.CopyFromMat(src);
77  if (spk_dim != 0) {
78  CuSubMatrix<BaseFloat> spk_dest(*input,
79  chunk * num_splice, num_splice,
80  feat_dim, spk_dim);
81  spk_dest.CopyRowsFromVec(data[chunk].spk_info);
82  }
83  }
84  // TODO : filter out the unnecessary rows from the input
85  nnet_->ComputeChunkInfo(num_splice, num_chunks, &chunk_info_out_);
86 
87 }
int32 LeftContext() const
Returns the left-context summed over all the Components...
Definition: nnet-nnet.cc:42
kaldi::int32 int32
int32 RightContext() const
Returns the right-context summed over all the Components...
Definition: nnet-nnet.cc:56
std::vector< ChunkInfo > chunk_info_out_
Definition: rescale-nnet.cc:50
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void ComputeChunkInfo(int32 input_chunk_size, int32 num_chunks, std::vector< ChunkInfo > *chunk_info_out) const
Uses the output of the Context() functions of the network, to compute a vector of size NumComponents(...
Definition: nnet-nnet.cc:65
int32 InputDim() const
Dimension of the input features, e.g.
Definition: nnet-nnet.cc:36

◆ GetTargetAvgDeriv()

BaseFloat GetTargetAvgDeriv ( int32  c)
private

Definition at line 98 of file rescale-nnet.cc.

References NnetRescaler::config_, Nnet::GetComponent(), KALDI_ASSERT, KALDI_ERR, NnetRescaler::nnet_, NnetRescaler::relevant_indexes_, NnetRescaleConfig::target_avg_deriv, NnetRescaleConfig::target_first_layer_avg_deriv, and NnetRescaleConfig::target_last_layer_avg_deriv.

Referenced by NnetRescaler::NnetRescaler(), and NnetRescaler::RescaleComponent().

98  {
99  KALDI_ASSERT(relevant_indexes_.count(c) == 1);
100  BaseFloat factor;
101  if (dynamic_cast<SigmoidComponent*>(&(nnet_->GetComponent(c + 1))) != NULL)
102  factor = 0.25;
103  else if (dynamic_cast<TanhComponent*>(&(nnet_->GetComponent(c + 1))) != NULL)
104  factor = 1.0;
105  else
106  KALDI_ERR << "This type of nonlinear component is not handled: index " << c;
107 
108  int32 last_c = *std::max_element(relevant_indexes_.begin(), relevant_indexes_.end()),
109  first_c = *std::min_element(relevant_indexes_.begin(), relevant_indexes_.end());
110  if (c == first_c)
111  return factor * config_.target_first_layer_avg_deriv;
112  else if (c == last_c)
113  return factor * config_.target_last_layer_avg_deriv;
114  else
115  return factor * config_.target_avg_deriv;
116 }
const Component & GetComponent(int32 c) const
Definition: nnet-nnet.cc:141
kaldi::int32 int32
float BaseFloat
Definition: kaldi-types.h:29
const NnetRescaleConfig & config_
Definition: rescale-nnet.cc:47
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::set< int32 > relevant_indexes_
Definition: rescale-nnet.cc:51

◆ Rescale()

void Rescale ( )

Definition at line 200 of file rescale-nnet.cc.

References NnetRescaler::chunk_info_out_, NnetRescaler::ComputeRelevantIndexes(), NnetRescaler::examples_, NnetRescaler::FormatInput(), Nnet::GetComponent(), NnetRescaler::nnet_, Nnet::NumComponents(), Component::Propagate(), NnetRescaler::relevant_indexes_, NnetRescaler::RescaleComponent(), and CuMatrix< Real >::Swap().

Referenced by NnetRescaler::NnetRescaler(), and kaldi::nnet2::RescaleNnet().

200  {
201  ComputeRelevantIndexes(); // set up relevant_indexes_.
202  CuMatrix<BaseFloat> cur_data, next_data;
203  FormatInput(examples_, &cur_data);
204  int32 num_chunks = examples_.size();
205  for (int32 c = 0; c < nnet_->NumComponents(); c++) {
206  Component &component = nnet_->GetComponent(c);
207  if (relevant_indexes_.count(c - 1) == 1) {
208  // the following function call also appropriately sets "next_data"
209  // after doing the rescaling
210  RescaleComponent(c - 1, num_chunks, &cur_data, &next_data);
211  } else {
212  component.Propagate(chunk_info_out_[c], chunk_info_out_[c+1], cur_data, &next_data);
213  }
214  cur_data.Swap(&next_data);
215  }
216 }
const Component & GetComponent(int32 c) const
Definition: nnet-nnet.cc:141
const std::vector< NnetExample > & examples_
Definition: rescale-nnet.cc:48
void RescaleComponent(int32 c, int32 num_chunks, CuMatrixBase< BaseFloat > *cur_data_in, CuMatrix< BaseFloat > *next_data)
void FormatInput(const std::vector< NnetExample > &data, CuMatrix< BaseFloat > *input)
takes the input and formats as a single matrix, in forward_data_[0].
Definition: rescale-nnet.cc:56
kaldi::int32 int32
int32 NumComponents() const
Returns number of components– think of this as similar to # of layers, but e.g.
Definition: nnet-nnet.h:69
std::vector< ChunkInfo > chunk_info_out_
Definition: rescale-nnet.cc:50
std::set< int32 > relevant_indexes_
Definition: rescale-nnet.cc:51

◆ RescaleComponent()

void RescaleComponent ( int32  c,
int32  num_chunks,
CuMatrixBase< BaseFloat > *  cur_data_in,
CuMatrix< BaseFloat > *  next_data 
)
private

Definition at line 121 of file rescale-nnet.cc.

References Component::Backprop(), NnetRescaler::chunk_info_out_, NnetRescaler::config_, NnetRescaleConfig::delta, Nnet::GetComponent(), NnetRescaler::GetTargetAvgDeriv(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_VLOG, NnetRescaleConfig::max_change, NnetRescaleConfig::min_change, NnetRescaler::nnet_, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), Component::Propagate(), UpdatableComponent::Scale(), and CuMatrixBase< Real >::Sum().

Referenced by NnetRescaler::NnetRescaler(), and NnetRescaler::Rescale().

125  {
126  int32 rows = cur_data_in->NumRows(), cols = cur_data_in->NumCols();
127  // Only handle sigmoid or tanh here.
128  if (dynamic_cast<SigmoidComponent*>(&(nnet_->GetComponent(c + 1))) == NULL &&
129  dynamic_cast<TanhComponent*>(&(nnet_->GetComponent(c + 1))) == NULL)
130  KALDI_ERR << "This type of nonlinear component is not handled: index " << c;
131  KALDI_ASSERT(chunk_info_out_[0].NumChunks() == num_chunks); //TODO verify how this component can be used
132  // rewrite the
133  // chunk_info_out_
134  // computation
135  // the nonlinear component:
136  NonlinearComponent &nc =
137  *(dynamic_cast<NonlinearComponent*>(&(nnet_->GetComponent(c + 1))));
138  ChunkInfo in_info, out_info;
139  in_info = chunk_info_out_[c+1];
140  out_info = chunk_info_out_[c+2];
141 
142  BaseFloat orig_avg_deriv, target_avg_deriv = GetTargetAvgDeriv(c);
143  BaseFloat cur_scaling = 1.0; // current rescaling factor (on input).
144  int32 num_iters = 10;
145 
146  CuMatrix<BaseFloat> cur_data(*cur_data_in),
147  ones(rows, cols), in_deriv(rows, cols);
148 
149  ones.Set(1.0);
150  nc.Propagate(in_info, out_info, cur_data, next_data);
151  nc.Backprop(in_info, out_info, cur_data, *next_data, ones, NULL, &in_deriv);
152  BaseFloat cur_avg_deriv;
153  cur_avg_deriv = in_deriv.Sum() / (rows * cols);
154  orig_avg_deriv = cur_avg_deriv;
155  for (int32 iter = 0; iter < num_iters; iter++) {
156  // We already have "cur_avg_deriv"; perturb the scale and compute
157  // the next avg_deriv, so we can see how it changes with the scale.
158  cur_data.CopyFromMat(*cur_data_in);
159  cur_data.Scale(cur_scaling + config_.delta);
160  nc.Propagate(in_info, out_info, cur_data, next_data);
161  nc.Backprop(in_info, out_info, cur_data, *next_data, ones, NULL, &in_deriv);
162  BaseFloat next_avg_deriv = in_deriv.Sum() / (rows * cols);
163  KALDI_ASSERT(next_avg_deriv < cur_avg_deriv);
164  // "gradient" is how avg_deriv changes as we change the scale.
165  // should be negative.
166  BaseFloat gradient = (next_avg_deriv - cur_avg_deriv) / config_.delta;
167  KALDI_ASSERT(gradient < 0.0);
168  BaseFloat proposed_change = (target_avg_deriv - cur_avg_deriv) / gradient;
169  KALDI_VLOG(2) << "cur_avg_deriv = " << cur_avg_deriv << ", target_avg_deriv = "
170  << target_avg_deriv << ", gradient = " << gradient
171  << ", proposed_change " << proposed_change;
172  // Limit size of proposed change in "cur_scaling", to ensure stability.
173  if (fabs(proposed_change / cur_scaling) > config_.max_change)
174  proposed_change = cur_scaling * config_.max_change *
175  (proposed_change > 0.0 ? 1.0 : -1.0);
176  cur_scaling += proposed_change;
177 
178  cur_data.CopyFromMat(*cur_data_in);
179  cur_data.Scale(cur_scaling);
180  nc.Propagate(in_info, out_info, cur_data, next_data);
181  nc.Backprop(in_info, out_info, cur_data, *next_data, ones, NULL, &in_deriv);
182  cur_avg_deriv = in_deriv.Sum() / (rows * cols);
183  if (fabs(proposed_change) < config_.min_change) break; // Terminate the
184  // optimization
185  }
186  UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(
187  &nnet_->GetComponent(c));
188  KALDI_ASSERT(uc != NULL);
189  uc->Scale(cur_scaling); // scale the parameters of the previous
190  // AffineComponent.
191 
192  KALDI_LOG << "For component " << c << ", scaling parameters by "
193  << cur_scaling << "; average "
194  << "derivative changed from " << orig_avg_deriv << " to "
195  << cur_avg_deriv << "; target was " << target_avg_deriv;
196 }
const Component & GetComponent(int32 c) const
Definition: nnet-nnet.cc:141
kaldi::int32 int32
float BaseFloat
Definition: kaldi-types.h:29
BaseFloat GetTargetAvgDeriv(int32 c)
Definition: rescale-nnet.cc:98
const NnetRescaleConfig & config_
Definition: rescale-nnet.cc:47
std::vector< ChunkInfo > chunk_info_out_
Definition: rescale-nnet.cc:50
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
#define KALDI_LOG
Definition: kaldi-error.h:153

Member Data Documentation

◆ chunk_info_out_

std::vector<ChunkInfo> chunk_info_out_
private

◆ config_

const NnetRescaleConfig& config_
private

◆ examples_

const std::vector<NnetExample>& examples_
private

Definition at line 48 of file rescale-nnet.cc.

Referenced by NnetRescaler::Rescale().

◆ nnet_

Nnet* nnet_
private

◆ relevant_indexes_

std::set<int32> relevant_indexes_
private

The documentation for this class was generated from the following file: