All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
rbm-train-cd1-frmshuff.cc File Reference
#include "nnet/nnet-trnopts.h"
#include "nnet/nnet-rbm.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-loss.h"
#include "nnet/nnet-randomizer.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "base/timer.h"
#include "cudamatrix/cu-device.h"
#include "cudamatrix/cu-rand.h"
Include dependency graph for rbm-train-cd1-frmshuff.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 32 of file rbm-train-cd1-frmshuff.cc.

References MatrixRandomizer::AddData(), CuRand< Real >::AddGaussNoise(), RbmBase::Bernoulli, CuRand< Real >::BinarizeProbs(), SequentialTableReader< Holder >::Close(), MatrixRandomizer::Done(), SequentialTableReader< Holder >::Done(), Timer::Elapsed(), Mse::Eval(), Nnet::Feedforward(), RbmBase::Gaussian, RandomizerMask::Generate(), ParseOptions::GetArg(), Nnet::GetComponent(), Component::GetType(), MatrixRandomizer::IsFull(), KALDI_ASSERT, KALDI_LOG, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), Component::kRbm, RbmTrainOptions::learn_rate, NnetDataRandomizerOptions::minibatch_size, RbmTrainOptions::momentum, RbmTrainOptions::momentum_max, RbmTrainOptions::momentum_step_period, RbmTrainOptions::momentum_steps, rnnlm::n, MatrixRandomizer::Next(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), Nnet::NumComponents(), MatrixRandomizer::NumFrames(), MatrixBase< Real >::NumRows(), CuMatrixBase< Real >::NumRows(), SequentialTableReader< Holder >::Open(), ParseOptions::PrintUsage(), MatrixRandomizer::Randomize(), ParseOptions::Read(), Nnet::Read(), LossOptions::Register(), NnetDataRandomizerOptions::Register(), ParseOptions::Register(), RbmTrainOptions::Register(), Mse::Report(), CuMatrix< Real >::Resize(), VectorBase< Real >::Set(), MatrixRandomizer::Value(), SequentialTableReader< Holder >::Value(), and Nnet::Write().

