discriminative-training.h
Go to the documentation of this file.
1 // nnet3/discriminative-training.h
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // Copyright 2014-2015 Vimal Manohar
5 
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 
23 #ifndef KALDI_NNET3_DISCRIMINATIVE_TRAINING_H_
24 #define KALDI_NNET3_DISCRIMINATIVE_TRAINING_H_
25 
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "tree/context-dep.h"
#include "lat/kaldi-lattice.h"
#include "matrix/kaldi-matrix.h"
#include "hmm/transition-model.h"
#include "nnet3/discriminative-supervision.h"
#include "lat/lattice-functions.h"
#include "cudamatrix/cu-matrix-lib.h"
37 namespace kaldi {
38 namespace discriminative {
39 
40 /* Options for discriminative training
41  *
42  * Legend:
43  * mmi - Maximum Mutual Information
44  * mpfe - Minimum Phone Frame Error
45  * smbr - State Minimum Bayes Risk
46  *
47  */
49  std::string criterion; // one of {"mmi", "mpfe", "smbr"}
50  // If the criterion does not match the supervision
51  // object, the derivatives may not be very accurate
53  bool drop_frames; // for MMI, true if we ignore frames where alignment
54  // pdf-id is not in the lattice.
55  bool one_silence_class; // Affects MPFE and SMBR objectives
56  BaseFloat boost; // for MMI, boosting factor (would be Boosted MMI)... e.g. 0.1.
57 
58  std::string silence_phones_str; // colon-separated list of integer ids of silence phones,
59  // for MPFE and SMBR objectives
60 
61  // Cross-entropy regularization constant. (e.g. try 0.1). If nonzero,
62  // the network is expected to have an output named 'output-xent', which
63  // should have a softmax as its final nonlinearity.
65 
66  // l2 regularization constant on the 'chain' output; the actual term added to
67  // the objf will be -0.5 times this constant times the squared l2 norm.
68  // (squared so it's additive across the dimensions). e.g. try 0.0005.
70 
71  // Options for debugging discriminative training
72 
73  // Accumulates gradients wrt nnet outputs
75 
76  // Accumulates nnet output
78 
79  // Applicable for debugging discriminative training when accumulate_gradients
80  // or accumulate_output is true
82 
83  DiscriminativeOptions(): criterion("smbr"),
84  acoustic_scale(0.1),
85  drop_frames(false),
86  one_silence_class(false),
87  boost(0.0),
88  xent_regularize(0.0),
89  l2_regularize(0.0),
90  accumulate_gradients(false),
91  accumulate_output(false),
92  num_pdfs(0) { }
93 
94  void Register(OptionsItf *opts) {
95  opts->Register("criterion", &criterion, "Criterion, 'mmi'|'mpfe'|'smbr', "
96  "determines the objective function to use. Should match "
97  "option used when we created the examples.");
98  opts->Register("acoustic-scale", &acoustic_scale, "Weighting factor to "
99  "apply to acoustic likelihoods.");
100  opts->Register("drop-frames", &drop_frames, "For MMI, if true we drop frames "
101  "with no overlap of num and den pdf-ids");
102  opts->Register("boost", &boost, "Boosting factor for boosted MMI (e.g. 0.1)");
103  opts->Register("one-silence-class", &one_silence_class, "If true, newer "
104  "behavior which will tend to reduce insertions "
105  "when using MPFE or SMBR objective");
106  opts->Register("silence-phones", &silence_phones_str,
107  "For MPFE or SMBR objectives, colon-separated list of "
108  "integer ids of silence phones, e.g. 1:2:3");
109  opts->Register("l2-regularize", &l2_regularize, "l2 regularization "
110  "constant for 'chain' output "
111  "of the neural net.");
112  opts->Register("xent-regularize", &xent_regularize, "Cross-entropy "
113  "regularization constant for sequence training. If "
114  "nonzero, the network is expected to have an output "
115  "named 'output-xent', which should have a softmax as "
116  "its final nonlinearity.");
117  opts->Register("accumulate-gradients", &accumulate_gradients,
118  "Accumulate gradients wrt nnet output "
119  "for debugging discriminative training");
120  opts->Register("accumulate-output", &accumulate_output,
121  "Accumulate nnet output "
122  "for debugging discriminative training");
123  opts->Register("num-pdfs", &num_pdfs,
124  "Number of pdfs; "
125  "applicable when accumulate-output or accumulate-gradients "
126  "is true for discriminative training");
127  }
128 };
129 
131  double tot_t; // total number of frames
132  double tot_t_weighted; // total number of frames times weight.
133  double tot_objf; // for 'mmi', the (weighted) denominator likelihood; for
134  // everything else, the objective function.
135  double tot_num_count; // total count of numerator posterior
136  double tot_den_count; // total count of denominator posterior
137  double tot_num_objf; // for 'mmi', the (weighted) numerator likelihood; for
138  // everything else 0
139 
140  double tot_l2_term; // l2 regularization objective
141  // l2 regularization constant on the 'chain' output; the actual term added to
142  // the objf will be -0.5 times this constant times the squared l2 norm.
143  // (squared so it's additive across the dimensions). e.g. try 0.0005.
144 
145  // Options for debugging discriminative training
146 
147  // Accumulates gradients wrt nnet outputs
149 
150  // Accumulates nnet output
152 
153  // Applicable for debugging discriminative training when accumulate_gradients
154  // or accumulate_output is true
156 
157  // Used to accumulates gradients wrt nnet outputs
158  // when accumulate_gradients is true
160  // Used to accumulates output when accumulate_output is true
162 
163  // Print statistics for the criterion
164  void Print(const std::string &criterion,
165  bool print_avg_gradients = false,
166  bool print_avg_output = false) const;
167 
168  // Print all accumulated statistics for debugging
169  void PrintAll(const std::string &criterion) const {
170  Print(criterion, true, true);
171  }
172 
173  // Print the gradient wrt nnet output accumulated for a pdf
174  void PrintAvgGradientForPdf(int32 pdf_id) const;
175 
176  // Add stats from another object
177  void Add(const DiscriminativeObjectiveInfo &other);
178 
179  // Returns the objective function value for the criterion
180  inline double TotalObjf(const std::string &criterion) const {
181  if (criterion == "mmi") return (tot_num_objf - tot_objf);
182  return tot_objf;
183  }
184 
185  // Returns true if accumulate_gradients is true
186  // and the gradients vector has been resized to store the
187  // accumulated gradients
188  inline bool AccumulateGradients() const {
189  return accumulate_gradients && gradients.Dim() > 0;
190  }
191 
192  // Returns true if accumulate_output is true
193  // and the output vector has been resized to store the
194  // accumulated nnet output
195  inline bool AccumulateOutput() const {
196  return accumulate_output && output.Dim() > 0;
197  }
198 
199  // Empty constructor
201 
202  // Constructor preparing to gradients or output to be accumulated
204 
205  // Constructor from config options
207 
208  // Reset statistics
209  void Reset();
210 
211  void Configure(const DiscriminativeOptions &opts);
212 };
213 
237  const DiscriminativeOptions &opts,
238  const TransitionModel &tmodel,
239  const CuVectorBase<BaseFloat> &log_priors,
240  const DiscriminativeSupervision &supervision,
241  const CuMatrixBase<BaseFloat> &nnet_output,
243  CuMatrixBase<BaseFloat> *nnet_output_deriv,
244  CuMatrixBase<BaseFloat> *xent_output_deriv);
245 
246 } // namespace discriminative
247 } // namespace kaldi
248 
249 #endif // KALDI_NNET3_DISCRIMINATIVE_TRAINING_H_
250 
251 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void PrintAll(const std::string &criterion) const
kaldi::int32 int32
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
double TotalObjf(const std::string &criterion) const
void ComputeDiscriminativeObjfAndDeriv(const DiscriminativeOptions &opts, const TransitionModel &tmodel, const CuVectorBase< BaseFloat > &log_priors, const DiscriminativeSupervision &supervision, const CuMatrixBase< BaseFloat > &nnet_output, DiscriminativeObjectiveInfo *stats, CuMatrixBase< BaseFloat > *nnet_output_deriv, CuMatrixBase< BaseFloat > *xent_output_deriv)
This function does forward-backward on the numerator and denominator lattices and computes derivatives ...
Matrix for CUDA computing.
Definition: matrix-common.h:69
void Print(const Fst< Arc > &fst, std::string message)
MatrixIndexT Dim() const
Dimensions.
Definition: cu-vector.h:69
Vector for CUDA computing.
Definition: matrix-common.h:72