86 using namespace kaldi;
91 "Perform one iteration of MMI training using SGD with per-utterance" 94 "Usage: nnet-train-mmi-sequential [options] " 95 "<model-in> <transition-model-in> <feature-rspecifier> " 96 "<den-lat-rspecifier> <ali-rspecifier> [<model-out>]\n" 98 "e.g.: nnet-train-mmi-sequential nnet.init trans.mdl scp:feats.scp " 99 "scp:denlats.scp ark:ali.ark nnet.iter1\n";
108 po.Register(
"binary", &binary,
"Write output in binary mode");
110 std::string feature_transform;
111 po.Register(
"feature-transform", &feature_transform,
112 "Feature transform in 'nnet1' format");
119 old_acoustic_scale = 0.0;
121 po.Register(
"acoustic-scale", &acoustic_scale,
122 "Scaling factor for acoustic likelihoods");
124 po.Register(
"lm-scale", &lm_scale,
125 "Scaling factor for \"graph costs\" (including LM costs)");
127 po.Register(
"old-acoustic-scale", &old_acoustic_scale,
128 "Add in the scores in the input lattices with this scale, " 129 "rather than discarding them.");
132 po.Register(
"max-frames", &max_frames,
133 "Maximum number of frames an utterance can have (skipped if longer)");
135 bool drop_frames =
true;
136 po.Register(
"drop-frames", &drop_frames,
137 "Drop frames, where is zero den-posterior under numerator path " 138 "(ie. path not in lattice)");
140 std::string use_gpu=
"yes";
141 po.Register(
"use-gpu", &use_gpu,
142 "yes|no|optional, only has effect if compiled with CUDA");
146 if (po.NumArgs() != 6) {
151 std::string model_filename = po.GetArg(1),
152 transition_model_filename = po.GetArg(2),
153 feature_rspecifier = po.GetArg(3),
154 den_lat_rspecifier = po.GetArg(4),
155 num_ali_rspecifier = po.GetArg(5),
156 target_model_filename = po.GetArg(6);
158 using namespace kaldi;
163 CuDevice::Instantiate().SelectGpuId(use_gpu);
167 if (feature_transform !=
"") {
168 nnet_transf.
Read(feature_transform);
172 nnet.
Read(model_filename);
179 KALDI_LOG <<
"Removing softmax from the nnet " << model_filename;
182 KALDI_LOG <<
"The nnet was without softmax. " 183 <<
"The last component in " << model_filename <<
" was " 204 " we will zero gradient for frames with total den/num mismatch." 205 " The mismatch is likely to be caused by missing correct path " 206 " from den-lattice due wrong annotation or search error." 207 " Leaving such frames out stabilizes the training.";
214 int32 num_done = 0, num_no_num_ali = 0, num_no_den_lat = 0,
215 num_other_error = 0, num_frm_drop = 0;
217 kaldi::int64 total_frames = 0;
220 double total_mmi_obj = 0.0, mmi_obj = 0.0;
221 double total_post_on_ali = 0.0, post_on_ali = 0.0;
224 for ( ; !feature_reader.Done(); feature_reader.Next()) {
225 std::string utt = feature_reader.Key();
226 if (!den_lat_reader.HasKey(utt)) {
231 if (!num_ali_reader.HasKey(utt)) {
239 const std::vector<int32> &num_ali = num_ali_reader.Value(utt);
241 if (static_cast<int32>(num_ali.size()) != mat.
NumRows()) {
243 <<
" alignment " << num_ali.size()
244 <<
" features " << mat.
NumRows();
248 if (mat.
NumRows() > max_frames) {
250 <<
" that has " << mat.
NumRows() <<
" frames," 251 <<
" it is longer than '--max-frames'" << max_frames;
257 Lattice den_lat = den_lat_reader.Value(utt);
258 if (den_lat.Start() == -1) {
259 KALDI_WARN <<
"Empty lattice of " << utt <<
", skipping.";
263 if (old_acoustic_scale != 1.0) {
268 kaldi::uint64 props = den_lat.Properties(fst::kFstProperties,
false);
269 if (!(props & fst::kTopSorted)) {
270 if (fst::TopSort(&den_lat) ==
false) {
271 KALDI_ERR <<
"Cycles detected in lattice.";
275 std::vector<int32> state_times;
278 if (max_time != mat.
NumRows()) {
280 <<
" denominator lattice " << max_time
281 <<
" features " << mat.
NumRows() <<
"," 282 <<
" skipping " << utt;
288 int32 num_frames = mat.
NumRows(),
298 log_prior.SubtractOnLogpost(&nnet_out);
303 feats_transf.
Resize(0, 0);
308 if (acoustic_scale != 1.0 || lm_scale != 1.0)
321 double path_ac_like = 0.0;
322 for (int32 t = 0; t < num_frames; t++) {
324 path_ac_like += nnet_out_h(t, pdf);
326 path_ac_like *= acoustic_scale;
327 mmi_obj = path_ac_like - lat_like;
337 for (int32 t = 0; t < num_frames; t++) {
339 double posterior = nnet_diff_h(t, pdf);
340 post_on_ali += posterior;
344 KALDI_VLOG(1) <<
"Lattice #" << num_done + 1 <<
" processed" 345 <<
" (" << utt <<
"): found " << den_lat.NumStates()
348 KALDI_VLOG(1) <<
"Utterance " << utt <<
": Average MMI obj. value = " 349 << (mmi_obj/num_frames) <<
" over " << num_frames <<
" frames." 350 <<
" (Avg. den-posterior on ali " << post_on_ali / num_frames <<
")";
355 std::vector<int32> frm_drop_vec;
356 for (int32 t = 0; t < num_frames; t++) {
358 double posterior = nnet_diff_h(t, pdf);
359 if (posterior < 1e-20) {
361 frm_drop_vec.push_back(t);
366 for (int32 t = 0; t < nnet_diff_h.
NumRows(); t++) {
368 nnet_diff_h(t, pdf) -= 1.0;
373 for (int32
i = 0;
i < frm_drop_vec.size();
i++) {
374 nnet_diff_h.
Row(frm_drop_vec[
i]).Set(0.0);
376 num_frm_drop += frm_drop;
380 std::stringstream ss;
381 ss << (drop_frames?
"Dropped":
"[dropping disabled] Would drop")
382 <<
" frames in " << utt <<
" " << frm_drop <<
"/" << num_frames
385 ss <<
" intervals :";
387 int32 beg_streak = frm_drop_vec[0];
388 int32 len_streak = 0;
390 for (i = 0; i < frm_drop_vec.size(); i++, len_streak++) {
391 if (beg_streak + len_streak != frm_drop_vec[i]) {
392 ss <<
" " << beg_streak <<
".." << frm_drop_vec[i-1] <<
"frm";
393 beg_streak = frm_drop_vec[
i];
397 ss <<
" " << beg_streak <<
".." << frm_drop_vec[i-1] <<
"frm";
410 total_mmi_obj += mmi_obj;
411 total_post_on_ali += post_on_ali;
412 total_frames += num_frames;
415 if (num_done % 100 == 0) {
417 KALDI_VLOG(1) <<
"After " << num_done <<
" utterances: " 418 <<
"time elapsed = " << time_now / 60 <<
" min; " 419 <<
"processed " << total_frames / time_now <<
" frames per sec.";
422 CuDevice::Instantiate().CheckGpuHealth();
435 if (num_done % 1000 == 0) {
449 KALDI_LOG <<
"Appending the softmax " << target_model_filename;
452 nnet.
Write(target_model_filename, binary);
456 <<
"Time taken = " << time_now/60 <<
" min; processed " 457 << (total_frames/time_now) <<
" frames per second.";
459 KALDI_LOG <<
"Done " << num_done <<
" files, " 460 << num_no_num_ali <<
" with no numerator alignments, " 461 << num_no_den_lat <<
" with no denominator lattices, " 462 << num_other_error <<
" with other errors.";
464 KALDI_LOG <<
"Overall MMI-objective/frame is " 465 << std::setprecision(8) << total_mmi_obj / total_frames
466 <<
" over " << total_frames <<
" frames," 467 <<
" (average den-posterior on ali " 468 << total_post_on_ali / total_frames <<
"," 469 <<
" dropped " << num_frm_drop
470 <<
" frames with num/den mismatch)";
473 CuDevice::Instantiate().PrintProfile();
477 }
catch(
const std::exception &e) {
478 std::cerr << e.what();
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network,.
void RemoveLastComponent()
Remove the last of the Components,.
void CopyFromMat(const MatrixBase< OtherReal > &src, MatrixTransposeType trans=kNoTrans)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void PosteriorToPdfMatrix(const Posterior &post, const TransitionModel &model, CuMatrix< Real > *mat)
Wrapper of PosteriorToMatrixMapped with CuMatrix argument.
void AppendComponentPointer(Component *dynamically_allocated_comp)
Append Component* to 'this' instance of Nnet by a shallow copy ('this' instance of Nnet over-takes th...
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network,.
std::string class_frame_counts
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
int32 GetVerboseLevel()
Get verbosity level, usually set via command line '–verbose=' switch.
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename',.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
int32 TransitionIdToPdf(int32 trans_id) const
const Component & GetLastComponent() const
LastComponent accessor,.
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
static const char * TypeToMarker(ComponentType t)
Converts component type to marker,.
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post, double *acoustic_like_sum)
This function does the forward-backward over lattices and computes the posterior probabilities of the...
int32 OutputDim() const
Dimensionality of network outputs (posteriors | bn-features | etc.),.
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics,.
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
std::vector< std::vector< double > > LatticeScale(double lmwt, double acwt)
fst::VectorFst< LatticeArc > Lattice
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename',.
void Register(OptionsItf *opts)
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers),.
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics,.
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics,.
virtual ComponentType GetType() const =0
Get Type Identification of the component,.
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents),.
Arc::StateId NumArcs(const ExpandedFst< Arc > &fst)
Returns the total number of arcs in an FST.
double Elapsed() const
Returns time in seconds.
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
void Register(OptionsItf *opts)
void LatticeAcousticRescore(const Matrix< BaseFloat > &log_like, const TransitionModel &trans_model, const std::vector< int32 > &state_times, Lattice *lat)