55 using namespace kaldi;
60 "Train the transition probabilities of an nnet3 neural network acoustic model\n" 62 "Usage: nnet3-am-train-transitions [options] <nnet-in> <alignments-rspecifier> <nnet-out>\n" 64 " nnet3-am-train-transitions 1.nnet \"ark:gunzip -c ali.*.gz|\" 2.nnet\n";
66 bool binary_write =
true;
67 bool set_priors =
true;
77 po.Register(
"binary", &binary_write,
"Write output in binary mode");
78 po.Register(
"set-priors", &set_priors,
"If true, also set priors in neural " 79 "net (we divide by these in test time)");
80 po.Register(
"prior-floor", &prior_floor,
"When setting priors, floor for " 82 transition_update_config.
Register(&po);
86 if (po.NumArgs() != 3) {
91 std::string nnet_rxfilename = po.GetArg(1),
92 ali_rspecifier = po.GetArg(2),
93 nnet_wxfilename = po.GetArg(3);
99 Input ki(nnet_rxfilename, &binary_read);
100 trans_model.
Read(ki.Stream(), binary_read);
101 am_nnet.
Read(ki.Stream(), binary_read);
109 for (; ! ali_reader.Done(); ali_reader.Next()) {
110 const std::vector<int32> alignment(ali_reader.Value());
111 for (
size_t i = 0;
i < alignment.size();
i++) {
112 int32 tid = alignment[
i];
114 trans_model.
Accumulate(weight, tid, &transition_accs);
118 KALDI_LOG <<
"Accumulated transition stats from " << num_done
123 trans_model.
MleUpdate(transition_accs, transition_update_config,
125 KALDI_LOG <<
"Transition model update: average " << (objf_impr/
count)
126 <<
" log-like improvement per frame over " << count
131 KALDI_LOG <<
"Setting priors of pdfs in the model.";
132 SetPriors(trans_model, transition_accs, prior_floor, &am_nnet);
136 Output ko(nnet_wxfilename, binary_write);
137 trans_model.
Write(ko.Stream(), binary_write);
138 am_nnet.
Write(ko.Stream(), binary_write);
140 KALDI_LOG <<
"Trained transitions of neural network model and wrote it to " 143 }
catch(
const std::exception &e) {
144 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void MleUpdate(const Vector< double > &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum Likelihood estimation.
void SetPriors(const TransitionModel &tmodel, const Vector< double > &transition_accs, double prior_floor, AmNnetSimple *am_nnet)
void Read(std::istream &is, bool binary)
void InitStats(Vector< double > *stats) const
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void Write(std::ostream &os, bool binary) const
void Register(OptionsItf *opts)
void Write(std::ostream &os, bool binary) const