gmm-est.cc
Go to the documentation of this file.
1 // gmmbin/gmm-est.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "gmm/am-diag-gmm.h"
23 #include "tree/context-dep.h"
24 #include "hmm/transition-model.h"
25 #include "gmm/mle-am-diag-gmm.h"
26 
27 int main(int argc, char *argv[]) {
28  try {
29  using namespace kaldi;
30  typedef kaldi::int32 int32;
31 
32  const char *usage =
33  "Do Maximum Likelihood re-estimation of GMM-based acoustic model\n"
34  "Usage: gmm-est [options] <model-in> <stats-in> <model-out>\n"
35  "e.g.: gmm-est 1.mdl 1.acc 2.mdl\n";
36 
37  bool binary_write = true;
39  MleDiagGmmOptions gmm_opts;
40  int32 mixup = 0;
41  int32 mixdown = 0;
42  BaseFloat perturb_factor = 0.01;
43  BaseFloat power = 0.2;
44  BaseFloat min_count = 20.0;
45  std::string update_flags_str = "mvwt";
46  std::string occs_out_filename;
47 
48  ParseOptions po(usage);
49  po.Register("binary", &binary_write, "Write output in binary mode");
50  po.Register("mix-up", &mixup, "Increase number of mixture components to "
51  "this overall target.");
52  po.Register("min-count", &min_count,
53  "Minimum per-Gaussian count enforced while mixing up and down.");
54  po.Register("mix-down", &mixdown, "If nonzero, merge mixture components to this "
55  "target.");
56  po.Register("power", &power, "If mixing up, power to allocate Gaussians to"
57  " states.");
58  po.Register("update-flags", &update_flags_str, "Which GMM parameters to "
59  "update: subset of mvwt.");
60  po.Register("perturb-factor", &perturb_factor, "While mixing up, perturb "
61  "means by standard deviation times this factor.");
62  po.Register("write-occs", &occs_out_filename, "File to write pdf "
63  "occupation counts to.");
64  tcfg.Register(&po);
65  gmm_opts.Register(&po);
66 
67  po.Read(argc, argv);
68 
69  if (po.NumArgs() != 3) {
70  po.PrintUsage();
71  exit(1);
72  }
73 
74  kaldi::GmmFlagsType update_flags =
75  StringToGmmFlags(update_flags_str);
76 
77  std::string model_in_filename = po.GetArg(1),
78  stats_filename = po.GetArg(2),
79  model_out_filename = po.GetArg(3);
80 
81  AmDiagGmm am_gmm;
82  TransitionModel trans_model;
83  {
84  bool binary_read;
85  Input ki(model_in_filename, &binary_read);
86  trans_model.Read(ki.Stream(), binary_read);
87  am_gmm.Read(ki.Stream(), binary_read);
88  }
89 
90  Vector<double> transition_accs;
91  AccumAmDiagGmm gmm_accs;
92  {
93  bool binary;
94  Input ki(stats_filename, &binary);
95  transition_accs.Read(ki.Stream(), binary);
96  gmm_accs.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
97  }
98 
99  if (update_flags & kGmmTransitions) { // Update transition model.
100  BaseFloat objf_impr, count;
101  trans_model.MleUpdate(transition_accs, tcfg, &objf_impr, &count);
102  KALDI_LOG << "Transition model update: Overall " << (objf_impr/count)
103  << " log-like improvement per frame over " << (count)
104  << " frames.";
105  }
106 
107  { // Update GMMs.
108  BaseFloat objf_impr, count;
109  BaseFloat tot_like = gmm_accs.TotLogLike(),
110  tot_t = gmm_accs.TotCount();
111  MleAmDiagGmmUpdate(gmm_opts, gmm_accs, update_flags, &am_gmm,
112  &objf_impr, &count);
113  KALDI_LOG << "GMM update: Overall " << (objf_impr/count)
114  << " objective function improvement per frame over "
115  << count << " frames";
116  KALDI_LOG << "GMM update: Overall avg like per frame = "
117  << (tot_like/tot_t) << " over " << tot_t << " frames.";
118  }
119 
120  if (mixup != 0 || mixdown != 0 || !occs_out_filename.empty()) {
121  // get pdf occupation counts
122  Vector<BaseFloat> pdf_occs;
123  pdf_occs.Resize(gmm_accs.NumAccs());
124  for (int i = 0; i < gmm_accs.NumAccs(); i++)
125  pdf_occs(i) = gmm_accs.GetAcc(i).occupancy().Sum();
126 
127  if (mixdown != 0)
128  am_gmm.MergeByCount(pdf_occs, mixdown, power, min_count);
129 
130  if (mixup != 0)
131  am_gmm.SplitByCount(pdf_occs, mixup, perturb_factor,
132  power, min_count);
133 
134  if (!occs_out_filename.empty()) {
135  bool binary = false;
136  WriteKaldiObject(pdf_occs, occs_out_filename, binary);
137  }
138  }
139 
140  {
141  Output ko(model_out_filename, binary_write);
142  trans_model.Write(ko.Stream(), binary_write);
143  am_gmm.Write(ko.Stream(), binary_write);
144  }
145 
146  KALDI_LOG << "Written model to " << model_out_filename;
147  return 0;
148  } catch(const std::exception &e) {
149  std::cerr << e.what() << '\n';
150  return -1;
151  }
152 }
153 
154 
void MleAmDiagGmmUpdate(const MleDiagGmmOptions &config, const AccumAmDiagGmm &am_diag_gmm_acc, GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
for computing the maximum-likelihood estimates of the parameters of an acoustic model that uses diago...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int main(int argc, char *argv[])
Definition: gmm-est.cc:27
void MleUpdate(const Vector< double > &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum Likelihood estimation.
GmmFlagsType StringToGmmFlags(std::string str)
Convert string which is some subset of "mSwa" to flags.
Definition: model-common.cc:26
void MergeByCount(const Vector< BaseFloat > &state_occs, int32 target_components, BaseFloat power, BaseFloat min_count)
Definition: am-diag-gmm.cc:125
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
const VectorBase< double > & occupancy() const
Definition: mle-diag-gmm.h:183
kaldi::int32 int32
BaseFloat TotCount() const
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
void Register(const std::string &name, bool *ptr, const std::string &doc)
const size_t count
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
BaseFloat TotLogLike() const
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void Read(std::istream &is, bool binary)
void Register(OptionsItf *opts)
Definition: mle-diag-gmm.h:59
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
Configuration variables like variance floor, minimum occupancy, etc.
Definition: mle-diag-gmm.h:38
void Read(std::istream &in_stream, bool binary, bool add=false)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
const AccumDiagGmm & GetAcc(int32 index) const
void Write(std::ostream &out_stream, bool binary) const
Definition: am-diag-gmm.cc:163
void Register(OptionsItf *opts)
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.
void SplitByCount(const Vector< BaseFloat > &state_occs, int32 target_components, float perturb_factor, BaseFloat power, BaseFloat min_count)
Definition: am-diag-gmm.cc:102