nnet-train-multistream.cc
Go to the documentation of this file.
1 // nnetbin/nnet-train-multistream.cc
2 
3 // Copyright 2015-2016 Brno University of Technology (Author: Karel Vesely)
4 // 2014 Jiayu DU (Jerry), Wei Li
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <numeric>
22 
23 #include "nnet/nnet-trnopts.h"
24 #include "nnet/nnet-nnet.h"
25 #include "nnet/nnet-loss.h"
26 #include "nnet/nnet-randomizer.h"
27 #include "base/kaldi-common.h"
28 #include "util/common-utils.h"
29 #include "base/timer.h"
30 #include "cudamatrix/cu-device.h"
31 
32 
33 namespace kaldi {
34 
36  RandomAccessPosteriorReader& target_reader,
37  RandomAccessBaseFloatVectorReader& weights_reader,
38  int32 length_tolerance,
39  Matrix<BaseFloat>* feats,
40  Posterior* targets,
41  Vector<BaseFloat>* weights,
42  int32* num_no_tgt_mat,
43  int32* num_other_error) {
44 
45  // We're looking for the 1st valid utterance...
46  for ( ; !feature_reader.Done(); feature_reader.Next()) {
47  // Do we have targets?
48  const std::string& utt = feature_reader.Key();
49  KALDI_VLOG(3) << "Reading: " << utt;
50  if (!target_reader.HasKey(utt)) {
51  KALDI_WARN << utt << ", missing targets";
52  (*num_no_tgt_mat)++;
53  continue;
54  }
55  // Do we have frame-weights?
56  if (weights_reader.IsOpen() && !weights_reader.HasKey(utt)) {
57  KALDI_WARN << utt << ", missing frame-weights";
58  (*num_other_error)++;
59  continue;
60  }
61 
62  // get the (feature,target) pair,
63  (*feats) = feature_reader.Value();
64  (*targets) = target_reader.Value(utt);
65 
66  // getting per-frame weights,
67  if (weights_reader.IsOpen()) {
68  (*weights) = weights_reader.Value(utt);
69  } else { // all per-frame weights are 1.0
70  weights->Resize(feats->NumRows());
71  weights->Set(1.0);
72  }
73 
74  // correct small length mismatch ... or drop sentence
75  {
76  // add lengths to vector
77  std::vector<int32> length;
78  length.push_back(feats->NumRows());
79  length.push_back(targets->size());
80  length.push_back(weights->Dim());
81  // find min, max
82  int32 min = *std::min_element(length.begin(), length.end());
83  int32 max = *std::max_element(length.begin(), length.end());
84  // fix or drop ?
85  if (max - min < length_tolerance) {
86  if (feats->NumRows() != min) feats->Resize(min, feats->NumCols(), kCopyData);
87  if (targets->size() != min) targets->resize(min);
88  if (weights->Dim() != min) weights->Resize(min, kCopyData);
89  } else {
90  KALDI_WARN << "Length mismatch! Targets " << targets->size()
91  << ", features " << feats->NumRows() << ", " << utt;
92  num_other_error++;
93  continue;
94  }
95  }
96 
97  // By getting here we got a valid utterance,
98  feature_reader.Next();
99  return true;
100  }
101 
102  // No more data,
103  return false;
104 }
105 
106 } // namespace kaldi
107 
108 
/// Train (or cross-validate) an 'nnet1' network with multi-stream truncated
/// BPTT: 'num_streams' utterances are processed in parallel, sliced into
/// 'batch_size'-frame chunks that are interleaved into one mini-batch
/// (row r*num_streams + s holds frame r of stream s).
int main(int argc, char *argv[]) {
  using namespace kaldi;
  using namespace kaldi::nnet1;
  typedef kaldi::int32 int32;

  try {
    const char *usage =
      "Perform one iteration of Multi-stream training, truncated BPTT for LSTMs.\n"
      "The training targets are pdf-posteriors, usually prepared by ali-to-post.\n"
      "The updates are per-utterance.\n"
      "\n"
      "Usage: nnet-train-multistream [options] "
      "<feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]\n"
      "e.g.: nnet-train-lstm-streams scp:feature.scp ark:posterior.ark nnet.init nnet.iter1\n";

    ParseOptions po(usage);

    NnetTrainOptions trn_opts;
    trn_opts.Register(&po);
    LossOptions loss_opts;
    loss_opts.Register(&po);

    bool binary = true;
    po.Register("binary", &binary, "Write output in binary mode");

    bool crossvalidate = false;
    po.Register("cross-validate", &crossvalidate,
        "Perform cross-validation (don't back-propagate)");

    std::string feature_transform;
    po.Register("feature-transform", &feature_transform,
        "Feature transform in Nnet format");

    std::string objective_function = "xent";
    po.Register("objective-function", &objective_function,
        "Objective function : xent|mse");

    int32 length_tolerance = 5;
    po.Register("length-tolerance", &length_tolerance,
        "Allowed length difference of features/targets (frames)");

    std::string frame_weights;
    po.Register("frame-weights", &frame_weights,
        "Per-frame weights to scale gradients (frame selection/weighting).");

    int32 batch_size = 20;
    po.Register("batch-size", &batch_size,
        "Length of 'one stream' in the Multi-stream training");

    int32 num_streams = 4;
    po.Register("num-streams", &num_streams,
        "Number of streams in the Multi-stream training");

    // accepted but unused; presumably kept so shared training scripts can
    // pass '--randomize' uniformly to all nnet1 trainers,
    bool dummy = false;
    po.Register("randomize", &dummy, "Dummy option.");

    std::string use_gpu="yes";
    po.Register("use-gpu", &use_gpu,
        "yes|no|optional, only has effect if compiled with CUDA");

    po.Read(argc, argv);

    // 4th positional arg (model-out) is required only when training,
    if (po.NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
      po.PrintUsage();
      exit(1);
    }

    std::string feature_rspecifier = po.GetArg(1),
      targets_rspecifier = po.GetArg(2),
      model_filename = po.GetArg(3);

    std::string target_model_filename;
    if (!crossvalidate) {
      target_model_filename = po.GetArg(4);
    }

    // (redundant re-declarations, kept as-is),
    using namespace kaldi;
    using namespace kaldi::nnet1;
    typedef kaldi::int32 int32;

#if HAVE_CUDA == 1
    CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif

    // optional input feature transform (may contain splicing),
    Nnet nnet_transf;
    if (feature_transform != "") {
      nnet_transf.Read(feature_transform);
    }

    Nnet nnet;
    nnet.Read(model_filename);
    nnet.SetTrainOptions(trn_opts);

    if (crossvalidate) {
      // disable dropout when evaluating held-out data,
      nnet_transf.SetDropoutRate(0.0);
      nnet.SetDropoutRate(0.0);
    }

    kaldi::int64 total_frames = 0;

    SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
    RandomAccessPosteriorReader target_reader(targets_rspecifier);
    RandomAccessBaseFloatVectorReader weights_reader;
    if (frame_weights != "") {
      weights_reader.Open(frame_weights);
    }

    Xent xent(loss_opts);
    Mse mse(loss_opts);

    Timer time;
    double time_gpu = 0;  // accumulates time spent in GPU-side calls,
    KALDI_LOG << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
              << " STARTED";

    int32 num_done = 0,
          num_no_tgt_mat = 0,
          num_other_error = 0;

    // book-keeping for multi-stream training: per-stream utterance data,
    // plus a read-cursor (frame offset) into each stream's utterance,
    std::vector<Matrix<BaseFloat> > feats_utt(num_streams);
    std::vector<Posterior> labels_utt(num_streams);
    std::vector<Vector<BaseFloat> > weights_utt(num_streams);
    std::vector<int32> cursor_utt(num_streams);  // 0 initialized,
    std::vector<int32> new_utt_flags(num_streams);

    CuMatrix<BaseFloat> feats_transf, nnet_out, obj_diff;

    // MAIN LOOP,
    while (1) {

      // Re-fill the streams, if needed,
      new_utt_flags.assign(num_streams, 0);  // set new-utterance flags to zero,
      for (int s = 0; s < num_streams; s++) {
        // Need a new utterance for stream 's'? (cursor ran past the data),
        if (cursor_utt[s] >= feats_utt[s].NumRows()) {
          Matrix<BaseFloat> feats;
          Posterior targets;
          Vector<BaseFloat> weights;
          // get the data from readers,
          if (ReadData(feature_reader, target_reader, weights_reader,
                       length_tolerance,
                       &feats, &targets, &weights,
                       &num_no_tgt_mat, &num_other_error)) {

            // input transform may contain splicing,
            Timer t;
            nnet_transf.Feedforward(CuMatrix<BaseFloat>(feats), &feats_transf);
            time_gpu += t.Elapsed();

            /* Here we could do the 'targets_delay', BUT...
             * It is better to do it by a <Splice> component!
             *
             * The prototype would look like this (6th frame becomes 1st frame, etc.):
             * '<Splice> <InputDim> dim1 <OutputDim> dim1 <BuildVector> 5 </BuildVector>'
             */

            // store,
            feats_utt[s] = Matrix<BaseFloat>(feats_transf);
            labels_utt[s] = targets;
            weights_utt[s] = weights;
            cursor_utt[s] = 0;
            new_utt_flags[s] = 1;  // tells the nnet to reset this stream's history,
          }
        }
      }

      // End the training when 1st stream is empty
      // (this avoids over-adaptation to last utterances),
      size_t inactive_streams = 0;
      for (int32 s = 0; s < num_streams; s++) {
        if (feats_utt[s].NumRows() - cursor_utt[s] <= 0) {
          inactive_streams += 1;
        }
      }
      if (inactive_streams >= 1) {
        KALDI_LOG << "No more data to re-fill one of the streams, end of the training!";
        KALDI_LOG << "(remaining stubs of data are discarded, don't overtrain on them)";
        break;
      }

      // number of frames we'll pack as the streams,
      std::vector<int32> frame_num_utt;

      // pack the parallel data,
      Matrix<BaseFloat> feat_mat_host;
      Posterior target_host;
      Vector<BaseFloat> weight_host;
      {
        // Number of sequences (can have zero length),
        int32 n_streams = num_streams;

        // Create the final feature matrix with 'interleaved feature-lines':
        // row r*n_streams + s is frame r of stream s,
        feat_mat_host.Resize(n_streams * batch_size, nnet.InputDim(), kSetZero);
        target_host.resize(n_streams * batch_size);
        weight_host.Resize(n_streams * batch_size, kSetZero);
        frame_num_utt.resize(n_streams, 0);

        // we slice at the 'cursor' at most 'batch_size' frames,
        for (int32 s = 0; s < n_streams; s++) {
          int32 num_rows = std::max(0, feats_utt[s].NumRows() - cursor_utt[s]);
          frame_num_utt[s] = std::min(batch_size, num_rows);
        }

        // pack the data,
        {
          // features,
          for (int32 s = 0; s < n_streams; s++) {
            if (frame_num_utt[s] > 0) {
              auto mat_tmp = feats_utt[s].RowRange(cursor_utt[s], frame_num_utt[s]);
              for (int32 r = 0; r < frame_num_utt[s]; r++) {
                feat_mat_host.Row(r*n_streams + s).CopyFromVec(mat_tmp.Row(r));
              }
            }
          }

          // targets,
          for (int32 s = 0; s < n_streams; s++) {
            for (int32 r = 0; r < frame_num_utt[s]; r++) {
              target_host[r*n_streams + s] = labels_utt[s][cursor_utt[s] + r];
            }
          }

          // padded frames will keep initial zero-weight,
          // (so padding does not contribute to the loss/gradient),
          for (int32 s = 0; s < n_streams; s++) {
            if (frame_num_utt[s] > 0) {
              auto weight_tmp = weights_utt[s].Range(cursor_utt[s], frame_num_utt[s]);
              for (int32 r = 0; r < frame_num_utt[s]; r++) {
                weight_host(r*n_streams + s) = weight_tmp(r);
              }
            }
          }
        }

        // advance the cursors,
        for (int32 s = 0; s < n_streams; s++) {
          cursor_utt[s] += frame_num_utt[s];
        }
      }

      // pass the info about padding,
      nnet.SetSeqLengths(frame_num_utt);

      // Show debug info,
      if (GetVerboseLevel() >= 4) {
        // cursors in the feature_matrices,
        {
          std::ostringstream os;
          os << "[ ";
          for (size_t i = 0; i < cursor_utt.size(); i++) {
            os << cursor_utt[i] << " ";
          }
          os << "]";
          KALDI_LOG << "cursor_utt[" << cursor_utt.size() << "]" << os.str();
        }
        // frames in the mini-batch,
        {
          std::ostringstream os;
          os << "[ ";
          for (size_t i = 0; i < frame_num_utt.size(); i++) {
            os << frame_num_utt[i] << " ";
          }
          os << "]";
          KALDI_LOG << "frame_num_utt[" << frame_num_utt.size() << "]" << os.str();
        }
      }

      Timer t;
      // with new utterance we reset the history,
      nnet.ResetStreams(new_utt_flags);

      // forward pass,
      nnet.Propagate(CuMatrix<BaseFloat>(feat_mat_host), &nnet_out);

      // evaluate objective function we've chosen,
      if (objective_function == "xent") {
        xent.Eval(weight_host, nnet_out, target_host, &obj_diff);
      } else if (objective_function == "mse") {
        mse.Eval(weight_host, nnet_out, target_host, &obj_diff);
      } else {
        KALDI_ERR << "Unknown objective function code : "
                  << objective_function;
      }

      if (!crossvalidate) {
        // back-propagate, and do the update,
        nnet.Backpropagate(obj_diff, NULL);
      }
      time_gpu += t.Elapsed();

      // 1st minibatch : show what happens in network,
      if (total_frames == 0) {
        KALDI_LOG << "### After " << total_frames << " frames,";
        KALDI_LOG << nnet.Info();
        KALDI_LOG << nnet.InfoPropagate();
        if (!crossvalidate) {
          KALDI_LOG << nnet.InfoBackPropagate();
          KALDI_LOG << nnet.InfoGradient();
        }
      }

      kaldi::int64 tmp_frames = total_frames;

      // utterances finished = streams that were re-filled this iteration,
      num_done += std::accumulate(new_utt_flags.begin(), new_utt_flags.end(), 0);
      total_frames += std::accumulate(frame_num_utt.begin(), frame_num_utt.end(), 0);

      // monitor the NN training (--verbose=2),
      int32 F = 25000;
      if (GetVerboseLevel() >= 2) {
        // print every 25k frames,
        if (tmp_frames / F != total_frames / F) {
          KALDI_VLOG(2) << "### After " << total_frames << " frames,";
          KALDI_VLOG(2) << nnet.Info();
          KALDI_VLOG(2) << nnet.InfoPropagate();
          if (!crossvalidate) {
            KALDI_VLOG(2) << nnet.InfoBackPropagate();
            KALDI_VLOG(2) << nnet.InfoGradient();
          }
        }
      }
    }

    // after last minibatch : show what happens in network,
    KALDI_LOG << "### After " << total_frames << " frames,";
    KALDI_LOG << nnet.Info();
    KALDI_LOG << nnet.InfoPropagate();
    if (!crossvalidate) {
      KALDI_LOG << nnet.InfoBackPropagate();
      KALDI_LOG << nnet.InfoGradient();
    }

    if (!crossvalidate) {
      nnet.Write(target_model_filename, binary);
    }

    if (objective_function == "xent") {
      KALDI_LOG << xent.ReportPerClass();
    }

    KALDI_LOG << "Done " << num_done << " files, "
              << num_no_tgt_mat << " with no tgt_mats, "
              << num_other_error << " with other errors. "
              << "[" << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
              << ", " << time.Elapsed() / 60 << " min, processing "
              << total_frames / time.Elapsed() << " frames per sec, "
              << "GPU_time " << 100.*time_gpu/time.Elapsed() << "% ]";

    if (objective_function == "xent") {
      KALDI_LOG << xent.Report();
    } else if (objective_function == "mse") {
      KALDI_LOG << mse.Report();
    } else {
      KALDI_ERR << "Unknown objective function code : " << objective_function;
    }

#if HAVE_CUDA == 1
    CuDevice::Instantiate().PrintProfile();
#endif

    return 0;
  } catch(const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network,.
Definition: nnet-nnet.cc:96
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void ResetStreams(const std::vector< int32 > &stream_reset_flag)
Reset streams in multi-stream training,.
Definition: nnet-nnet.cc:281
int main(int argc, char *argv[])
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network,.
Definition: nnet-nnet.cc:70
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
void SetSeqLengths(const std::vector< int32 > &sequence_lengths)
Set sequence length in LSTM multi-stream training,.
Definition: nnet-nnet.cc:291
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int32 GetVerboseLevel()
Get verbosity level, usually set via command line '--verbose=' switch.
Definition: kaldi-error.h:60
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
int32 InputDim() const
Dimensionality on network input (input feature dim.),.
Definition: nnet-nnet.cc:148
bool Open(const std::string &rspecifier)
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
void Register(const std::string &name, bool *ptr, const std::string &doc)
std::string Report()
Generate string with error report.
Definition: nnet-loss.cc:299
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics,.
Definition: nnet-nnet.cc:443
void Eval(const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
Evaluate cross entropy using target-matrix (supports soft labels),.
Definition: nnet-loss.cc:62
std::string Report()
Generate string with error report,.
Definition: nnet-loss.cc:182
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void Eval(const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
Evaluate mean square error using target-matrix,.
Definition: nnet-loss.cc:228
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
void Register(OptionsItf *opts)
Definition: nnet-trnopts.h:46
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
bool HasKey(const std::string &key)
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers),.
Definition: nnet-nnet.cc:131
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void Set(Real f)
Set all members of a vector to a specified value.
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
Definition: nnet-nnet.cc:268
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics,.
Definition: nnet-nnet.cc:407
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics,.
Definition: nnet-nnet.cc:420
std::string Info() const
Create string with human readable description of the nnet,.
Definition: nnet-nnet.cc:386
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents),.
Definition: nnet-nnet.cc:508
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
std::string ReportPerClass()
Generate string with per-class error report,.
Definition: nnet-loss.cc:203
void Register(OptionsItf *opts)
Definition: nnet-loss.h:45
#define KALDI_LOG
Definition: kaldi-error.h:153
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
bool ReadData(SequentialBaseFloatMatrixReader &feature_reader, RandomAccessPosteriorReader &target_reader, RandomAccessBaseFloatVectorReader &weights_reader, int32 length_tolerance, Matrix< BaseFloat > *feats, Posterior *targets, Vector< BaseFloat > *weights, int32 *num_no_tgt_mat, int32 *num_other_error)