28 int main(
int argc,
char *argv[]) {
29 using namespace kaldi;
34 "Estimate SGMM model parameters discriminatively using Extended\n" 35 "Baum-Welch style of update\n" 36 "Usage: sgmm2-est-ebw [options] <model-in> <num-stats-in> <den-stats-in> <model-out>\n";
39 string update_flags_str =
"vMNwcSt";
40 bool binary_write =
true;
41 string write_flags_str =
"gsnu";
46 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
47 po.
Register(
"update-flags", &update_flags_str,
"Which SGMM parameters to " 48 "update: subset of vMNwcSt.");
49 po.
Register(
"write-flags", &write_flags_str,
"Which SGMM parameters to " 50 "write: subset of gsnu");
52 "weight update and normalizer computation");
60 string model_in_filename = po.
GetArg(1),
61 num_stats_filename = po.
GetArg(2),
62 den_stats_filename = po.
GetArg(3),
63 model_out_filename = po.
GetArg(4);
72 Input ki(model_in_filename, &binary);
81 Input ki(num_stats_filename, &binary);
83 sgmm_num_accs.
Read(ki.
Stream(), binary,
false);
89 Input ki(den_stats_filename, &binary);
91 sgmm_den_accs.
Read(ki.
Stream(), binary,
false);
94 sgmm_num_accs.
Check(am_sgmm,
true);
95 sgmm_den_accs.
Check(am_sgmm,
true);
100 sgmm_updater.
Update(sgmm_num_accs, sgmm_den_accs, &am_sgmm,
101 update_flags, &auxf_impr, &count);
102 KALDI_LOG <<
"Overall auxf impr/frame from SGMM update is " << (auxf_impr/
count)
103 <<
" over " << count <<
" frames.";
107 Output ko(model_out_filename, binary_write);
109 am_sgmm.
Write(ko.
Stream(), binary_write, write_flags);
112 KALDI_LOG <<
"Wrote model to " << model_out_filename;
114 }
catch(
const std::exception &e) {
115 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
uint16 SgmmWriteFlagsType
Bitwise OR of the above flags.
void Write(std::ostream &os, bool binary, SgmmWriteFlagsType write_params) const
Class for definition of the subspace Gmm acoustic model.
This header implements a form of Extended Baum-Welch training for SGMMs.
int main(int argc, char *argv[])
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Read(std::istream &in_stream, bool binary, bool add)
void Read(std::istream &is, bool binary)
SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str)
void Register(const std::string &name, bool *ptr, const std::string &doc)
uint16 SgmmUpdateFlagsType
Bitwise OR of the above flags.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (cf. argc-1).
void Write(std::ostream &os, bool binary) const
void Register(OptionsItf *opts)
void Update(const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, AmSgmm2 *model, SgmmUpdateFlagsType flags, BaseFloat *auxf_change_out, BaseFloat *count_out)
SgmmUpdateFlagsType StringToSgmmWriteFlags(std::string str)
void Check(const AmSgmm2 &model, bool show_properties=true) const
Checks the various accumulators for correct sizes given a model.
Class for the accumulators associated with the phonetic-subspace model parameters.
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.