46 const TransitionModel &trans_model,
47 const std::vector<int32> &state_times,
49 kaldi::uint64 props = lat->Properties(fst::kFstProperties,
false);
50 if (!(props & fst::kTopSorted))
51 KALDI_ERR <<
"Input lattice must be topologically sorted.";
54 std::vector<std::vector<int32> > time_to_state(log_like.NumRows());
55 for (
size_t i = 0;
i < state_times.size();
i++) {
57 if (state_times[
i] < log_like.NumRows())
58 time_to_state[state_times[
i]].push_back(
i);
61 &&
"There appears to be lattice/feature mismatch.");
64 for (
int32 t = 0; t < log_like.NumRows(); t++) {
65 for (
size_t i = 0;
i < time_to_state[t].size();
i++) {
66 int32 state = time_to_state[t][
i];
67 for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
70 int32 trans_id = arc.ilabel;
72 int32 pdf_id = trans_model.TransitionIdToPdf(trans_id);
73 arc.weight.SetValue2(-log_like(t, pdf_id) + arc.weight.Value2());
85 int main(
int argc,
char *argv[]) {
86 using namespace kaldi;
91 "Perform one iteration of MPE/sMBR training using SGD with per-utterance" 94 "Usage: nnet-train-mpe-sequential [options] " 95 "<model-in> <transition-model-in> <feature-rspecifier> " 96 "<den-lat-rspecifier> <ali-rspecifier> [<model-out>]\n" 98 "e.g.: nnet-train-mpe-sequential nnet.init trans.mdl scp:feats.scp " 99 "scp:denlats.scp ark:ali.ark nnet.iter1\n";
108 po.
Register(
"binary", &binary,
"Write output in binary mode");
110 std::string feature_transform;
111 po.
Register(
"feature-transform", &feature_transform,
112 "Feature transform in 'nnet1' format");
114 std::string silence_phones_str;
115 po.
Register(
"silence-phones", &silence_phones_str,
116 "Colon-separated list of integer id's of silence phones, e.g. 46:47");
123 old_acoustic_scale = 0.0;
125 po.
Register(
"acoustic-scale", &acoustic_scale,
126 "Scaling factor for acoustic likelihoods");
129 "Scaling factor for \"graph costs\" (including LM costs)");
131 po.
Register(
"old-acoustic-scale", &old_acoustic_scale,
132 "Add in the scores in the input lattices with this scale, rather " 133 "than discarding them.");
135 bool one_silence_class =
false;
136 po.
Register(
"one-silence-class", &one_silence_class,
137 "If true, the newer behavior reduces insertions.");
140 po.
Register(
"max-frames", &max_frames,
141 "Maximum number of frames an utterance can have (skipped if longer)");
143 bool do_smbr =
false;
145 "Use state-level accuracies instead of phone accuracies.");
147 std::string use_gpu=
"yes";
149 "yes|no|optional, only has effect if compiled with CUDA");
158 std::string model_filename = po.
GetArg(1),
159 transition_model_filename = po.
GetArg(2),
160 feature_rspecifier = po.
GetArg(3),
161 den_lat_rspecifier = po.
GetArg(4),
162 ref_ali_rspecifier = po.
GetArg(5),
163 target_model_filename = po.
GetArg(6);
165 std::vector<int32> silence_phones;
168 KALDI_ERR <<
"Invalid silence-phones string " << silence_phones_str;
171 if (silence_phones.empty()) {
172 KALDI_LOG <<
"No silence phones specified.";
176 CuDevice::Instantiate().SelectGpuId(use_gpu);
180 if (feature_transform !=
"") {
181 nnet_transf.
Read(feature_transform);
185 nnet.
Read(model_filename);
192 KALDI_LOG <<
"Removing softmax from the nnet " << model_filename;
195 KALDI_LOG <<
"The nnet was without softmax. " 196 <<
"The last component in " << model_filename <<
" was " 205 TransitionModel trans_model;
213 Matrix<BaseFloat> nnet_out_h;
224 kaldi::int64 total_frames = 0;
225 double total_frame_acc = 0.0, utt_frame_acc;
228 for (; !feature_reader.
Done(); feature_reader.
Next()) {
229 std::string utt = feature_reader.
Key();
230 if (!den_lat_reader.
HasKey(utt)) {
235 if (!ref_ali_reader.
HasKey(utt)) {
236 KALDI_WARN <<
"Missing alignment for " << utt;
242 const Matrix<BaseFloat> &mat = feature_reader.
Value();
243 const std::vector<int32> &ref_ali = ref_ali_reader.
Value(utt);
245 if (static_cast<MatrixIndexT>(ref_ali.size()) != mat.NumRows()) {
247 <<
" alignment " << ref_ali.size()
248 <<
" features " << mat.NumRows();
252 if (mat.NumRows() > max_frames) {
254 <<
" that has " << mat.NumRows() <<
" frames," 255 <<
" it is longer than '--max-frames'" << max_frames;
262 if (den_lat.Start() == -1) {
263 KALDI_WARN <<
"Empty lattice of " << utt <<
", skipping.";
267 if (old_acoustic_scale != 1.0) {
272 kaldi::uint64 props = den_lat.Properties(fst::kFstProperties,
false);
273 if (!(props & fst::kTopSorted)) {
274 if (fst::TopSort(&den_lat) ==
false) {
275 KALDI_ERR <<
"Cycles detected in lattice.";
279 std::vector<int32> state_times;
282 if (max_time != mat.NumRows()) {
284 <<
" denominator lattice " << max_time
285 <<
" features " << mat.NumRows() <<
"," 286 <<
" skipping " << utt;
292 int32 num_frames = mat.NumRows();
304 nnet_out_h = Matrix<BaseFloat>(nnet_out);
306 feats_transf.
Resize(0, 0);
311 if (acoustic_scale != 1.0 || lm_scale != 1.0)
318 trans_model, silence_phones, den_lat, ref_ali,
"smbr",
319 one_silence_class, &post);
323 trans_model, silence_phones, den_lat, ref_ali,
"mpfe",
324 one_silence_class, &post);
329 nnet_diff.
Scale(-1.0);
331 KALDI_VLOG(1) <<
"Lattice #" << num_done + 1 <<
" processed" 332 <<
" (" << utt <<
"): found " << den_lat.NumStates()
335 KALDI_VLOG(1) <<
"Utterance " << utt <<
": Average frame accuracy = " 336 << (utt_frame_acc/num_frames) <<
" over " << num_frames
338 <<
" diff-range(" << nnet_diff.
Min() <<
"," 339 << nnet_diff.
Max() <<
")";
346 total_frame_acc += utt_frame_acc;
347 total_frames += num_frames;
350 if (num_done % 100 == 0) {
352 KALDI_VLOG(1) <<
"After " << num_done <<
" utterances: " 353 <<
"time elapsed = " << time_now / 60 <<
" min; " 354 <<
"processed " << total_frames / time_now <<
" frames per sec.";
357 CuDevice::Instantiate().CheckGpuHealth();
370 if (num_done % 1000 == 0) {
384 KALDI_LOG <<
"Appending the softmax " << target_model_filename;
387 nnet.
Write(target_model_filename, binary);
391 <<
"Time taken = " << time_now / 60 <<
" min; processed " 392 << total_frames / time_now <<
" frames per second.";
394 KALDI_LOG <<
"Done " << num_done <<
" files, " 395 << num_no_ref_ali <<
" with no reference alignments, " 396 << num_no_den_lat <<
" with no lattices, " 397 << num_other_error <<
" with other errors.";
399 KALDI_LOG <<
"Overall average frame-accuracy is " 400 << total_frame_acc / total_frames <<
" over " 401 << total_frames <<
" frames.";
404 CuDevice::Instantiate().PrintProfile();
408 }
catch(
const std::exception &e) {
409 std::cerr << e.what();
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
void RemoveLastComponent()
Remove the last of the Components.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void PosteriorToPdfMatrix(const Posterior &post, const TransitionModel &model, CuMatrix< Real > *mat)
Wrapper of PosteriorToMatrixMapped with CuMatrix argument.
void AppendComponentPointer(Component *dynamically_allocated_comp)
Append Component* to 'this' instance of Nnet by a shallow copy ('this' instance of Nnet over-takes th...
fst::ArcTpl< LatticeWeight > LatticeArc
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
std::string class_frame_counts
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g. 1:2:3) into a vector of integers.
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int32 GetVerboseLevel()
Get verbosity level, usually set via the command-line '--verbose=' switch.
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
void Min(const CuMatrixBase< Real > &A)
Do, elementwise, *this = min(*this, A).
void SortAndUniq(std::vector< T > *vec)
Sorts and uniq's (removes duplicates) from a vector.
void SubtractOnLogpost(CuMatrixBase< BaseFloat > *llk)
Subtract pdf priors from log-posteriors to get pseudo log-likelihoods.
const Component & GetLastComponent() const
LastComponent accessor.
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
int main(int argc, char *argv[])
static const char * TypeToMarker(ComponentType t)
Converts component type to marker.
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Max(const CuMatrixBase< Real > &A)
Do, elementwise, *this = max(*this, A).
int32 OutputDim() const
Dimensionality of network outputs (posteriors | bn-features | etc.).
const T & Value(const std::string &key)
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics.
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
std::vector< std::vector< double > > LatticeScale(double lmwt, double acwt)
fst::VectorFst< LatticeArc > Lattice
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
BaseFloat LatticeForwardBackwardMpeVariants(const TransitionModel &trans, const std::vector< int32 > &silence_phones, const Lattice &lat, const std::vector< int32 > &num_ali, std::string criterion, bool one_silence_class, Posterior *post)
This function implements either the MPFE (minimum phone frame error) or SMBR (state-level minimum bay...
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
void Register(OptionsItf *opts)
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics.
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics.
virtual ComponentType GetType() const =0
Get Type Identification of the component.
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
Arc::StateId NumArcs(const ExpandedFst< Arc > &fst)
Returns the total number of arcs in an FST.
double Elapsed() const
Returns time in seconds.
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
void Register(OptionsItf *opts)
void LatticeAcousticRescore(const Matrix< BaseFloat > &log_like, const TransitionModel &trans_model, const std::vector< int32 > &state_times, Lattice *lat)