All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
nnet-train-frmshuff.cc File Reference
Include dependency graph for nnet-train-frmshuff.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 29 of file nnet-train-frmshuff.cc.

References MatrixRandomizer::AddData(), VectorRandomizer::AddData(), StdVectorRandomizer< T >::AddData(), Nnet::Backpropagate(), CuMatrixBase< Real >::CopyRows(), VectorBase< Real >::Dim(), MatrixRandomizer::Done(), SequentialTableReader< Holder >::Done(), Timer::Elapsed(), Xent::Eval(), Mse::Eval(), MultiTaskLoss::Eval(), Nnet::Feedforward(), RandomizerMask::Generate(), ParseOptions::GetArg(), kaldi::GetVerboseLevel(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, Nnet::InfoBackPropagate(), Nnet::InfoGradient(), Nnet::InfoPropagate(), MultiTaskLoss::InitFromString(), MatrixRandomizer::IsFull(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, kaldi::kCopyData, SequentialTableReader< Holder >::Key(), VectorBase< Real >::Min(), MatrixRandomizer::Next(), VectorRandomizer::Next(), StdVectorRandomizer< T >::Next(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumCols(), MatrixRandomizer::NumFrames(), MatrixBase< Real >::NumRows(), CuMatrixBase< Real >::NumRows(), RandomAccessTableReader< Holder >::Open(), ParseOptions::PrintUsage(), Nnet::Propagate(), MatrixRandomizer::Randomize(), VectorRandomizer::Randomize(), StdVectorRandomizer< T >::Randomize(), ParseOptions::Read(), Nnet::Read(), LossOptions::Register(), NnetTrainOptions::Register(), NnetDataRandomizerOptions::Register(), ParseOptions::Register(), Xent::Report(), Mse::Report(), MultiTaskLoss::Report(), Xent::ReportPerClass(), Timer::Reset(), Vector< Real >::Resize(), Matrix< Real >::Resize(), VectorBase< Real >::Scale(), VectorBase< Real >::Set(), Nnet::SetDropoutRate(), Nnet::SetTrainOptions(), MatrixRandomizer::Value(), VectorRandomizer::Value(), StdVectorRandomizer< T >::Value(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and Nnet::Write().

29  {
30  using namespace kaldi;
31  using namespace kaldi::nnet1;
32  typedef kaldi::int32 int32;
33 
34  try {
35  const char *usage =
36  "Perform one iteration (epoch) of Neural Network training with\n"
37  "mini-batch Stochastic Gradient Descent. The training targets\n"
38  "are usually pdf-posteriors, prepared by ali-to-post.\n"
39  "Usage: nnet-train-frmshuff [options] <feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]\n"
40  "e.g.: nnet-train-frmshuff scp:feats.scp ark:posterior.ark nnet.init nnet.iter1\n";
41 
42  ParseOptions po(usage);
43 
44  NnetTrainOptions trn_opts;
45  trn_opts.Register(&po);
47  rnd_opts.Register(&po);
48  LossOptions loss_opts;
49  loss_opts.Register(&po);
50 
51  bool binary = true;
52  po.Register("binary", &binary, "Write output in binary mode");
53 
54  bool crossvalidate = false;
55  po.Register("cross-validate", &crossvalidate,
56  "Perform cross-validation (don't back-propagate)");
57 
58  bool randomize = true;
59  po.Register("randomize", &randomize,
60  "Perform the frame-level shuffling within the Cache::");
61 
62  std::string feature_transform;
63  po.Register("feature-transform", &feature_transform,
64  "Feature transform in Nnet format");
65 
66  std::string objective_function = "xent";
67  po.Register("objective-function", &objective_function,
68  "Objective function : xent|mse|multitask");
69 
70  int32 max_frames = 360000;
71  po.Register("max-frames", &max_frames,
72  "Maximum number of frames an utterance can have (skipped if longer)");
73 
74  int32 length_tolerance = 5;
75  po.Register("length-tolerance", &length_tolerance,
76  "Allowed length mismatch of features/targets/weights "
77  "(in frames, we truncate to the shortest)");
78 
79  std::string frame_weights;
80  po.Register("frame-weights", &frame_weights,
81  "Per-frame weights, used to re-scale gradients.");
82 
83  std::string utt_weights;
84  po.Register("utt-weights", &utt_weights,
85  "Per-utterance weights, used to re-scale frame-weights.");
86 
87  std::string use_gpu="yes";
88  po.Register("use-gpu", &use_gpu,
89  "yes|no|optional, only has effect if compiled with CUDA");
90 
91  po.Read(argc, argv);
92 
93  if (po.NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
94  po.PrintUsage();
95  exit(1);
96  }
97 
98  std::string feature_rspecifier = po.GetArg(1),
99  targets_rspecifier = po.GetArg(2),
100  model_filename = po.GetArg(3);
101 
102  std::string target_model_filename;
103  if (!crossvalidate) {
104  target_model_filename = po.GetArg(4);
105  }
106 
107  using namespace kaldi;
108  using namespace kaldi::nnet1;
109  typedef kaldi::int32 int32;
110 
111 #if HAVE_CUDA == 1
112  CuDevice::Instantiate().SelectGpuId(use_gpu);
113 #endif
114 
115  Nnet nnet_transf;
116  if (feature_transform != "") {
117  nnet_transf.Read(feature_transform);
118  }
119 
120  Nnet nnet;
121  nnet.Read(model_filename);
122  nnet.SetTrainOptions(trn_opts);
123 
124  if (crossvalidate) {
125  nnet_transf.SetDropoutRate(0.0);
126  nnet.SetDropoutRate(0.0);
127  }
128 
129  kaldi::int64 total_frames = 0;
130 
131  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
132  RandomAccessPosteriorReader targets_reader(targets_rspecifier);
133  RandomAccessBaseFloatVectorReader weights_reader;
134  if (frame_weights != "") {
135  weights_reader.Open(frame_weights);
136  }
137  RandomAccessBaseFloatReader utt_weights_reader;
138  if (utt_weights != "") {
139  utt_weights_reader.Open(utt_weights);
140  }
141 
142  RandomizerMask randomizer_mask(rnd_opts);
143  MatrixRandomizer feature_randomizer(rnd_opts);
144  PosteriorRandomizer targets_randomizer(rnd_opts);
145  VectorRandomizer weights_randomizer(rnd_opts);
146 
147  Xent xent(loss_opts);
148  Mse mse(loss_opts);
149 
150  MultiTaskLoss multitask(loss_opts);
151  if (0 == objective_function.compare(0, 9, "multitask")) {
152  // objective_function contains something like :
153  // 'multitask,xent,2456,1.0,mse,440,0.001'
154  //
155  // the meaning is following:
156  // 'multitask,<type1>,<dim1>,<weight1>,...,<typeN>,<dimN>,<weightN>'
157  multitask.InitFromString(objective_function);
158  }
159 
160  CuMatrix<BaseFloat> feats_transf, nnet_out, obj_diff;
161 
162  Timer time, time_io;
163  KALDI_LOG << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
164  << " STARTED";
165 
166  int32 num_done = 0,
167  num_no_tgt_mat = 0,
168  num_other_error = 0;
169 
170  double time_io_accu = 0.0;
171 
172  // main loop,
173  while (!feature_reader.Done()) {
174 #if HAVE_CUDA == 1
175  // check that GPU computes accurately,
176  CuDevice::Instantiate().CheckGpuHealth();
177 #endif
178  // fill the randomizer,
179  time_io.Reset();
180  for ( ; !feature_reader.Done(); feature_reader.Next()) {
181  if (feature_randomizer.IsFull()) {
182  // break the loop without calling Next(),
183  // we keep the 'utt' for next round,
184  break;
185  }
186  std::string utt = feature_reader.Key();
187  KALDI_VLOG(3) << "Reading " << utt;
188  // check that we have targets,
189  if (!targets_reader.HasKey(utt)) {
190  KALDI_WARN << utt << ", missing targets";
191  num_no_tgt_mat++;
192  continue;
193  }
194  // check we have per-frame weights,
195  if (frame_weights != "" && !weights_reader.HasKey(utt)) {
196  KALDI_WARN << utt << ", missing per-frame weights";
197  num_other_error++;
198  continue;
199  }
200  // check we have per-utterance weights,
201  if (utt_weights != "" && !utt_weights_reader.HasKey(utt)) {
202  KALDI_WARN << utt << ", missing per-utterance weight";
203  num_other_error++;
204  continue;
205  }
206  // get feature / target pair,
207  Matrix<BaseFloat> mat = feature_reader.Value();
208  Posterior targets = targets_reader.Value(utt);
209  // get per-frame weights,
210  Vector<BaseFloat> weights;
211  if (frame_weights != "") {
212  weights = weights_reader.Value(utt);
213  } else { // all per-frame weights are 1.0,
214  weights.Resize(mat.NumRows());
215  weights.Set(1.0);
216  }
217  // multiply with per-utterance weight,
218  if (utt_weights != "") {
219  BaseFloat w = utt_weights_reader.Value(utt);
220  KALDI_ASSERT(w >= 0.0);
221  if (w == 0.0) continue; // remove sentence from training,
222  weights.Scale(w);
223  }
224 
225  // accumulate the I/O time,
226  time_io_accu += time_io.Elapsed();
227  time_io.Reset(); // to be sure we don't count 2x,
228 
229  // skip too long utterances (or we run out of memory),
230  if (mat.NumRows() > max_frames) {
231  KALDI_WARN << "Utterance too long, skipping! " << utt
232  << " (length " << mat.NumRows() << ", max_frames "
233  << max_frames << ")";
234  num_other_error++;
235  continue;
236  }
237 
238  // correct small length mismatch or drop sentence,
239  {
240  // add lengths to vector,
241  std::vector<int32> length;
242  length.push_back(mat.NumRows());
243  length.push_back(targets.size());
244  length.push_back(weights.Dim());
245  // find min, max,
246  int32 min = *std::min_element(length.begin(), length.end());
247  int32 max = *std::max_element(length.begin(), length.end());
248  // fix or drop ?
249  if (max - min < length_tolerance) {
250  // we truncate to shortest,
251  if (mat.NumRows() != min) mat.Resize(min, mat.NumCols(), kCopyData);
252  if (targets.size() != min) targets.resize(min);
253  if (weights.Dim() != min) weights.Resize(min, kCopyData);
254  } else {
255  KALDI_WARN << "Length mismatch! Targets " << targets.size()
256  << ", features " << mat.NumRows() << ", " << utt;
257  num_other_error++;
258  continue;
259  }
260  }
261  // apply feature transform (if empty, input is copied),
262  nnet_transf.Feedforward(CuMatrix<BaseFloat>(mat), &feats_transf);
263 
264  // remove frames with '0' weight from training,
265  {
266  // are there any frames to be removed? (frames with zero weight),
267  BaseFloat weight_min = weights.Min();
268  KALDI_ASSERT(weight_min >= 0.0);
269  if (weight_min == 0.0) {
270  // create vector with frame-indices to keep,
271  std::vector<MatrixIndexT> keep_frames;
272  for (int32 i = 0; i < weights.Dim(); i++) {
273  if (weights(i) > 0.0) {
274  keep_frames.push_back(i);
275  }
276  }
277 
278  // when all frames are removed, we skip the sentence,
279  if (keep_frames.size() == 0) continue;
280 
281  // filter feature-frames,
282  CuMatrix<BaseFloat> tmp_feats(keep_frames.size(), feats_transf.NumCols());
283  tmp_feats.CopyRows(feats_transf, CuArray<MatrixIndexT>(keep_frames));
284  tmp_feats.Swap(&feats_transf);
285 
286  // filter targets,
287  Posterior tmp_targets;
288  for (int32 i = 0; i < keep_frames.size(); i++) {
289  tmp_targets.push_back(targets[keep_frames[i]]);
290  }
291  tmp_targets.swap(targets);
292 
293  // filter weights,
294  Vector<BaseFloat> tmp_weights(keep_frames.size());
295  for (int32 i = 0; i < keep_frames.size(); i++) {
296  tmp_weights(i) = weights(keep_frames[i]);
297  }
298  tmp_weights.Swap(&weights);
299  }
300  }
301 
302  // pass data to randomizers,
303  KALDI_ASSERT(feats_transf.NumRows() == targets.size());
304  feature_randomizer.AddData(feats_transf);
305  targets_randomizer.AddData(targets);
306  weights_randomizer.AddData(weights);
307  num_done++;
308 
309  time_io.Reset(); // reset before reading next feature matrix,
310  }
311 
312  // randomize,
313  if (!crossvalidate && randomize) {
314  const std::vector<int32>& mask =
315  randomizer_mask.Generate(feature_randomizer.NumFrames());
316  feature_randomizer.Randomize(mask);
317  targets_randomizer.Randomize(mask);
318  weights_randomizer.Randomize(mask);
319  }
320 
321  // train with data from randomizers (using mini-batches),
322  for ( ; !feature_randomizer.Done(); feature_randomizer.Next(),
323  targets_randomizer.Next(),
324  weights_randomizer.Next()) {
325  // get block of feature/target pairs,
326  const CuMatrixBase<BaseFloat>& nnet_in = feature_randomizer.Value();
327  const Posterior& nnet_tgt = targets_randomizer.Value();
328  const Vector<BaseFloat>& frm_weights = weights_randomizer.Value();
329 
330  // forward pass,
331  nnet.Propagate(nnet_in, &nnet_out);
332 
333  // evaluate objective function we've chosen,
334  if (objective_function == "xent") {
335  // gradients re-scaled by weights in Eval,
336  xent.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
337  } else if (objective_function == "mse") {
338  // gradients re-scaled by weights in Eval,
339  mse.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
340  } else if (0 == objective_function.compare(0, 9, "multitask")) {
341  // gradients re-scaled by weights in Eval,
342  multitask.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
343  } else {
344  KALDI_ERR << "Unknown objective function code : " << objective_function;
345  }
346 
347  if (!crossvalidate) {
348  // back-propagate, and do the update,
349  nnet.Backpropagate(obj_diff, NULL);
350  }
351 
352  // 1st mini-batch : show what happens in network,
353  if (total_frames == 0) {
354  KALDI_LOG << "### After " << total_frames << " frames,";
355  KALDI_LOG << nnet.InfoPropagate();
356  if (!crossvalidate) {
357  KALDI_LOG << nnet.InfoBackPropagate();
358  KALDI_LOG << nnet.InfoGradient();
359  }
360  }
361 
362  // VERBOSE LOG
363  // monitor the NN training (--verbose=2),
364  if (GetVerboseLevel() >= 2) {
365  static int32 counter = 0;
366  counter += nnet_in.NumRows();
367  // print every 25k frames,
368  if (counter >= 25000) {
369  KALDI_VLOG(2) << "### After " << total_frames << " frames,";
370  KALDI_VLOG(2) << nnet.InfoPropagate();
371  if (!crossvalidate) {
372  KALDI_VLOG(2) << nnet.InfoBackPropagate();
373  KALDI_VLOG(2) << nnet.InfoGradient();
374  }
375  counter = 0;
376  }
377  }
378 
379  total_frames += nnet_in.NumRows();
380  }
381  } // main loop,
382 
383  // after last mini-batch : show what happens in network,
384  KALDI_LOG << "### After " << total_frames << " frames,";
385  KALDI_LOG << nnet.InfoPropagate();
386  if (!crossvalidate) {
387  KALDI_LOG << nnet.InfoBackPropagate();
388  KALDI_LOG << nnet.InfoGradient();
389  }
390 
391  if (!crossvalidate) {
392  nnet.Write(target_model_filename, binary);
393  }
394 
395  KALDI_LOG << "Done " << num_done << " files, "
396  << num_no_tgt_mat << " with no tgt_mats, "
397  << num_other_error << " with other errors. "
398  << "[" << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
399  << ", " << (randomize ? "RANDOMIZED" : "NOT-RANDOMIZED")
400  << ", " << time.Elapsed() / 60 << " min, processing "
401  << total_frames / time.Elapsed() << " frames per sec;"
402  << " i/o time " << 100.*time_io_accu/time.Elapsed() << "%]";
403 
404  if (objective_function == "xent") {
405  KALDI_LOG << xent.ReportPerClass();
406  KALDI_LOG << xent.Report();
407  } else if (objective_function == "mse") {
408  KALDI_LOG << mse.Report();
409  } else if (0 == objective_function.compare(0, 9, "multitask")) {
410  KALDI_LOG << multitask.Report();
411  } else {
412  KALDI_ERR << "Unknown objective function code : " << objective_function;
413  }
414 
415 #if HAVE_CUDA == 1
416  CuDevice::Instantiate().PrintProfile();
417 #endif
418 
419  return 0;
420  } catch(const std::exception &e) {
421  std::cerr << e.what();
422  return -1;
423  }
424 }
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
Definition: nnet-nnet.cc:96
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
Definition: nnet-nnet.cc:70
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics.
Definition: nnet-nnet.cc:420
void Reset()
Definition: timer.h:71
int32 GetVerboseLevel()
Definition: kaldi-error.h:69
Generates randomly ordered vector of indices.
bool Open(const std::string &rspecifier)
Randomizes elements of a vector according to a mask.
MatrixIndexT NumCols() const
Definition: cu-matrix.h:215
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
Real Min() const
Returns the minimum value of any element, or +infinity for the empty vector.
Allows random access to a collection of objects in an archive or script file; see The Table concept.
Definition: kaldi-table.h:233
Class CuArray represents a vector of an integer or struct of type T.
Definition: cu-array.h:32
Configuration variables that affect how frame-level shuffling is done.
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an utterance.
Definition: posterior.h:43
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more documentation.
Definition: parse-options.h:36
const T & Value(const std::string &key)
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:214
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics.
Definition: nnet-nnet.cc:407
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:127
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
void Register(OptionsItf *opts)
Definition: nnet-trnopts.h:46
#define KALDI_WARN
Definition: kaldi-error.h:130
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics.
Definition: nnet-nnet.cc:443
void Scale(Real alpha)
Multiplies all elements by this constant.
bool HasKey(const std::string &key)
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
Definition: nnet-nnet.cc:131
Shuffles rows of a matrix according to the indices in the mask.
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:61
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
void Set(Real f)
Set all members of a vector to a specified value.
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
Definition: nnet-nnet.cc:268
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
Definition: nnet-nnet.cc:508
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
void CopyRows(const CuMatrixBase< Real > &src, const CuArrayBase< MatrixIndexT > &indexes)
Copies row r from row indexes[r] of src.
Definition: cu-matrix.cc:2645
Randomizes elements of a vector according to a mask.
void Register(OptionsItf *opts)
Definition: nnet-loss.h:45
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
#define KALDI_LOG
Definition: kaldi-error.h:133
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:63