nnet-training.cc
Go to the documentation of this file.
1 // nnet3/nnet-training.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 // 2015 Xiaohui Zhang
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "nnet3/nnet-training.h"
22 #include "nnet3/nnet-utils.h"
23 
24 namespace kaldi {
25 namespace nnet3 {
26 
28  Nnet *nnet):
29  config_(config),
30  nnet_(nnet),
31  compiler_(*nnet, config_.optimize_config, config_.compiler_config),
32  num_minibatches_processed_(0),
33  max_change_stats_(*nnet),
34  srand_seed_(RandInt(0, 100000)) {
35  if (config.zero_component_stats)
36  ZeroComponentStats(nnet);
37  KALDI_ASSERT(config.momentum >= 0.0 &&
38  config.max_param_change >= 0.0 &&
39  config.backstitch_training_interval > 0);
40  delta_nnet_ = nnet_->Copy();
41  ScaleNnet(0.0, delta_nnet_);
42 
43  if (config_.read_cache != "") {
44  bool binary;
45  Input ki;
46  if (ki.Open(config_.read_cache, &binary)) {
47  compiler_.ReadCache(ki.Stream(), binary);
48  KALDI_LOG << "Read computation cache from " << config_.read_cache;
49  } else {
50  KALDI_WARN << "Could not open cached computation. "
51  "Probably this is the first training iteration.";
52  }
53  }
54 }
55 
56 
58  bool need_model_derivative = true;
59  ComputationRequest request;
60  GetComputationRequest(*nnet_, eg, need_model_derivative,
62  &request);
63  std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
64 
68  // backstitch training is incompatible with momentum > 0
71  bool is_backstitch_step1 = true;
74  TrainInternalBackstitch(eg, *computation, is_backstitch_step1);
75  FreezeNaturalGradient(false, delta_nnet_); // un-freeze natural gradient
76  is_backstitch_step1 = false;
79  TrainInternalBackstitch(eg, *computation, is_backstitch_step1);
80  } else { // conventional training
81  TrainInternal(eg, *computation);
82  }
83  if (num_minibatches_processed_ == 0) {
86  }
88 
89 }
90 
92  const NnetComputation &computation) {
93  // note: because we give the 1st arg (nnet_) as a pointer to the
94  // constructor of 'computer', it will use that copy of the nnet to
95  // store stats.
96  NnetComputer computer(config_.compute_config, computation,
98  // give the inputs to the computer object.
99  computer.AcceptInputs(*nnet_, eg.io);
100  computer.Run();
101 
102  this->ProcessOutputs(false, eg, &computer);
103  computer.Run();
104 
105  // If relevant, add in the part of the gradient that comes from L2
106  // regularization.
109  delta_nnet_);
110 
111  // Update the parameters of nnet
112  bool success = UpdateNnetWithMaxChange(
114  1.0, 1.0 - config_.momentum, nnet_, &max_change_stats_);
115 
116  // Scale down the batchnorm stats (keeps them fresh... this affects what
117  // happens when we use the model with batchnorm test-mode set).
119 
120  // The following will only do something if we have a LinearComponent
121  // or AffineComponent with orthonormal-constraint set to a nonzero value.
123 
124  // Scale deta_nnet
125  if (success)
127  else
128  ScaleNnet(0.0, delta_nnet_);
129 }
130 
132  const NnetComputation &computation,
133  bool is_backstitch_step1) {
134  // note: because we give the 1st arg (nnet_) as a pointer to the
135  // constructor of 'computer', it will use that copy of the nnet to
136  // store stats.
137  NnetComputer computer(config_.compute_config, computation,
138  nnet_, delta_nnet_);
139  // give the inputs to the computer object.
140  computer.AcceptInputs(*nnet_, eg.io);
141  computer.Run();
142 
143  bool is_backstitch_step2 = !is_backstitch_step1;
144  this->ProcessOutputs(is_backstitch_step2, eg, &computer);
145  computer.Run();
146 
147  BaseFloat max_change_scale, scale_adding;
148  if (is_backstitch_step1) {
149  // max-change is scaled by backstitch_training_scale;
150  // delta_nnet is scaled by -backstitch_training_scale when added to nnet;
151  max_change_scale = config_.backstitch_training_scale;
152  scale_adding = -config_.backstitch_training_scale;
153  } else {
154  // max-change is scaled by 1 + backstitch_training_scale;
155  // delta_nnet is scaled by 1 + backstitch_training_scale when added to nnet;
156  max_change_scale = 1.0 + config_.backstitch_training_scale;
157  scale_adding = 1.0 + config_.backstitch_training_scale;
158  // If relevant, add in the part of the gradient that comes from L2
159  // regularization. It may not be optimally inefficient to do it on both
160  // passes of the backstitch, like we do here, but it probably minimizes
161  // any harmful interactions with the max-change.
163  1.0 / scale_adding * GetNumNvalues(eg.io, false) *
165  }
166 
167  // Updates the parameters of nnet
170  max_change_scale, scale_adding, nnet_,
172 
173  if (is_backstitch_step1) {
174  // The following will only do something if we have a LinearComponent or
175  // AffineComponent with orthonormal-constraint set to a nonzero value. We
176  // choose to do this only on the 1st backstitch step, for efficiency.
178  }
179 
180  if (!is_backstitch_step1) {
181  // Scale down the batchnorm stats (keeps them fresh... this affects what
182  // happens when we use the model with batchnorm test-mode set). Do this
183  // after backstitch step 2 so that the stats are scaled down before we start
184  // the next minibatch.
186  }
187 
188  ScaleNnet(0.0, delta_nnet_);
189 }
190 
191 void NnetTrainer::ProcessOutputs(bool is_backstitch_step2,
192  const NnetExample &eg,
193  NnetComputer *computer) {
194  // normally the eg will have just one output named 'output', but
195  // we don't assume this.
196  // In backstitch training, the output-name with the "_backstitch" suffix is
197  // the one computed after the first, backward step of backstitch.
198  const std::string suffix = (is_backstitch_step2 ? "_backstitch" : "");
199  std::vector<NnetIo>::const_iterator iter = eg.io.begin(),
200  end = eg.io.end();
201  for (; iter != end; ++iter) {
202  const NnetIo &io = *iter;
203  int32 node_index = nnet_->GetNodeIndex(io.name);
204  KALDI_ASSERT(node_index >= 0);
205  if (nnet_->IsOutputNode(node_index)) {
206  ObjectiveType obj_type = nnet_->GetNode(node_index).u.objective_type;
207  BaseFloat tot_weight, tot_objf;
208  bool supply_deriv = true;
209  ComputeObjectiveFunction(io.features, obj_type, io.name,
210  supply_deriv, computer,
211  &tot_weight, &tot_objf);
212  objf_info_[io.name + suffix].UpdateStats(io.name + suffix,
215  tot_weight, tot_objf);
216  }
217  }
218 }
219 
221  unordered_map<std::string, ObjectiveFunctionInfo, StringHasher>::const_iterator
222  iter = objf_info_.begin(),
223  end = objf_info_.end();
224  std::vector<std::pair<std::string, const ObjectiveFunctionInfo*> > all_pairs;
225  for (; iter != end; ++iter)
226  all_pairs.push_back(std::pair<std::string, const ObjectiveFunctionInfo*>(
227  iter->first, &(iter->second)));
228  // ensure deterministic order of these names (this will matter in situations
229  // where a script greps for the objective from the log).
230  std::sort(all_pairs.begin(), all_pairs.end());
231  bool ans = false;
232  for (size_t i = 0; i < all_pairs.size(); i++) {
233  const std::string &name = all_pairs[i].first;
234  const ObjectiveFunctionInfo &info = *(all_pairs[i].second);
235  bool ok = info.PrintTotalStats(name);
236  ans = ans || ok;
237  }
239  return ans;
240 }
241 
243  const std::string &output_name,
244  int32 minibatches_per_phase,
245  int32 minibatch_counter,
246  BaseFloat this_minibatch_weight,
247  BaseFloat this_minibatch_tot_objf,
248  BaseFloat this_minibatch_tot_aux_objf) {
249  int32 phase = minibatch_counter / minibatches_per_phase;
250  if (phase != current_phase) {
251  KALDI_ASSERT(phase > current_phase);
252  PrintStatsForThisPhase(output_name, minibatches_per_phase,
253  phase);
254  current_phase = phase;
255  tot_weight_this_phase = 0.0;
256  tot_objf_this_phase = 0.0;
257  tot_aux_objf_this_phase = 0.0;
258  minibatches_this_phase = 0;
259  }
260  minibatches_this_phase++;
261  tot_weight_this_phase += this_minibatch_weight;
262  tot_objf_this_phase += this_minibatch_tot_objf;
263  tot_aux_objf_this_phase += this_minibatch_tot_aux_objf;
264  tot_weight += this_minibatch_weight;
265  tot_objf += this_minibatch_tot_objf;
266  tot_aux_objf += this_minibatch_tot_aux_objf;
267 }
268 
270  const std::string &output_name,
271  int32 minibatches_per_phase,
272  int32 phase) const {
273  int32 start_minibatch = current_phase * minibatches_per_phase,
274  end_minibatch = phase * minibatches_per_phase - 1;
275 
276  if (tot_aux_objf_this_phase == 0.0) {
277  if (minibatches_per_phase == minibatches_this_phase) {
278  KALDI_LOG << "Average objective function for '" << output_name
279  << "' for minibatches " << start_minibatch
280  << '-' << end_minibatch << " is "
281  << (tot_objf_this_phase / tot_weight_this_phase) << " over "
282  << tot_weight_this_phase << " frames.";
283  } else {
284  KALDI_LOG << "Average objective function for '" << output_name
285  << " using " << minibatches_this_phase
286  << " minibatches in minibatch range " << start_minibatch
287  << '-' << end_minibatch << " is "
288  << (tot_objf_this_phase / tot_weight_this_phase) << " over "
289  << tot_weight_this_phase << " frames.";
290  }
291  } else {
292  BaseFloat objf = (tot_objf_this_phase / tot_weight_this_phase),
293  aux_objf = (tot_aux_objf_this_phase / tot_weight_this_phase),
294  sum_objf = objf + aux_objf;
295  if (minibatches_per_phase == minibatches_this_phase) {
296  KALDI_LOG << "Average objective function for '" << output_name
297  << "' for minibatches " << start_minibatch
298  << '-' << end_minibatch << " is "
299  << objf << " + " << aux_objf << " = " << sum_objf
300  << " over " << tot_weight_this_phase << " frames.";
301  } else {
302  KALDI_LOG << "Average objective function for '" << output_name
303  << "' using " << minibatches_this_phase
304  << " minibatches in minibatch range " << start_minibatch
305  << '-' << end_minibatch << " is "
306  << objf << " + " << aux_objf << " = " << sum_objf
307  << " over " << tot_weight_this_phase << " frames.";
308  }
309  }
310 }
311 
312 bool ObjectiveFunctionInfo::PrintTotalStats(const std::string &name) const {
313  BaseFloat objf = (tot_objf / tot_weight),
314  aux_objf = (tot_aux_objf / tot_weight),
315  sum_objf = objf + aux_objf;
316  if (tot_aux_objf == 0.0) {
317  KALDI_LOG << "Overall average objective function for '" << name << "' is "
318  << (tot_objf / tot_weight) << " over " << tot_weight << " frames.";
319  } else {
320  KALDI_LOG << "Overall average objective function for '" << name << "' is "
321  << objf << " + " << aux_objf << " = " << sum_objf
322  << " over " << tot_weight << " frames.";
323  }
324  KALDI_LOG << "[this line is to be parsed by a script:] "
325  << "log-prob-per-frame="
326  << objf;
327  return (tot_weight != 0.0);
328 }
329 
331  if (config_.write_cache != "") {
334  KALDI_LOG << "Wrote computation cache to " << config_.write_cache;
335  }
336  delete delta_nnet_;
337 }
338 
339 void ComputeObjectiveFunction(const GeneralMatrix &supervision,
340  ObjectiveType objective_type,
341  const std::string &output_name,
342  bool supply_deriv,
343  NnetComputer *computer,
344  BaseFloat *tot_weight,
345  BaseFloat *tot_objf) {
346  const CuMatrixBase<BaseFloat> &output = computer->GetOutput(output_name);
347 
348  if (output.NumCols() != supervision.NumCols())
349  KALDI_ERR << "Nnet versus example output dimension (num-classes) "
350  << "mismatch for '" << output_name << "': " << output.NumCols()
351  << " (nnet) vs. " << supervision.NumCols() << " (egs)\n";
352 
353  switch (objective_type) {
354  case kLinear: {
355  // objective is x * y.
356  switch (supervision.Type()) {
357  case kSparseMatrix: {
358  const SparseMatrix<BaseFloat> &post = supervision.GetSparseMatrix();
359  CuSparseMatrix<BaseFloat> cu_post(post);
360  // The cross-entropy objective is computed by a simple dot product,
361  // because after the LogSoftmaxLayer, the output is already in the form
362  // of log-likelihoods that are normalized to sum to one.
363  *tot_weight = cu_post.Sum();
364  *tot_objf = TraceMatSmat(output, cu_post, kTrans);
365  if (supply_deriv) {
366  CuMatrix<BaseFloat> output_deriv(output.NumRows(), output.NumCols(),
367  kUndefined);
368  cu_post.CopyToMat(&output_deriv);
369  computer->AcceptInput(output_name, &output_deriv);
370  }
371  break;
372  }
373  case kFullMatrix: {
374  // there is a redundant matrix copy in here if we're not using a GPU
375  // but we don't anticipate this code branch being used in many cases.
376  CuMatrix<BaseFloat> cu_post(supervision.GetFullMatrix());
377  *tot_weight = cu_post.Sum();
378  *tot_objf = TraceMatMat(output, cu_post, kTrans);
379  if (supply_deriv)
380  computer->AcceptInput(output_name, &cu_post);
381  break;
382  }
383  case kCompressedMatrix: {
384  Matrix<BaseFloat> post;
385  supervision.GetMatrix(&post);
386  CuMatrix<BaseFloat> cu_post;
387  cu_post.Swap(&post);
388  *tot_weight = cu_post.Sum();
389  *tot_objf = TraceMatMat(output, cu_post, kTrans);
390  if (supply_deriv)
391  computer->AcceptInput(output_name, &cu_post);
392  break;
393  }
394  }
395  break;
396  }
397  case kQuadratic: {
398  // objective is -0.5 (x - y)^2
399  CuMatrix<BaseFloat> diff(supervision.NumRows(),
400  supervision.NumCols(),
401  kUndefined);
402  diff.CopyFromGeneralMat(supervision);
403  diff.AddMat(-1.0, output);
404  *tot_weight = diff.NumRows();
405  *tot_objf = -0.5 * TraceMatMat(diff, diff, kTrans);
406  if (supply_deriv)
407  computer->AcceptInput(output_name, &diff);
408  break;
409  }
410  default:
411  KALDI_ERR << "Objective function type " << objective_type
412  << " not handled.";
413  }
414 }
415 
416 
417 
418 } // namespace nnet3
419 } // namespace kaldi
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:111
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
This class is a wrapper that enables you to store a matrix in one of three forms: either as a Matrix<...
void CopyToMat(CuMatrixBase< OtherReal > *dest, MatrixTransposeType trans=kNoTrans) const
void GetMatrix(Matrix< BaseFloat > *mat) const
Outputs the contents as a matrix.
void ComputeObjectiveFunction(const GeneralMatrix &supervision, ObjectiveType objective_type, const std::string &output_name, bool supply_deriv, NnetComputer *computer, BaseFloat *tot_weight, BaseFloat *tot_objf)
This function computes the objective function, and if supply_deriv = true, supplies its derivative to...
void UpdateStats(const std::string &output_name, int32 minibatches_per_phase, int32 minibatch_counter, BaseFloat this_minibatch_weight, BaseFloat this_minibatch_tot_objf, BaseFloat this_minibatch_tot_aux_objf=0.0)
void TrainInternalBackstitch(const NnetExample &eg, const NnetComputation &computation, bool is_backstitch_step1)
bool Open(const std::string &rxfilename, bool *contents_binary=NULL)
Definition: kaldi-io-inl.h:26
Real Sum() const
Definition: cu-matrix.cc:3012
void ScaleBatchnormStats(BaseFloat batchnorm_stats_scale, Nnet *nnet)
This function scales the batchnorm stats of any batchnorm components (components of type BatchNormComp...
Definition: nnet-utils.cc:536
kaldi::int32 int32
GeneralMatrix features
The features or labels.
Definition: nnet-example.h:46
void Train(const NnetExample &eg)
const Matrix< BaseFloat > & GetFullMatrix() const
Returns the contents as a Matrix<BaseFloat>.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
ObjectiveType objective_type
Definition: nnet-nnet.h:97
This file contains some miscellaneous functions dealing with class Nnet.
void ConstrainOrthonormal(Nnet *nnet)
This function, to be called after processing every minibatch, is responsible for enforcing the orthog...
Definition: nnet-utils.cc:1108
void FreezeNaturalGradient(bool freeze, Nnet *nnet)
Controls if natural gradient will be updated.
Definition: nnet-utils.cc:432
int32 GetNumNvalues(const std::vector< NnetIo > &io_vec, bool exhaustive)
This utility function can be used to obtain the number of distinct 'n' values in a training example...
Definition: nnet-utils.cc:2198
void AcceptInput(const std::string &node_name, CuMatrix< BaseFloat > *input)
e.g.
void ResetGenerators(Nnet *nnet)
This function calls 'ResetGenerator()' on all components in 'nnet' that inherit from class RandomComp...
Definition: nnet-utils.cc:582
MatrixIndexT NumCols() const
const NetworkNode & GetNode(int32 node) const
returns const reference to a particular numbered network node.
Definition: nnet-nnet.h:146
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
MaxChangeStats max_change_stats_
std::ostream & Stream()
Definition: kaldi-io.cc:701
const CuMatrixBase< BaseFloat > & GetOutput(const std::string &node_name)
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
Definition: nnet-nnet.cc:112
void ApplyL2Regularization(const Nnet &nnet, BaseFloat l2_regularize_scale, Nnet *delta_nnet)
This function is used as part of the regular training workflow, prior to UpdateNnetWithMaxChange().
Definition: nnet-utils.cc:2244
GeneralMatrixType Type() const
Returns the type of the matrix: kSparseMatrix, kCompressedMatrix or kFullMatrix.
unordered_map< std::string, ObjectiveFunctionInfo, StringHasher > objf_info_
void Swap(Matrix< Real > *mat)
Definition: cu-matrix.cc:123
void AcceptInputs(const Nnet &nnet, const std::vector< NnetIo > &io)
This convenience function calls AcceptInput() in turn on all the inputs in the training example...
Nnet * Copy() const
Definition: nnet-nnet.h:246
#define KALDI_ERR
Definition: kaldi-error.h:147
CachingOptimizingCompiler compiler_
NnetTrainer(const NnetTrainerOptions &config, Nnet *nnet)
void ZeroComponentStats(Nnet *nnet)
Zeroes the component stats in all nonlinear components in the nnet.
Definition: nnet-utils.cc:269
#define KALDI_WARN
Definition: kaldi-error.h:150
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
bool PrintTotalStats(const std::string &output_name) const
const NnetTrainerOptions config_
void CopyFromGeneralMat(const GeneralMatrix &src, MatrixTransposeType trans=kNoTrans)
Definition: cu-matrix.cc:3096
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
ObjectiveType
This enum is for a kind of annotation we associate with output nodes of the network; it's for the con...
Definition: nnet-nnet.h:52
void ReadCache(std::istream &is, bool binary)
Matrix for CUDA computing.
Definition: matrix-common.h:69
MatrixIndexT NumCols() const
Definition: cu-matrix.h:216
void WriteCache(std::ostream &os, bool binary)
void ConsolidateMemory(Nnet *nnet)
This just calls ConsolidateMemory() on all the components of the nnet.
Definition: nnet-utils.cc:1147
class NnetComputer is responsible for executing the computation described in the "computation" object...
Definition: nnet-compute.h:59
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
void TrainInternal(const NnetExample &eg, const NnetComputation &computation)
void Print(const Nnet &nnet) const
Definition: nnet-utils.cc:2284
std::string name
the name of the input in the neural net; in simple setups it will just be "input".
Definition: nnet-example.h:36
Real TraceMatSmat(const MatrixBase< Real > &A, const SparseMatrix< Real > &B, MatrixTransposeType trans)
union kaldi::nnet3::NetworkNode::@15 u
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466
const SparseMatrix< BaseFloat > & GetSparseMatrix() const
Returns the contents as a SparseMatrix.
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:215
void PrintStatsForThisPhase(const std::string &output_name, int32 minibatches_per_phase, int32 phase) const
NnetComputeOptions compute_config
Definition: nnet-training.h:49
std::vector< NnetIo > io
"io" contains the input and output.
Definition: nnet-example.h:116
#define KALDI_LOG
Definition: kaldi-error.h:153
bool UpdateNnetWithMaxChange(const Nnet &delta_nnet, BaseFloat max_param_change, BaseFloat max_change_scale, BaseFloat scale, Nnet *nnet, std::vector< int32 > *num_max_change_per_component_applied, int32 *num_max_change_global_applied)
This function does the operation &#39;*nnet += scale * delta_nnet&#39;, while respecting any max-parameter-ch...
Definition: nnet-utils.cc:2106
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95
void GetComputationRequest(const Nnet &nnet, const NnetExample &eg, bool need_model_derivative, bool store_component_stats, ComputationRequest *request)
This function takes a NnetExample (which should already have been frame-selected, if desired...
void Run()
This does either the forward or backward computation, depending when it is called (in a typical compu...
void ProcessOutputs(bool is_backstitch_step2, const NnetExample &eg, NnetComputer *computer)