nnet-train-perutt.cc File Reference
Include dependency graph for nnet-train-perutt.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 29 of file nnet-train-perutt.cc.

References Nnet::Backpropagate(), VectorBase< Real >::Dim(), SequentialTableReader< Holder >::Done(), Timer::Elapsed(), Xent::Eval(), Mse::Eval(), MultiTaskLoss::Eval(), Nnet::Feedforward(), ParseOptions::GetArg(), kaldi::GetVerboseLevel(), RandomAccessTableReader< Holder >::HasKey(), Nnet::InfoBackPropagate(), Nnet::InfoGradient(), Nnet::InfoPropagate(), MultiTaskLoss::InitFromString(), KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, kaldi::kCopyData, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), RandomAccessTableReader< Holder >::Open(), ParseOptions::PrintUsage(), Nnet::Propagate(), ParseOptions::Read(), Nnet::Read(), LossOptions::Register(), NnetTrainOptions::Register(), ParseOptions::Register(), Xent::Report(), Mse::Report(), MultiTaskLoss::Report(), Xent::ReportPerClass(), Vector< Real >::Resize(), Matrix< Real >::Resize(), VectorBase< Real >::Set(), Nnet::SetDropoutRate(), Nnet::SetTrainOptions(), VectorBase< Real >::Sum(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and Nnet::Write().

29  {
30  using namespace kaldi;
31  using namespace kaldi::nnet1;
32  typedef kaldi::int32 int32;
33 
34  try {
35  const char *usage =
36  "Perform one iteration of NN training by SGD with per-utterance updates.\n"
37  "The training targets are represented as pdf-posteriors, usually prepared "
38  "by ali-to-post.\n"
39  "Usage: nnet-train-perutt [options] "
40  "<feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]\n"
41  "e.g.: nnet-train-perutt scp:feature.scp ark:posterior.ark nnet.init nnet.iter1\n";
42 
43  ParseOptions po(usage);
44 
45  NnetTrainOptions trn_opts;
46  trn_opts.Register(&po);
47  LossOptions loss_opts;
48  loss_opts.Register(&po);
49 
50  bool binary = true;
51  po.Register("binary", &binary, "Write output in binary mode");
52 
53  bool crossvalidate = false;
54  po.Register("cross-validate", &crossvalidate,
55  "Perform cross-validation (don't backpropagate)");
56 
57  std::string feature_transform;
58  po.Register("feature-transform", &feature_transform,
59  "Feature transform in Nnet format");
60 
61  std::string objective_function = "xent";
62  po.Register("objective-function", &objective_function,
63  "Objective function : xent|mse");
64 
65  int32 length_tolerance = 5;
66  po.Register("length-tolerance", &length_tolerance,
67  "Allowed length difference of features/targets (frames)");
68 
69  std::string frame_weights;
70  po.Register("frame-weights", &frame_weights,
71  "Per-frame weights to scale gradients (frame selection/weighting).");
72 
73  kaldi::int32 max_frames = 6000; // Allow segments maximum of one minute by default
74  po.Register("max-frames",&max_frames, "Maximum number of frames a segment can have to be processed");
75 
76  std::string use_gpu="yes";
77  po.Register("use-gpu", &use_gpu,
78  "yes|no|optional, only has effect if compiled with CUDA");
79 
81  bool randomize = false;
82  po.Register("randomize", &randomize,
83  "Dummy, for compatibility with 'steps/nnet/train_scheduler.sh'");
85 
86  po.Read(argc, argv);
87 
88  if (po.NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
89  po.PrintUsage();
90  exit(1);
91  }
92 
93  std::string feature_rspecifier = po.GetArg(1),
94  targets_rspecifier = po.GetArg(2),
95  model_filename = po.GetArg(3);
96 
97  std::string target_model_filename;
98  if (!crossvalidate) {
99  target_model_filename = po.GetArg(4);
100  }
101 
102  using namespace kaldi;
103  using namespace kaldi::nnet1;
104  typedef kaldi::int32 int32;
105 
106 #if HAVE_CUDA == 1
107  CuDevice::Instantiate().SelectGpuId(use_gpu);
108 #endif
109 
110  Nnet nnet_transf;
111  if (feature_transform != "") {
112  nnet_transf.Read(feature_transform);
113  }
114 
115  Nnet nnet;
116  nnet.Read(model_filename);
117  nnet.SetTrainOptions(trn_opts);
118 
119  if (crossvalidate) {
120  nnet_transf.SetDropoutRate(0.0);
121  nnet.SetDropoutRate(0.0);
122  }
123 
124  kaldi::int64 total_frames = 0;
125 
126  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
127  RandomAccessPosteriorReader targets_reader(targets_rspecifier);
128  RandomAccessBaseFloatVectorReader weights_reader;
129  if (frame_weights != "") {
130  weights_reader.Open(frame_weights);
131  }
132 
133  Xent xent(loss_opts);
134  Mse mse(loss_opts);
135 
136  MultiTaskLoss multitask(loss_opts);
137  if (0 == objective_function.compare(0, 9, "multitask")) {
138  // objective_function contains something like :
139  // 'multitask,xent,2456,1.0,mse,440,0.001'
140  //
141  // the meaning is following:
142  // 'multitask,<type1>,<dim1>,<weight1>,...,<typeN>,<dimN>,<weightN>'
143  multitask.InitFromString(objective_function);
144  }
145 
146  CuMatrix<BaseFloat> feats, feats_transf, nnet_out, obj_diff;
147 
148  Timer time;
149  KALDI_LOG << (crossvalidate?"CROSS-VALIDATION":"TRAINING") << " STARTED";
150 
151  int32 num_done = 0,
152  num_no_tgt_mat = 0,
153  num_other_error = 0;
154 
155  // main loop,
156  for ( ; !feature_reader.Done(); feature_reader.Next()) {
157  std::string utt = feature_reader.Key();
158  KALDI_VLOG(3) << "Reading " << utt;
159  // check that we have targets
160  if (!targets_reader.HasKey(utt)) {
161  KALDI_WARN << utt << ", missing targets";
162  num_no_tgt_mat++;
163  continue;
164  }
165  // check we have per-frame weights
166  if (frame_weights != "" && !weights_reader.HasKey(utt)) {
167  KALDI_WARN << utt << ", missing per-frame weights";
168  num_other_error++;
169  feature_reader.Next();
170  continue;
171  }
172  // get feature / target pair
173  Matrix<BaseFloat> mat = feature_reader.Value();
174  Posterior nnet_tgt = targets_reader.Value(utt);
175  // skip the sentence if it is too long,
176  if (mat.NumRows() > max_frames) {
177  KALDI_WARN << "Skipping " << utt
178  << " that has " << mat.NumRows() << " frames,"
179  << " it is longer than '--max-frames'" << max_frames;
180  num_other_error++;
181  continue;
182  }
183  // get per-frame weights
184  Vector<BaseFloat> frm_weights;
185  if (frame_weights != "") {
186  frm_weights = weights_reader.Value(utt);
187  } else { // all per-frame weights are 1.0
188  frm_weights.Resize(mat.NumRows());
189  frm_weights.Set(1.0);
190  }
191  // correct small length mismatch ... or drop sentence
192  {
193  // add lengths to vector
194  std::vector<int32> length;
195  length.push_back(mat.NumRows());
196  length.push_back(nnet_tgt.size());
197  length.push_back(frm_weights.Dim());
198  // find min, max
199  int32 min = *std::min_element(length.begin(), length.end());
200  int32 max = *std::max_element(length.begin(), length.end());
201  // fix or drop ?
202  if (max - min < length_tolerance) {
203  if (mat.NumRows() != min) mat.Resize(min, mat.NumCols(), kCopyData);
204  if (nnet_tgt.size() != min) nnet_tgt.resize(min);
205  if (frm_weights.Dim() != min) frm_weights.Resize(min, kCopyData);
206  } else {
207  KALDI_WARN << utt << ", length mismatch of targets " << nnet_tgt.size()
208  << " and features " << mat.NumRows();
209  num_other_error++;
210  continue;
211  }
212  }
213  // apply optional feature transform
214  nnet_transf.Feedforward(CuMatrix<BaseFloat>(mat), &feats_transf);
215 
216  // forward pass
217  nnet.Propagate(feats_transf, &nnet_out);
218 
219  // evaluate objective function we've chosen,
220  if (objective_function == "xent") {
221  // gradients are re-scaled by weights inside Eval,
222  xent.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
223  } else if (objective_function == "mse") {
224  // gradients are re-scaled by weights inside Eval,
225  mse.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
226  } else if (0 == objective_function.compare(0, 9, "multitask")) {
227  // gradients re-scaled by weights in Eval,
228  multitask.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
229  } else {
230  KALDI_ERR << "Unknown objective function code : "
231  << objective_function;
232  }
233 
234  if (!crossvalidate) {
235  // backpropagate and update,
236  nnet.Backpropagate(obj_diff, NULL);
237  }
238 
239  // 1st minibatch : show what happens in network,
240  if (total_frames == 0) {
241  KALDI_LOG << "### After " << total_frames << " frames,";
242  KALDI_LOG << nnet.InfoPropagate();
243  if (!crossvalidate) {
244  KALDI_LOG << nnet.InfoBackPropagate();
245  KALDI_LOG << nnet.InfoGradient();
246  }
247  }
248 
249  // VERBOSE LOG
250  // monitor the NN training (--verbose=2),
251  if (GetVerboseLevel() >= 2) {
252  static int32 counter = 0;
253  counter += mat.NumRows();
254  // print every 25k frames,
255  if (counter >= 25000) {
256  KALDI_VLOG(2) << "### After " << total_frames << " frames,";
257  KALDI_VLOG(2) << nnet.InfoPropagate();
258  if (!crossvalidate) {
259  KALDI_VLOG(2) << nnet.InfoBackPropagate();
260  KALDI_VLOG(2) << nnet.InfoGradient();
261  }
262  counter = 0;
263  }
264  }
265 
266  num_done++;
267  total_frames += frm_weights.Sum();
268  } // main loop,
269 
270  // after last minibatch : show what happens in network,
271  KALDI_LOG << "### After " << total_frames << " frames,";
272  KALDI_LOG << nnet.InfoPropagate();
273  if (!crossvalidate) {
274  KALDI_LOG << nnet.InfoBackPropagate();
275  KALDI_LOG << nnet.InfoGradient();
276  }
277 
278  if (!crossvalidate) {
279  nnet.Write(target_model_filename, binary);
280  }
281 
282  KALDI_LOG << "Done " << num_done << " files, "
283  << num_no_tgt_mat << " with no tgt_mats, "
284  << num_other_error << " with other errors. "
285  << "[" << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
286  << ", " << (randomize ? "RANDOMIZED" : "NOT-RANDOMIZED")
287  << ", " << time.Elapsed() / 60 << " min, processing "
288  << total_frames / time.Elapsed() << " frames per sec.]";
289 
290  if (objective_function == "xent") {
291  KALDI_LOG << xent.ReportPerClass();
292  KALDI_LOG << xent.Report();
293  } else if (objective_function == "mse") {
294  KALDI_LOG << mse.Report();
295  } else if (0 == objective_function.compare(0, 9, "multitask")) {
296  KALDI_LOG << multitask.Report();
297  } else {
298  KALDI_ERR << "Unknown objective function code : " << objective_function;
299  }
300 
301 #if HAVE_CUDA == 1
302  CuDevice::Instantiate().PrintProfile();
303 #endif
304 
305  return 0;
306  } catch(const std::exception &e) {
307  std::cerr << e.what();
308  return -1;
309  }
310 }
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
Definition: nnet-nnet.cc:96
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
Definition: nnet-nnet.cc:70
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
int32 GetVerboseLevel()
Get verbosity level, usually set via the command-line '--verbose=' switch.
Definition: kaldi-error.h:60
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
bool Open(const std::string &rspecifier)
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics.
Definition: nnet-nnet.cc:443
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:147
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
void Register(OptionsItf *opts)
Definition: nnet-trnopts.h:46
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
bool HasKey(const std::string &key)
Real Sum() const
Returns sum of the elements.
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
Definition: nnet-nnet.cc:131
A class representing a vector.
Definition: kaldi-vector.h:406
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void Set(Real f)
Set all members of a vector to a specified value.
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
Definition: nnet-nnet.cc:268
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics.
Definition: nnet-nnet.cc:407
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics.
Definition: nnet-nnet.cc:420
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
Definition: nnet-nnet.cc:508
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
void Register(OptionsItf *opts)
Definition: nnet-loss.h:45
#define KALDI_LOG
Definition: kaldi-error.h:153
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74