33 using namespace kaldi;
38 "Train RBM by Contrastive Divergence alg. with 1 step of " 39 "Markov Chain Monte-Carlo.\n" 40 "The tool can perform several iterations (--num-iters) " 41 "or it can subsample the training dataset (--drop-data)\n" 43 "Usage: rbm-train-cd1-frmshuff [options] <model-in> " 44 "<feature-rspecifier> <model-out>\n" 45 "e.g.: rbm-train-cd1-frmshuff 1.rbm.init scp:train.scp 1.rbm\n";
55 po.Register(
"binary", &binary,
"Write output in binary mode");
58 po.Register(
"with-bug", &with_bug,
59 "Apply bug which led to better results (set-initial-momentum-to-max)");
62 po.Register(
"num-iters", &num_iters,
63 "Number of iterations (smaller datasets should have more iterations, " 64 "iterating within tool because of linear momentum scheduling)");
66 std::string feature_transform;
67 po.Register(
"feature-transform", &feature_transform,
68 "Feature transform in 'nnet1' format");
75 po.Register(
"max-frames", &max_frames,
76 "Maximum number of frames an utterance can have (skipped if longer)");
78 std::string use_gpu=
"yes";
79 po.Register(
"use-gpu", &use_gpu,
80 "yes|no|optional, only has effect if compiled with CUDA");
84 if (po.NumArgs() != 3) {
89 std::string model_filename = po.GetArg(1),
90 feature_rspecifier = po.GetArg(2);
92 std::string target_model_filename;
93 target_model_filename = po.GetArg(3);
96 using namespace kaldi;
101 CuDevice::Instantiate().SelectGpuId(use_gpu);
105 if (feature_transform !=
"") {
106 rbm_transf.
Read(feature_transform);
111 nnet.
Read(model_filename);
125 trn_opts_rbm = trn_opts;
127 trn_opts_rbm.
learn_rate = learn_rate * (1 - momentum);
129 rbm.SetRbmTrainOptions(trn_opts_rbm);
131 kaldi::int64 total_frames = 0;
141 pos_hid, pos_hid_aux,
149 KALDI_LOG <<
"Iteration " << iter <<
"/" << num_iters;
151 int32 num_done = 0, num_other_error = 0;
152 while (!feature_reader.Done()) {
155 CuDevice::Instantiate().CheckGpuHealth();
158 for ( ; !feature_reader.Done(); feature_reader.Next()) {
159 if (feature_randomizer.IsFull()) {
164 std::string utt = feature_reader.Key();
169 if (mat.
NumRows() > max_frames) {
171 <<
" that has " << mat.
NumRows() <<
" frames," 172 <<
" it is longer than '--max-frames'" << max_frames;
179 feature_randomizer.AddData(feats_transf);
183 if (num_done % 5000 == 0) {
184 double time_now = time.
Elapsed();
185 KALDI_VLOG(1) <<
"After " << num_done <<
" utterances: " 186 <<
"time elapsed = " << time_now / 60 <<
" min; " 187 <<
"processed " << total_frames / time_now <<
" frames per sec.";
192 feature_randomizer.Randomize(
193 randomizer_mask.Generate(feature_randomizer.NumFrames())
197 for ( ; !feature_randomizer.Done(); feature_randomizer.Next()) {
201 int32 num_frames = pos_vis.
NumRows(),
202 dim_hid = rbm.OutputDim();
205 dummy_weights.Set(1.0);
209 rbm.Propagate(pos_vis, &pos_hid);
213 pos_hid_aux.
Resize(num_frames, dim_hid);
217 pos_hid_aux = pos_hid;
222 rbm.Reconstruct(pos_hid_aux, &neg_vis);
224 rbm.Propagate(neg_vis, &neg_hid);
226 rbm.RbmUpdate(pos_vis, pos_hid, neg_vis, neg_hid);
228 mse.Eval(dummy_weights, neg_vis, pos_vis, &dummy_mse_mat);
230 total_frames += num_frames;
234 static int32 n_prev = -1;
235 BaseFloat step = (momentum_max - momentum) / momentum_steps;
237 int32
n = total_frames / momentum_step_period;
239 if (n > momentum_steps) {
240 momentum_actual = momentum_max;
242 momentum_actual = momentum + n*step;
244 if (n - n_prev > 0) {
246 BaseFloat learning_rate_actual = learn_rate*(1-momentum_actual);
248 << (with_bug ? momentum_max : momentum_actual)
249 <<
" and learning rate " << learning_rate_actual
250 <<
" after processing " 251 << static_cast<double>(total_frames) / 360000 <<
" h";
253 trn_opts_rbm.
momentum = (with_bug ? momentum_max : momentum_actual);
254 trn_opts_rbm.
learn_rate = learning_rate_actual;
255 rbm.SetRbmTrainOptions(trn_opts_rbm);
261 if (feature_reader.Done() && (iter < num_iters)) {
263 KALDI_LOG <<
"Iteration " << iter <<
"/" << num_iters;
264 feature_reader.Close();
265 feature_reader.Open(feature_rspecifier);
269 nnet.
Write(target_model_filename, binary);
271 KALDI_LOG <<
"Done " << iter <<
" iterations, " << num_done <<
" files, " 272 <<
"skipped " << num_other_error <<
" files. " 273 <<
"[" << time.
Elapsed() / 60 <<
" min, " 274 <<
"processing" << total_frames / time.
Elapsed() <<
" " 275 <<
"frames per sec.]";
280 CuDevice::Instantiate().PrintProfile();
283 }
catch(
const std::exception &e) {
284 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
int32 momentum_step_period
int32 NumComponents() const
Returns the number of 'Components' which form the NN.
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename',.
Generates randomly ordered vector of indices,.
void BinarizeProbs(const CuMatrix< Real > &probs, CuMatrix< Real > *states)
align probabilities to discrete 0/1 states (use uniform sampling),
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
void AddGaussNoise(CuMatrix< Real > *tgt, Real gscale=1.0)
add gaussian noise to each element,
void Register(OptionsItf *opts)
Configuration variables that affect how frame-level shuffling is done.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename',.
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers),.
Shuffles rows of a matrix according to the indices in the mask,.
Matrix for CUDA computing.
A class representing a vector.
#define KALDI_ASSERT(cond)
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
void Register(OptionsItf *opts)
const Component & GetComponent(int32 c) const
Component accessor,.
virtual ComponentType GetType() const =0
Get Type Identification of the component,.
void Register(OptionsItf *opts)
MatrixIndexT NumRows() const
Dimensions.
double Elapsed() const
Returns time in seconds.
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.