nnet-train-multistream-perutt.cc
// nnetbin/nnet-train-multistream-perutt.cc

// Copyright 2016 Brno University of Technology (author: Karel Vesely)
// Copyright 2015 Chongjia Ni

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet/nnet-trnopts.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-loss.h"
#include "nnet/nnet-utils.h"  // MatrixBuffer,

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "base/timer.h"
#include "cudamatrix/cu-device.h"

#include <numeric>
#include <algorithm>

int main(int argc, char *argv[]) {
  using namespace kaldi;
  using namespace kaldi::nnet1;
  typedef kaldi::int32 int32;

  try {
    const char *usage =
      "Perform one iteration of multi-stream training, per-utterance BPTT for (B)LSTMs.\n"
      "The updates are done per-utterance, while several utterances are\n"
      "processed at the same time.\n"
      "\n"
      "Usage: nnet-train-multistream-perutt [options] <feature-rspecifier> <labels-rspecifier> <model-in> [<model-out>]\n"
      "e.g.: nnet-train-multistream-perutt scp:feats.scp ark:targets.ark nnet.init nnet.iter1\n";

    ParseOptions po(usage);

    // training options,
    NnetTrainOptions trn_opts;
    trn_opts.Register(&po);
    LossOptions loss_opts;
    loss_opts.Register(&po);

    bool binary = true;
    po.Register("binary", &binary, "Write model in binary mode");

    bool crossvalidate = false;
    po.Register("cross-validate", &crossvalidate,
        "Perform cross-validation (no backpropagation)");

    std::string feature_transform;
    po.Register("feature-transform", &feature_transform,
        "Feature transform in Nnet format");

    int32 length_tolerance = 5;
    po.Register("length-tolerance", &length_tolerance,
        "Allowed length difference of features/targets (frames)");

    std::string frame_weights;
    po.Register("frame-weights", &frame_weights,
        "Per-frame weights to scale gradients (frame selection/weighting).");

    int32 num_streams = 20;
    po.Register("num-streams", &num_streams,
        "Number of sentences processed in parallel (can be lower if sentences are long)");

    double max_frames = 8000;
    po.Register("max-frames", &max_frames,
        "Max number of frames to be processed");

    bool dummy = false;
    po.Register("randomize", &dummy, "Dummy option.");

    std::string use_gpu = "yes";
    po.Register("use-gpu", &use_gpu,
        "yes|no|optional, only has effect if compiled with CUDA");

    po.Read(argc, argv);

    if (po.NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
      po.PrintUsage();
      exit(1);
    }

    std::string feature_rspecifier = po.GetArg(1),
      targets_rspecifier = po.GetArg(2),
      model_filename = po.GetArg(3);

    std::string target_model_filename;
    if (!crossvalidate) {
      target_model_filename = po.GetArg(4);
    }

    using namespace kaldi;
    using namespace kaldi::nnet1;
    typedef kaldi::int32 int32;

#if HAVE_CUDA == 1
    CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif

    Nnet nnet_transf;
    if (feature_transform != "") {
      nnet_transf.Read(feature_transform);
    }

    Nnet nnet;
    nnet.Read(model_filename);
    nnet.SetTrainOptions(trn_opts);

    if (crossvalidate) {
      nnet_transf.SetDropoutRate(0.0);
      nnet.SetDropoutRate(0.0);
    }

    kaldi::int64 total_frames = 0;

    // Initialize feature and target readers,
    SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
    RandomAccessPosteriorReader targets_reader(targets_rspecifier);
    RandomAccessBaseFloatVectorReader weights_reader;
    if (frame_weights != "") {
      weights_reader.Open(frame_weights);
    }

    Xent xent(loss_opts);

    CuMatrix<BaseFloat> feats_transf, nnet_out, obj_diff;

    Timer time;
    KALDI_LOG << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
              << " STARTED";

    // Buffer for input features, used for choosing utt's with similar length,
    MatrixBuffer matrix_buffer;
    matrix_buffer.Init(&feature_reader);
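    // (the buffer reads features ahead and serves utterances of similar length
    //  next to each other, which keeps the zero-padding of the multi-stream
    //  mini-batches built below small)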

    int32 num_done = 0,
          num_no_tgt_mat = 0,
          num_other_error = 0;

    while (!matrix_buffer.Done()) {

      // Fill the parallel data into 'std::vector',
      std::vector<Matrix<BaseFloat> > feats_utt;
      std::vector<Posterior> labels_utt;
      std::vector<Vector<BaseFloat> > weights_utt;
      std::vector<int32> frame_num_utt;
      {
        matrix_buffer.ResetLength();
        for (matrix_buffer.Next(); !matrix_buffer.Done(); matrix_buffer.Next()) {
          std::string utt = matrix_buffer.Key();
          // Check that we have targets,
          if (!targets_reader.HasKey(utt)) {
            KALDI_WARN << utt << ", missing targets";
            num_no_tgt_mat++;
            continue;
          }
          // Do we have frame-weights?
          if (frame_weights != "" && !weights_reader.HasKey(utt)) {
            KALDI_WARN << utt << ", missing frame-weights";
            num_other_error++;
            continue;
          }

          // Get feature / target pair,
          Matrix<BaseFloat> mat = matrix_buffer.Value();
          Posterior targets = targets_reader.Value(utt);

          // Skip too long sentences,
          if (mat.NumRows() > max_frames) continue;

          Vector<BaseFloat> weights;
          if (frame_weights != "") {
            weights = weights_reader.Value(utt);
          } else {  // all per-frame weights are 1.0
            weights.Resize(mat.NumRows());
            weights.Set(1.0);
          }

          // correct small length mismatch ... or drop sentence
          {
            // add lengths to vector
            std::vector<int32> length;
            length.push_back(mat.NumRows());
            length.push_back(targets.size());
            length.push_back(weights.Dim());
            // find min, max
            int32 min = *std::min_element(length.begin(), length.end());
            int32 max = *std::max_element(length.begin(), length.end());
            // fix or drop ?
            if (max - min < length_tolerance) {
              if (mat.NumRows() != min) mat.Resize(min, mat.NumCols(), kCopyData);
              if (targets.size() != min) targets.resize(min);
              if (weights.Dim() != min) weights.Resize(min, kCopyData);
            } else {
              KALDI_WARN << "Length mismatch! Targets " << targets.size()
                         << ", features " << mat.NumRows() << ", " << utt;
              num_other_error++;
              continue;
            }
          }

          // input transform may contain splicing,
          nnet_transf.Feedforward(CuMatrix<BaseFloat>(mat), &feats_transf);

          // store,
          feats_utt.push_back(Matrix<BaseFloat>(feats_transf));
          labels_utt.push_back(targets);
          weights_utt.push_back(weights);
          frame_num_utt.push_back(feats_transf.NumRows());

          if (frame_num_utt.size() == num_streams) break;

          // See how many frames we'd have (after padding), if we add one more utterance,
          int32 max = (*std::max_element(frame_num_utt.begin(), frame_num_utt.end()));
          if (max * (frame_num_utt.size() + 1) > max_frames) break;
        }
      }
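      // The batch now holds up to 'num_streams' utterances; no further utterance
      // was added once the padded size (longest utterance x stream count) would
      // have exceeded 'max_frames'.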
      // Having no data? Skip the cycle...
      if (frame_num_utt.size() == 0) continue;

      // Pack the parallel data,
      Matrix<BaseFloat> feat_mat_host;
      Posterior target_host;
      Vector<BaseFloat> weight_host;
      {
        // Number of sequences,
        int32 n_streams = frame_num_utt.size();
        int32 frame_num_padded = (*std::max_element(frame_num_utt.begin(), frame_num_utt.end()));
        int32 feat_dim = feats_utt.front().NumCols();

        // Create the final feature matrix. Every utterance is padded to the max
        // length within this group of utterances,
        feat_mat_host.Resize(n_streams * frame_num_padded, feat_dim, kSetZero);
        target_host.resize(n_streams * frame_num_padded);
        weight_host.Resize(n_streams * frame_num_padded, kSetZero);

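        // The rows are interleaved in time-major order: row (r * n_streams + s)
        // holds frame 'r' of stream 's', which is the layout nnet1 recurrent
        // components expect in multi-stream mode. Rows of padded frames stay zero.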
        for (int32 s = 0; s < n_streams; s++) {
          const Matrix<BaseFloat>& mat_tmp = feats_utt[s];
          for (int32 r = 0; r < frame_num_utt[s]; r++) {
            feat_mat_host.Row(r*n_streams + s).CopyFromVec(mat_tmp.Row(r));
          }
        }

        for (int32 s = 0; s < n_streams; s++) {
          const Posterior& target_tmp = labels_utt[s];
          for (int32 r = 0; r < frame_num_utt[s]; r++) {
            target_host[r*n_streams + s] = target_tmp[r];
          }
        }

        // padded frames will keep initial zero-weight,
        for (int32 s = 0; s < n_streams; s++) {
          const Vector<BaseFloat>& weight_tmp = weights_utt[s];
          for (int32 r = 0; r < frame_num_utt[s]; r++) {
            weight_host(r*n_streams + s) = weight_tmp(r);
          }
        }
      }

      // Set the original lengths of utterances before padding,
      nnet.SetSeqLengths(frame_num_utt);
      // Show the 'utt' lengths in the VLOG[2],
      if (GetVerboseLevel() >= 2) {
        std::ostringstream os;
        os << "[ ";
        for (size_t i = 0; i < frame_num_utt.size(); i++) {
          os << frame_num_utt[i] << " ";
        }
        os << "]";
        KALDI_LOG << "frame_num_utt[" << frame_num_utt.size() << "]" << os.str();
      }
      // Reset all the streams (we have new sentences),
      nnet.ResetStreams(std::vector<int32>(frame_num_utt.size(), 1));
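      // (a '1' per stream clears the recurrent histories, so every batch of
      //  utterances starts its BPTT from zero state, i.e. per-utterance updates
      //  without truncation)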

      // Propagation,
      nnet.Propagate(CuMatrix<BaseFloat>(feat_mat_host), &nnet_out);

      // Per-frame cross-entropy, gradients get re-scaled by weights,
      xent.Eval(weight_host, nnet_out, target_host, &obj_diff);
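      // (padded rows kept zero weight, so they contribute neither to the loss
      //  nor to 'obj_diff')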

      // Backward pass
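      // (in_diff = NULL below: the derivative w.r.t. the input features is not
      //  needed; the updatable components adjust their parameters inside
      //  Backpropagate())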
      if (!crossvalidate) {
        nnet.Backpropagate(obj_diff, NULL);
      }

      // 1st model update : show what happens in network,
      if (total_frames == 0) {
        KALDI_LOG << "### After " << total_frames << " frames,";
        KALDI_LOG << nnet.Info();
        KALDI_LOG << nnet.InfoPropagate();
        if (!crossvalidate) {
          KALDI_LOG << nnet.InfoBackPropagate();
          KALDI_LOG << nnet.InfoGradient();
        }
      }

      kaldi::int64 tmp_frames = total_frames;

      num_done += frame_num_utt.size();
      total_frames += std::accumulate(frame_num_utt.begin(), frame_num_utt.end(), 0);
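      // ('num_done' counts utterances, 'total_frames' counts only the real,
      //  un-padded frames)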

      // monitor the NN training (--verbose=3),
      int32 F = 25000;
      if (GetVerboseLevel() >= 3) {
        // print every 25k frames,
        if (tmp_frames / F != total_frames / F) {
          KALDI_VLOG(2) << "### After " << total_frames << " frames,";
          KALDI_VLOG(2) << nnet.Info();
          KALDI_VLOG(2) << nnet.InfoPropagate();
          if (!crossvalidate) {
            KALDI_VLOG(2) << nnet.InfoBackPropagate();
            KALDI_VLOG(2) << nnet.InfoGradient();
          }
        }
      }
    }

    // after last model update : show what happens in network,
    KALDI_LOG << "### After " << total_frames << " frames,";
    KALDI_LOG << nnet.Info();
    KALDI_LOG << nnet.InfoPropagate();
    if (!crossvalidate) {
      KALDI_LOG << nnet.InfoBackPropagate();
      KALDI_LOG << nnet.InfoGradient();
    }

    if (!crossvalidate) {
      nnet.Write(target_model_filename, binary);
    }

    KALDI_LOG << xent.ReportPerClass();
    KALDI_LOG << "Done " << num_done << " files, " << num_no_tgt_mat
              << " with no tgt_mats, " << num_other_error
              << " with other errors. "
              << "[" << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
              << ", " << time.Elapsed() / 60 << " min, "
              << "fps " << total_frames / time.Elapsed() << "]";
    KALDI_LOG << xent.Report();

#if HAVE_CUDA == 1
    CuDevice::Instantiate().PrintProfile();
#endif
    return 0;
  } catch (const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}