All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
nnet-train-frmshuff.cc File Reference
Include dependency graph for nnet-train-frmshuff.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 29 of file nnet-train-frmshuff.cc.

References MatrixRandomizer::AddData(), VectorRandomizer::AddData(), StdVectorRandomizer< T >::AddData(), Nnet::Backpropagate(), CuMatrixBase< Real >::CopyRows(), VectorBase< Real >::Dim(), MatrixRandomizer::Done(), SequentialTableReader< Holder >::Done(), Timer::Elapsed(), Xent::Eval(), Mse::Eval(), MultiTaskLoss::Eval(), Nnet::Feedforward(), RandomizerMask::Generate(), ParseOptions::GetArg(), kaldi::GetVerboseLevel(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, Nnet::InfoBackPropagate(), Nnet::InfoGradient(), Nnet::InfoPropagate(), MultiTaskLoss::InitFromString(), MatrixRandomizer::IsFull(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, kaldi::kCopyData, SequentialTableReader< Holder >::Key(), VectorBase< Real >::Min(), MatrixRandomizer::Next(), VectorRandomizer::Next(), StdVectorRandomizer< T >::Next(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumCols(), MatrixRandomizer::NumFrames(), MatrixBase< Real >::NumRows(), CuMatrixBase< Real >::NumRows(), RandomAccessTableReader< Holder >::Open(), ParseOptions::PrintUsage(), Nnet::Propagate(), MatrixRandomizer::Randomize(), VectorRandomizer::Randomize(), StdVectorRandomizer< T >::Randomize(), ParseOptions::Read(), Nnet::Read(), LossOptions::Register(), NnetTrainOptions::Register(), NnetDataRandomizerOptions::Register(), ParseOptions::Register(), Xent::Report(), Mse::Report(), MultiTaskLoss::Report(), Xent::ReportPerClass(), Vector< Real >::Resize(), Matrix< Real >::Resize(), VectorBase< Real >::Scale(), VectorBase< Real >::Set(), Nnet::SetDropoutRate(), Nnet::SetTrainOptions(), MatrixRandomizer::Value(), VectorRandomizer::Value(), StdVectorRandomizer< T >::Value(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and Nnet::Write().

29  {
30  using namespace kaldi;
31  using namespace kaldi::nnet1;
32  typedef kaldi::int32 int32;
33 
34  try {
35  const char *usage =
36  "Perform one iteration (epoch) of Neural Network training with\n"
37  "mini-batch Stochastic Gradient Descent. The training targets\n"
38  "are usually pdf-posteriors, prepared by ali-to-post.\n"
39  "Usage: nnet-train-frmshuff [options] <feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]\n"
40  "e.g.: nnet-train-frmshuff scp:feats.scp ark:posterior.ark nnet.init nnet.iter1\n";
41 
42  ParseOptions po(usage);
43 
44  NnetTrainOptions trn_opts;
45  trn_opts.Register(&po);
47  rnd_opts.Register(&po);
48  LossOptions loss_opts;
49  loss_opts.Register(&po);
50 
51  bool binary = true;
52  po.Register("binary", &binary, "Write output in binary mode");
53 
54  bool crossvalidate = false;
55  po.Register("cross-validate", &crossvalidate,
56  "Perform cross-validation (don't back-propagate)");
57 
58  bool randomize = true;
59  po.Register("randomize", &randomize,
60  "Perform the frame-level shuffling within the Cache::");
61 
62  std::string feature_transform;
63  po.Register("feature-transform", &feature_transform,
64  "Feature transform in Nnet format");
65 
66  std::string objective_function = "xent";
67  po.Register("objective-function", &objective_function,
68  "Objective function : xent|mse|multitask");
69 
70  int32 max_frames = 360000;
71  po.Register("max-frames", &max_frames,
72  "Maximum number of frames an utterance can have (skipped if longer)");
73 
74  int32 length_tolerance = 5;
75  po.Register("length-tolerance", &length_tolerance,
76  "Allowed length mismatch of features/targets/weights "
77  "(in frames, we truncate to the shortest)");
78 
79  std::string frame_weights;
80  po.Register("frame-weights", &frame_weights,
81  "Per-frame weights, used to re-scale gradients.");
82 
83  std::string utt_weights;
84  po.Register("utt-weights", &utt_weights,
85  "Per-utterance weights, used to re-scale frame-weights.");
86 
87  std::string use_gpu="yes";
88  po.Register("use-gpu", &use_gpu,
89  "yes|no|optional, only has effect if compiled with CUDA");
90 
91  po.Read(argc, argv);
92 
93  if (po.NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
94  po.PrintUsage();
95  exit(1);
96  }
97 
98  std::string feature_rspecifier = po.GetArg(1),
99  targets_rspecifier = po.GetArg(2),
100  model_filename = po.GetArg(3);
101 
102  std::string target_model_filename;
103  if (!crossvalidate) {
104  target_model_filename = po.GetArg(4);
105  }
106 
107  using namespace kaldi;
108  using namespace kaldi::nnet1;
109  typedef kaldi::int32 int32;
110 
111 #if HAVE_CUDA == 1
112  CuDevice::Instantiate().SelectGpuId(use_gpu);
113 #endif
114 
115  Nnet nnet_transf;
116  if (feature_transform != "") {
117  nnet_transf.Read(feature_transform);
118  }
119 
120  Nnet nnet;
121  nnet.Read(model_filename);
122  nnet.SetTrainOptions(trn_opts);
123 
124  if (crossvalidate) {
125  nnet_transf.SetDropoutRate(0.0);
126  nnet.SetDropoutRate(0.0);
127  }
128 
129  kaldi::int64 total_frames = 0;
130 
131  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
132  RandomAccessPosteriorReader targets_reader(targets_rspecifier);
133  RandomAccessBaseFloatVectorReader weights_reader;
134  if (frame_weights != "") {
135  weights_reader.Open(frame_weights);
136  }
137  RandomAccessBaseFloatReader utt_weights_reader;
138  if (utt_weights != "") {
139  utt_weights_reader.Open(utt_weights);
140  }
141 
142  RandomizerMask randomizer_mask(rnd_opts);
143  MatrixRandomizer feature_randomizer(rnd_opts);
144  PosteriorRandomizer targets_randomizer(rnd_opts);
145  VectorRandomizer weights_randomizer(rnd_opts);
146 
147  Xent xent(loss_opts);
148  Mse mse(loss_opts);
149 
150  MultiTaskLoss multitask(loss_opts);
151  if (0 == objective_function.compare(0, 9, "multitask")) {
152  // objective_function contains something like :
153  // 'multitask,xent,2456,1.0,mse,440,0.001'
154  //
155  // the meaning is following:
156  // 'multitask,<type1>,<dim1>,<weight1>,...,<typeN>,<dimN>,<weightN>'
157  multitask.InitFromString(objective_function);
158  }
159 
160  CuMatrix<BaseFloat> feats_transf, nnet_out, obj_diff;
161 
162  Timer time;
163  KALDI_LOG << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
164  << " STARTED";
165 
166  int32 num_done = 0,
167  num_no_tgt_mat = 0,
168  num_other_error = 0;
169 
170  // main loop,
171  while (!feature_reader.Done()) {
172 #if HAVE_CUDA == 1
173  // check that GPU computes accurately,
174  CuDevice::Instantiate().CheckGpuHealth();
175 #endif
176  // fill the randomizer,
177  for ( ; !feature_reader.Done(); feature_reader.Next()) {
178  if (feature_randomizer.IsFull()) {
179  // break the loop without calling Next(),
180  // we keep the 'utt' for next round,
181  break;
182  }
183  std::string utt = feature_reader.Key();
184  KALDI_VLOG(3) << "Reading " << utt;
185  // check that we have targets,
186  if (!targets_reader.HasKey(utt)) {
187  KALDI_WARN << utt << ", missing targets";
188  num_no_tgt_mat++;
189  continue;
190  }
191  // check we have per-frame weights,
192  if (frame_weights != "" && !weights_reader.HasKey(utt)) {
193  KALDI_WARN << utt << ", missing per-frame weights";
194  num_other_error++;
195  continue;
196  }
197  // check we have per-utterance weights,
198  if (utt_weights != "" && !utt_weights_reader.HasKey(utt)) {
199  KALDI_WARN << utt << ", missing per-utterance weight";
200  num_other_error++;
201  continue;
202  }
203  // get feature / target pair,
204  Matrix<BaseFloat> mat = feature_reader.Value();
205  Posterior targets = targets_reader.Value(utt);
206  // get per-frame weights,
207  Vector<BaseFloat> weights;
208  if (frame_weights != "") {
209  weights = weights_reader.Value(utt);
210  } else { // all per-frame weights are 1.0,
211  weights.Resize(mat.NumRows());
212  weights.Set(1.0);
213  }
214  // multiply with per-utterance weight,
215  if (utt_weights != "") {
216  BaseFloat w = utt_weights_reader.Value(utt);
217  KALDI_ASSERT(w >= 0.0);
218  if (w == 0.0) continue; // remove sentence from training,
219  weights.Scale(w);
220  }
221 
222  // skip too long utterances (or we run out of memory),
223  if (mat.NumRows() > max_frames) {
224  KALDI_WARN << "Utterance too long, skipping! " << utt
225  << " (length " << mat.NumRows() << ", max_frames "
226  << max_frames << ")";
227  num_other_error++;
228  continue;
229  }
230 
231  // correct small length mismatch or drop sentence,
232  {
233  // add lengths to vector,
234  std::vector<int32> length;
235  length.push_back(mat.NumRows());
236  length.push_back(targets.size());
237  length.push_back(weights.Dim());
238  // find min, max,
239  int32 min = *std::min_element(length.begin(), length.end());
240  int32 max = *std::max_element(length.begin(), length.end());
241  // fix or drop ?
242  if (max - min < length_tolerance) {
243  // we truncate to shortest,
244  if (mat.NumRows() != min) mat.Resize(min, mat.NumCols(), kCopyData);
245  if (targets.size() != min) targets.resize(min);
246  if (weights.Dim() != min) weights.Resize(min, kCopyData);
247  } else {
248  KALDI_WARN << "Length mismatch! Targets " << targets.size()
249  << ", features " << mat.NumRows() << ", " << utt;
250  num_other_error++;
251  continue;
252  }
253  }
254  // apply feature transform (if empty, input is copied),
255  nnet_transf.Feedforward(CuMatrix<BaseFloat>(mat), &feats_transf);
256 
257  // remove frames with '0' weight from training,
258  {
259  // are there any frames to be removed? (frames with zero weight),
260  BaseFloat weight_min = weights.Min();
261  KALDI_ASSERT(weight_min >= 0.0);
262  if (weight_min == 0.0) {
263  // create vector with frame-indices to keep,
264  std::vector<MatrixIndexT> keep_frames;
265  for (int32 i = 0; i < weights.Dim(); i++) {
266  if (weights(i) > 0.0) {
267  keep_frames.push_back(i);
268  }
269  }
270 
271  // when all frames are removed, we skip the sentence,
272  if (keep_frames.size() == 0) continue;
273 
274  // filter feature-frames,
275  CuMatrix<BaseFloat> tmp_feats(keep_frames.size(), feats_transf.NumCols());
276  tmp_feats.CopyRows(feats_transf, CuArray<MatrixIndexT>(keep_frames));
277  tmp_feats.Swap(&feats_transf);
278 
279  // filter targets,
280  Posterior tmp_targets;
281  for (int32 i = 0; i < keep_frames.size(); i++) {
282  tmp_targets.push_back(targets[keep_frames[i]]);
283  }
284  tmp_targets.swap(targets);
285 
286  // filter weights,
287  Vector<BaseFloat> tmp_weights(keep_frames.size());
288  for (int32 i = 0; i < keep_frames.size(); i++) {
289  tmp_weights(i) = weights(keep_frames[i]);
290  }
291  tmp_weights.Swap(&weights);
292  }
293  }
294 
295  // pass data to randomizers,
296  KALDI_ASSERT(feats_transf.NumRows() == targets.size());
297  feature_randomizer.AddData(feats_transf);
298  targets_randomizer.AddData(targets);
299  weights_randomizer.AddData(weights);
300  num_done++;
301 
302  // report the speed,
303  if (num_done % 5000 == 0) {
304  double time_now = time.Elapsed();
305  KALDI_VLOG(1) << "After " << num_done << " utterances: "
306  << "time elapsed = " << time_now / 60 << " min; "
307  << "processed " << total_frames / time_now << " frames per sec.";
308  }
309  }
310 
311  // randomize,
312  if (!crossvalidate && randomize) {
313  const std::vector<int32>& mask =
314  randomizer_mask.Generate(feature_randomizer.NumFrames());
315  feature_randomizer.Randomize(mask);
316  targets_randomizer.Randomize(mask);
317  weights_randomizer.Randomize(mask);
318  }
319 
320  // train with data from randomizers (using mini-batches),
321  for ( ; !feature_randomizer.Done(); feature_randomizer.Next(),
322  targets_randomizer.Next(),
323  weights_randomizer.Next()) {
324  // get block of feature/target pairs,
325  const CuMatrixBase<BaseFloat>& nnet_in = feature_randomizer.Value();
326  const Posterior& nnet_tgt = targets_randomizer.Value();
327  const Vector<BaseFloat>& frm_weights = weights_randomizer.Value();
328 
329  // forward pass,
330  nnet.Propagate(nnet_in, &nnet_out);
331 
332  // evaluate objective function we've chosen,
333  if (objective_function == "xent") {
334  // gradients re-scaled by weights in Eval,
335  xent.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
336  } else if (objective_function == "mse") {
337  // gradients re-scaled by weights in Eval,
338  mse.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
339  } else if (0 == objective_function.compare(0, 9, "multitask")) {
340  // gradients re-scaled by weights in Eval,
341  multitask.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
342  } else {
343  KALDI_ERR << "Unknown objective function code : " << objective_function;
344  }
345 
346  if (!crossvalidate) {
347  // back-propagate, and do the update,
348  nnet.Backpropagate(obj_diff, NULL);
349  }
350 
351  // 1st mini-batch : show what happens in network,
352  if (total_frames == 0) {
353  KALDI_VLOG(1) << "### After " << total_frames << " frames,";
354  KALDI_VLOG(1) << nnet.InfoPropagate();
355  if (!crossvalidate) {
356  KALDI_VLOG(1) << nnet.InfoBackPropagate();
357  KALDI_VLOG(1) << nnet.InfoGradient();
358  }
359  }
360 
361  // VERBOSE LOG
362  // monitor the NN training (--verbose=2),
363  if (GetVerboseLevel() >= 2) {
364  static int32 counter = 0;
365  counter += nnet_in.NumRows();
366  // print every 25k frames,
367  if (counter >= 25000) {
368  KALDI_VLOG(2) << "### After " << total_frames << " frames,";
369  KALDI_VLOG(2) << nnet.InfoPropagate();
370  if (!crossvalidate) {
371  KALDI_VLOG(2) << nnet.InfoBackPropagate();
372  KALDI_VLOG(2) << nnet.InfoGradient();
373  }
374  counter = 0;
375  }
376  }
377 
378  total_frames += nnet_in.NumRows();
379  }
380  } // main loop,
381 
382  // after last mini-batch : show what happens in network,
383  KALDI_VLOG(1) << "### After " << total_frames << " frames,";
384  KALDI_VLOG(1) << nnet.InfoPropagate();
385  if (!crossvalidate) {
386  KALDI_VLOG(1) << nnet.InfoBackPropagate();
387  KALDI_VLOG(1) << nnet.InfoGradient();
388  }
389 
390  if (!crossvalidate) {
391  nnet.Write(target_model_filename, binary);
392  }
393 
394  KALDI_LOG << "Done " << num_done << " files, "
395  << num_no_tgt_mat << " with no tgt_mats, "
396  << num_other_error << " with other errors. "
397  << "[" << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
398  << ", " << (randomize ? "RANDOMIZED" : "NOT-RANDOMIZED")
399  << ", " << time.Elapsed() / 60 << " min, processing "
400  << total_frames / time.Elapsed() << " frames per sec.]";
401 
402  if (objective_function == "xent") {
403  KALDI_LOG << xent.ReportPerClass();
404  KALDI_LOG << xent.Report();
405  } else if (objective_function == "mse") {
406  KALDI_LOG << mse.Report();
407  } else if (0 == objective_function.compare(0, 9, "multitask")) {
408  KALDI_LOG << multitask.Report();
409  } else {
410  KALDI_ERR << "Unknown objective function code : " << objective_function;
411  }
412 
413 #if HAVE_CUDA == 1
414  CuDevice::Instantiate().PrintProfile();
415 #endif
416 
417  return 0;
418  } catch(const std::exception &e) {
419  std::cerr << e.what();
420  return -1;
421  }
422 }
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
Definition: nnet-nnet.cc:96
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
Definition: nnet-nnet.cc:70
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics.
Definition: nnet-nnet.cc:420
int32 GetVerboseLevel()
Definition: kaldi-error.h:69
Generates randomly ordered vector of indices.
bool Open(const std::string &rspecifier)
Randomizes elements of a vector according to a mask.
MatrixIndexT NumCols() const
Definition: cu-matrix.h:206
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
Real Min() const
Returns the minimum value of any element, or +infinity for the empty vector.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
Class CuArray represents a vector of an integer or struct of type T.
Definition: cu-array.h:32
Configuration variables that affect how frame-level shuffling is done.
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:43
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:205
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics.
Definition: nnet-nnet.cc:407
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:127
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
void Register(OptionsItf *opts)
Definition: nnet-trnopts.h:46
#define KALDI_WARN
Definition: kaldi-error.h:130
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics.
Definition: nnet-nnet.cc:443
void Scale(Real alpha)
Multiplies all elements by this constant.
bool HasKey(const std::string &key)
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
Definition: nnet-nnet.cc:131
Shuffles rows of a matrix according to the indices in the mask.
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:61
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
void Set(Real f)
Set all members of a vector to a specified value.
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
Definition: nnet-nnet.cc:268
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
Definition: nnet-nnet.cc:508
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
void CopyRows(const CuMatrixBase< Real > &src, const CuArrayBase< MatrixIndexT > &indexes)
Copies row r from row indexes[r] of src.
Definition: cu-matrix.cc:2614
Randomizes elements of a vector according to a mask.
void Register(OptionsItf *opts)
Definition: nnet-loss.h:45
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
#define KALDI_LOG
Definition: kaldi-error.h:133
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:63