All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
nnet-train-mmi-sequential.cc File Reference
Include dependency graph for nnet-train-mmi-sequential.cc:

Go to the source code of this file.

Namespaces

 kaldi
 Relabels neural network egs with the read pdf-id alignments.
 
 kaldi::nnet1
 

Functions

void LatticeAcousticRescore (const Matrix< BaseFloat > &log_like, const TransitionModel &trans_model, const std::vector< int32 > &state_times, Lattice *lat)
 
int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 85 of file nnet-train-mmi-sequential.cc.

References fst::AcousticLatticeScale(), Nnet::AppendComponentPointer(), Nnet::Backpropagate(), PdfPriorOptions::class_frame_counts, CuMatrixBase< Real >::CopyFromMat(), SequentialTableReader< Holder >::Done(), Timer::Elapsed(), Nnet::Feedforward(), ParseOptions::GetArg(), Nnet::GetLastComponent(), Component::GetType(), kaldi::GetVerboseLevel(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, Nnet::InfoBackPropagate(), Nnet::InfoGradient(), Nnet::InfoPropagate(), KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), Component::kSoftmax, kaldi::kUndefined, kaldi::nnet1::LatticeAcousticRescore(), kaldi::LatticeForwardBackward(), fst::LatticeScale(), kaldi::LatticeStateTimes(), NnetTrainOptions::learn_rate, SequentialTableReader< Holder >::Next(), fst::NumArcs(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), Nnet::OutputDim(), kaldi::nnet1::PosteriorToPdfMatrix(), ParseOptions::PrintUsage(), Nnet::Propagate(), ParseOptions::Read(), Nnet::Read(), kaldi::ReadKaldiObject(), PdfPriorOptions::Register(), NnetTrainOptions::Register(), ParseOptions::Register(), Nnet::RemoveLastComponent(), CuMatrix< Real >::Resize(), MatrixBase< Real >::Row(), fst::ScaleLattice(), Nnet::SetTrainOptions(), PdfPrior::SubtractOnLogpost(), TransitionModel::TransitionIdToPdf(), Component::TypeToMarker(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and Nnet::Write().

85  {
86  using namespace kaldi;
87  using namespace kaldi::nnet1;
88  typedef kaldi::int32 int32;
89  try {
90  const char *usage =
91  "Perform one iteration of MMI training using SGD with per-utterance"
92  "updates\n"
93 
94  "Usage: nnet-train-mmi-sequential [options] "
95  "<model-in> <transition-model-in> <feature-rspecifier> "
96  "<den-lat-rspecifier> <ali-rspecifier> [<model-out>]\n"
97 
98  "e.g.: nnet-train-mmi-sequential nnet.init trans.mdl scp:feats.scp "
99  "scp:denlats.scp ark:ali.ark nnet.iter1\n";
100 
101  ParseOptions po(usage);
102 
103  NnetTrainOptions trn_opts;
104  trn_opts.learn_rate = 0.00001; // changing default,
105  trn_opts.Register(&po);
106 
107  bool binary = true;
108  po.Register("binary", &binary, "Write output in binary mode");
109 
110  std::string feature_transform;
111  po.Register("feature-transform", &feature_transform,
112  "Feature transform in 'nnet1' format");
113 
114  PdfPriorOptions prior_opts;
115  prior_opts.Register(&po);
116 
117  BaseFloat acoustic_scale = 1.0,
118  lm_scale = 1.0,
119  old_acoustic_scale = 0.0;
120 
121  po.Register("acoustic-scale", &acoustic_scale,
122  "Scaling factor for acoustic likelihoods");
123 
124  po.Register("lm-scale", &lm_scale,
125  "Scaling factor for \"graph costs\" (including LM costs)");
126 
127  po.Register("old-acoustic-scale", &old_acoustic_scale,
128  "Add in the scores in the input lattices with this scale, "
129  "rather than discarding them.");
130 
131  kaldi::int32 max_frames = 6000;
132  po.Register("max-frames", &max_frames,
133  "Maximum number of frames an utterance can have (skipped if longer)");
134 
135  bool drop_frames = true;
136  po.Register("drop-frames", &drop_frames,
137  "Drop frames, where is zero den-posterior under numerator path "
138  "(ie. path not in lattice)");
139 
140  std::string use_gpu="yes";
141  po.Register("use-gpu", &use_gpu,
142  "yes|no|optional, only has effect if compiled with CUDA");
143 
144  po.Read(argc, argv);
145 
146  if (po.NumArgs() != 6) {
147  po.PrintUsage();
148  exit(1);
149  }
150 
151  std::string model_filename = po.GetArg(1),
152  transition_model_filename = po.GetArg(2),
153  feature_rspecifier = po.GetArg(3),
154  den_lat_rspecifier = po.GetArg(4),
155  num_ali_rspecifier = po.GetArg(5),
156  target_model_filename = po.GetArg(6);
157 
158  using namespace kaldi;
159  using namespace kaldi::nnet1;
160  typedef kaldi::int32 int32;
161 
162 #if HAVE_CUDA == 1
163  CuDevice::Instantiate().SelectGpuId(use_gpu);
164 #endif
165 
166  Nnet nnet_transf;
167  if (feature_transform != "") {
168  nnet_transf.Read(feature_transform);
169  }
170 
171  Nnet nnet;
172  nnet.Read(model_filename);
173  // we will use pre-softmax activations, removing softmax,
174  // - pre-softmax activations are equivalent to 'log-posterior + C_frame',
175  // - all paths crossing a frame share same 'C_frame',
176  // - with GMM, we also have the unnormalized acoustic likelihoods,
177  if (nnet.GetLastComponent().GetType() ==
179  KALDI_LOG << "Removing softmax from the nnet " << model_filename;
180  nnet.RemoveLastComponent();
181  } else {
182  KALDI_LOG << "The nnet was without softmax. "
183  << "The last component in " << model_filename << " was "
185  }
186  nnet.SetTrainOptions(trn_opts);
187 
188  // Read the class-frame-counts, compute priors,
189  PdfPrior log_prior(prior_opts);
190 
191  // Read transition model,
192  TransitionModel trans_model;
193  ReadKaldiObject(transition_model_filename, &trans_model);
194 
195  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
196  RandomAccessLatticeReader den_lat_reader(den_lat_rspecifier);
197  RandomAccessInt32VectorReader num_ali_reader(num_ali_rspecifier);
198 
199  CuMatrix<BaseFloat> feats_transf, nnet_out, nnet_diff;
200  Matrix<BaseFloat> nnet_out_h, nnet_diff_h;
201 
202  if (drop_frames) {
203  KALDI_LOG << "--drop-frames=true :"
204  " we will zero gradient for frames with total den/num mismatch."
205  " The mismatch is likely to be caused by missing correct path "
206  " from den-lattice due wrong annotation or search error."
207  " Leaving such frames out stabilizes the training.";
208  }
209 
210  Timer time;
211  double time_now = 0;
212  KALDI_LOG << "TRAINING STARTED";
213 
214  int32 num_done = 0, num_no_num_ali = 0, num_no_den_lat = 0,
215  num_other_error = 0, num_frm_drop = 0;
216 
217  kaldi::int64 total_frames = 0;
218  double lat_like; // total likelihood of the lattice
219  double lat_ac_like; // acoustic likelihood weighted by posterior.
220  double total_mmi_obj = 0.0, mmi_obj = 0.0;
221  double total_post_on_ali = 0.0, post_on_ali = 0.0;
222 
223  // main loop over utterances,
224  for ( ; !feature_reader.Done(); feature_reader.Next()) {
225  std::string utt = feature_reader.Key();
226  if (!den_lat_reader.HasKey(utt)) {
227  KALDI_WARN << "Missing lattice of " << utt;
228  num_no_den_lat++;
229  continue;
230  }
231  if (!num_ali_reader.HasKey(utt)) {
232  KALDI_WARN << "Missing alignment of " << utt;
233  num_no_num_ali++;
234  continue;
235  }
236 
237  // 1) get the features, numerator alignment,
238  const Matrix<BaseFloat> &mat = feature_reader.Value();
239  const std::vector<int32> &num_ali = num_ali_reader.Value(utt);
240  // check duration of numerator alignments
241  if (static_cast<int32>(num_ali.size()) != mat.NumRows()) {
242  KALDI_WARN << "Duration mismatch!"
243  << " alignment " << num_ali.size()
244  << " features " << mat.NumRows();
245  num_other_error++;
246  continue;
247  }
248  if (mat.NumRows() > max_frames) {
249  KALDI_WARN << "Skipping " << utt
250  << " that has " << mat.NumRows() << " frames,"
251  << " it is longer than '--max-frames'" << max_frames;
252  num_other_error++;
253  continue;
254  }
255 
256  // 2) get the denominator-lattice, preprocess
257  Lattice den_lat = den_lat_reader.Value(utt);
258  if (den_lat.Start() == -1) {
259  KALDI_WARN << "Empty lattice of " << utt << ", skipping.";
260  num_other_error++;
261  continue;
262  }
263  if (old_acoustic_scale != 1.0) {
264  fst::ScaleLattice(fst::AcousticLatticeScale(old_acoustic_scale),
265  &den_lat);
266  }
267  // optional sort it topologically
268  kaldi::uint64 props = den_lat.Properties(fst::kFstProperties, false);
269  if (!(props & fst::kTopSorted)) {
270  if (fst::TopSort(&den_lat) == false) {
271  KALDI_ERR << "Cycles detected in lattice.";
272  }
273  }
274  // get the lattice length and times of states,
275  vector<int32> state_times;
276  int32 max_time = kaldi::LatticeStateTimes(den_lat, &state_times);
277  // check duration of den. lattice,
278  if (max_time != mat.NumRows()) {
279  KALDI_WARN << "Duration mismatch!"
280  << " denominator lattice " << max_time
281  << " features " << mat.NumRows() << ","
282  << " skipping " << utt;
283  num_other_error++;
284  continue;
285  }
286 
287  // get dims,
288  int32 num_frames = mat.NumRows(),
289  num_pdfs = nnet.OutputDim();
290 
291  // 3) get the pre-softmax outputs from NN,
292  // apply transform,
293  nnet_transf.Feedforward(CuMatrix<BaseFloat>(mat), &feats_transf);
294  // propagate through the nnet (we know it's w/o softmax),
295  nnet.Propagate(feats_transf, &nnet_out);
296  // subtract the log_prior,
297  if (prior_opts.class_frame_counts != "") {
298  log_prior.SubtractOnLogpost(&nnet_out);
299  }
300  // transfer it back to the host,
301  nnet_out_h = Matrix<BaseFloat>(nnet_out);
302  // release the buffers we don't need anymore,
303  feats_transf.Resize(0, 0);
304  nnet_out.Resize(0, 0);
305 
306  // 4) rescore the latice,
307  LatticeAcousticRescore(nnet_out_h, trans_model, state_times, &den_lat);
308  if (acoustic_scale != 1.0 || lm_scale != 1.0)
309  fst::ScaleLattice(fst::LatticeScale(lm_scale, acoustic_scale), &den_lat);
310 
311  // 5) get the posteriors,
312  kaldi::Posterior post;
313  lat_like = kaldi::LatticeForwardBackward(den_lat, &post, &lat_ac_like);
314 
315  // 6) convert the Posterior to a matrix,
316  PosteriorToPdfMatrix(post, trans_model, &nnet_diff_h);
317 
318  // 7) Calculate the MMI-objective function,
319  // Calculate the likelihood of correct path from acoustic score,
320  // the denominator likelihood is the total likelihood of the lattice.
321  double path_ac_like = 0.0;
322  for (int32 t = 0; t < num_frames; t++) {
323  int32 pdf = trans_model.TransitionIdToPdf(num_ali[t]);
324  path_ac_like += nnet_out_h(t, pdf);
325  }
326  path_ac_like *= acoustic_scale;
327  mmi_obj = path_ac_like - lat_like;
328  //
329  // Note: numerator likelihood does not include graph score,
330  // while denominator likelihood contains graph scores.
331  // The result is offset at the MMI-objective.
332  // However the offset is constant for given alignment,
333  // so it does not change accross epochs.
334 
335  // Sum the den-posteriors under the correct path,
336  post_on_ali = 0.0;
337  for (int32 t = 0; t < num_frames; t++) {
338  int32 pdf = trans_model.TransitionIdToPdf(num_ali[t]);
339  double posterior = nnet_diff_h(t, pdf);
340  post_on_ali += posterior;
341  }
342 
343  // Report,
344  KALDI_VLOG(1) << "Lattice #" << num_done + 1 << " processed"
345  << " (" << utt << "): found " << den_lat.NumStates()
346  << " states and " << fst::NumArcs(den_lat) << " arcs.";
347 
348  KALDI_VLOG(1) << "Utterance " << utt << ": Average MMI obj. value = "
349  << (mmi_obj/num_frames) << " over " << num_frames << " frames."
350  << " (Avg. den-posterior on ali " << post_on_ali / num_frames << ")";
351 
352 
353  // 7a) Search for the frames with num/den mismatch,
354  int32 frm_drop = 0;
355  std::vector<int32> frm_drop_vec;
356  for (int32 t = 0; t < num_frames; t++) {
357  int32 pdf = trans_model.TransitionIdToPdf(num_ali[t]);
358  double posterior = nnet_diff_h(t, pdf);
359  if (posterior < 1e-20) {
360  frm_drop++;
361  frm_drop_vec.push_back(t);
362  }
363  }
364 
365  // 8) subtract the pdf-Viterbi-path,
366  for (int32 t = 0; t < nnet_diff_h.NumRows(); t++) {
367  int32 pdf = trans_model.TransitionIdToPdf(num_ali[t]);
368  nnet_diff_h(t, pdf) -= 1.0;
369  }
370 
371  // 9) Drop mismatched frames from the training by zeroing the derivative,
372  if (drop_frames) {
373  for (int32 i = 0; i < frm_drop_vec.size(); i++) {
374  nnet_diff_h.Row(frm_drop_vec[i]).Set(0.0);
375  }
376  num_frm_drop += frm_drop;
377  }
378  // Report the frame dropping
379  if (frm_drop > 0) {
380  std::stringstream ss;
381  ss << (drop_frames?"Dropped":"[dropping disabled] Would drop")
382  << " frames in " << utt << " " << frm_drop << "/" << num_frames
383  << ",";
384  // get frame intervals from vec frm_drop_vec,
385  ss << " intervals :";
386  // search for streaks of consecutive numbers,
387  int32 beg_streak = frm_drop_vec[0];
388  int32 len_streak = 0;
389  int32 i;
390  for (i = 0; i < frm_drop_vec.size(); i++, len_streak++) {
391  if (beg_streak + len_streak != frm_drop_vec[i]) {
392  ss << " " << beg_streak << ".." << frm_drop_vec[i-1] << "frm";
393  beg_streak = frm_drop_vec[i];
394  len_streak = 0;
395  }
396  }
397  ss << " " << beg_streak << ".." << frm_drop_vec[i-1] << "frm";
398  // print,
399  KALDI_WARN << ss.str();
400  }
401 
402  // 10) backpropagate through the nnet, update,
403  nnet_diff.Resize(num_frames, num_pdfs, kUndefined);
404  nnet_diff.CopyFromMat(nnet_diff_h);
405  nnet.Backpropagate(nnet_diff, NULL);
406  // relase the buffer, we don't need anymore,
407  nnet_diff.Resize(0, 0);
408 
409  // increase time counter
410  total_mmi_obj += mmi_obj;
411  total_post_on_ali += post_on_ali;
412  total_frames += num_frames;
413  num_done++;
414 
415  if (num_done % 100 == 0) {
416  time_now = time.Elapsed();
417  KALDI_VLOG(1) << "After " << num_done << " utterances: "
418  << "time elapsed = " << time_now / 60 << " min; "
419  << "processed " << total_frames / time_now << " frames per sec.";
420 #if HAVE_CUDA == 1
421  // check that GPU computes accurately,
422  CuDevice::Instantiate().CheckGpuHealth();
423 #endif
424  }
425 
426  // GRADIENT LOGGING
427  // First utterance,
428  if (num_done == 1) {
429  KALDI_VLOG(1) << nnet.InfoPropagate();
430  KALDI_VLOG(1) << nnet.InfoBackPropagate();
431  KALDI_VLOG(1) << nnet.InfoGradient();
432  }
433  // Every 1000 utterances (--verbose=2),
434  if (GetVerboseLevel() >= 2) {
435  if (num_done % 1000 == 0) {
436  KALDI_VLOG(2) << nnet.InfoPropagate();
437  KALDI_VLOG(2) << nnet.InfoBackPropagate();
438  KALDI_VLOG(2) << nnet.InfoGradient();
439  }
440  }
441  } // main loop over utterances,
442 
443  // After last utterance,
444  KALDI_VLOG(1) << nnet.InfoPropagate();
445  KALDI_VLOG(1) << nnet.InfoBackPropagate();
446  KALDI_VLOG(1) << nnet.InfoGradient();
447 
448  // Add the softmax layer back before writing,
449  KALDI_LOG << "Appending the softmax " << target_model_filename;
450  nnet.AppendComponentPointer(new Softmax(nnet.OutputDim(), nnet.OutputDim()));
451  // Store the nnet,
452  nnet.Write(target_model_filename, binary);
453 
454  time_now = time.Elapsed();
455  KALDI_LOG << "TRAINING FINISHED; "
456  << "Time taken = " << time_now/60 << " min; processed "
457  << (total_frames/time_now) << " frames per second.";
458 
459  KALDI_LOG << "Done " << num_done << " files, "
460  << num_no_num_ali << " with no numerator alignments, "
461  << num_no_den_lat << " with no denominator lattices, "
462  << num_other_error << " with other errors.";
463 
464  KALDI_LOG << "Overall MMI-objective/frame is "
465  << std::setprecision(8) << total_mmi_obj / total_frames
466  << " over " << total_frames << " frames,"
467  << " (average den-posterior on ali "
468  << total_post_on_ali / total_frames << ","
469  << " dropped " << num_frm_drop
470  << " frames with num/den mismatch)";
471 
472 #if HAVE_CUDA == 1
473  CuDevice::Instantiate().PrintProfile();
474 #endif
475 
476  return 0;
477  } catch(const std::exception &e) {
478  std::cerr << e.what();
479  return -1;
480  }
481 }
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
Definition: nnet-nnet.cc:96
void RemoveLastComponent()
Remove the last of the Components.
Definition: nnet-nnet.cc:206
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
void PosteriorToPdfMatrix(const Posterior &post, const TransitionModel &model, CuMatrix< Real > *mat)
Wrapper of PosteriorToMatrixMapped with CuMatrix argument.
Definition: nnet-utils.h:304
void AppendComponentPointer(Component *dynamically_allocated_comp)
Append Component* to 'this' instance of Nnet by a shallow copy ('this' instance of Nnet over-takes th...
Definition: nnet-nnet.cc:187
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
Definition: nnet-nnet.cc:70
int32 OutputDim() const
Dimensionality of network outputs (posteriors | bn-features | etc.).
Definition: nnet-nnet.cc:143
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics.
Definition: nnet-nnet.cc:420
int32 GetVerboseLevel()
Definition: kaldi-error.h:69
int32 TransitionIdToPdf(int32 trans_id) const
void CopyFromMat(const MatrixBase< OtherReal > &src, MatrixTransposeType trans=kNoTrans)
Definition: cu-matrix.cc:337
const Component & GetLastComponent() const
LastComponent accessor.
Definition: nnet-nnet.cc:161
double Elapsed()
Returns time in seconds.
Definition: timer.h:65
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:818
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:182
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
static const char * TypeToMarker(ComponentType t)
Converts component type to marker.
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:43
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post, double *acoustic_like_sum)
This function does the forward-backward over lattices and computes the posterior probabilities of the...
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
Definition: cu-matrix.cc:47
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics.
Definition: nnet-nnet.cc:407
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
vector< vector< double > > AcousticLatticeScale(double acwt)
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
#define KALDI_ERR
Definition: kaldi-error.h:127
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
void Register(OptionsItf *opts)
Definition: nnet-trnopts.h:46
#define KALDI_WARN
Definition: kaldi-error.h:130
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics.
Definition: nnet-nnet.cc:443
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
Definition: nnet-nnet.cc:131
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:58
void ScaleLattice(const vector< vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
virtual ComponentType GetType() const =0
Get type identification of the component.
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
Definition: nnet-nnet.cc:508
vector< vector< double > > LatticeScale(double lmwt, double acwt)
Arc::StateId NumArcs(const ExpandedFst< Arc > &fst)
Returns the total number of arcs in an FST.
#define KALDI_LOG
Definition: kaldi-error.h:133
void Register(OptionsItf *opts)
void LatticeAcousticRescore(const Matrix< BaseFloat > &log_like, const TransitionModel &trans_model, const std::vector< int32 > &state_times, Lattice *lat)