All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
rbm-train-cd1-frmshuff.cc File Reference
#include "nnet/nnet-trnopts.h"
#include "nnet/nnet-rbm.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-loss.h"
#include "nnet/nnet-randomizer.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "base/timer.h"
#include "cudamatrix/cu-device.h"
#include "cudamatrix/cu-rand.h"
Include dependency graph for rbm-train-cd1-frmshuff.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 32 of file rbm-train-cd1-frmshuff.cc.

References MatrixRandomizer::AddData(), CuRand< Real >::AddGaussNoise(), RbmBase::Bernoulli, CuRand< Real >::BinarizeProbs(), SequentialTableReader< Holder >::Close(), MatrixRandomizer::Done(), SequentialTableReader< Holder >::Done(), Timer::Elapsed(), Mse::Eval(), Nnet::Feedforward(), RbmBase::Gaussian, RandomizerMask::Generate(), ParseOptions::GetArg(), Nnet::GetComponent(), Component::GetType(), MatrixRandomizer::IsFull(), KALDI_ASSERT, KALDI_LOG, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), Component::kRbm, RbmTrainOptions::learn_rate, NnetDataRandomizerOptions::minibatch_size, RbmTrainOptions::momentum, RbmTrainOptions::momentum_max, RbmTrainOptions::momentum_step_period, RbmTrainOptions::momentum_steps, rnnlm::n, MatrixRandomizer::Next(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), Nnet::NumComponents(), MatrixRandomizer::NumFrames(), MatrixBase< Real >::NumRows(), CuMatrixBase< Real >::NumRows(), SequentialTableReader< Holder >::Open(), ParseOptions::PrintUsage(), MatrixRandomizer::Randomize(), ParseOptions::Read(), Nnet::Read(), NnetDataRandomizerOptions::Register(), ParseOptions::Register(), RbmTrainOptions::Register(), Mse::Report(), CuMatrix< Real >::Resize(), VectorBase< Real >::Set(), MatrixRandomizer::Value(), SequentialTableReader< Holder >::Value(), and Nnet::Write().

