nnet-compute-discriminative.cc
Go to the documentation of this file.
1 // nnet2/nnet-compute-discriminative.cc
2 
3 // Copyright 2012-2013 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
#include "nnet2/nnet-compute-discriminative.h"
#include "hmm/posterior.h"
#include "lat/lattice-functions.h"
23 
24 namespace kaldi {
25 namespace nnet2 {
26 
27 /*
28  This class does the forward and possibly backward computation for (typically)
29  a whole utterance of contiguous features. You'll instantiate one of
30  these classes each time you want to do this computation.
31 */
33  public:
34 
35  NnetDiscriminativeUpdater(const AmNnet &am_nnet,
36  const TransitionModel &tmodel,
38  const DiscriminativeNnetExample &eg,
39  Nnet *nnet_to_update,
41 
42  void Update() {
43  Propagate();
45  if (nnet_to_update_ != NULL)
46  Backprop();
47  }
48 
50  void Propagate();
51 
54  void LatticeComputations();
55 
56  void Backprop();
57 
65 
67 
69 
70  static inline Int32Pair MakePair(int32 first, int32 second) {
71  Int32Pair ans;
72  ans.first = first;
73  ans.second = second;
74  return ans;
75  }
76 
77  private:
78  typedef LatticeArc Arc;
80 
81 
82  const AmNnet &am_nnet_;
86  Nnet *nnet_to_update_; // will equal am_nnet_.GetNnet(), in SGD case, or
87  // another Nnet, in gradient-computation case, or
88  // NULL if we just need the objective function.
89  NnetDiscriminativeStats *stats_; // the objective function, etc.
90  std::vector<ChunkInfo> chunk_info_out_;
91  // forward_data_[i] is the input of the i'th component and (if i > 0)
92  // the output of the i-1'th component.
93  std::vector<CuMatrix<BaseFloat> > forward_data_;
94  Lattice lat_; // we convert the CompactLattice in the eg, into Lattice form.
96  std::vector<int32> silence_phones_; // derived from opts_.silence_phones_str
97 };
98 
99 
100 
102  const AmNnet &am_nnet,
103  const TransitionModel &tmodel,
105  const DiscriminativeNnetExample &eg,
106  Nnet *nnet_to_update,
107  NnetDiscriminativeStats *stats):
108  am_nnet_(am_nnet), tmodel_(tmodel), opts_(opts), eg_(eg),
109  nnet_to_update_(nnet_to_update), stats_(stats) {
111  &silence_phones_)) {
112  KALDI_ERR << "Bad value for --silence-phones option: "
114  }
115  const Nnet &nnet = am_nnet_.GetNnet();
117 }
118 
119 
120 
122  int32 num_frames_output = eg_.num_ali.size();
123  int32 eg_left_context = eg_.left_context,
124  eg_right_context = eg_.input_frames.NumRows() -
125  num_frames_output - eg_left_context;
126  KALDI_ASSERT(eg_right_context >= 0);
127  const Nnet &nnet = am_nnet_.GetNnet();
128  // Make sure the example has enough acoustic left and right
129  // context... normally we'll use examples generated using the same model,
130  // which will have the exact context, but we enable a mismatch in context as
131  // long as it is more, not less.
132  KALDI_ASSERT(eg_left_context >= nnet.LeftContext() &&
133  eg_right_context >= nnet.RightContext());
134  int32 offset = eg_left_context - nnet.LeftContext(),
135  num_output_frames =
136  num_frames_output + nnet.LeftContext() + nnet.RightContext();
137  SubMatrix<BaseFloat> ans(eg_.input_frames, offset, num_output_frames,
138  0, eg_.input_frames.NumCols());
139  return ans;
140 }
141 
143  const Nnet &nnet = am_nnet_.GetNnet();
144  forward_data_.resize(nnet.NumComponents() + 1);
145 
146  SubMatrix<BaseFloat> input_feats = GetInputFeatures();
147  int32 spk_dim = eg_.spk_info.Dim();
148  if (spk_dim == 0) {
149  forward_data_[0] = input_feats;
150  } else {
151  forward_data_[0].Resize(input_feats.NumRows(),
152  input_feats.NumCols() + eg_.spk_info.Dim());
153  forward_data_[0].Range(0, input_feats.NumRows(),
154  0, input_feats.NumCols()).CopyFromMat(input_feats);
155  forward_data_[0].Range(0, input_feats.NumRows(),
156  input_feats.NumCols(), spk_dim).CopyRowsFromVec(
157  eg_.spk_info);
158  }
159 
160  for (int32 c = 0; c < nnet.NumComponents(); c++) {
161  const Component &component = nnet.GetComponent(c);
163  &output = forward_data_[c+1];
164  component.Propagate(chunk_info_out_[c] , chunk_info_out_[c+1], input, &output);
165  const Component *prev_component = (c == 0 ? NULL :
166  &(nnet.GetComponent(c-1)));
167  bool will_do_backprop = (nnet_to_update_ != NULL),
168  keep_last_output = will_do_backprop &&
169  ((c>0 && prev_component->BackpropNeedsOutput()) ||
170  component.BackpropNeedsInput());
171  if (!keep_last_output)
172  forward_data_[c].Resize(0, 0); // We won't need this data; save memory.
173  }
174 }
175 
176 
177 
179  ConvertLattice(eg_.den_lat, &lat_); // convert to Lattice.
180  TopSort(&lat_); // Topologically sort (required by forward-backward algorithms)
181 
182  if (opts_.criterion == "mmi" && opts_.boost != 0.0) {
183  BaseFloat max_silence_error = 0.0;
185  opts_.boost, max_silence_error, &lat_);
186  }
187 
188  int32 num_frames = static_cast<int32>(eg_.num_ali.size());
189 
190  stats_->tot_t += num_frames;
191  stats_->tot_t_weighted += num_frames * eg_.weight;
192 
193  const VectorBase<BaseFloat> &priors = am_nnet_.Priors();
194  const CuMatrix<BaseFloat> &posteriors = forward_data_.back();
195 
196  KALDI_ASSERT(posteriors.NumRows() == num_frames);
197  int32 num_pdfs = posteriors.NumCols();
198  KALDI_ASSERT(num_pdfs == priors.Dim());
199 
200  // We need to look up the posteriors of some pdf-ids in the matrix
201  // "posteriors". Rather than looking them all up using operator (), which is
202  // very slow because each lookup involves a separate CUDA call with
203  // communication over PciExpress, we look them up all at once using
204  // CuMatrix::Lookup().
205  // Note: regardless of the criterion, we evaluate the likelihoods in
206  // the numerator alignment. Even though they may be irrelevant to
207  // the optimization, they will affect the value of the objective function.
208 
209  std::vector<Int32Pair> requested_indexes;
210  BaseFloat wiggle_room = 1.3; // value not critical.. it's just 'reserve'
211  requested_indexes.reserve(num_frames + wiggle_room * lat_.NumStates());
212 
213  if (opts_.criterion == "mmi") { // need numerator probabilities...
214  for (int32 t = 0; t < num_frames; t++) {
215  int32 tid = eg_.num_ali[t], pdf_id = tmodel_.TransitionIdToPdf(tid);
216  KALDI_ASSERT(pdf_id >= 0 && pdf_id < num_pdfs);
217  requested_indexes.push_back(MakePair(t, pdf_id));
218  }
219  }
220 
221  std::vector<int32> state_times;
222  int32 T = LatticeStateTimes(lat_, &state_times);
223  KALDI_ASSERT(T == num_frames);
224 
225  StateId num_states = lat_.NumStates();
226  for (StateId s = 0; s < num_states; s++) {
227  StateId t = state_times[s];
228  for (fst::ArcIterator<Lattice> aiter(lat_, s); !aiter.Done(); aiter.Next()) {
229  const Arc &arc = aiter.Value();
230  if (arc.ilabel != 0) { // input-side has transition-ids, output-side empty
231  int32 tid = arc.ilabel, pdf_id = tmodel_.TransitionIdToPdf(tid);
232  requested_indexes.push_back(MakePair(t, pdf_id));
233  }
234  }
235  }
236 
237  std::vector<BaseFloat> answers;
238  CuArray<Int32Pair> cu_requested_indexes(requested_indexes);
239  answers.resize(requested_indexes.size());
240  posteriors.Lookup(cu_requested_indexes, &(answers[0]));
241 
242  int32 num_floored = 0;
243 
244  BaseFloat floor_val = 1.0e-20; // floor for posteriors.
245  size_t index;
246 
247  // Replace "answers" with the vector of scaled log-probs. If this step takes
248  // too much time, we can look at other ways to do it, using the CUDA card.
249  for (index = 0; index < answers.size(); index++) {
250  BaseFloat post = answers[index];
251  if (post < floor_val) {
252  post = floor_val;
253  num_floored++;
254  }
255  int32 pdf_id = requested_indexes[index].second;
256  BaseFloat pseudo_loglike = Log(post / priors(pdf_id)) * opts_.acoustic_scale;
257  KALDI_ASSERT(!KALDI_ISINF(pseudo_loglike) && !KALDI_ISNAN(pseudo_loglike));
258  answers[index] = pseudo_loglike;
259  }
260  if (num_floored > 0) {
261  KALDI_WARN << "Floored " << num_floored << " probabilities from nnet.";
262  }
263 
264  index = 0;
265 
266  if (opts_.criterion == "mmi") {
267  double tot_num_like = 0.0;
268  for (; index < eg_.num_ali.size(); index++)
269  tot_num_like += answers[index];
270  stats_->tot_num_objf += eg_.weight * tot_num_like;
271  }
272 
273  // Now put the (scaled) acoustic log-likelihoods in the lattice.
274  for (StateId s = 0; s < num_states; s++) {
275  for (fst::MutableArcIterator<Lattice> aiter(&lat_, s);
276  !aiter.Done(); aiter.Next()) {
277  Arc arc = aiter.Value();
278  if (arc.ilabel != 0) { // input-side has transition-ids, output-side empty
279  arc.weight.SetValue2(-answers[index]);
280  index++;
281  aiter.SetValue(arc);
282  }
283  }
284  LatticeWeight final = lat_.Final(s);
285  if (final != LatticeWeight::Zero()) {
286  final.SetValue2(0.0); // make sure no acoustic term in final-prob.
287  lat_.SetFinal(s, final);
288  }
289  }
290  KALDI_ASSERT(index == answers.size());
291 
292  // Get the MPE or MMI posteriors.
293  Posterior post;
295 
296  ScalePosterior(eg_.weight, &post);
297 
298  double tot_num_post = 0.0, tot_den_post = 0.0;
299  std::vector<MatrixElement<BaseFloat> > sv_labels;
300  sv_labels.reserve(answers.size());
301  for (int32 t = 0; t < post.size(); t++) {
302  for (int32 i = 0; i < post[t].size(); i++) {
303  int32 pdf_id = post[t][i].first;
304  BaseFloat weight = post[t][i].second;
305  if (weight > 0.0) { tot_num_post += weight; }
306  else { tot_den_post -= weight; }
307  MatrixElement<BaseFloat> elem = {t, pdf_id, weight};
308  sv_labels.push_back(elem);
309  }
310  }
311  stats_->tot_num_count += tot_num_post;
312  int32 num_components = am_nnet_.GetNnet().NumComponents();
313  const CuMatrix<BaseFloat> &output(forward_data_[num_components]);
314  backward_data_.Resize(output.NumRows(), output.NumCols()); // zeroes it.
315 
316  { // We don't actually need tot_objf and tot_weight; we have already
317  // computed the objective function.
318  BaseFloat tot_objf, tot_weight;
319  backward_data_.CompObjfAndDeriv(sv_labels, output, &tot_objf, &tot_weight);
320  // Now backward_data_ will contan the derivative at the output.
321  // Our work here is done..
322  }
323 }
324 
325 
327  if (opts_.criterion == "mpfe" || opts_.criterion == "smbr") {
328  Posterior tid_post;
329  double ans;
333  &tid_post);
334  ConvertPosteriorToPdfs(tmodel_, tid_post, post);
335  return ans; // returns the objective function.
336  } else {
337  KALDI_ASSERT(opts_.criterion == "mmi");
338  bool convert_to_pdfs = true, cancel = true;
339  // we'll return the denominator-lattice forward backward likelihood,
340  // which is one term in the objective function.
342  opts_.drop_frames, convert_to_pdfs,
343  cancel, post);
344  }
345 }
346 
347 
348 
350  const Nnet &nnet = am_nnet_.GetNnet();
351  for (int32 c = nnet.NumComponents() - 1; c >= 0; c--) {
352  const Component &component = nnet.GetComponent(c);
353  Component *component_to_update = &(nnet_to_update_->GetComponent(c));
354  const CuMatrix<BaseFloat> &input = forward_data_[c],
355  &output = forward_data_[c+1],
356  &output_deriv = backward_data_;
357  CuMatrix<BaseFloat> input_deriv;
358  component.Backprop(chunk_info_out_[c], chunk_info_out_[c+1], input, output, output_deriv,
359  component_to_update, &input_deriv);
360  backward_data_.Swap(&input_deriv); // backward_data_ = input_deriv.
361  }
362 }
363 
364 
365 void NnetDiscriminativeUpdate(const AmNnet &am_nnet,
366  const TransitionModel &tmodel,
368  const DiscriminativeNnetExample &eg,
369  Nnet *nnet_to_update,
370  NnetDiscriminativeStats *stats) {
371  NnetDiscriminativeUpdater updater(am_nnet, tmodel, opts, eg,
372  nnet_to_update, stats);
373  updater.Update();
374 }
375 
377  tot_t += other.tot_t;
378  tot_t_weighted += other.tot_t_weighted;
379  tot_num_count += other.tot_num_count;
380  tot_num_objf += other.tot_num_objf;
381  tot_den_objf += other.tot_den_objf;
382 }
383 
384 void NnetDiscriminativeStats::Print(std::string criterion) {
385  KALDI_ASSERT(criterion == "mmi" || criterion == "smbr" ||
386  criterion == "mpfe");
387 
388  double avg_post_per_frame = tot_num_count / tot_t_weighted;
389  KALDI_LOG << "Number of frames is " << tot_t
390  << " (weighted: " << tot_t_weighted
391  << "), average (num or den) posterior per frame is "
392  << avg_post_per_frame;
393 
394  if (criterion == "mmi") {
395  double num_objf = tot_num_objf / tot_t_weighted,
396  den_objf = tot_den_objf / tot_t_weighted,
397  objf = num_objf - den_objf;
398  KALDI_LOG << "MMI objective function is " << num_objf << " - "
399  << den_objf << " = " << objf << " per frame, over "
400  << tot_t_weighted << " frames.";
401  } else if (criterion == "mpfe") {
402  double objf = tot_den_objf / tot_t_weighted; // this contains the actual
403  // summed objf
404  KALDI_LOG << "MPFE objective function is " << objf
405  << " per frame, over " << tot_t_weighted << " frames.";
406  } else {
407  double objf = tot_den_objf / tot_t_weighted; // this contains the actual
408  // summed objf
409  KALDI_LOG << "SMBR objective function is " << objf
410  << " per frame, over " << tot_t_weighted << " frames.";
411  }
412 }
413 
414 
415 } // namespace nnet2
416 } // namespace kaldi
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
const Component & GetComponent(int32 c) const
Definition: nnet-nnet.cc:141
int32 LeftContext() const
Returns the left-context summed over all the Components...
Definition: nnet-nnet.cc:42
fst::ArcTpl< LatticeWeight > LatticeArc
Definition: kaldi-lattice.h:40
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
std::vector< CuMatrix< BaseFloat > > forward_data_
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
Definition: text-utils.h:68
Abstract class, basic element of the network, it is a box with defined inputs, outputs, and transformation functions interface.
#define KALDI_ISINF
Definition: kaldi-math.h:73
virtual bool BackpropNeedsInput() const
static Int32Pair MakePair(int32 first, int32 second)
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
int32 TransitionIdToPdf(int32 trans_id) const
void Add(const NnetDiscriminativeStats &other)
void Lookup(const std::vector< Int32Pair > &indexes, Real *output) const
Definition: cu-matrix.cc:3370
int32 NumComponents() const
Returns number of components– think of this as similar to # of layers, but e.g.
Definition: nnet-nnet.h:69
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
BaseFloat LatticeForwardBackwardMmi(const TransitionModel &tmodel, const Lattice &lat, const std::vector< int32 > &num_ali, bool drop_frames, bool convert_to_pdf_ids, bool cancel, Posterior *post)
This function can be used to compute posteriors for MMI, with a positive contribution for the numerat...
double Log(double x)
Definition: kaldi-math.h:100
NnetDiscriminativeUpdater(const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, const DiscriminativeNnetExample &eg, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
const VectorBase< BaseFloat > & Priors() const
Definition: am-nnet.h:67
int32 RightContext() const
Returns the right-context summed over all the Components...
Definition: nnet-nnet.cc:56
double GetDiscriminativePosteriors(Posterior *post)
Assuming the lattice already has the correct scores in it, this function does the MPE or MMI forward-...
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
static const LatticeWeightTpl Zero()
Vector< BaseFloat > spk_info
spk_info contains any component of the features that varies slowly or not at all with time (and hence...
Definition: nnet-example.h:171
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
virtual void Backprop(const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, Component *to_update, CuMatrix< BaseFloat > *in_deriv) const =0
Perform backward pass propagation of the derivative, and also either update the model (if to_update =...
#define KALDI_ERR
Definition: kaldi-error.h:147
BaseFloat LatticeForwardBackwardMpeVariants(const TransitionModel &trans, const std::vector< int32 > &silence_phones, const Lattice &lat, const std::vector< int32 > &num_ali, std::string criterion, bool one_silence_class, Posterior *post)
This function implements either the MPFE (minimum phone frame error) or SMBR (state-level minimum bay...
CompactLattice den_lat
The denominator lattice.
Definition: nnet-example.h:148
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
Matrix< BaseFloat > input_frames
The input data– typically with a number of frames [NumRows()] larger than labels.size(), because it includes features to the left and right as needed for the temporal context of the network.
Definition: nnet-example.h:159
std::vector< int32 > num_ali
The numerator alignment.
Definition: nnet-example.h:143
BaseFloat weight
The weight we assign to this example; this will typically be one, but we include it for the sake of g...
Definition: nnet-example.h:140
void ScalePosterior(BaseFloat scale, Posterior *post)
Scales the BaseFloat (weight) element in the posterior entries.
Definition: posterior.cc:218
Matrix for CUDA computing.
Definition: matrix-common.h:69
void LatticeComputations()
Does the parts between Propagate() and Backprop(), that involve forward-backward over the lattice...
MatrixIndexT NumCols() const
Definition: cu-matrix.h:216
This struct is used to store the information we need for discriminative training (MMI or MPE)...
Definition: nnet-example.h:136
virtual void Propagate(const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const =0
Perform forward pass propagation Input->Output.
#define KALDI_ISNAN
Definition: kaldi-math.h:72
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
const NnetDiscriminativeUpdateOptions & opts_
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
bool LatticeBoost(const TransitionModel &trans, const std::vector< int32 > &alignment, const std::vector< int32 > &silence_phones, BaseFloat b, BaseFloat max_silence_error, Lattice *lat)
Boosts LM probabilities by b * [number of frame errors]; equivalently, adds -b*[number of frame error...
int32_cuda second
Definition: cu-matrixdim.h:80
void Propagate()
The forward-through-the-layers part of the computation.
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:215
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void ConvertPosteriorToPdfs(const TransitionModel &tmodel, const Posterior &post_in, Posterior *post_out)
Converts a posterior over transition-ids to be a posterior over pdf-ids.
Definition: posterior.cc:322
int32 left_context
The number of frames of left context in the features (we can work out the #frames of right context fr...
Definition: nnet-example.h:164
#define KALDI_LOG
Definition: kaldi-error.h:153
void ComputeChunkInfo(int32 input_chunk_size, int32 num_chunks, std::vector< ChunkInfo > *chunk_info_out) const
Uses the output of the Context() functions of the network, to compute a vector of size NumComponents(...
Definition: nnet-nnet.cc:65
Sub-matrix representation.
Definition: kaldi-matrix.h:988
virtual bool BackpropNeedsOutput() const
const Nnet & GetNnet() const
Definition: am-nnet.h:61
void NnetDiscriminativeUpdate(const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, const DiscriminativeNnetExample &eg, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
Does the neural net computation, lattice forward-backward, and backprop, for either the MMI...
int32_cuda first
Definition: cu-matrixdim.h:79