nnet-train-multistream.cc File Reference
#include <numeric>
#include "nnet/nnet-trnopts.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-loss.h"
#include "nnet/nnet-randomizer.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "base/timer.h"
#include "cudamatrix/cu-device.h"
Include dependency graph for nnet-train-multistream.cc:

Go to the source code of this file.

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks.
 

Functions

bool ReadData (SequentialBaseFloatMatrixReader &feature_reader, RandomAccessPosteriorReader &target_reader, RandomAccessBaseFloatVectorReader &weights_reader, int32 length_tolerance, Matrix< BaseFloat > *feats, Posterior *targets, Vector< BaseFloat > *weights, int32 *num_no_tgt_mat, int32 *num_other_error)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 109 of file nnet-train-multistream.cc.

References Nnet::Backpropagate(), Timer::Elapsed(), Xent::Eval(), Mse::Eval(), Nnet::Feedforward(), ParseOptions::GetArg(), kaldi::GetVerboseLevel(), rnnlm::i, Nnet::Info(), Nnet::InfoBackPropagate(), Nnet::InfoGradient(), Nnet::InfoPropagate(), Nnet::InputDim(), KALDI_ERR, KALDI_LOG, KALDI_VLOG, kaldi::kSetZero, ParseOptions::NumArgs(), RandomAccessTableReader< Holder >::Open(), ParseOptions::PrintUsage(), Nnet::Propagate(), ParseOptions::Read(), Nnet::Read(), kaldi::ReadData(), LossOptions::Register(), NnetTrainOptions::Register(), ParseOptions::Register(), Xent::Report(), Mse::Report(), Xent::ReportPerClass(), Nnet::ResetStreams(), Vector< Real >::Resize(), Matrix< Real >::Resize(), MatrixBase< Real >::Row(), Nnet::SetDropoutRate(), Nnet::SetSeqLengths(), Nnet::SetTrainOptions(), and Nnet::Write().