// Doxygen-rendered listing of main() from rbm-train-cd1-frmshuff.cc.
// Trains a single RBM with 1-step Contrastive Divergence (CD-1) over
// frame-shuffled features, with a linear momentum schedule and an
// "effective" learning rate kept constant as momentum grows.
// NOTE(review): this rendering drops some original source lines (68 and 111,
// see inline notes below) -- consult the actual .cc file before editing.
32  {
33  using namespace kaldi;
34  using namespace kaldi::nnet1;
35  typedef kaldi::int32 int32;
36  try {
37  const char *usage =
38  "Train RBM by Contrastive Divergence alg. with 1 step of "
39  "Markov Chain Monte-Carlo.\n"
40  "The tool can perform several iterations (--num-iters) "
41  "or it can subsample the training dataset (--drop-data)\n"
42 
43  "Usage: rbm-train-cd1-frmshuff [options] <model-in> "
44  "<feature-rspecifier> <model-out>\n"
45  "e.g.: rbm-train-cd1-frmshuff 1.rbm.init scp:train.scp 1.rbm\n";
46 
47  ParseOptions po(usage);
48 
49  RbmTrainOptions trn_opts, trn_opts_rbm;
50  trn_opts.Register(&po);
51 
52  bool binary = false;
53  po.Register("binary", &binary, "Write output in binary mode");
54 
55  bool with_bug = true;
56  po.Register("with-bug", &with_bug,
57  "Apply bug which led to better results (set-initial-momentum-to-max)");
58 
59  int32 num_iters = 1;
60  po.Register("num-iters", &num_iters,
61  "Number of iterations (smaller datasets should have more iterations, "
62  "iterating within tool because of linear momentum scheduling)");
63 
64  std::string feature_transform;
65  po.Register("feature-transform", &feature_transform,
66  "Feature transform in 'nnet1' format");
67 
// NOTE(review): original line 68 is missing from this rendering; it
// presumably declared 'rnd_opts' (NnetDataRandomizerOptions rnd_opts;),
// which is used immediately below -- confirm against the actual source.
69  rnd_opts.minibatch_size = 100;
70  rnd_opts.Register(&po);
71 
72  kaldi::int32 max_frames = 6000;
73  po.Register("max-frames", &max_frames,
74  "Maximum number of frames an utterance can have (skipped if longer)");
75 
76  std::string use_gpu="yes";
77  po.Register("use-gpu", &use_gpu,
78  "yes|no|optional, only has effect if compiled with CUDA");
79 
80  po.Read(argc, argv);
81 
82  if (po.NumArgs() != 3) {
83  po.PrintUsage();
84  exit(1);
85  }
86 
87  std::string model_filename = po.GetArg(1),
88  feature_rspecifier = po.GetArg(2);
89 
90  std::string target_model_filename;
91  target_model_filename = po.GetArg(3);
92 
93 
// NOTE(review): lines 94-96 duplicate the using-declarations/typedef from
// lines 33-35 in the same scope; harmless but redundant -- candidate cleanup.
94  using namespace kaldi;
95  using namespace kaldi::nnet1;
96  typedef kaldi::int32 int32;
97 
98 #if HAVE_CUDA == 1
99  CuDevice::Instantiate().SelectGpuId(use_gpu);
100 #endif
101 
// Optional feature-transform network; empty Nnet acts as identity below.
102  Nnet rbm_transf;
103  if (feature_transform != "") {
104  rbm_transf.Read(feature_transform);
105  }
106 
107  // Read nnet, extract the RBM,
108  Nnet nnet;
109  nnet.Read(model_filename);
110  KALDI_ASSERT(nnet.NumComponents() == 1);
// NOTE(review): original line 111 is missing from this rendering; per the
// reference list above it likely asserted
// nnet.GetComponent(0).GetType() == Component::kRbm -- confirm upstream.
112  RbmBase &rbm = dynamic_cast<RbmBase&>(nnet.GetComponent(0));
113 
114  // Configure the RBM,
115  // make some constants accessible, will use them later,
116  const BaseFloat& learn_rate = trn_opts.learn_rate;
117  const BaseFloat& momentum = trn_opts.momentum;
118  const BaseFloat& momentum_max = trn_opts.momentum_max;
119  const int32& momentum_steps = trn_opts.momentum_steps;
120  const int32& momentum_step_period = trn_opts.momentum_step_period;
121 
122  // 'trn_opts_rbm' is a local copy of 'trn_opts' which is passed to RBM,
123  trn_opts_rbm = trn_opts;
124  // keep `effective' learning rate constant
125  trn_opts_rbm.learn_rate = learn_rate * (1 - momentum);
126  // pass options to RBM,
127  rbm.SetRbmTrainOptions(trn_opts_rbm);
128 
129  kaldi::int64 total_frames = 0;
130 
131  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
132  RandomizerMask randomizer_mask(rnd_opts);
133  MatrixRandomizer feature_randomizer(rnd_opts);
134 
135  CuRand<BaseFloat> cu_rand; // parallel random number generator,
136  Mse mse;
137 
138  CuMatrix<BaseFloat> feats_transf,
139  pos_hid, pos_hid_aux,
140  neg_vis, neg_hid;
141  CuMatrix<BaseFloat> dummy_mse_mat;
142 
143  Timer time;
144  KALDI_LOG << "RBM TRAINING STARTED";
145 
146  int32 iter = 1;
147  KALDI_LOG << "Iteration " << iter << "/" << num_iters;
148 
149  int32 num_done = 0, num_other_error = 0;
// Outer loop: alternates between filling the randomizer from the reader
// and training on the shuffled contents, until the reader is exhausted
// (and possibly reopened for further --num-iters iterations, see below).
150  while (!feature_reader.Done()) {
151 #if HAVE_CUDA == 1
152  // check that GPU is computing accurately,
153  CuDevice::Instantiate().CheckGpuHealth();
154 #endif
155  // fill the randomizer,
156  for ( ; !feature_reader.Done(); feature_reader.Next()) {
157  if (feature_randomizer.IsFull()) {
158  // break the loop without calling Next(),
159  // we keep the 'utt' for next round,
160  break;
161  }
162  std::string utt = feature_reader.Key();
163  KALDI_VLOG(3) << "Reading " << utt;
164  // get feature matrix,
165  const Matrix<BaseFloat> &mat = feature_reader.Value();
166  // skip too long segments (avoid running out of memory)
167  if (mat.NumRows() > max_frames) {
168  KALDI_WARN << "Skipping " << utt
169  << " that has " << mat.NumRows() << " frames,"
170  << " it is longer than '--max-frames'" << max_frames;
171  num_other_error++;
172  continue;
173  }
174  // apply feature transform,
175  rbm_transf.Feedforward(CuMatrix<BaseFloat>(mat), &feats_transf);
176  // add to randomizer,
177  feature_randomizer.AddData(feats_transf);
178  num_done++;
179 
180  // report the speed
181  if (num_done % 5000 == 0) {
182  double time_now = time.Elapsed();
183  KALDI_VLOG(1) << "After " << num_done << " utterances: "
184  << "time elapsed = " << time_now / 60 << " min; "
185  << "processed " << total_frames / time_now << " frames per sec.";
186  }
187  }
188 
189  // randomize,
190  feature_randomizer.Randomize(
191  randomizer_mask.Generate(feature_randomizer.NumFrames())
192  );
193 
194  // train with data from randomizer (using mini-batches)
195  for ( ; !feature_randomizer.Done(); feature_randomizer.Next()) {
196  // get the mini-batch,
197  const CuMatrixBase<BaseFloat>& pos_vis = feature_randomizer.Value();
198  // get the dims,
199  int32 num_frames = pos_vis.NumRows(),
200  dim_hid = rbm.OutputDim();
201  // Create dummy frame-weights for Mse::Eval,
202  Vector<BaseFloat> dummy_weights(num_frames);
203  dummy_weights.Set(1.0);
204 
205  // TRAIN with CD1,
206  // forward pass,
207  rbm.Propagate(pos_vis, &pos_hid);
208 
209  // alter the hidden values, so we can generate negative example,
210  if (rbm.HidType() == Rbm::Bernoulli) {
211  pos_hid_aux.Resize(num_frames, dim_hid);
212  cu_rand.BinarizeProbs(pos_hid, &pos_hid_aux); // => 0 / 1,
213  } else {
214  KALDI_ASSERT(rbm.HidType() == Rbm::Gaussian);
215  pos_hid_aux = pos_hid;
216  cu_rand.AddGaussNoise(&pos_hid_aux);
217  }
218 
219  // reconstruct pass,
220  rbm.Reconstruct(pos_hid_aux, &neg_vis);
221  // propagate negative examples
222  rbm.Propagate(neg_vis, &neg_hid);
223  // update step
224  rbm.RbmUpdate(pos_vis, pos_hid, neg_vis, neg_hid);
225  // evaluate mean square error
226  mse.Eval(dummy_weights, neg_vis, pos_vis, &dummy_mse_mat);
227 
228  total_frames += num_frames;
229 
230  // change the momentum progressively per 0.5 million samples of the data
231  {
232  // static local: retains the last momentum-step index across
233  // mini-batches (and across --num-iters iterations),
232  static int32 n_prev = -1;
233  BaseFloat step = (momentum_max - momentum) / momentum_steps;
234  // change every momentum_step_period data,
235  int32 n = total_frames / momentum_step_period;
236  BaseFloat momentum_actual;
237  if (n > momentum_steps) {
238  momentum_actual = momentum_max;
239  } else {
240  momentum_actual = momentum + n*step;
241  }
// Re-apply options only when the step index actually advanced; learning
// rate shrinks as momentum grows so the effective rate stays constant.
242  if (n - n_prev > 0) {
243  n_prev = n;
244  BaseFloat learning_rate_actual = learn_rate*(1-momentum_actual);
245  KALDI_VLOG(1) << "Setting momentum "
246  << (with_bug ? momentum_max : momentum_actual)
247  << " and learning rate " << learning_rate_actual
248  << " after processing "
249  << static_cast<double>(total_frames) / 360000 << " h";
250  // pass values to rbm,
251  trn_opts_rbm.momentum = (with_bug ? momentum_max : momentum_actual);
252  trn_opts_rbm.learn_rate = learning_rate_actual;
253  rbm.SetRbmTrainOptions(trn_opts_rbm);
254  }
255  }
256  }
257 
258  // reopen the feature stream if we will run another iteration
259  if (feature_reader.Done() && (iter < num_iters)) {
260  iter++;
261  KALDI_LOG << "Iteration " << iter << "/" << num_iters;
262  feature_reader.Close();
263  feature_reader.Open(feature_rspecifier);
264  }
265  }
266 
267  nnet.Write(target_model_filename, binary);
268 
269  KALDI_LOG << "Done " << iter << " iterations, " << num_done << " files, "
270  << "skipped " << num_other_error << " files. "
271  << "[" << time.Elapsed() / 60 << " min, "
// NOTE(review): the "processing" literal below is missing a trailing space,
// so the log prints e.g. "processing1234.5 frames per sec." -- fix upstream
// (string literal: not altered here).
272  << "processing" << total_frames / time.Elapsed() << " "
273  << "frames per sec.]";
274 
275  KALDI_LOG << mse.Report();
276 
277 #if HAVE_CUDA == 1
278  CuDevice::Instantiate().PrintProfile();
279 #endif
280  return 0;
281  } catch(const std::exception &e) {
282  std::cerr << e.what();
283  return -1;
284  }
285 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
Generates randomly ordered vector of indices,.
std::string Report()
Generate string with error report.
Definition: nnet-loss.cc:289
void Register(OptionsItf *opts)
Definition: nnet-trnopts.h:86
void AddGaussNoise(CuMatrix< Real > *tgt, Real gscale=1.0)
add Gaussian noise to each element,
Definition: cu-rand.cc:189
Configuration variables that affect how frame-level shuffling is done.
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
Definition: cu-matrix.cc:47
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:195
struct rnnlm::@11::@12 n
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void Eval(const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
Evaluate mean square error using target-matrix,.
Definition: nnet-loss.cc:217
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename',.
Definition: nnet-nnet.cc:333
#define KALDI_WARN
Definition: kaldi-error.h:130
void BinarizeProbs(const CuMatrix< Real > &probs, CuMatrix< Real > *states)
align probabilities to discrete 0/1 states (use uniform sampling),
Definition: cu-rand.cc:179
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers),.
Definition: nnet-nnet.cc:131
const Component & GetComponent(int32 c) const
Component accessor,.
Definition: nnet-nnet.cc:153
Shuffles rows of a matrix according to the indices in the mask,.
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:58
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename',.
Definition: nnet-nnet.cc:367
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
virtual ComponentType GetType() const =0
Get Type Identification of the component,.
int32 NumComponents() const
Returns the number of 'Components' which form the NN.
Definition: nnet-nnet.h:66
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
#define KALDI_LOG
Definition: kaldi-error.h:133