// nnet3/nnet-training.h

// Copyright 2015  Johns Hopkins University (author: Daniel Povey)
//           2016  Xiaohui Zhang

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET3_NNET_TRAINING_H_
#define KALDI_NNET3_NNET_TRAINING_H_

#include "nnet3/nnet-example.h"
#include "nnet3/nnet-computation.h"
#include "nnet3/nnet-compute.h"
#include "nnet3/nnet-optimize.h"
#include "nnet3/nnet-example-utils.h"
#include "nnet3/nnet-utils.h"

namespace kaldi {
namespace nnet3 {

struct NnetTrainerOptions {
  bool zero_component_stats;
  bool store_component_stats;
  int32 print_interval;
  bool debug_computation;
  BaseFloat momentum;
  BaseFloat l2_regularize_factor;
  BaseFloat backstitch_training_scale;
  int32 backstitch_training_interval;
  BaseFloat batchnorm_stats_scale;
  std::string read_cache;
  std::string write_cache;
  bool binary_write_cache;
  BaseFloat max_param_change;
  NnetOptimizeOptions optimize_config;
  NnetComputeOptions compute_config;
  CachingOptimizingCompilerOptions compiler_config;

  NnetTrainerOptions():
      zero_component_stats(true),
      store_component_stats(true),
      print_interval(100),
      debug_computation(false),
      momentum(0.0),
      l2_regularize_factor(1.0),
      backstitch_training_scale(0.0),
      backstitch_training_interval(1),
      batchnorm_stats_scale(0.8),
      binary_write_cache(true),
      max_param_change(2.0) { }
  void Register(OptionsItf *opts) {
    opts->Register("store-component-stats", &store_component_stats,
                   "If true, store activations and derivatives for nonlinear "
                   "components during training.");
    opts->Register("zero-component-stats", &zero_component_stats,
                   "If both this and --store-component-stats are true, then "
                   "the component stats are zeroed before training.");
    opts->Register("print-interval", &print_interval, "Interval (measured in "
                   "minibatches) after which we print out the objective "
                   "function during training.");
    opts->Register("max-param-change", &max_param_change, "The maximum change "
                   "in parameters allowed per minibatch, measured in Euclidean "
                   "norm over the entire model (the change will be clipped to "
                   "this value).");
    opts->Register("momentum", &momentum, "Momentum constant to apply during "
                   "training (helps to stabilize the update), e.g. 0.9. Note: "
                   "we automatically multiply the learning rate by "
                   "(1-momentum) so that the 'effective' learning rate is the "
                   "same as before (because momentum would normally increase "
                   "the effective learning rate by 1/(1-momentum)).");
    opts->Register("l2-regularize-factor", &l2_regularize_factor, "Factor that "
                   "affects the strength of l2 regularization on model "
                   "parameters. The primary way to specify this type of "
                   "l2 regularization is via the 'l2-regularize' "
                   "configuration value at the config-file level. "
                   "--l2-regularize-factor will be multiplied by the "
                   "component-level l2-regularize values, and can be used to "
                   "correct for effects related to parallelization by model "
                   "averaging.");
    opts->Register("batchnorm-stats-scale", &batchnorm_stats_scale,
                   "Factor by which we scale down the accumulated stats of "
                   "batchnorm layers after processing each minibatch. This "
                   "ensures that the final model we write out has batchnorm "
                   "stats that are fairly fresh.");
    opts->Register("backstitch-training-scale", &backstitch_training_scale,
                   "Backstitch training factor; if 0, we do normal training. "
                   "It is referred to as '\\alpha' in our publications.");
    opts->Register("backstitch-training-interval",
                   &backstitch_training_interval,
                   "Do backstitch training with the specified interval of "
                   "minibatches. It is referred to as 'n' in our "
                   "publications.");
    opts->Register("read-cache", &read_cache, "The location from which to "
                   "read the cached computation.");
    opts->Register("write-cache", &write_cache, "The location to which to "
                   "write the cached computation.");
    opts->Register("binary-write-cache", &binary_write_cache, "Write the "
                   "computation cache in binary mode.");

    // Register the optimization options with the prefix "optimization".
    ParseOptions optimization_opts("optimization", opts);
    optimize_config.Register(&optimization_opts);
    // Register the compiler options with the prefix "compiler".
    ParseOptions compiler_opts("compiler", opts);
    compiler_config.Register(&compiler_opts);
    // Register the compute options with the prefix "computation".
    ParseOptions compute_opts("computation", opts);
    compute_config.Register(&compute_opts);
  }
};
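
// Below is a minimal sketch (not part of Kaldi itself) of how these options
// are typically hooked up in a command-line program: Register() puts each
// option on the ParseOptions object, and the nested configs appear under the
// prefixes "optimization", "compiler" and "computation".
//
//   #include "util/parse-options.h"
//   #include "nnet3/nnet-training.h"
//
//   int main(int argc, char *argv[]) {
//     using namespace kaldi;
//     using namespace kaldi::nnet3;
//     ParseOptions po("Sketch: parse nnet3 training options.");
//     NnetTrainerOptions train_config;
//     train_config.Register(&po);
//     po.Read(argc, argv);  // accepts e.g. --momentum=0.9
//                           // --max-param-change=2.0
//     return 0;
//   }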

// This struct is used in multiple nnet training classes for keeping
// track of objective function values.
// Also see struct AccuracyInfo, in nnet-diagnostics.h.
struct ObjectiveFunctionInfo {
  int32 current_phase;
  int32 minibatches_this_phase;  // The number of minibatches' worth of stats
                                 // that we accumulated in the phase numbered
                                 // 'current_phase'.
  double tot_weight;
  double tot_objf;
  double tot_aux_objf;  // An 'auxiliary' objective function that is optional;
                        // it may be used when things like regularization are
                        // being used.

  double tot_weight_this_phase;  // The stats for the current phase only.
  double tot_objf_this_phase;
  double tot_aux_objf_this_phase;

  ObjectiveFunctionInfo():
      current_phase(0),
      minibatches_this_phase(0),
      tot_weight(0.0), tot_objf(0.0), tot_aux_objf(0.0),
      tot_weight_this_phase(0.0), tot_objf_this_phase(0.0),
      tot_aux_objf_this_phase(0.0) { }

  // This function updates the stats and, if the phase has just changed,
  // prints a message indicating progress. The phase equals
  // minibatch_counter / minibatches_per_phase; its only function is to
  // control how frequently we print logging messages.
  void UpdateStats(const std::string &output_name,
                   int32 minibatches_per_phase,
                   int32 minibatch_counter,
                   BaseFloat this_minibatch_weight,
                   BaseFloat this_minibatch_tot_objf,
                   BaseFloat this_minibatch_tot_aux_objf = 0.0);

  // Prints stats for the current phase.
  // Note: 'phase' will normally be this->current_phase + 1, but may under
  // unusual circumstances (e.g. multilingual training, where not all outputs
  // are seen on all minibatches) be larger than that.
  void PrintStatsForThisPhase(const std::string &output_name,
                              int32 minibatches_per_phase,
                              int32 phase) const;
  // Prints total stats, and returns true if the total stats' weight was
  // nonzero.
  bool PrintTotalStats(const std::string &output_name) const;
};
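
// Illustrative sketch (assumed usage, mirroring how the trainer classes use
// this struct): stats are accumulated once per minibatch, and a progress
// line is logged each time the phase index advances.
//
//   ObjectiveFunctionInfo info;
//   for (int32 mb = 0; mb < num_minibatches; mb++) {
//     // this_mb_weight / this_mb_objf would come from the computation:
//     info.UpdateStats("output", 100 /* minibatches_per_phase */, mb,
//                      this_mb_weight, this_mb_objf);
//   }
//   info.PrintTotalStats("output");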


// This class is for single-threaded training of neural nets using standard
// objective functions such as cross-entropy (implemented with logsoftmax
// nonlinearity and a linear objective function) and quadratic loss.
class NnetTrainer {
 public:
  NnetTrainer(const NnetTrainerOptions &config,
              Nnet *nnet);

  // Train on one minibatch.
  void Train(const NnetExample &eg);

  // Prints out the final stats, and returns true if there was a nonzero
  // count.
  bool PrintTotalStats() const;

  ~NnetTrainer();
 private:
  // The internal function for doing one step of conventional SGD training.
  void TrainInternal(const NnetExample &eg,
                     const NnetComputation &computation);

  // The internal function for doing one step of backstitch training.
  // Depending on whether is_backstitch_step1 is true, it could be either the
  // first (backward) step or the second (forward) step of backstitch.
  void TrainInternalBackstitch(const NnetExample &eg,
                               const NnetComputation &computation,
                               bool is_backstitch_step1);

  // Computes the objective-function values for each output of the example
  // (and their derivatives, if we are training), updating objf_info_.
  void ProcessOutputs(bool is_backstitch_step2, const NnetExample &eg,
                      NnetComputer *computer);

  const NnetTrainerOptions config_;
  Nnet *nnet_;
  Nnet *delta_nnet_;  // nnet representing parameter-change for this minibatch
                      // (or, when using momentum, the moving weighted average
                      // of this).
  CachingOptimizingCompiler compiler_;

  // This code supports multiple output layers, even though in the
  // normal case there will be just one output layer named "output".
  // So we store the objective functions per output layer.
  int32 num_minibatches_processed_;

  // Stats for max-change.
  MaxChangeStats max_change_stats_;

  unordered_map<std::string, ObjectiveFunctionInfo, StringHasher> objf_info_;

  // This value is used in backstitch training when we need to ensure
  // consistent dropout masks. It's set to a value derived from rand()
  // when the class is initialized.
  int32 srand_seed_;
};
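
// Typical usage (a sketch; see nnet3bin/nnet3-train.cc for the real program):
// construct the trainer around an Nnet, feed it examples one minibatch at a
// time, and print the accumulated objective-function stats at the end.
//
//   NnetTrainerOptions config;
//   Nnet nnet;
//   ReadKaldiObject(nnet_rxfilename, &nnet);
//   NnetTrainer trainer(config, &nnet);
//   SequentialNnetExampleReader example_reader(examples_rspecifier);
//   for (; !example_reader.Done(); example_reader.Next())
//     trainer.Train(example_reader.Value());
//   bool ok = trainer.PrintTotalStats();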

// This function computes the objective function, and if supply_deriv = true,
// supplies its derivative to the NnetComputer object.
// See also the function ComputeAccuracy(), declared in nnet-diagnostics.h.
// On exit, 'tot_weight' will be set to the total weight of the training
// examples and 'tot_objf' to the total objective function; divide tot_objf
// by tot_weight to get the normalized objective function.
void ComputeObjectiveFunction(const GeneralMatrix &supervision,
                              ObjectiveType objective_type,
                              const std::string &output_name,
                              bool supply_deriv,
                              NnetComputer *computer,
                              BaseFloat *tot_weight,
                              BaseFloat *tot_objf);
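
// For orientation (a sketch of the semantics; the authoritative definitions
// are in the implementation and in the ObjectiveType documentation in
// nnet-nnet.h): with output matrix y and supervision matrix s, kLinear
// accumulates
//     tot_objf += sum_{i,j} s(i,j) * y(i,j)
// (suitable for log-probability outputs, giving a cross-entropy-style
// objective), while kQuadratic accumulates
//     tot_objf += -0.5 * sum_{i,j} (y(i,j) - s(i,j))^2.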


}  // namespace nnet3
}  // namespace kaldi

#endif  // KALDI_NNET3_NNET_TRAINING_H_