30 using namespace kaldi;
34 "Do Maximum A Posteriori re-estimation of GMM-based acoustic model\n" 35 "Usage: gmm-est-map [options] <model-in> <stats-in> <model-out>\n" 36 "e.g.: gmm-est-map 1.mdl 1.acc 2.mdl\n";
38 bool binary_write =
true;
41 std::string update_flags_str =
"mvwt";
42 std::string occs_out_filename;
45 po.Register(
"binary", &binary_write,
"Write output in binary mode");
46 po.Register(
"update-flags", &update_flags_str,
"Which GMM parameters to " 47 "update: subset of mvwt.");
48 po.Register(
"write-occs", &occs_out_filename,
"File to write state " 55 if (po.NumArgs() != 3) {
63 std::string model_in_filename = po.GetArg(1),
64 stats_filename = po.GetArg(2),
65 model_out_filename = po.GetArg(3);
71 Input ki(model_in_filename, &binary_read);
72 trans_model.
Read(ki.Stream(), binary_read);
73 am_gmm.
Read(ki.Stream(), binary_read);
80 Input ki(stats_filename, &binary);
81 transition_accs.
Read(ki.Stream(), binary);
82 gmm_accs.
Read(ki.Stream(), binary,
true);
87 trans_model.
MapUpdate(transition_accs, tcfg, &objf_impr, &count);
88 KALDI_LOG <<
"Transition model update: Overall " << (objf_impr/
count)
89 <<
" log-like improvement per frame over " << (count)
100 <<
" objective function improvement per frame over " 101 << count <<
" frames";
102 KALDI_LOG <<
"GMM update: Overall avg like per frame = " 103 << (tot_like/tot_t) <<
" over " << tot_t <<
" frames.";
106 if (!occs_out_filename.empty()) {
116 Output ko(model_out_filename, binary_write);
117 trans_model.
Write(ko.Stream(), binary_write);
118 am_gmm.
Write(ko.Stream(), binary_write);
120 KALDI_LOG <<
"Written model to " << model_out_filename;
122 }
catch(
const std::exception &e) {
123 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void MapAmDiagGmmUpdate(const MapDiagGmmOptions &config, const AccumAmDiagGmm &am_diag_gmm_acc, GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
Maximum A Posteriori update.
GmmFlagsType StringToGmmFlags(std::string str)
Convert string which is some subset of "mSwa" to flags.
const VectorBase< double > & occupancy() const
BaseFloat TotCount() const
uint16 GmmFlagsType
Bitwise OR of the above flags.
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
void Register(OptionsItf *opts)
BaseFloat TotLogLike() const
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Register(OptionsItf *opts)
void Read(std::istream &is, bool binary)
void Read(std::istream &in_stream, bool binary, bool add=false)
void Write(std::ostream &os, bool binary) const
const AccumDiagGmm & GetAcc(int32 index) const
void Write(std::ostream &out_stream, bool binary) const
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
void Read(std::istream &in_stream, bool binary)
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.
void MapUpdate(const Vector< double > &stats, const MapTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum A Posteriori (MAP) estimation.
Configuration variables for Maximum A Posteriori (MAP) update.