// nnet3/nnet-chain-training.cc

// Copyright      2015    Johns Hopkins University (author: Daniel Povey)
//                2016    Xiaohui Zhang

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet-chain-training.h"
#include "nnet3/nnet-utils.h"

namespace kaldi {
namespace nnet3 {
NnetChainTrainer::NnetChainTrainer(const NnetChainTrainingOptions &opts,
                                   const fst::StdVectorFst &den_fst,
                                   Nnet *nnet):
    opts_(opts),
    den_graph_(den_fst, nnet->OutputDim("output")),
    nnet_(nnet),
    compiler_(*nnet, opts_.nnet_config.optimize_config,
              opts_.nnet_config.compiler_config),
    num_minibatches_processed_(0),
    max_change_stats_(*nnet),
    srand_seed_(RandInt(0, 100000)) {
  if (opts.nnet_config.zero_component_stats)
    ZeroComponentStats(nnet);
  KALDI_ASSERT(opts.nnet_config.momentum >= 0.0 &&
               opts.nnet_config.max_param_change >= 0.0 &&
               opts.nnet_config.backstitch_training_interval > 0);
  delta_nnet_ = nnet_->Copy();
  ScaleNnet(0.0, delta_nnet_);

  if (opts.nnet_config.read_cache != "") {
    bool binary;
    try {
      Input ki(opts.nnet_config.read_cache, &binary);
      compiler_.ReadCache(ki.Stream(), binary);
      KALDI_LOG << "Read computation cache from "
                << opts.nnet_config.read_cache;
    } catch (...) {
      KALDI_WARN << "Could not open cached computation. "
                    "Probably this is the first training iteration.";
    }
  }
}

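// A minimal usage sketch (not part of this file), loosely following how the
// nnet3-chain-train binary drives this class; 'example_reader' and the
// surrounding option/I-O setup are hypothetical stand-ins:
//
//   NnetChainTrainer trainer(opts, den_fst, &nnet);
//   for (; !example_reader.Done(); example_reader.Next())
//     trainer.Train(example_reader.Value());
//   bool ok = trainer.PrintTotalStats();
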
void NnetChainTrainer::Train(const NnetChainExample &chain_eg) {
  NVTX_RANGE(__func__);
  bool need_model_derivative = true;
  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
  bool use_xent_regularization = (opts_.chain_config.xent_regularize != 0.0);
  ComputationRequest request;
  GetChainComputationRequest(*nnet_, chain_eg, need_model_derivative,
                             nnet_config.store_component_stats,
                             use_xent_regularization, need_model_derivative,
                             &request);
  std::shared_ptr<const NnetComputation> computation =
      compiler_.Compile(request);

  if (nnet_config.backstitch_training_scale > 0.0 &&
      num_minibatches_processed_ % nnet_config.backstitch_training_interval ==
      srand_seed_ % nnet_config.backstitch_training_interval) {
    // backstitch training is incompatible with momentum > 0
    KALDI_ASSERT(nnet_config.momentum == 0.0);
    FreezeNaturalGradient(true, delta_nnet_);
    bool is_backstitch_step1 = true;
    srand(srand_seed_ + num_minibatches_processed_);
    ResetGenerators(nnet_);
    TrainInternalBackstitch(chain_eg, *computation, is_backstitch_step1);
    FreezeNaturalGradient(false, delta_nnet_);  // un-freeze natural gradient
    is_backstitch_step1 = false;
    srand(srand_seed_ + num_minibatches_processed_);
    ResetGenerators(nnet_);
    TrainInternalBackstitch(chain_eg, *computation, is_backstitch_step1);
  } else {  // conventional training
    TrainInternal(chain_eg, *computation);
  }
  if (num_minibatches_processed_ == 0) {
    ConsolidateMemory(nnet_);
    ConsolidateMemory(delta_nnet_);
  }
  num_minibatches_processed_++;
}

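// Worked example of the schedule above (an illustration, not code from this
// file): with backstitch_training_interval == 4 and srand_seed_ % 4 == 1,
// backstitch runs on minibatches 1, 5, 9, ... and conventional training on
// the rest, i.e. on average one minibatch in every 'interval' minibatches
// gets the two-pass backstitch update.
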
void NnetChainTrainer::TrainInternal(const NnetChainExample &eg,
                                     const NnetComputation &computation) {
  NVTX_RANGE(__func__);
  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
  // note: because we give the 1st arg (nnet_) as a pointer to the
  // constructor of 'computer', it will use that copy of the nnet to
  // store stats.
  NnetComputer computer(nnet_config.compute_config, computation,
                        nnet_, delta_nnet_);

  // give the inputs to the computer object.
  computer.AcceptInputs(*nnet_, eg.inputs);
  computer.Run();

  this->ProcessOutputs(false, eg, &computer);
  computer.Run();

  // If relevant, add in the part of the gradient that comes from
  // parameter-level L2 regularization.
  ApplyL2Regularization(*nnet_,
                        GetNumNvalues(eg.inputs, false) *
                        nnet_config.l2_regularize_factor,
                        delta_nnet_);

  // Updates the parameters of nnet
  bool success = UpdateNnetWithMaxChange(
      *delta_nnet_,
      nnet_config.max_param_change,
      1.0, 1.0 - nnet_config.momentum, nnet_,
      &max_change_stats_);

  // Scale down the batchnorm stats (keeps them fresh... this affects what
  // happens when we use the model with batchnorm test-mode set).
  ScaleBatchnormStats(nnet_config.batchnorm_stats_scale, nnet_);

  // The following will only do something if we have a LinearComponent
  // or AffineComponent with orthonormal-constraint set to a nonzero value.
  ConstrainOrthonormal(nnet_);

  // Scale delta_nnet
  if (success)
    ScaleNnet(nnet_config.momentum, delta_nnet_);
  else
    ScaleNnet(0.0, delta_nnet_);
}

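// The update performed above, written out (a sketch of the momentum scheme
// this file uses, not extra code): with momentum m,
//   delta_nnet <- m * delta_nnet + (minibatch gradient, scaled by the
//                                   per-component learning rates)
//   nnet       <- nnet + (1 - m) * delta_nnet
// The (1 - m) factor on the add keeps the effective learning rate the same
// as it would be without momentum; m only smooths the updates over time.
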
void NnetChainTrainer::TrainInternalBackstitch(const NnetChainExample &eg,
                                               const NnetComputation &computation,
                                               bool is_backstitch_step1) {
  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
  // note: because we give the 1st arg (nnet_) as a pointer to the
  // constructor of 'computer', it will use that copy of the nnet to
  // store stats.
  NnetComputer computer(nnet_config.compute_config, computation,
                        nnet_, delta_nnet_);
  // give the inputs to the computer object.
  computer.AcceptInputs(*nnet_, eg.inputs);
  computer.Run();

  bool is_backstitch_step2 = !is_backstitch_step1;
  this->ProcessOutputs(is_backstitch_step2, eg, &computer);
  computer.Run();

  BaseFloat max_change_scale, scale_adding;
  if (is_backstitch_step1) {
    // max-change is scaled by backstitch_training_scale;
    // delta_nnet is scaled by -backstitch_training_scale when added to nnet.
    max_change_scale = nnet_config.backstitch_training_scale;
    scale_adding = -nnet_config.backstitch_training_scale;
  } else {
    // max-change is scaled by 1 + backstitch_training_scale;
    // delta_nnet is scaled by 1 + backstitch_training_scale when added to nnet.
    max_change_scale = 1.0 + nnet_config.backstitch_training_scale;
    scale_adding = 1.0 + nnet_config.backstitch_training_scale;
    // If relevant, add in the part of the gradient that comes from L2
    // regularization.  It may not be optimally efficient to do it on both
    // passes of the backstitch, like we do here, but it probably minimizes
    // any harmful interactions with the max-change.
    ApplyL2Regularization(*nnet_,
        1.0 / scale_adding * GetNumNvalues(eg.inputs, false) *
        nnet_config.l2_regularize_factor, delta_nnet_);
  }

  // Updates the parameters of nnet
  UpdateNnetWithMaxChange(
      *delta_nnet_, nnet_config.max_param_change,
      max_change_scale, scale_adding, nnet_,
      &max_change_stats_);

  if (is_backstitch_step1) {
    // The following will only do something if we have a LinearComponent or
    // AffineComponent with orthonormal-constraint set to a nonzero value. We
    // choose to do this only on the 1st backstitch step, for efficiency.
    ConstrainOrthonormal(nnet_);
  }

  if (!is_backstitch_step1) {
    // Scale down the batchnorm stats (keeps them fresh... this affects what
    // happens when we use the model with batchnorm test-mode set). Do this
    // after backstitch step 2 so that the stats are scaled down before we
    // start the next minibatch.
    ScaleBatchnormStats(nnet_config.batchnorm_stats_scale, nnet_);
  }

  ScaleNnet(0.0, delta_nnet_);
}

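// The two-pass update above, written out (a sketch following the backstitch
// method of Wang et al., "Backstitch: Counteracting finite-sample bias via
// negative steps", Interspeech 2017; alpha is backstitch_training_scale and
// g() the learning-rate-scaled minibatch gradient):
//   step 1:  theta' = theta - alpha * g(theta)            // negative step
//   step 2:  theta  = theta' + (1 + alpha) * g(theta')    // overshoot step
// which matches scale_adding = -alpha on step 1 and 1 + alpha on step 2.
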
void NnetChainTrainer::ProcessOutputs(bool is_backstitch_step2,
                                      const NnetChainExample &eg,
                                      NnetComputer *computer) {
  NVTX_RANGE(__func__);
  // normally the eg will have just one output named 'output', but
  // we don't assume this.
  // In backstitch training, the output-name with the "_backstitch" suffix is
  // the one computed after the first, backward step of backstitch.
  const std::string suffix = (is_backstitch_step2 ? "_backstitch" : "");
  std::vector<NnetChainSupervision>::const_iterator iter = eg.outputs.begin(),
      end = eg.outputs.end();
  for (; iter != end; ++iter) {
    const NnetChainSupervision &sup = *iter;
    int32 node_index = nnet_->GetNodeIndex(sup.name);
    if (node_index < 0 ||
        !nnet_->IsOutputNode(node_index))
      KALDI_ERR << "Network has no output named " << sup.name;

    const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);
    CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
                                          nnet_output.NumCols(),
                                          kUndefined);

    bool use_xent = (opts_.chain_config.xent_regularize != 0.0);
    std::string xent_name = sup.name + "-xent";  // typically "output-xent".
    CuMatrix<BaseFloat> xent_deriv;

    BaseFloat tot_objf, tot_l2_term, tot_weight;

    ComputeChainObjfAndDeriv(opts_.chain_config, den_graph_,
                             sup.supervision, nnet_output,
                             &tot_objf, &tot_l2_term, &tot_weight,
                             &nnet_output_deriv,
                             (use_xent ? &xent_deriv : NULL));

    if (use_xent) {
      // this block computes the cross-entropy objective.
      const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(
          xent_name);
      // at this point, xent_deriv is posteriors derived from the numerator
      // computation.  note, xent_objf has a factor of '.supervision.weight'
      BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
      objf_info_[xent_name + suffix].UpdateStats(xent_name + suffix,
                                                 opts_.nnet_config.print_interval,
                                                 num_minibatches_processed_,
                                                 tot_weight, xent_objf);
    }

    if (opts_.apply_deriv_weights && sup.deriv_weights.Dim() != 0) {
      CuVector<BaseFloat> cu_deriv_weights(sup.deriv_weights);
      nnet_output_deriv.MulRowsVec(cu_deriv_weights);
      if (use_xent)
        xent_deriv.MulRowsVec(cu_deriv_weights);
    }

    computer->AcceptInput(sup.name, &nnet_output_deriv);

    objf_info_[sup.name + suffix].UpdateStats(sup.name + suffix,
                                              opts_.nnet_config.print_interval,
                                              num_minibatches_processed_,
                                              tot_weight, tot_objf, tot_l2_term);

    if (use_xent) {
      xent_deriv.Scale(opts_.chain_config.xent_regularize);
      computer->AcceptInput(xent_name, &xent_deriv);
    }
  }
}

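// Why TraceMatMat gives the cross-entropy objective above (a note, not code
// from this file): xent_output holds log-probabilities and xent_deriv holds
// the numerator posteriors, so
//   TraceMatMat(xent_output, xent_deriv, kTrans)
//     = sum over (frame, pdf) of posterior * log-prob,
// i.e. the weighted cross-entropy objective for this minibatch.
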
bool NnetChainTrainer::PrintTotalStats() const {
  unordered_map<std::string, ObjectiveFunctionInfo, StringHasher>::const_iterator
      iter = objf_info_.begin(),
      end = objf_info_.end();
  bool ans = false;
  for (; iter != end; ++iter) {
    const std::string &name = iter->first;
    const ObjectiveFunctionInfo &info = iter->second;
    ans = info.PrintTotalStats(name) || ans;
  }
  max_change_stats_.Print(*nnet_);
  return ans;
}

NnetChainTrainer::~NnetChainTrainer() {
  if (opts_.nnet_config.write_cache != "") {
    Output ko(opts_.nnet_config.write_cache,
              opts_.nnet_config.binary_write_cache);
    compiler_.WriteCache(ko.Stream(), opts_.nnet_config.binary_write_cache);
    KALDI_LOG << "Wrote computation cache to " << opts_.nnet_config.write_cache;
  }
  delete delta_nnet_;
}

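// How the cache above is typically used (a sketch; assuming the standard
// read_cache/write_cache options of NnetTrainerOptions): training iteration N
// writes its compiled computations via WriteCache() in this destructor, and
// iteration N+1 passes the same file as read_cache so that the constructor's
// ReadCache() call lets the compiler skip recompiling identical
// ComputationRequests.
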

}  // namespace nnet3
}  // namespace kaldi