29 using namespace kaldi;
33 "Do Maximum Likelihood re-estimation of GMM-based acoustic model\n" 34 "Usage: gmm-est [options] <model-in> <stats-in> <model-out>\n" 35 "e.g.: gmm-est 1.mdl 1.acc 2.mdl\n";
37 bool binary_write =
true;
45 std::string update_flags_str =
"mvwt";
46 std::string occs_out_filename;
49 po.Register(
"binary", &binary_write,
"Write output in binary mode");
50 po.Register(
"mix-up", &mixup,
"Increase number of mixture components to " 51 "this overall target.");
52 po.Register(
"min-count", &min_count,
53 "Minimum per-Gaussian count enforced while mixing up and down.");
54 po.Register(
"mix-down", &mixdown,
"If nonzero, merge mixture components to this " 56 po.Register(
"power", &power,
"If mixing up, power to allocate Gaussians to" 58 po.Register(
"update-flags", &update_flags_str,
"Which GMM parameters to " 59 "update: subset of mvwt.");
60 po.Register(
"perturb-factor", &perturb_factor,
"While mixing up, perturb " 61 "means by standard deviation times this factor.");
62 po.Register(
"write-occs", &occs_out_filename,
"File to write pdf " 63 "occupation counts to.");
69 if (po.NumArgs() != 3) {
77 std::string model_in_filename = po.GetArg(1),
78 stats_filename = po.GetArg(2),
79 model_out_filename = po.GetArg(3);
85 Input ki(model_in_filename, &binary_read);
86 trans_model.
Read(ki.Stream(), binary_read);
87 am_gmm.
Read(ki.Stream(), binary_read);
94 Input ki(stats_filename, &binary);
95 transition_accs.
Read(ki.Stream(), binary);
96 gmm_accs.
Read(ki.Stream(), binary,
true);
101 trans_model.
MleUpdate(transition_accs, tcfg, &objf_impr, &count);
102 KALDI_LOG <<
"Transition model update: Overall " << (objf_impr/
count)
103 <<
" log-like improvement per frame over " << (count)
114 <<
" objective function improvement per frame over " 115 << count <<
" frames";
116 KALDI_LOG <<
"GMM update: Overall avg like per frame = " 117 << (tot_like/tot_t) <<
" over " << tot_t <<
" frames.";
120 if (mixup != 0 || mixdown != 0 || !occs_out_filename.empty()) {
128 am_gmm.
MergeByCount(pdf_occs, mixdown, power, min_count);
134 if (!occs_out_filename.empty()) {
141 Output ko(model_out_filename, binary_write);
142 trans_model.
Write(ko.Stream(), binary_write);
143 am_gmm.
Write(ko.Stream(), binary_write);
146 KALDI_LOG <<
"Written model to " << model_out_filename;
148 }
catch(
const std::exception &e) {
149 std::cerr << e.what() <<
'\n';
void MleAmDiagGmmUpdate(const MleDiagGmmOptions &config, const AccumAmDiagGmm &am_diag_gmm_acc, GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
for computing the maximum-likelihood estimates of the parameters of an acoustic model that uses diago...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void MleUpdate(const Vector< double > &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum Likelihood estimation.
GmmFlagsType StringToGmmFlags(std::string str)
Convert a string that is some subset of "mvwta" to flags.
void MergeByCount(const Vector< BaseFloat > &state_occs, int32 target_components, BaseFloat power, BaseFloat min_count)
const VectorBase< double > & occupancy() const
BaseFloat TotCount() const
uint16 GmmFlagsType
Bitwise OR of the above flags.
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
BaseFloat TotLogLike() const
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
void Register(OptionsItf *opts)
Configuration variables like variance floor, minimum occupancy, etc.
void Read(std::istream &in_stream, bool binary, bool add=false)
void Write(std::ostream &os, bool binary) const
const AccumDiagGmm & GetAcc(int32 index) const
void Write(std::ostream &out_stream, bool binary) const
void Register(OptionsItf *opts)
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
void Read(std::istream &in_stream, bool binary)
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.
void SplitByCount(const Vector< BaseFloat > &state_occs, int32 target_components, float perturb_factor, BaseFloat power, BaseFloat min_count)