gmm-est-map.cc
Go to the documentation of this file.
1 // gmmbin/gmm-est-map.cc
2 
3 // Copyright 2009-2012 Microsoft Corporation
4 // Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "gmm/am-diag-gmm.h"
24 #include "tree/context-dep.h"
25 #include "hmm/transition-model.h"
26 #include "gmm/mle-am-diag-gmm.h"
27 
28 int main(int argc, char *argv[]) {
29  try {
30  using namespace kaldi;
31  typedef kaldi::int32 int32;
32 
33  const char *usage =
34  "Do Maximum A Posteriori re-estimation of GMM-based acoustic model\n"
35  "Usage: gmm-est-map [options] <model-in> <stats-in> <model-out>\n"
36  "e.g.: gmm-est-map 1.mdl 1.acc 2.mdl\n";
37 
38  bool binary_write = true;
40  MapDiagGmmOptions gmm_opts;
41  std::string update_flags_str = "mvwt";
42  std::string occs_out_filename;
43 
44  ParseOptions po(usage);
45  po.Register("binary", &binary_write, "Write output in binary mode");
46  po.Register("update-flags", &update_flags_str, "Which GMM parameters to "
47  "update: subset of mvwt.");
48  po.Register("write-occs", &occs_out_filename, "File to write state "
49  "occupancies to.");
50  tcfg.Register(&po);
51  gmm_opts.Register(&po);
52 
53  po.Read(argc, argv);
54 
55  if (po.NumArgs() != 3) {
56  po.PrintUsage();
57  exit(1);
58  }
59 
60  kaldi::GmmFlagsType update_flags =
61  StringToGmmFlags(update_flags_str);
62 
63  std::string model_in_filename = po.GetArg(1),
64  stats_filename = po.GetArg(2),
65  model_out_filename = po.GetArg(3);
66 
67  AmDiagGmm am_gmm;
68  TransitionModel trans_model;
69  {
70  bool binary_read;
71  Input ki(model_in_filename, &binary_read);
72  trans_model.Read(ki.Stream(), binary_read);
73  am_gmm.Read(ki.Stream(), binary_read);
74  }
75 
76  Vector<double> transition_accs;
77  AccumAmDiagGmm gmm_accs;
78  {
79  bool binary;
80  Input ki(stats_filename, &binary);
81  transition_accs.Read(ki.Stream(), binary);
82  gmm_accs.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
83  }
84 
85  if (update_flags & kGmmTransitions) { // Update transition model.
86  BaseFloat objf_impr, count;
87  trans_model.MapUpdate(transition_accs, tcfg, &objf_impr, &count);
88  KALDI_LOG << "Transition model update: Overall " << (objf_impr/count)
89  << " log-like improvement per frame over " << (count)
90  << " frames.";
91  }
92 
93  { // Update GMMs.
94  BaseFloat objf_impr, count;
95  BaseFloat tot_like = gmm_accs.TotLogLike(),
96  tot_t = gmm_accs.TotCount();
97  MapAmDiagGmmUpdate(gmm_opts, gmm_accs, update_flags, &am_gmm,
98  &objf_impr, &count);
99  KALDI_LOG << "GMM update: Overall " << (objf_impr/count)
100  << " objective function improvement per frame over "
101  << count << " frames";
102  KALDI_LOG << "GMM update: Overall avg like per frame = "
103  << (tot_like/tot_t) << " over " << tot_t << " frames.";
104  }
105 
106  if (!occs_out_filename.empty()) { // get state occs
107  Vector<BaseFloat> state_occs;
108  state_occs.Resize(gmm_accs.NumAccs());
109  for (int i = 0; i < gmm_accs.NumAccs(); i++)
110  state_occs(i) = gmm_accs.GetAcc(i).occupancy().Sum();
111  bool binary = false;
112  WriteKaldiObject(state_occs, occs_out_filename, binary);
113  }
114 
115  {
116  Output ko(model_out_filename, binary_write);
117  trans_model.Write(ko.Stream(), binary_write);
118  am_gmm.Write(ko.Stream(), binary_write);
119  }
120  KALDI_LOG << "Written model to " << model_out_filename;
121  return 0;
122  } catch(const std::exception &e) {
123  std::cerr << e.what() << '\n';
124  return -1;
125  }
126 }
127 
128 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void MapAmDiagGmmUpdate(const MapDiagGmmOptions &config, const AccumAmDiagGmm &am_diag_gmm_acc, GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
Maximum A Posteriori update.
GmmFlagsType StringToGmmFlags(std::string str)
Convert string which is some subset of "mSwa" to flags.
Definition: model-common.cc:26
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
const VectorBase< double > & occupancy() const
Definition: mle-diag-gmm.h:183
kaldi::int32 int32
BaseFloat TotCount() const
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
void Register(OptionsItf *opts)
Definition: mle-diag-gmm.h:93
void Register(const std::string &name, bool *ptr, const std::string &doc)
const size_t count
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
BaseFloat TotLogLike() const
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void Register(OptionsItf *opts)
void Read(std::istream &is, bool binary)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
void Read(std::istream &in_stream, bool binary, bool add=false)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
const AccumDiagGmm & GetAcc(int32 index) const
void Write(std::ostream &out_stream, bool binary) const
Definition: am-diag-gmm.cc:163
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
int main(int argc, char *argv[])
Definition: gmm-est-map.cc:28
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.
void MapUpdate(const Vector< double > &stats, const MapTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum A Posteriori (MAP) estimation.
Configuration variables for Maximum A Posteriori (MAP) update.
Definition: mle-diag-gmm.h:76