86 using namespace kaldi;
91 "Perform one iteration of MPE/sMBR training using SGD with per-utterance" 94 "Usage: nnet-train-mpe-sequential [options] " 95 "<model-in> <transition-model-in> <feature-rspecifier> " 96 "<den-lat-rspecifier> <ali-rspecifier> [<model-out>]\n" 98 "e.g.: nnet-train-mpe-sequential nnet.init trans.mdl scp:feats.scp " 99 "scp:denlats.scp ark:ali.ark nnet.iter1\n";
108 po.Register(
"binary", &binary,
"Write output in binary mode");
110 std::string feature_transform;
111 po.Register(
"feature-transform", &feature_transform,
112 "Feature transform in 'nnet1' format");
114 std::string silence_phones_str;
115 po.Register(
"silence-phones", &silence_phones_str,
116 "Colon-separated list of integer id's of silence phones, e.g. 46:47");
123 old_acoustic_scale = 0.0;
125 po.Register(
"acoustic-scale", &acoustic_scale,
126 "Scaling factor for acoustic likelihoods");
128 po.Register(
"lm-scale", &lm_scale,
129 "Scaling factor for \"graph costs\" (including LM costs)");
131 po.Register(
"old-acoustic-scale", &old_acoustic_scale,
132 "Add in the scores in the input lattices with this scale, rather " 133 "than discarding them.");
135 bool one_silence_class =
false;
136 po.Register(
"one-silence-class", &one_silence_class,
137 "If true, the newer behavior reduces insertions.");
140 po.Register(
"max-frames", &max_frames,
141 "Maximum number of frames an utterance can have (skipped if longer)");
143 bool do_smbr =
false;
144 po.Register(
"do-smbr", &do_smbr,
145 "Use state-level accuracies instead of phone accuracies.");
147 std::string use_gpu=
"yes";
148 po.Register(
"use-gpu", &use_gpu,
149 "yes|no|optional, only has effect if compiled with CUDA");
153 if (po.NumArgs() != 6) {
158 std::string model_filename = po.GetArg(1),
159 transition_model_filename = po.GetArg(2),
160 feature_rspecifier = po.GetArg(3),
161 den_lat_rspecifier = po.GetArg(4),
162 ref_ali_rspecifier = po.GetArg(5),
163 target_model_filename = po.GetArg(6);
165 std::vector<int32> silence_phones;
168 KALDI_ERR <<
"Invalid silence-phones string " << silence_phones_str;
171 if (silence_phones.empty()) {
172 KALDI_LOG <<
"No silence phones specified.";
176 CuDevice::Instantiate().SelectGpuId(use_gpu);
180 if (feature_transform !=
"") {
181 nnet_transf.
Read(feature_transform);
185 nnet.
Read(model_filename);
192 KALDI_LOG <<
"Removing softmax from the nnet " << model_filename;
195 KALDI_LOG <<
"The nnet was without softmax. " 196 <<
"The last component in " << model_filename <<
" was " 205 TransitionModel trans_model;
213 Matrix<BaseFloat> nnet_out_h;
224 kaldi::int64 total_frames = 0;
225 double total_frame_acc = 0.0, utt_frame_acc;
228 for (; !feature_reader.Done(); feature_reader.Next()) {
229 std::string utt = feature_reader.Key();
230 if (!den_lat_reader.HasKey(utt)) {
235 if (!ref_ali_reader.HasKey(utt)) {
236 KALDI_WARN <<
"Missing alignment for " << utt;
242 const Matrix<BaseFloat> &mat = feature_reader.Value();
243 const std::vector<int32> &ref_ali = ref_ali_reader.Value(utt);
245 if (static_cast<MatrixIndexT>(ref_ali.size()) != mat.NumRows()) {
247 <<
" alignment " << ref_ali.size()
248 <<
" features " << mat.NumRows();
252 if (mat.NumRows() > max_frames) {
254 <<
" that has " << mat.NumRows() <<
" frames," 255 <<
" it is longer than '--max-frames'" << max_frames;
261 Lattice den_lat = den_lat_reader.Value(utt);
262 if (den_lat.Start() == -1) {
263 KALDI_WARN <<
"Empty lattice of " << utt <<
", skipping.";
267 if (old_acoustic_scale != 1.0) {
272 kaldi::uint64 props = den_lat.Properties(fst::kFstProperties,
false);
273 if (!(props & fst::kTopSorted)) {
274 if (fst::TopSort(&den_lat) ==
false) {
275 KALDI_ERR <<
"Cycles detected in lattice.";
279 std::vector<int32> state_times;
282 if (max_time != mat.NumRows()) {
284 <<
" denominator lattice " << max_time
285 <<
" features " << mat.NumRows() <<
"," 286 <<
" skipping " << utt;
292 int32 num_frames = mat.NumRows();
301 log_prior.SubtractOnLogpost(&nnet_out);
304 nnet_out_h = Matrix<BaseFloat>(nnet_out);
306 feats_transf.
Resize(0, 0);
311 if (acoustic_scale != 1.0 || lm_scale != 1.0)
318 trans_model, silence_phones, den_lat, ref_ali,
"smbr",
319 one_silence_class, &post);
323 trans_model, silence_phones, den_lat, ref_ali,
"mpfe",
324 one_silence_class, &post);
329 nnet_diff.
Scale(-1.0);
331 KALDI_VLOG(1) <<
"Lattice #" << num_done + 1 <<
" processed" 332 <<
" (" << utt <<
"): found " << den_lat.NumStates()
335 KALDI_VLOG(1) <<
"Utterance " << utt <<
": Average frame accuracy = " 336 << (utt_frame_acc/num_frames) <<
" over " << num_frames
338 <<
" diff-range(" << nnet_diff.
Min() <<
"," 339 << nnet_diff.
Max() <<
")";
346 total_frame_acc += utt_frame_acc;
347 total_frames += num_frames;
350 if (num_done % 100 == 0) {
352 KALDI_VLOG(1) <<
"After " << num_done <<
" utterances: " 353 <<
"time elapsed = " << time_now / 60 <<
" min; " 354 <<
"processed " << total_frames / time_now <<
" frames per sec.";
357 CuDevice::Instantiate().CheckGpuHealth();
370 if (num_done % 1000 == 0) {
384 KALDI_LOG <<
"Appending the softmax " << target_model_filename;
387 nnet.
Write(target_model_filename, binary);
391 <<
"Time taken = " << time_now / 60 <<
" min; processed " 392 << total_frames / time_now <<
" frames per second.";
394 KALDI_LOG <<
"Done " << num_done <<
" files, " 395 << num_no_ref_ali <<
" with no reference alignments, " 396 << num_no_den_lat <<
" with no lattices, " 397 << num_other_error <<
" with other errors.";
399 KALDI_LOG <<
"Overall average frame-accuracy is " 400 << total_frame_acc / total_frames <<
" over " 401 << total_frames <<
" frames.";
404 CuDevice::Instantiate().PrintProfile();
408 }
catch(
const std::exception &e) {
409 std::cerr << e.what();
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network,.
void RemoveLastComponent()
Remove the last of the Components,.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void PosteriorToPdfMatrix(const Posterior &post, const TransitionModel &model, CuMatrix< Real > *mat)
Wrapper of PosteriorToMatrixMapped with CuMatrix argument.
void AppendComponentPointer(Component *dynamically_allocated_comp)
Append Component* to 'this' instance of Nnet by a shallow copy ('this' instance of Nnet over-takes th...
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network,.
std::string class_frame_counts
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
int32 GetVerboseLevel()
Get verbosity level, usually set via command line '–verbose=' switch.
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename',.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
void Min(const CuMatrixBase< Real > &A)
Do, elementwise, *this = min(*this, A).
void SortAndUniq(std::vector< T > *vec)
Sorts and uniq's (removes duplicates) from a vector.
const Component & GetLastComponent() const
LastComponent accessor,.
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
static const char * TypeToMarker(ComponentType t)
Converts component type to marker,.
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Max(const CuMatrixBase< Real > &A)
Do, elementwise, *this = max(*this, A).
int32 OutputDim() const
Dimensionality of network outputs (posteriors | bn-features | etc.),.
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics,.
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
std::vector< std::vector< double > > LatticeScale(double lmwt, double acwt)
fst::VectorFst< LatticeArc > Lattice
BaseFloat LatticeForwardBackwardMpeVariants(const TransitionModel &trans, const std::vector< int32 > &silence_phones, const Lattice &lat, const std::vector< int32 > &num_ali, std::string criterion, bool one_silence_class, Posterior *post)
This function implements either the MPFE (minimum phone frame error) or SMBR (state-level minimum bay...
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename',.
void Register(OptionsItf *opts)
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers),.
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics,.
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics,.
virtual ComponentType GetType() const =0
Get Type Identification of the component,.
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents),.
Arc::StateId NumArcs(const ExpandedFst< Arc > &fst)
Returns the total number of arcs in an FST.
double Elapsed() const
Returns time in seconds.
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
void Register(OptionsItf *opts)
void LatticeAcousticRescore(const Matrix< BaseFloat > &log_like, const TransitionModel &trans_model, const std::vector< int32 > &state_times, Lattice *lat)