26 int main(
int argc,
char *argv[]) {
28 using namespace kaldi;
32 "Train the transition probabilities in transition-model " 33 "(used in nnet1 recipe).\n" 35 "Usage: train-transitions [options] " 36 "<trans-model-in> <alignments-rspecifier> <trans-model-out>\n" 37 "e.g.: train-transitions 1.mdl \"ark:gunzip -c ali.*.gz|\" 2.mdl\n";
39 bool binary_write =
true;
43 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
44 transition_update_config.
Register(&po);
53 std::string trans_model_rxfilename = po.
GetArg(1),
54 ali_rspecifier = po.
GetArg(2),
55 trans_model_wxfilename = po.
GetArg(3);
60 Input ki(trans_model_rxfilename, &binary_read);
69 for (; !ali_reader.
Done(); ali_reader.
Next()) {
70 const std::vector<int32> alignment(ali_reader.
Value());
71 for (
size_t i = 0;
i < alignment.size();
i++) {
72 int32 tid = alignment[
i];
74 trans_model.
Accumulate(weight, tid, &transition_accs);
78 KALDI_LOG <<
"Accumulated transition stats from " << num_done
83 trans_model.
MleUpdate(transition_accs, transition_update_config,
85 KALDI_LOG <<
"Transition model update: average " << (objf_impr/
count)
86 <<
" log-like improvement per frame over " << count
91 Output ko(trans_model_wxfilename, binary_write);
94 KALDI_LOG <<
"Trained transition model and wrote it to " 95 << trans_model_wxfilename;
97 }
catch(
const std::exception &e) {
98 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
void MleUpdate(const Vector< double > &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum Likelihood estimation.
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int main(int argc, char *argv[])
void Register(const std::string &name, bool *ptr, const std::string &doc)
void InitStats(Vector< double > *stats) const
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more details.
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (cf. argc-1).
void Write(std::ostream &os, bool binary) const
void Register(OptionsItf *opts)