32  {
33  using namespace kaldi;
34  using namespace kaldi::nnet1;
35  typedef kaldi::int32 int32;
36  try {
37  const char *usage =
38  "Train RBM by Contrastive Divergence alg. with 1 step of "
39  "Markov Chain Monte-Carlo.\n"
40  "The tool can perform several iterations (--num-iters) "
41  "or it can subsample the training dataset (--drop-data)\n"
42 
43  "Usage: rbm-train-cd1-frmshuff [options] <model-in> "
44  "<feature-rspecifier> <model-out>\n"
45  "e.g.: rbm-train-cd1-frmshuff 1.rbm.init scp:train.scp 1.rbm\n";
46 
47  ParseOptions po(usage);
48 
49  RbmTrainOptions trn_opts, trn_opts_rbm;
50  trn_opts.Register(&po);
51  LossOptions loss_opts;
52  loss_opts.Register(&po);
53 
54  bool binary = false;
55  po.Register("binary", &binary, "Write output in binary mode");
56 
57  bool with_bug = true;
58  po.Register("with-bug", &with_bug,
59  "Apply bug which led to better results (set-initial-momentum-to-max)");
60 
61  int32 num_iters = 1;
62  po.Register("num-iters", &num_iters,
63  "Number of iterations (smaller datasets should have more iterations, "
64  "iterating within tool because of linear momentum scheduling)");
65 
66  std::string feature_transform;
67  po.Register("feature-transform", &feature_transform,
68  "Feature transform in 'nnet1' format");
69 
71  rnd_opts.minibatch_size = 100;
72  rnd_opts.Register(&po);
73 
74  kaldi::int32 max_frames = 6000;
75  po.Register("max-frames", &max_frames,
76  "Maximum number of frames an utterance can have (skipped if longer)");
77 
78  std::string use_gpu="yes";
79  po.Register("use-gpu", &use_gpu,
80  "yes|no|optional, only has effect if compiled with CUDA");
81 
82  po.Read(argc, argv);
83 
84  if (po.NumArgs() != 3) {
85  po.PrintUsage();
86  exit(1);
87  }
88 
89  std::string model_filename = po.GetArg(1),
90  feature_rspecifier = po.GetArg(2);
91 
92  std::string target_model_filename;
93  target_model_filename = po.GetArg(3);
94 
95 
96  using namespace kaldi;
97  using namespace kaldi::nnet1;
98  typedef kaldi::int32 int32;
99 
100 #if HAVE_CUDA == 1
101  CuDevice::Instantiate().SelectGpuId(use_gpu);
102 #endif
103 
104  Nnet rbm_transf;
105  if (feature_transform != "") {
106  rbm_transf.Read(feature_transform);
107  }
108 
109  // Read nnet, extract the RBM,
110  Nnet nnet;
111  nnet.Read(model_filename);
112  KALDI_ASSERT(nnet.NumComponents() == 1);
114  RbmBase &rbm = dynamic_cast<RbmBase&>(nnet.GetComponent(0));
115 
116  // Configure the RBM,
117  // make some constants accessible, will use them later,
118  const BaseFloat& learn_rate = trn_opts.learn_rate;
119  const BaseFloat& momentum = trn_opts.momentum;
120  const BaseFloat& momentum_max = trn_opts.momentum_max;
121  const int32& momentum_steps = trn_opts.momentum_steps;
122  const int32& momentum_step_period = trn_opts.momentum_step_period;
123 
124  // 'trn_opts_rbm' is a local copy of 'trn_opts' which is passed to RBM,
125  trn_opts_rbm = trn_opts;
126  // keep `effective' learning rate constant
127  trn_opts_rbm.learn_rate = learn_rate * (1 - momentum);
128  // pass options to RBM,
129  rbm.SetRbmTrainOptions(trn_opts_rbm);
130 
131  kaldi::int64 total_frames = 0;
132 
133  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
134  RandomizerMask randomizer_mask(rnd_opts);
135  MatrixRandomizer feature_randomizer(rnd_opts);
136 
137  CuRand<BaseFloat> cu_rand; // parallel random number generator,
138  Mse mse(loss_opts);
139 
140  CuMatrix<BaseFloat> feats_transf,
141  pos_hid, pos_hid_aux,
142  neg_vis, neg_hid;
143  CuMatrix<BaseFloat> dummy_mse_mat;
144 
145  Timer time;
146  KALDI_LOG << "RBM TRAINING STARTED";
147 
148  int32 iter = 1;
149  KALDI_LOG << "Iteration " << iter << "/" << num_iters;
150 
151  int32 num_done = 0, num_other_error = 0;
152  while (!feature_reader.Done()) {
153 #if HAVE_CUDA == 1
154  // check that GPU is computing accurately,
155  CuDevice::Instantiate().CheckGpuHealth();
156 #endif
157  // fill the randomizer,
158  for ( ; !feature_reader.Done(); feature_reader.Next()) {
159  if (feature_randomizer.IsFull()) {
160  // break the loop without calling Next(),
161  // we keep the 'utt' for next round,
162  break;
163  }
164  std::string utt = feature_reader.Key();
165  KALDI_VLOG(3) << "Reading " << utt;
166  // get feature matrix,
167  const Matrix<BaseFloat> &mat = feature_reader.Value();
168  // skip too long segments (avoid runinning out of memory)
169  if (mat.NumRows() > max_frames) {
170  KALDI_WARN << "Skipping " << utt
171  << " that has " << mat.NumRows() << " frames,"
172  << " it is longer than '--max-frames'" << max_frames;
173  num_other_error++;
174  continue;
175  }
176  // apply feature transform,
177  rbm_transf.Feedforward(CuMatrix<BaseFloat>(mat), &feats_transf);
178  // add to randomizer,
179  feature_randomizer.AddData(feats_transf);
180  num_done++;
181 
182  // report the speed
183  if (num_done % 5000 == 0) {
184  double time_now = time.Elapsed();
185  KALDI_VLOG(1) << "After " << num_done << " utterances: "
186  << "time elapsed = " << time_now / 60 << " min; "
187  << "processed " << total_frames / time_now << " frames per sec.";
188  }
189  }
190 
191  // randomize,
192  feature_randomizer.Randomize(
193  randomizer_mask.Generate(feature_randomizer.NumFrames())
194  );
195 
196  // train with data from randomizer (using mini-batches)
197  for ( ; !feature_randomizer.Done(); feature_randomizer.Next()) {
198  // get the mini-batch,
199  const CuMatrixBase<BaseFloat>& pos_vis = feature_randomizer.Value();
200  // get the dims,
201  int32 num_frames = pos_vis.NumRows(),
202  dim_hid = rbm.OutputDim();
203  // Create dummy frame-weights for Mse::Eval,
204  Vector<BaseFloat> dummy_weights(num_frames);
205  dummy_weights.Set(1.0);
206 
207  // TRAIN with CD1,
208  // forward pass,
209  rbm.Propagate(pos_vis, &pos_hid);
210 
211  // alter the hidden values, so we can generate negative example,
212  if (rbm.HidType() == Rbm::Bernoulli) {
213  pos_hid_aux.Resize(num_frames, dim_hid);
214  cu_rand.BinarizeProbs(pos_hid, &pos_hid_aux); // => 0 / 1,
215  } else {
216  KALDI_ASSERT(rbm.HidType() == Rbm::Gaussian);
217  pos_hid_aux = pos_hid;
218  cu_rand.AddGaussNoise(&pos_hid_aux);
219  }
220 
221  // reconstruct pass,
222  rbm.Reconstruct(pos_hid_aux, &neg_vis);
223  // propagate negative examples
224  rbm.Propagate(neg_vis, &neg_hid);
225  // update step
226  rbm.RbmUpdate(pos_vis, pos_hid, neg_vis, neg_hid);
227  // evaluate mean square error
228  mse.Eval(dummy_weights, neg_vis, pos_vis, &dummy_mse_mat);
229 
230  total_frames += num_frames;
231 
232  // change the momentum progressively per 0.5million samples of the data
233  {
234  static int32 n_prev = -1;
235  BaseFloat step = (momentum_max - momentum) / momentum_steps;
236  // change every momentum_step_period data,
237  int32 n = total_frames / momentum_step_period;
238  BaseFloat momentum_actual;
239  if (n > momentum_steps) {
240  momentum_actual = momentum_max;
241  } else {
242  momentum_actual = momentum + n*step;
243  }
244  if (n - n_prev > 0) {
245  n_prev = n;
246  BaseFloat learning_rate_actual = learn_rate*(1-momentum_actual);
247  KALDI_VLOG(1) << "Setting momentum "
248  << (with_bug ? momentum_max : momentum_actual)
249  << " and learning rate " << learning_rate_actual
250  << " after processing "
251  << static_cast<double>(total_frames) / 360000 << " h";
252  // pass values to rbm,
253  trn_opts_rbm.momentum = (with_bug ? momentum_max : momentum_actual);
254  trn_opts_rbm.learn_rate = learning_rate_actual;
255  rbm.SetRbmTrainOptions(trn_opts_rbm);
256  }
257  }
258  }
259 
260  // reopen the feature stream if we will run another iteration
261  if (feature_reader.Done() && (iter < num_iters)) {
262  iter++;
263  KALDI_LOG << "Iteration " << iter << "/" << num_iters;
264  feature_reader.Close();
265  feature_reader.Open(feature_rspecifier);
266  }
267  }
268 
269  nnet.Write(target_model_filename, binary);
270 
271  KALDI_LOG << "Done " << iter << " iterations, " << num_done << " files, "
272  << "skipped " << num_other_error << " files. "
273  << "[" << time.Elapsed() / 60 << " min, "
274  << "processing" << total_frames / time.Elapsed() << " "
275  << "frames per sec.]";
276 
277  KALDI_LOG << mse.Report();
278 
279 #if HAVE_CUDA == 1
280  CuDevice::Instantiate().PrintProfile();
281 #endif
282  return 0;
283  } catch(const std::exception &e) {
284  std::cerr << e.what();
285  return -1;
286  }
287 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
Generates a randomly ordered vector of indices.
void Register(OptionsItf *opts)
Definition: nnet-trnopts.h:86
void AddGaussNoise(CuMatrix< Real > *tgt, Real gscale=1.0)
Adds Gaussian noise to each element.
Definition: cu-rand.cc:202
Configuration variables that affect how frame-level shuffling is done.
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
Definition: cu-matrix.cc:49
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:205
struct rnnlm::@11::@12 n
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void Read(const std::string &rxfilename)
Reads the Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
#define KALDI_WARN
Definition: kaldi-error.h:130
void BinarizeProbs(const CuMatrix< Real > &probs, CuMatrix< Real > *states)
Converts probabilities to discrete 0/1 states (using uniform sampling).
Definition: cu-rand.cc:192
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Performs a forward pass through the network (using two swapping buffers).
Definition: nnet-nnet.cc:131
const Component & GetComponent(int32 c) const
Component accessor.
Definition: nnet-nnet.cc:153
Shuffles the rows of a matrix according to the indices in the mask.
MatrixIndexT NumRows() const
Returns the number of rows (or zero for an empty matrix).
Definition: kaldi-matrix.h:61
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
void Write(const std::string &wxfilename, bool binary) const
Writes the Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
virtual ComponentType GetType() const =0
Gets the type identification of the component.
int32 NumComponents() const
Returns the number of 'Components' which form the NN.
Definition: nnet-nnet.h:66
void Register(OptionsItf *opts)
Definition: nnet-loss.h:45
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
#define KALDI_LOG
Definition: kaldi-error.h:133