36 for (
int32 tid = 1; tid < transition_accs.
Dim(); tid++) {
38 pdf_counts(pdf) += transition_accs(tid);
43 pdf_counts.Scale(1.0 / sum);
44 pdf_counts.ApplyFloor(prior_floor);
45 pdf_counts.Scale(1.0 / pdf_counts.Sum());
53 int main(
int argc,
char *argv[]) {
55 using namespace kaldi;
60 "Train the transition probabilities of a neural network acoustic model\n" 62 "Usage: nnet-train-transitions [options] <nnet-in> <alignments-rspecifier> <nnet-out>\n" 64 " nnet-train-transitions 1.nnet \"ark:gunzip -c ali.*.gz|\" 2.nnet\n";
66 bool binary_write =
true;
67 bool set_priors =
true;
77 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
78 po.
Register(
"set-priors", &set_priors,
"If true, also set priors in neural " 79 "net (we divide by these in test time)");
80 po.
Register(
"prior-floor", &prior_floor,
"When setting priors, floor for " 82 transition_update_config.
Register(&po);
91 std::string nnet_rxfilename = po.
GetArg(1),
92 ali_rspecifier = po.
GetArg(2),
93 nnet_wxfilename = po.
GetArg(3);
99 Input ki(nnet_rxfilename, &binary_read);
109 for (; ! ali_reader.
Done(); ali_reader.
Next()) {
110 const std::vector<int32> alignment(ali_reader.
Value());
111 for (
size_t i = 0;
i < alignment.size();
i++) {
112 int32 tid = alignment[
i];
114 trans_model.
Accumulate(weight, tid, &transition_accs);
118 KALDI_LOG <<
"Accumulated transition stats from " << num_done
123 trans_model.
MleUpdate(transition_accs, transition_update_config,
125 KALDI_LOG <<
"Transition model update: average " << (objf_impr/
count)
126 <<
" log-like improvement per frame over " << count
131 KALDI_LOG <<
"Setting priors of pdfs in the model.";
132 SetPriors(trans_model, transition_accs, prior_floor, &am_nnet);
136 Output ko(nnet_wxfilename, binary_write);
140 KALDI_LOG <<
"Trained transitions of neural network model and wrote it to " 143 }
catch(
const std::exception &e) {
144 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks.
void MleUpdate(const Vector< double > &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum Likelihood estimation.
int main(int argc, char *argv[])
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Read(std::istream &is, bool binary)
int32 TransitionIdToPdf(int32 trans_id) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
void InitStats(Vector< double > *stats) const
The class ParseOptions is for parsing command-line options; see "Parsing command-line options" for more documentation.
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
void Write(std::ostream &os, bool binary) const
void SetPriors(const TransitionModel &tmodel, const Vector< double > &transition_accs, double prior_floor, AmNnet *am_nnet)
A templated class for reading objects sequentially from an archive or script file; see "The Table concept".
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
#define KALDI_ASSERT(cond)
void Register(OptionsItf *opts)
void SetPriors(const VectorBase< BaseFloat > &priors)