gmm-est-gaussians-ebw.cc
Go to the documentation of this file.
1 // gmmbin/gmm-est-gaussians-ebw.cc
2 
3 // Copyright 2009-2011 Petr Motlicek Chao Weng
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "gmm/am-diag-gmm.h"
23 #include "tree/context-dep.h"
24 #include "hmm/transition-model.h"
25 #include "gmm/ebw-diag-gmm.h"
26 
27 int main(int argc, char *argv[]) {
28  try {
29  using namespace kaldi;
30  typedef kaldi::int32 int32;
31 
32  const char *usage =
33  "Do EBW update for MMI, MPE or MCE discriminative training.\n"
34  "Numerator stats should already be I-smoothed (e.g. use gmm-ismooth-stats)\n"
35  "Usage: gmm-est-gaussians-ebw [options] <model-in> <stats-num-in> <stats-den-in> <model-out>\n"
36  "e.g.: gmm-est-gaussians-ebw 1.mdl num.acc den.acc 2.mdl\n";
37 
38  bool binary_write = false;
39  std::string update_flags_str = "mv";
40 
41  EbwOptions ebw_opts;
42  ParseOptions po(usage);
43  po.Register("binary", &binary_write, "Write output in binary mode");
44  po.Register("update-flags", &update_flags_str, "Which GMM parameters to "
45  "update: e.g. m or mv (w, t ignored).");
46 
47  ebw_opts.Register(&po);
48 
49  po.Read(argc, argv);
50 
51  if (po.NumArgs() != 4) {
52  po.PrintUsage();
53  exit(1);
54  }
55 
56  kaldi::GmmFlagsType update_flags =
57  StringToGmmFlags(update_flags_str);
58 
59  std::string model_in_filename = po.GetArg(1),
60  num_stats_filename = po.GetArg(2),
61  den_stats_filename = po.GetArg(3),
62  model_out_filename = po.GetArg(4);
63 
64  AmDiagGmm am_gmm;
65  TransitionModel trans_model;
66  {
67  bool binary_read;
68  Input ki(model_in_filename, &binary_read);
69  trans_model.Read(ki.Stream(), binary_read);
70  am_gmm.Read(ki.Stream(), binary_read);
71  }
72 
73  Vector<double> num_transition_accs; // won't be used.
74  Vector<double> den_transition_accs; // won't be used.
75 
76  AccumAmDiagGmm num_stats;
77  AccumAmDiagGmm den_stats;
78  {
79  bool binary;
80  Input ki(num_stats_filename, &binary);
81  num_transition_accs.Read(ki.Stream(), binary);
82  num_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
83  }
84 
85  {
86  bool binary;
87  Input ki(den_stats_filename, &binary);
88  num_transition_accs.Read(ki.Stream(), binary);
89  den_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
90  }
91 
92 
93  { // Update GMMs.
94  BaseFloat auxf_impr, count;
95  int32 num_floored;
96  UpdateEbwAmDiagGmm(num_stats, den_stats, update_flags, ebw_opts, &am_gmm,
97  &auxf_impr, &count, &num_floored);
98  KALDI_LOG << "Num count " << num_stats.TotStatsCount() << ", den count "
99  << den_stats.TotStatsCount();
100  KALDI_LOG << "Overall auxf impr/frame from Gaussian update is " << (auxf_impr/count)
101  << " over " << count << " frames; floored D for "
102  << num_floored << " Gaussians.";
103  }
104 
105  {
106  Output ko(model_out_filename, binary_write);
107  trans_model.Write(ko.Stream(), binary_write);
108  am_gmm.Write(ko.Stream(), binary_write);
109  }
110 
111  KALDI_LOG << "Written model to " << model_out_filename;
112 
113  } catch(const std::exception &e) {
114  std::cerr << e.what() << '\n';
115  return -1;
116  }
117 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
GmmFlagsType StringToGmmFlags(std::string str)
Convert string which is some subset of "mSwa" to flags.
Definition: model-common.cc:26
int main(int argc, char *argv[])
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
void Register(const std::string &name, bool *ptr, const std::string &doc)
const size_t count
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void Read(std::istream &is, bool binary)
BaseFloat TotStatsCount() const
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
void Read(std::istream &in_stream, bool binary, bool add=false)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
void Write(std::ostream &out_stream, bool binary) const
Definition: am-diag-gmm.cc:163
void UpdateEbwAmDiagGmm(const AccumAmDiagGmm &num_stats, const AccumAmDiagGmm &den_stats, GmmFlagsType flags, const EbwOptions &opts, AmDiagGmm *am_gmm, BaseFloat *auxf_change_out, BaseFloat *count_out, int32 *num_floored_out)
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.
void Register(OptionsItf *opts)
Definition: ebw-diag-gmm.h:39