All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
gmm-est-map.cc File Reference
Include dependency graph for gmm-est-map.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 28 of file gmm-est-map.cc.

References count, AccumAmDiagGmm::GetAcc(), ParseOptions::GetArg(), rnnlm::i, KALDI_LOG, kaldi::kGmmTransitions, kaldi::MapAmDiagGmmUpdate(), TransitionModel::MapUpdate(), AccumAmDiagGmm::NumAccs(), ParseOptions::NumArgs(), AccumDiagGmm::occupancy(), ParseOptions::PrintUsage(), AccumAmDiagGmm::Read(), AmDiagGmm::Read(), ParseOptions::Read(), TransitionModel::Read(), Vector< Real >::Read(), ParseOptions::Register(), MapDiagGmmOptions::Register(), MapTransitionUpdateConfig::Register(), Vector< Real >::Resize(), Output::Stream(), Input::Stream(), kaldi::StringToGmmFlags(), AccumAmDiagGmm::TotCount(), AccumAmDiagGmm::TotLogLike(), AmDiagGmm::Write(), TransitionModel::Write(), and kaldi::WriteKaldiObject().

28  {
29  try {
30  using namespace kaldi;
31  typedef kaldi::int32 int32;
32 
33  const char *usage =
34  "Do Maximum A Posteriori re-estimation of GMM-based acoustic model\n"
35  "Usage: gmm-est-map [options] <model-in> <stats-in> <model-out>\n"
36  "e.g.: gmm-est-map 1.mdl 1.acc 2.mdl\n";
37 
38  bool binary_write = true;
40  MapDiagGmmOptions gmm_opts;
41  std::string update_flags_str = "mvwt";
42  std::string occs_out_filename;
43 
44  ParseOptions po(usage);
45  po.Register("binary", &binary_write, "Write output in binary mode");
46  po.Register("update-flags", &update_flags_str, "Which GMM parameters to "
47  "update: subset of mvwt.");
48  po.Register("write-occs", &occs_out_filename, "File to write state "
49  "occupancies to.");
50  tcfg.Register(&po);
51  gmm_opts.Register(&po);
52 
53  po.Read(argc, argv);
54 
55  if (po.NumArgs() != 3) {
56  po.PrintUsage();
57  exit(1);
58  }
59 
60  kaldi::GmmFlagsType update_flags =
61  StringToGmmFlags(update_flags_str);
62 
63  std::string model_in_filename = po.GetArg(1),
64  stats_filename = po.GetArg(2),
65  model_out_filename = po.GetArg(3);
66 
67  AmDiagGmm am_gmm;
68  TransitionModel trans_model;
69  {
70  bool binary_read;
71  Input ki(model_in_filename, &binary_read);
72  trans_model.Read(ki.Stream(), binary_read);
73  am_gmm.Read(ki.Stream(), binary_read);
74  }
75 
76  Vector<double> transition_accs;
77  AccumAmDiagGmm gmm_accs;
78  {
79  bool binary;
80  Input ki(stats_filename, &binary);
81  transition_accs.Read(ki.Stream(), binary);
82  gmm_accs.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
83  }
84 
85  if (update_flags & kGmmTransitions) { // Update transition model.
86  BaseFloat objf_impr, count;
87  trans_model.MapUpdate(transition_accs, tcfg, &objf_impr, &count);
88  KALDI_LOG << "Transition model update: Overall " << (objf_impr/count)
89  << " log-like improvement per frame over " << (count)
90  << " frames.";
91  }
92 
93  { // Update GMMs.
94  BaseFloat objf_impr, count;
95  BaseFloat tot_like = gmm_accs.TotLogLike(),
96  tot_t = gmm_accs.TotCount();
97  MapAmDiagGmmUpdate(gmm_opts, gmm_accs, update_flags, &am_gmm,
98  &objf_impr, &count);
99  KALDI_LOG << "GMM update: Overall " << (objf_impr/count)
100  << " objective function improvement per frame over "
101  << count << " frames";
102  KALDI_LOG << "GMM update: Overall avg like per frame = "
103  << (tot_like/tot_t) << " over " << tot_t << " frames.";
104  }
105 
106  if (!occs_out_filename.empty()) { // get state occs
107  Vector<BaseFloat> state_occs;
108  state_occs.Resize(gmm_accs.NumAccs());
109  for (int i = 0; i < gmm_accs.NumAccs(); i++)
110  state_occs(i) = gmm_accs.GetAcc(i).occupancy().Sum();
111  bool binary = false;
112  WriteKaldiObject(state_occs, occs_out_filename, binary);
113  }
114 
115  {
116  Output ko(model_out_filename, binary_write);
117  trans_model.Write(ko.Stream(), binary_write);
118  am_gmm.Write(ko.Stream(), binary_write);
119  }
120  KALDI_LOG << "Written model to " << model_out_filename;
121  return 0;
122  } catch(const std::exception &e) {
123  std::cerr << e.what() << '\n';
124  return -1;
125  }
126 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
void MapAmDiagGmmUpdate(const MapDiagGmmOptions &config, const AccumAmDiagGmm &am_diag_gmm_acc, GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
Maximum A Posteriori update.
const AccumDiagGmm & GetAcc(int32 index) const
GmmFlagsType StringToGmmFlags(std::string str)
Convert a string that is some subset of "mvwt" (as accepted by this tool's --update-flags option) to flags.
Definition: model-common.cc:26
BaseFloat TotCount() const
uint16 GmmFlagsType
Bitwise OR of the GMM update flags (e.g. kGmmTransitions).
Definition: model-common.h:35
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
void Register(OptionsItf *opts)
Definition: mle-diag-gmm.h:93
void Write(std::ostream &out_stream, bool binary) const
Definition: am-diag-gmm.cc:163
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Register(OptionsItf *opts)
void Read(std::istream &is, bool binary)
BaseFloat TotLogLike() const
void Read(std::istream &in_stream, bool binary, bool add=false)
void Write(std::ostream &os, bool binary) const
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
const VectorBase< double > & occupancy() const
Definition: mle-diag-gmm.h:183
#define KALDI_LOG
Definition: kaldi-error.h:133
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.
void MapUpdate(const Vector< double > &stats, const MapTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum A Posteriori (MAP) estimation.
Configuration variables for Maximum A Posteriori (MAP) update.
Definition: mle-diag-gmm.h:76