nnet-train-frmshuff.cc File Reference

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main (int argc, char *argv[])

Definition at line 29 of file nnet-train-frmshuff.cc.

References MatrixRandomizer::AddData(), VectorRandomizer::AddData(), StdVectorRandomizer< T >::AddData(), Nnet::Backpropagate(), CuMatrixBase< Real >::CopyRows(), VectorBase< Real >::Dim(), MatrixRandomizer::Done(), SequentialTableReader< Holder >::Done(), Timer::Elapsed(), Xent::Eval(), Mse::Eval(), MultiTaskLoss::Eval(), Nnet::Feedforward(), RandomizerMask::Generate(), ParseOptions::GetArg(), kaldi::GetVerboseLevel(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, Nnet::InfoBackPropagate(), Nnet::InfoGradient(), Nnet::InfoPropagate(), MultiTaskLoss::InitFromString(), MatrixRandomizer::IsFull(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, kaldi::kCopyData, SequentialTableReader< Holder >::Key(), VectorBase< Real >::Min(), MatrixRandomizer::Next(), VectorRandomizer::Next(), StdVectorRandomizer< T >::Next(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumCols(), MatrixRandomizer::NumFrames(), MatrixBase< Real >::NumRows(), CuMatrixBase< Real >::NumRows(), RandomAccessTableReader< Holder >::Open(), ParseOptions::PrintUsage(), Nnet::Propagate(), MatrixRandomizer::Randomize(), VectorRandomizer::Randomize(), StdVectorRandomizer< T >::Randomize(), ParseOptions::Read(), Nnet::Read(), NnetTrainOptions::Register(), NnetDataRandomizerOptions::Register(), ParseOptions::Register(), Xent::Report(), Mse::Report(), MultiTaskLoss::Report(), Xent::ReportPerClass(), Vector< Real >::Resize(), Matrix< Real >::Resize(), VectorBase< Real >::Scale(), VectorBase< Real >::Set(), Nnet::SetDropoutRate(), Nnet::SetTrainOptions(), MatrixRandomizer::Value(), VectorRandomizer::Value(), StdVectorRandomizer< T >::Value(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and Nnet::Write().

{
  using namespace kaldi;
  using namespace kaldi::nnet1;
  typedef kaldi::int32 int32;

  try {
    const char *usage =
      "Perform one iteration (epoch) of Neural Network training with\n"
      "mini-batch Stochastic Gradient Descent. The training targets\n"
      "are usually pdf-posteriors, prepared by ali-to-post.\n"
      "Usage: nnet-train-frmshuff [options] <feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]\n"
      "e.g.: nnet-train-frmshuff scp:feats.scp ark:posterior.ark nnet.init nnet.iter1\n";

    ParseOptions po(usage);

    NnetTrainOptions trn_opts;
    trn_opts.Register(&po);
    NnetDataRandomizerOptions rnd_opts;
    rnd_opts.Register(&po);

    bool binary = true;
    po.Register("binary", &binary, "Write output in binary mode");

    bool crossvalidate = false;
    po.Register("cross-validate", &crossvalidate,
                "Perform cross-validation (don't back-propagate)");

    bool randomize = true;
    po.Register("randomize", &randomize,
                "Perform the frame-level shuffling within the Cache::");

    std::string feature_transform;
    po.Register("feature-transform", &feature_transform,
                "Feature transform in Nnet format");

    std::string objective_function = "xent";
    po.Register("objective-function", &objective_function,
                "Objective function : xent|mse|multitask");

    int32 length_tolerance = 5;
    po.Register("length-tolerance", &length_tolerance,
                "Allowed length mismatch of features/targets/weights "
                "(in frames, we truncate to the shortest)");

    std::string frame_weights;
    po.Register("frame-weights", &frame_weights,
                "Per-frame weights, used to re-scale gradients.");

    std::string utt_weights;
    po.Register("utt-weights", &utt_weights,
                "Per-utterance weights, used to re-scale frame-weights.");

    std::string use_gpu = "yes";
    po.Register("use-gpu", &use_gpu,
                "yes|no|optional, only has effect if compiled with CUDA");

    po.Read(argc, argv);

    if (po.NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
      po.PrintUsage();
      exit(1);
    }

    std::string feature_rspecifier = po.GetArg(1),
      targets_rspecifier = po.GetArg(2),
      model_filename = po.GetArg(3);

    std::string target_model_filename;
    if (!crossvalidate) {
      target_model_filename = po.GetArg(4);
    }

    using namespace kaldi;
    using namespace kaldi::nnet1;
    typedef kaldi::int32 int32;

#if HAVE_CUDA == 1
    CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif

    Nnet nnet_transf;
    if (feature_transform != "") {
      nnet_transf.Read(feature_transform);
    }

    Nnet nnet;
    nnet.Read(model_filename);
    nnet.SetTrainOptions(trn_opts);

    if (crossvalidate) {
      nnet_transf.SetDropoutRate(0.0);
      nnet.SetDropoutRate(0.0);
    }

    kaldi::int64 total_frames = 0;

    SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
    RandomAccessPosteriorReader targets_reader(targets_rspecifier);
    RandomAccessBaseFloatVectorReader weights_reader;
    if (frame_weights != "") {
      weights_reader.Open(frame_weights);
    }
    RandomAccessBaseFloatReader utt_weights_reader;
    if (utt_weights != "") {
      utt_weights_reader.Open(utt_weights);
    }

    RandomizerMask randomizer_mask(rnd_opts);
    MatrixRandomizer feature_randomizer(rnd_opts);
    PosteriorRandomizer targets_randomizer(rnd_opts);
    VectorRandomizer weights_randomizer(rnd_opts);

    Xent xent;
    Mse mse;

    MultiTaskLoss multitask;
    if (0 == objective_function.compare(0, 9, "multitask")) {
      // objective_function contains something like :
      // 'multitask,xent,2456,1.0,mse,440,0.001'
      //
      // the meaning is following:
      // 'multitask,<type1>,<dim1>,<weight1>,...,<typeN>,<dimN>,<weightN>'
      multitask.InitFromString(objective_function);
    }

    CuMatrix<BaseFloat> feats_transf, nnet_out, obj_diff;

    Timer time;
    KALDI_LOG << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
              << " STARTED";

    int32 num_done = 0,
          num_no_tgt_mat = 0,
          num_other_error = 0;

    // main loop,
    while (!feature_reader.Done()) {
#if HAVE_CUDA == 1
      // check that GPU computes accurately,
      CuDevice::Instantiate().CheckGpuHealth();
#endif
      // fill the randomizer,
      for ( ; !feature_reader.Done(); feature_reader.Next()) {
        if (feature_randomizer.IsFull()) {
          // break the loop without calling Next(),
          // we keep the 'utt' for next round,
          break;
        }
        std::string utt = feature_reader.Key();
        KALDI_VLOG(3) << "Reading " << utt;
        // check that we have targets,
        if (!targets_reader.HasKey(utt)) {
          KALDI_WARN << utt << ", missing targets";
          num_no_tgt_mat++;
          continue;
        }
        // check we have per-frame weights,
        if (frame_weights != "" && !weights_reader.HasKey(utt)) {
          KALDI_WARN << utt << ", missing per-frame weights";
          num_other_error++;
          continue;
        }
        // check we have per-utterance weights,
        if (utt_weights != "" && !utt_weights_reader.HasKey(utt)) {
          KALDI_WARN << utt << ", missing per-utterance weight";
          num_other_error++;
          continue;
        }
        // get feature / target pair,
        Matrix<BaseFloat> mat = feature_reader.Value();
        Posterior targets = targets_reader.Value(utt);
        // get per-frame weights,
        Vector<BaseFloat> weights;
        if (frame_weights != "") {
          weights = weights_reader.Value(utt);
        } else {  // all per-frame weights are 1.0,
          weights.Resize(mat.NumRows());
          weights.Set(1.0);
        }
        // multiply with per-utterance weight,
        if (utt_weights != "") {
          BaseFloat w = utt_weights_reader.Value(utt);
          KALDI_ASSERT(w >= 0.0);
          if (w == 0.0) continue;  // remove sentence from training,
          weights.Scale(w);
        }

        // correct small length mismatch or drop sentence,
        {
          // add lengths to vector,
          std::vector<int32> length;
          length.push_back(mat.NumRows());
          length.push_back(targets.size());
          length.push_back(weights.Dim());
          // find min, max,
          int32 min = *std::min_element(length.begin(), length.end());
          int32 max = *std::max_element(length.begin(), length.end());
          // fix or drop ?
          if (max - min < length_tolerance) {
            // we truncate to shortest,
            if (mat.NumRows() != min) mat.Resize(min, mat.NumCols(), kCopyData);
            if (targets.size() != min) targets.resize(min);
            if (weights.Dim() != min) weights.Resize(min, kCopyData);
          } else {
            KALDI_WARN << "Length mismatch! Targets " << targets.size()
                       << ", features " << mat.NumRows() << ", " << utt;
            num_other_error++;
            continue;
          }
        }
        // apply feature transform (if empty, input is copied),
        nnet_transf.Feedforward(CuMatrix<BaseFloat>(mat), &feats_transf);

        // remove frames with '0' weight from training,
        {
          // are there any frames to be removed? (frames with zero weight),
          BaseFloat weight_min = weights.Min();
          KALDI_ASSERT(weight_min >= 0.0);
          if (weight_min == 0.0) {
            // create vector with frame-indices to keep,
            std::vector<MatrixIndexT> keep_frames;
            for (int32 i = 0; i < weights.Dim(); i++) {
              if (weights(i) > 0.0) {
                keep_frames.push_back(i);
              }
            }

            // when all frames are removed, we skip the sentence,
            if (keep_frames.size() == 0) continue;

            // filter feature-frames,
            CuMatrix<BaseFloat> tmp_feats(keep_frames.size(), feats_transf.NumCols());
            tmp_feats.CopyRows(feats_transf, CuArray<MatrixIndexT>(keep_frames));
            tmp_feats.Swap(&feats_transf);

            // filter targets,
            Posterior tmp_targets;
            for (int32 i = 0; i < keep_frames.size(); i++) {
              tmp_targets.push_back(targets[keep_frames[i]]);
            }
            tmp_targets.swap(targets);

            // filter weights,
            Vector<BaseFloat> tmp_weights(keep_frames.size());
            for (int32 i = 0; i < keep_frames.size(); i++) {
              tmp_weights(i) = weights(keep_frames[i]);
            }
            tmp_weights.Swap(&weights);
          }
        }

        // pass data to randomizers,
        KALDI_ASSERT(feats_transf.NumRows() == targets.size());
        feature_randomizer.AddData(feats_transf);
        targets_randomizer.AddData(targets);
        weights_randomizer.AddData(weights);
        num_done++;

        // report the speed,
        if (num_done % 5000 == 0) {
          double time_now = time.Elapsed();
          KALDI_VLOG(1) << "After " << num_done << " utterances: "
                        << "time elapsed = " << time_now / 60 << " min; "
                        << "processed " << total_frames / time_now << " frames per sec.";
        }
      }

      // randomize,
      if (!crossvalidate && randomize) {
        const std::vector<int32>& mask =
          randomizer_mask.Generate(feature_randomizer.NumFrames());
        feature_randomizer.Randomize(mask);
        targets_randomizer.Randomize(mask);
        weights_randomizer.Randomize(mask);
      }

      // train with data from randomizers (using mini-batches),
      for ( ; !feature_randomizer.Done(); feature_randomizer.Next(),
                                          targets_randomizer.Next(),
                                          weights_randomizer.Next()) {
        // get block of feature/target pairs,
        const CuMatrixBase<BaseFloat>& nnet_in = feature_randomizer.Value();
        const Posterior& nnet_tgt = targets_randomizer.Value();
        const Vector<BaseFloat>& frm_weights = weights_randomizer.Value();

        // forward pass,
        nnet.Propagate(nnet_in, &nnet_out);

        // evaluate objective function we've chosen,
        if (objective_function == "xent") {
          // gradients re-scaled by weights in Eval,
          xent.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
        } else if (objective_function == "mse") {
          // gradients re-scaled by weights in Eval,
          mse.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
        } else if (0 == objective_function.compare(0, 9, "multitask")) {
          // gradients re-scaled by weights in Eval,
          multitask.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
        } else {
          KALDI_ERR << "Unknown objective function code : " << objective_function;
        }

        if (!crossvalidate) {
          // back-propagate, and do the update,
          nnet.Backpropagate(obj_diff, NULL);
        }

        // 1st mini-batch : show what happens in network,
        if (total_frames == 0) {
          KALDI_VLOG(1) << "### After " << total_frames << " frames,";
          KALDI_VLOG(1) << nnet.InfoPropagate();
          if (!crossvalidate) {
            KALDI_VLOG(1) << nnet.InfoBackPropagate();
            KALDI_VLOG(1) << nnet.InfoGradient();
          }
        }

        // VERBOSE LOG
        // monitor the NN training (--verbose=2),
        if (GetVerboseLevel() >= 2) {
          static int32 counter = 0;
          counter += nnet_in.NumRows();
          // print every 25k frames,
          if (counter >= 25000) {
            KALDI_VLOG(2) << "### After " << total_frames << " frames,";
            KALDI_VLOG(2) << nnet.InfoPropagate();
            if (!crossvalidate) {
              KALDI_VLOG(2) << nnet.InfoBackPropagate();
              KALDI_VLOG(2) << nnet.InfoGradient();
            }
            counter = 0;
          }
        }

        total_frames += nnet_in.NumRows();
      }
    }  // main loop,

    // after last mini-batch : show what happens in network,
    KALDI_VLOG(1) << "### After " << total_frames << " frames,";
    KALDI_VLOG(1) << nnet.InfoPropagate();
    if (!crossvalidate) {
      KALDI_VLOG(1) << nnet.InfoBackPropagate();
      KALDI_VLOG(1) << nnet.InfoGradient();
    }

    if (!crossvalidate) {
      nnet.Write(target_model_filename, binary);
    }

    KALDI_LOG << "Done " << num_done << " files, "
              << num_no_tgt_mat << " with no tgt_mats, "
              << num_other_error << " with other errors. "
              << "[" << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
              << ", " << (randomize ? "RANDOMIZED" : "NOT-RANDOMIZED")
              << ", " << time.Elapsed() / 60 << " min, processing "
              << total_frames / time.Elapsed() << " frames per sec.]";

    if (objective_function == "xent") {
      KALDI_LOG << xent.ReportPerClass();
      KALDI_LOG << xent.Report();
    } else if (objective_function == "mse") {
      KALDI_LOG << mse.Report();
    } else if (0 == objective_function.compare(0, 9, "multitask")) {
      KALDI_LOG << multitask.Report();
    } else {
      KALDI_ERR << "Unknown objective function code : " << objective_function;
    }

#if HAVE_CUDA == 1
    CuDevice::Instantiate().PrintProfile();
#endif

    return 0;
  } catch (const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}
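The length-tolerance step in the code above (truncate all streams to the shortest if they differ by less than `--length-tolerance` frames, otherwise drop the utterance) can be sketched in isolation. `ReconcileLengths` is a hypothetical helper written for illustration; it is not part of Kaldi:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Illustrative sketch (not Kaldi code): reconcile the per-utterance lengths
// of the feature / target / weight streams. If the longest and shortest
// differ by fewer than 'tolerance' frames, all streams are truncated to the
// shortest; otherwise the utterance is rejected.
// Returns the common length to truncate to, or -1 to signal "drop".
int ReconcileLengths(const std::vector<int> &lengths, int tolerance) {
  int min = *std::min_element(lengths.begin(), lengths.end());
  int max = *std::max_element(lengths.begin(), lengths.end());
  if (max - min < tolerance) return min;  // caller truncates to 'min',
  return -1;  // mismatch too large, drop the sentence,
}
```

With the default `--length-tolerance=5`, a 2-frame mismatch is repaired by truncation, while a 10-frame mismatch drops the utterance.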
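For the 'multitask' objective, the option string follows the format documented in the comment above, e.g. 'multitask,xent,2456,1.0,mse,440,0.001' meaning 'multitask,&lt;type1&gt;,&lt;dim1&gt;,&lt;weight1&gt;,...'. A minimal parsing sketch might look as follows; `TaskSpec` and `ParseMultiTaskSpec` are illustrative names only, since the real parsing is done inside `MultiTaskLoss::InitFromString()`:

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// One sub-task of the hypothetical multitask spec,
struct TaskSpec {
  std::string type;  // loss type, e.g. "xent" or "mse",
  int dim;           // output dimension handled by this loss,
  double weight;     // scaling factor of this loss in the total objective,
};

// Split 'multitask,<type1>,<dim1>,<weight1>,...' into (type, dim, weight)
// triples; a sketch of the format, not Kaldi's actual implementation,
std::vector<TaskSpec> ParseMultiTaskSpec(const std::string &spec) {
  // tokenize on commas,
  std::vector<std::string> tok;
  std::stringstream ss(spec);
  std::string item;
  while (std::getline(ss, item, ',')) tok.push_back(item);
  // tok[0] is the literal "multitask"; triples follow,
  std::vector<TaskSpec> out;
  for (std::size_t i = 1; i + 3 <= tok.size(); i += 3) {
    out.push_back({tok[i], std::stoi(tok[i + 1]), std::stod(tok[i + 2])});
  }
  return out;
}
```

On the example string this yields two tasks: a cross-entropy loss over 2456 outputs with weight 1.0, and a mean-square-error loss over 440 outputs with weight 0.001.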