All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
nnet-train-mpe-sequential.cc File Reference
Include dependency graph for nnet-train-mpe-sequential.cc:

Go to the source code of this file.

Namespaces

 kaldi
 Relabels neural network egs with the read pdf-id alignments.
 
 kaldi::nnet1
 

Functions

void LatticeAcousticRescore (const Matrix< BaseFloat > &log_like, const TransitionModel &trans_model, const std::vector< int32 > &state_times, Lattice *lat)
 
int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 85 of file nnet-train-mpe-sequential.cc.

References fst::AcousticLatticeScale(), Nnet::AppendComponentPointer(), Nnet::Backpropagate(), PdfPriorOptions::class_frame_counts, SequentialTableReader< Holder >::Done(), Timer::Elapsed(), Nnet::Feedforward(), ParseOptions::GetArg(), Nnet::GetLastComponent(), Component::GetType(), kaldi::GetVerboseLevel(), RandomAccessTableReader< Holder >::HasKey(), Nnet::InfoBackPropagate(), Nnet::InfoGradient(), Nnet::InfoPropagate(), KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), Component::kSoftmax, kaldi::nnet1::LatticeAcousticRescore(), kaldi::LatticeForwardBackwardMpeVariants(), fst::LatticeScale(), kaldi::LatticeStateTimes(), NnetTrainOptions::learn_rate, CuMatrixBase< Real >::Max(), CuMatrixBase< Real >::Min(), SequentialTableReader< Holder >::Next(), fst::NumArcs(), ParseOptions::NumArgs(), Nnet::OutputDim(), kaldi::nnet1::PosteriorToPdfMatrix(), ParseOptions::PrintUsage(), Nnet::Propagate(), ParseOptions::Read(), Nnet::Read(), kaldi::ReadKaldiObject(), PdfPriorOptions::Register(), NnetTrainOptions::Register(), ParseOptions::Register(), Nnet::RemoveLastComponent(), CuMatrix< Real >::Resize(), CuMatrixBase< Real >::Scale(), fst::ScaleLattice(), Nnet::SetTrainOptions(), kaldi::SortAndUniq(), kaldi::SplitStringToIntegers(), PdfPrior::SubtractOnLogpost(), Component::TypeToMarker(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and Nnet::Write().

85  {
86  using namespace kaldi;
87  using namespace kaldi::nnet1;
88  typedef kaldi::int32 int32;
89  try {
90  const char *usage =
91  "Perform one iteration of MPE/sMBR training using SGD with per-utterance"
92  "updates.\n"
93 
94  "Usage: nnet-train-mpe-sequential [options] "
95  "<model-in> <transition-model-in> <feature-rspecifier> "
96  "<den-lat-rspecifier> <ali-rspecifier> [<model-out>]\n"
97 
98  "e.g.: nnet-train-mpe-sequential nnet.init trans.mdl scp:feats.scp "
99  "scp:denlats.scp ark:ali.ark nnet.iter1\n";
100 
101  ParseOptions po(usage);
102 
103  NnetTrainOptions trn_opts;
104  trn_opts.learn_rate = 0.00001; // changing default,
105  trn_opts.Register(&po);
106 
107  bool binary = true;
108  po.Register("binary", &binary, "Write output in binary mode");
109 
110  std::string feature_transform;
111  po.Register("feature-transform", &feature_transform,
112  "Feature transform in 'nnet1' format");
113 
114  std::string silence_phones_str;
115  po.Register("silence-phones", &silence_phones_str,
116  "Colon-separated list of integer id's of silence phones, e.g. 46:47");
117 
118  PdfPriorOptions prior_opts;
119  prior_opts.Register(&po);
120 
121  BaseFloat acoustic_scale = 1.0,
122  lm_scale = 1.0,
123  old_acoustic_scale = 0.0;
124 
125  po.Register("acoustic-scale", &acoustic_scale,
126  "Scaling factor for acoustic likelihoods");
127 
128  po.Register("lm-scale", &lm_scale,
129  "Scaling factor for \"graph costs\" (including LM costs)");
130 
131  po.Register("old-acoustic-scale", &old_acoustic_scale,
132  "Add in the scores in the input lattices with this scale, rather "
133  "than discarding them.");
134 
135  bool one_silence_class = false;
136  po.Register("one-silence-class", &one_silence_class,
137  "If true, the newer behavior reduces insertions.");
138 
139  kaldi::int32 max_frames = 6000;
140  po.Register("max-frames", &max_frames,
141  "Maximum number of frames an utterance can have (skipped if longer)");
142 
143  bool do_smbr = false;
144  po.Register("do-smbr", &do_smbr,
145  "Use state-level accuracies instead of phone accuracies.");
146 
147  std::string use_gpu="yes";
148  po.Register("use-gpu", &use_gpu,
149  "yes|no|optional, only has effect if compiled with CUDA");
150 
151  po.Read(argc, argv);
152 
153  if (po.NumArgs() != 6) {
154  po.PrintUsage();
155  exit(1);
156  }
157 
158  std::string model_filename = po.GetArg(1),
159  transition_model_filename = po.GetArg(2),
160  feature_rspecifier = po.GetArg(3),
161  den_lat_rspecifier = po.GetArg(4),
162  ref_ali_rspecifier = po.GetArg(5),
163  target_model_filename = po.GetArg(6);
164 
165  std::vector<int32> silence_phones;
166  if (!kaldi::SplitStringToIntegers(silence_phones_str, ":", false,
167  &silence_phones)) {
168  KALDI_ERR << "Invalid silence-phones string " << silence_phones_str;
169  }
170  kaldi::SortAndUniq(&silence_phones);
171  if (silence_phones.empty()) {
172  KALDI_LOG << "No silence phones specified.";
173  }
174 
175 #if HAVE_CUDA == 1
176  CuDevice::Instantiate().SelectGpuId(use_gpu);
177 #endif
178 
179  Nnet nnet_transf;
180  if (feature_transform != "") {
181  nnet_transf.Read(feature_transform);
182  }
183 
184  Nnet nnet;
185  nnet.Read(model_filename);
186  // we will use pre-softmax activations, removing softmax,
187  // - pre-softmax activations are equivalent to 'log-posterior + C_frame',
188  // - all paths crossing a frame share same 'C_frame',
189  // - with GMM, we also have the unnormalized acoustic likelihoods,
190  if (nnet.GetLastComponent().GetType() ==
192  KALDI_LOG << "Removing softmax from the nnet " << model_filename;
193  nnet.RemoveLastComponent();
194  } else {
195  KALDI_LOG << "The nnet was without softmax. "
196  << "The last component in " << model_filename << " was "
198  }
199  nnet.SetTrainOptions(trn_opts);
200 
201  // Read the class-frame-counts, compute priors,
202  PdfPrior log_prior(prior_opts);
203 
204  // Read transition model,
205  TransitionModel trans_model;
206  ReadKaldiObject(transition_model_filename, &trans_model);
207 
208  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
209  RandomAccessLatticeReader den_lat_reader(den_lat_rspecifier);
210  RandomAccessInt32VectorReader ref_ali_reader(ref_ali_rspecifier);
211 
212  CuMatrix<BaseFloat> feats_transf, nnet_out, nnet_diff;
213  Matrix<BaseFloat> nnet_out_h;
214 
215  Timer time;
216  double time_now = 0;
217  KALDI_LOG << "TRAINING STARTED";
218 
219  int32 num_done = 0,
220  num_no_ref_ali = 0,
221  num_no_den_lat = 0,
222  num_other_error = 0;
223 
224  kaldi::int64 total_frames = 0;
225  double total_frame_acc = 0.0, utt_frame_acc;
226 
227  // main loop over utterances,
228  for (; !feature_reader.Done(); feature_reader.Next()) {
229  std::string utt = feature_reader.Key();
230  if (!den_lat_reader.HasKey(utt)) {
231  KALDI_WARN << "Missing lattice for " << utt;
232  num_no_den_lat++;
233  continue;
234  }
235  if (!ref_ali_reader.HasKey(utt)) {
236  KALDI_WARN << "Missing alignment for " << utt;
237  num_no_ref_ali++;
238  continue;
239  }
240 
241  // 1) get the features, numerator alignment,
242  const Matrix<BaseFloat> &mat = feature_reader.Value();
243  const std::vector<int32> &ref_ali = ref_ali_reader.Value(utt);
244  // check duration of numerator alignments,
245  if (static_cast<MatrixIndexT>(ref_ali.size()) != mat.NumRows()) {
246  KALDI_WARN << "Duration mismatch!"
247  << " alignment " << ref_ali.size()
248  << " features " << mat.NumRows();
249  num_other_error++;
250  continue;
251  }
252  if (mat.NumRows() > max_frames) {
253  KALDI_WARN << "Skipping " << utt
254  << " that has " << mat.NumRows() << " frames,"
255  << " it is longer than '--max-frames'" << max_frames;
256  num_other_error++;
257  continue;
258  }
259 
260  // 2) get the denominator lattice, preprocess
261  Lattice den_lat = den_lat_reader.Value(utt);
262  if (den_lat.Start() == -1) {
263  KALDI_WARN << "Empty lattice of " << utt << ", skipping.";
264  num_other_error++;
265  continue;
266  }
267  if (old_acoustic_scale != 1.0) {
268  fst::ScaleLattice(fst::AcousticLatticeScale(old_acoustic_scale),
269  &den_lat);
270  }
271  // optional sort it topologically
272  kaldi::uint64 props = den_lat.Properties(fst::kFstProperties, false);
273  if (!(props & fst::kTopSorted)) {
274  if (fst::TopSort(&den_lat) == false) {
275  KALDI_ERR << "Cycles detected in lattice.";
276  }
277  }
278  // get the lattice length and times of states
279  std::vector<int32> state_times;
280  int32 max_time = kaldi::LatticeStateTimes(den_lat, &state_times);
281  // check for temporal length of denominator lattices
282  if (max_time != mat.NumRows()) {
283  KALDI_WARN << "Duration mismatch!"
284  << " denominator lattice " << max_time
285  << " features " << mat.NumRows() << ","
286  << " skipping " << utt;
287  num_other_error++;
288  continue;
289  }
290 
291  // get dims,
292  int32 num_frames = mat.NumRows();
293 
294  // 3) get the pre-softmax outputs from NN,
295  // apply transform,
296  nnet_transf.Feedforward(CuMatrix<BaseFloat>(mat), &feats_transf);
297  // propagate through the nnet (we know it's w/o softmax),
298  nnet.Propagate(feats_transf, &nnet_out);
299  // subtract the log_prior,
300  if (prior_opts.class_frame_counts != "") {
301  log_prior.SubtractOnLogpost(&nnet_out);
302  }
303  // transfer it back to the host,
304  nnet_out_h = Matrix<BaseFloat>(nnet_out);
305  // release the buffers we don't need anymore
306  feats_transf.Resize(0, 0);
307  nnet_out.Resize(0, 0);
308 
309  // 4) rescore the latice
310  LatticeAcousticRescore(nnet_out_h, trans_model, state_times, &den_lat);
311  if (acoustic_scale != 1.0 || lm_scale != 1.0)
312  fst::ScaleLattice(fst::LatticeScale(lm_scale, acoustic_scale), &den_lat);
313 
314  kaldi::Posterior post;
315  if (do_smbr) {
316  // use state-level accuracies, i.e. sMBR estimation,
317  utt_frame_acc = LatticeForwardBackwardMpeVariants(
318  trans_model, silence_phones, den_lat, ref_ali, "smbr",
319  one_silence_class, &post);
320  } else {
321  // use phone-level accuracies, i.e. MPFE (minimum phone frame error),
322  utt_frame_acc = LatticeForwardBackwardMpeVariants(
323  trans_model, silence_phones, den_lat, ref_ali, "mpfe",
324  one_silence_class, &post);
325  }
326 
327  // 6) convert the Posterior to a matrix,
328  PosteriorToPdfMatrix(post, trans_model, &nnet_diff);
329  nnet_diff.Scale(-1.0); // need to flip the sign of derivative,
330 
331  KALDI_VLOG(1) << "Lattice #" << num_done + 1 << " processed"
332  << " (" << utt << "): found " << den_lat.NumStates()
333  << " states and " << fst::NumArcs(den_lat) << " arcs.";
334 
335  KALDI_VLOG(1) << "Utterance " << utt << ": Average frame accuracy = "
336  << (utt_frame_acc/num_frames) << " over " << num_frames
337  << " frames,"
338  << " diff-range(" << nnet_diff.Min() << ","
339  << nnet_diff.Max() << ")";
340 
341  // 7) backpropagate through the nnet, update,
342  nnet.Backpropagate(nnet_diff, NULL);
343  nnet_diff.Resize(0, 0); // release GPU memory,
344 
345  // increase time counter
346  total_frame_acc += utt_frame_acc;
347  total_frames += num_frames;
348  num_done++;
349 
350  if (num_done % 100 == 0) {
351  time_now = time.Elapsed();
352  KALDI_VLOG(1) << "After " << num_done << " utterances: "
353  << "time elapsed = " << time_now / 60 << " min; "
354  << "processed " << total_frames / time_now << " frames per sec.";
355 #if HAVE_CUDA == 1
356  // check that GPU computes accurately,
357  CuDevice::Instantiate().CheckGpuHealth();
358 #endif
359  }
360 
361  // GRADIENT LOGGING
362  // First utterance,
363  if (num_done == 1) {
364  KALDI_VLOG(1) << nnet.InfoPropagate();
365  KALDI_VLOG(1) << nnet.InfoBackPropagate();
366  KALDI_VLOG(1) << nnet.InfoGradient();
367  }
368  // Every 1000 utterances (--verbose=2),
369  if (GetVerboseLevel() >= 2) {
370  if (num_done % 1000 == 0) {
371  KALDI_VLOG(2) << nnet.InfoPropagate();
372  KALDI_VLOG(2) << nnet.InfoBackPropagate();
373  KALDI_VLOG(2) << nnet.InfoGradient();
374  }
375  }
376  } // main loop over utterances,
377 
378  // After last utterance,
379  KALDI_VLOG(1) << nnet.InfoPropagate();
380  KALDI_VLOG(1) << nnet.InfoBackPropagate();
381  KALDI_VLOG(1) << nnet.InfoGradient();
382 
383  // Add the softmax layer back before writing,
384  KALDI_LOG << "Appending the softmax " << target_model_filename;
385  nnet.AppendComponentPointer(new Softmax(nnet.OutputDim(), nnet.OutputDim()));
386  // Store the nnet,
387  nnet.Write(target_model_filename, binary);
388 
389  time_now = time.Elapsed();
390  KALDI_LOG << "TRAINING FINISHED; "
391  << "Time taken = " << time_now / 60 << " min; processed "
392  << total_frames / time_now << " frames per second.";
393 
394  KALDI_LOG << "Done " << num_done << " files, "
395  << num_no_ref_ali << " with no reference alignments, "
396  << num_no_den_lat << " with no lattices, "
397  << num_other_error << " with other errors.";
398 
399  KALDI_LOG << "Overall average frame-accuracy is "
400  << total_frame_acc / total_frames << " over "
401  << total_frames << " frames.";
402 
403 #if HAVE_CUDA == 1
404  CuDevice::Instantiate().PrintProfile();
405 #endif
406 
407  return 0;
408  } catch(const std::exception &e) {
409  std::cerr << e.what();
410  return -1;
411  }
412 }
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
Definition: nnet-nnet.cc:96
void RemoveLastComponent()
Remove the last of the Components.
Definition: nnet-nnet.cc:206
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
void PosteriorToPdfMatrix(const Posterior &post, const TransitionModel &model, CuMatrix< Real > *mat)
Wrapper of PosteriorToMatrixMapped with CuMatrix argument.
Definition: nnet-utils.h:304
void AppendComponentPointer(Component *dynamically_allocated_comp)
Append Component* to 'this' instance of Nnet by a shallow copy ('this' instance of Nnet takes over the ownership of the pointer).
Definition: nnet-nnet.cc:187
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
Definition: nnet-nnet.cc:70
int32 OutputDim() const
Dimensionality of network outputs (posteriors | bn-features | etc.).
Definition: nnet-nnet.cc:143
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance corresponding to each state; it returns the maximum time in the lattice (the number of frames).
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics.
Definition: nnet-nnet.cc:420
void Scale(Real value)
Definition: cu-matrix.cc:608
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g. 46:47) on the characters in 'delim' into a vector of integers; returns false on failure.
Definition: text-utils.h:64
int32 GetVerboseLevel()
Definition: kaldi-error.h:69
void Min(const CuMatrixBase< Real > &A)
Do, elementwise, *this = min(*this, A).
Definition: cu-matrix.cc:725
void SortAndUniq(std::vector< T > *vec)
Sorts and uniq's (removes duplicates) from a vector.
Definition: stl-utils.h:39
const Component & GetLastComponent() const
LastComponent accessor.
Definition: nnet-nnet.cc:161
double Elapsed()
Returns time in seconds.
Definition: timer.h:65
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:818
Allows random access to a collection of objects in an archive or script file; see The Table concept.
Definition: kaldi-table.h:233
static const char * TypeToMarker(ComponentType t)
Converts component type to marker.
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an utterance.
Definition: posterior.h:43
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more details.
Definition: parse-options.h:36
void Max(const CuMatrixBase< Real > &A)
Do, elementwise, *this = max(*this, A).
Definition: cu-matrix.cc:700
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
Definition: cu-matrix.cc:47
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics.
Definition: nnet-nnet.cc:407
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
Definition: kaldi-table.h:287
vector< vector< double > > AcousticLatticeScale(double acwt)
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
#define KALDI_ERR
Definition: kaldi-error.h:127
BaseFloat LatticeForwardBackwardMpeVariants(const TransitionModel &trans, const std::vector< int32 > &silence_phones, const Lattice &lat, const std::vector< int32 > &num_ali, std::string criterion, bool one_silence_class, Posterior *post)
This function implements either the MPFE (minimum phone frame error) or SMBR (state-level minimum Bayes risk) forward-backward, selected by 'criterion' ("mpfe" or "smbr").
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
void Register(OptionsItf *opts)
Definition: nnet-trnopts.h:46
#define KALDI_WARN
Definition: kaldi-error.h:130
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics.
Definition: nnet-nnet.cc:443
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
Definition: nnet-nnet.cc:131
void ScaleLattice(const vector< vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a, b) as a 2-vector and pre-multiplying it by the 2x2 matrix 'scale'.
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
virtual ComponentType GetType() const =0
Get Type Identification of the component.
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
Definition: nnet-nnet.cc:508
vector< vector< double > > LatticeScale(double lmwt, double acwt)
Arc::StateId NumArcs(const ExpandedFst< Arc > &fst)
Returns the total number of arcs in an FST.
#define KALDI_LOG
Definition: kaldi-error.h:133
void Register(OptionsItf *opts)
void LatticeAcousticRescore(const Matrix< BaseFloat > &log_like, const TransitionModel &trans_model, const std::vector< int32 > &state_times, Lattice *lat)