// main() for nnet-train-multistream: one pass of multi-stream training
// (truncated BPTT for LSTMs, per the usage string) or cross-validation of an
// nnet1 model.  Mini-batches interleave 'num_streams' parallel utterance
// streams, up to 'batch_size' frames per stream; padded positions keep a
// zero frame-weight so they do not contribute to the objective.
109  {
110  using namespace kaldi;
111  using namespace kaldi::nnet1;
112  typedef kaldi::int32 int32;
113 
114  try {
  // NOTE(review): the "e.g." line below still shows the old binary name
  // 'nnet-train-lstm-streams', while the Usage line says 'nnet-train-multistream'.
115  const char *usage =
116  "Perform one iteration of Multi-stream training, truncated BPTT for LSTMs.\n"
117  "The training targets are pdf-posteriors, usually prepared by ali-to-post.\n"
118  "The updates are per-utterance.\n"
119  "\n"
120  "Usage: nnet-train-multistream [options] "
121  "<feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]\n"
122  "e.g.: nnet-train-lstm-streams scp:feature.scp ark:posterior.ark nnet.init nnet.iter1\n";
123 
124  ParseOptions po(usage);
125 
126  NnetTrainOptions trn_opts;
127  trn_opts.Register(&po);
128  LossOptions loss_opts;
129  loss_opts.Register(&po);
130 
131  bool binary = true;
132  po.Register("binary", &binary, "Write output in binary mode");
133 
134  bool crossvalidate = false;
135  po.Register("cross-validate", &crossvalidate,
136  "Perform cross-validation (don't back-propagate)");
137 
138  std::string feature_transform;
139  po.Register("feature-transform", &feature_transform,
140  "Feature transform in Nnet format");
141 
142  std::string objective_function = "xent";
143  po.Register("objective-function", &objective_function,
144  "Objective function : xent|mse");
145 
146  int32 length_tolerance = 5;
147  po.Register("length-tolerance", &length_tolerance,
148  "Allowed length difference of features/targets (frames)");
149 
150  std::string frame_weights;
151  po.Register("frame-weights", &frame_weights,
152  "Per-frame weights to scale gradients (frame selection/weighting).");
153 
154  int32 batch_size = 20;
155  po.Register("batch-size", &batch_size,
156  "Length of 'one stream' in the Multi-stream training");
157 
158  int32 num_streams = 4;
159  po.Register("num-streams", &num_streams,
160  "Number of streams in the Multi-stream training");
161 
  // 'randomize' is accepted but never read; presumably kept so shared
  // training scripts can pass it uniformly — TODO confirm against scripts.
162  bool dummy = false;
163  po.Register("randomize", &dummy, "Dummy option.");
164 
165  std::string use_gpu="yes";
166  po.Register("use-gpu", &use_gpu,
167  "yes|no|optional, only has effect if compiled with CUDA");
168 
169  po.Read(argc, argv);
170 
  // 4th positional arg (<model-out>) is required only when training.
171  if (po.NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
172  po.PrintUsage();
173  exit(1);
174  }
175 
176  std::string feature_rspecifier = po.GetArg(1),
177  targets_rspecifier = po.GetArg(2),
178  model_filename = po.GetArg(3);
179 
180  std::string target_model_filename;
181  if (!crossvalidate) {
182  target_model_filename = po.GetArg(4);
183  }
184 
  // NOTE(review): duplicates the using/typedef declarations at lines 110-112.
185  using namespace kaldi;
186  using namespace kaldi::nnet1;
187  typedef kaldi::int32 int32;
188 
189 #if HAVE_CUDA == 1
190  CuDevice::Instantiate().SelectGpuId(use_gpu);
191 #endif
192 
  // Optional input feature transform (empty string means identity).
193  Nnet nnet_transf;
194  if (feature_transform != "") {
195  nnet_transf.Read(feature_transform);
196  }
197 
198  Nnet nnet;
199  nnet.Read(model_filename);
200  nnet.SetTrainOptions(trn_opts);
201 
  // No dropout during held-out evaluation.
202  if (crossvalidate) {
203  nnet_transf.SetDropoutRate(0.0);
204  nnet.SetDropoutRate(0.0);
205  }
206 
207  kaldi::int64 total_frames = 0;
208 
209  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
210  RandomAccessPosteriorReader target_reader(targets_rspecifier);
211  RandomAccessBaseFloatVectorReader weights_reader;
212  if (frame_weights != "") {
213  weights_reader.Open(frame_weights);
214  }
215 
216  Xent xent(loss_opts);
217  Mse mse(loss_opts);
218 
219  Timer time;
220  double time_gpu = 0;
221  KALDI_LOG << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
222  << " STARTED";
223 
224  int32 num_done = 0,
225  num_no_tgt_mat = 0,
226  num_other_error = 0;
227 
228  // book-keeping for multi-stream training,
229  std::vector<Matrix<BaseFloat> > feats_utt(num_streams);
230  std::vector<Posterior> labels_utt(num_streams);
231  std::vector<Vector<BaseFloat> > weights_utt(num_streams);
232  std::vector<int32> cursor_utt(num_streams); // 0 initialized,
233  std::vector<int32> new_utt_flags(num_streams);
234 
235  CuMatrix<BaseFloat> feats_transf, nnet_out, obj_diff;
236 
237  // MAIN LOOP,
238  while (1) {
239 
240  // Re-fill the streams, if needed,
241  new_utt_flags.assign(num_streams, 0); // set new-utterance flags to zero,
242  for (int s = 0; s < num_streams; s++) {
243  // Need a new utterance for stream 's'?
244  if (cursor_utt[s] >= feats_utt[s].NumRows()) {
245  Matrix<BaseFloat> feats;
246  Posterior targets;
247  Vector<BaseFloat> weights;
248  // get the data from readers,
249  if (ReadData(feature_reader, target_reader, weights_reader,
250  length_tolerance,
251  &feats, &targets, &weights,
252  &num_no_tgt_mat, &num_other_error)) {
253 
254  // input transform may contain splicing,
255  Timer t;
256  nnet_transf.Feedforward(CuMatrix<BaseFloat>(feats), &feats_transf);
257  time_gpu += t.Elapsed();
258 
259  /* Here we could do the 'targets_delay', BUT...
260  * It is better to do it by a <Splice> component!
261  *
262  * The prototype would look like this (6th frame becomes 1st frame, etc.):
263  * '<Splice> <InputDim> dim1 <OutputDim> dim1 <BuildVector> 5 </BuildVector>'
264  */
265 
266  // store,
267  feats_utt[s] = Matrix<BaseFloat>(feats_transf);
268  labels_utt[s] = targets;
269  weights_utt[s] = weights;
270  cursor_utt[s] = 0;
271  new_utt_flags[s] = 1;
272  }
273  }
274  }
275 
276  // End the training as soon as ANY stream runs empty
277  // (this avoids over-adaptation to last utterances),
278  size_t inactive_streams = 0;
279  for (int32 s = 0; s < num_streams; s++) {
280  if (feats_utt[s].NumRows() - cursor_utt[s] <= 0) {
281  inactive_streams += 1;
282  }
283  }
284  if (inactive_streams >= 1) {
285  KALDI_LOG << "No more data to re-fill one of the streams, end of the training!";
286  KALDI_LOG << "(remaining stubs of data are discarded, don't overtrain on them)";
287  break;
288  }
289 
290  // number of frames we'll pack as the streams,
291  std::vector<int32> frame_num_utt;
292 
293  // pack the parallel data,
294  Matrix<BaseFloat> feat_mat_host;
295  Posterior target_host;
296  Vector<BaseFloat> weight_host;
297  {
298  // Number of sequences (can have zero length),
299  int32 n_streams = num_streams;
300 
  // Layout: frame r of stream s is stored at row (r * n_streams + s),
  // i.e. mini-batch rows interleave the streams frame-by-frame.
301  // Create the final feature matrix with 'interleaved feature-lines',
302  feat_mat_host.Resize(n_streams * batch_size, nnet.InputDim(), kSetZero);
303  target_host.resize(n_streams * batch_size);
304  weight_host.Resize(n_streams * batch_size, kSetZero);
305  frame_num_utt.resize(n_streams, 0);
306 
307  // we slice at the 'cursor' at most 'batch_size' frames,
308  for (int32 s = 0; s < n_streams; s++) {
309  int32 num_rows = std::max(0, feats_utt[s].NumRows() - cursor_utt[s]);
310  frame_num_utt[s] = std::min(batch_size, num_rows);
311  }
312 
313  // pack the data,
314  {
315  for (int32 s = 0; s < n_streams; s++) {
316  if (frame_num_utt[s] > 0) {
317  auto mat_tmp = feats_utt[s].RowRange(cursor_utt[s], frame_num_utt[s]);
318  for (int32 r = 0; r < frame_num_utt[s]; r++) {
319  feat_mat_host.Row(r*n_streams + s).CopyFromVec(mat_tmp.Row(r));
320  }
321  }
322  }
323 
324  for (int32 s = 0; s < n_streams; s++) {
325  for (int32 r = 0; r < frame_num_utt[s]; r++) {
326  target_host[r*n_streams + s] = labels_utt[s][cursor_utt[s] + r];
327  }
328  }
329 
330  // padded frames will keep initial zero-weight,
331  for (int32 s = 0; s < n_streams; s++) {
332  if (frame_num_utt[s] > 0) {
333  auto weight_tmp = weights_utt[s].Range(cursor_utt[s], frame_num_utt[s]);
334  for (int32 r = 0; r < frame_num_utt[s]; r++) {
335  weight_host(r*n_streams + s) = weight_tmp(r);
336  }
337  }
338  }
339  }
340 
341  // advance the cursors,
342  for (int32 s = 0; s < n_streams; s++) {
343  cursor_utt[s] += frame_num_utt[s];
344  }
345  }
346 
347  // pass the info about padding,
348  nnet.SetSeqLengths(frame_num_utt);
349 
350  // Show debug info,
351  if (GetVerboseLevel() >= 4) {
352  // cursors in the feature_matrices,
353  {
354  std::ostringstream os;
355  os << "[ ";
356  for (size_t i = 0; i < cursor_utt.size(); i++) {
357  os << cursor_utt[i] << " ";
358  }
359  os << "]";
360  KALDI_LOG << "cursor_utt[" << cursor_utt.size() << "]" << os.str();
361  }
362  // frames in the mini-batch,
363  {
364  std::ostringstream os;
365  os << "[ ";
366  for (size_t i = 0; i < frame_num_utt.size(); i++) {
367  os << frame_num_utt[i] << " ";
368  }
369  os << "]";
370  KALDI_LOG << "frame_num_utt[" << frame_num_utt.size() << "]" << os.str();
371  }
372  }
373 
374  Timer t;
375  // with new utterance we reset the history,
376  nnet.ResetStreams(new_utt_flags);
377 
378  // forward pass,
379  nnet.Propagate(CuMatrix<BaseFloat>(feat_mat_host), &nnet_out);
380 
381  // evaluate objective function we've chosen,
382  if (objective_function == "xent") {
383  xent.Eval(weight_host, nnet_out, target_host, &obj_diff);
384  } else if (objective_function == "mse") {
385  mse.Eval(weight_host, nnet_out, target_host, &obj_diff);
386  } else {
387  KALDI_ERR << "Unknown objective function code : "
388  << objective_function;
389  }
390 
391  if (!crossvalidate) {
392  // back-propagate, and do the update,
393  nnet.Backpropagate(obj_diff, NULL);
394  }
395  time_gpu += t.Elapsed();
396 
397  // 1st minibatch : show what happens in network,
398  if (total_frames == 0) {
399  KALDI_LOG << "### After " << total_frames << " frames,";
400  KALDI_LOG << nnet.Info();
401  KALDI_LOG << nnet.InfoPropagate();
402  if (!crossvalidate) {
403  KALDI_LOG << nnet.InfoBackPropagate();
404  KALDI_LOG << nnet.InfoGradient();
405  }
406  }
407 
408  kaldi::int64 tmp_frames = total_frames;
409 
  // num_done counts utterances (one per raised new-utterance flag);
  // total_frames counts real (non-padded) frames in this mini-batch.
410  num_done += std::accumulate(new_utt_flags.begin(), new_utt_flags.end(), 0);
411  total_frames += std::accumulate(frame_num_utt.begin(), frame_num_utt.end(), 0);
412 
413  // monitor the NN training (--verbose=2),
414  int32 F = 25000;
415  if (GetVerboseLevel() >= 2) {
  // print when the frame counter crosses a multiple of F (25k frames),
416  // print every 25k frames,
417  if (tmp_frames / F != total_frames / F) {
418  KALDI_VLOG(2) << "### After " << total_frames << " frames,";
419  KALDI_VLOG(2) << nnet.Info();
420  KALDI_VLOG(2) << nnet.InfoPropagate();
421  if (!crossvalidate) {
422  KALDI_VLOG(2) << nnet.InfoBackPropagate();
423  KALDI_VLOG(2) << nnet.InfoGradient();
424  }
425  }
426  }
427  }
428 
429  // after last minibatch : show what happens in network,
430  KALDI_LOG << "### After " << total_frames << " frames,";
431  KALDI_LOG << nnet.Info();
432  KALDI_LOG << nnet.InfoPropagate();
433  if (!crossvalidate) {
434  KALDI_LOG << nnet.InfoBackPropagate();
435  KALDI_LOG << nnet.InfoGradient();
436  }
437 
438  if (!crossvalidate) {
439  nnet.Write(target_model_filename, binary);
440  }
441 
442  if (objective_function == "xent") {
443  KALDI_LOG << xent.ReportPerClass();
444  }
445 
446  KALDI_LOG << "Done " << num_done << " files, "
447  << num_no_tgt_mat << " with no tgt_mats, "
448  << num_other_error << " with other errors. "
449  << "[" << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
450  << ", " << time.Elapsed() / 60 << " min, processing "
451  << total_frames / time.Elapsed() << " frames per sec, "
452  << "GPU_time " << 100.*time_gpu/time.Elapsed() << "% ]";
453 
454  if (objective_function == "xent") {
455  KALDI_LOG << xent.Report();
456  } else if (objective_function == "mse") {
457  KALDI_LOG << mse.Report();
458  } else {
459  KALDI_ERR << "Unknown objective function code : " << objective_function;
460  }
461 
462 #if HAVE_CUDA == 1
463  CuDevice::Instantiate().PrintProfile();
464 #endif
465 
466  return 0;
467  } catch(const std::exception &e) {
468  std::cerr << e.what();
469  return -1;
470  }
471 }
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network,.
Definition: nnet-nnet.cc:96
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
void ResetStreams(const std::vector< int32 > &stream_reset_flag)
Reset streams in multi-stream training,.
Definition: nnet-nnet.cc:281
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network,.
Definition: nnet-nnet.cc:70
void SetSeqLengths(const std::vector< int32 > &sequence_lengths)
Set sequence length in LSTM multi-stream training,.
Definition: nnet-nnet.cc:291
int32 GetVerboseLevel()
Get verbosity level, usually set via command line '--verbose=' switch.
Definition: kaldi-error.h:60
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename',
Definition: nnet-nnet.cc:367
int32 InputDim() const
Dimensionality on network input (input feature dim.),.
Definition: nnet-nnet.cc:148
bool Open(const std::string &rspecifier)
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics,.
Definition: nnet-nnet.cc:443
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:147
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename',
Definition: nnet-nnet.cc:333
void Register(OptionsItf *opts)
Definition: nnet-trnopts.h:46
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers),.
Definition: nnet-nnet.cc:131
A class representing a vector.
Definition: kaldi-vector.h:406
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
Definition: nnet-nnet.cc:268
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics,.
Definition: nnet-nnet.cc:407
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics,.
Definition: nnet-nnet.cc:420
std::string Info() const
Create string with human readable description of the nnet,.
Definition: nnet-nnet.cc:386
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents),.
Definition: nnet-nnet.cc:508
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
void Register(OptionsItf *opts)
Definition: nnet-loss.h:45
#define KALDI_LOG
Definition: kaldi-error.h:153
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
bool ReadData(SequentialBaseFloatMatrixReader &feature_reader, RandomAccessPosteriorReader &target_reader, RandomAccessBaseFloatVectorReader &weights_reader, int32 length_tolerance, Matrix< BaseFloat > *feats, Posterior *targets, Vector< BaseFloat > *weights, int32 *num_no_tgt_mat, int32 *num_other_error)