gmm-est-gaussians-ebw.cc File Reference
Include dependency graph for gmm-est-gaussians-ebw.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 27 of file gmm-est-gaussians-ebw.cc.

References count, ParseOptions::GetArg(), KALDI_LOG, ParseOptions::NumArgs(), ParseOptions::PrintUsage(), AccumAmDiagGmm::Read(), AmDiagGmm::Read(), ParseOptions::Read(), TransitionModel::Read(), Vector< Real >::Read(), EbwOptions::Register(), ParseOptions::Register(), Output::Stream(), Input::Stream(), kaldi::StringToGmmFlags(), AccumAmDiagGmm::TotStatsCount(), kaldi::UpdateEbwAmDiagGmm(), AmDiagGmm::Write(), and TransitionModel::Write().

27  {
28  try {
29  using namespace kaldi;
30  typedef kaldi::int32 int32;
31 
32  const char *usage =
33  "Do EBW update for MMI, MPE or MCE discriminative training.\n"
34  "Numerator stats should already be I-smoothed (e.g. use gmm-ismooth-stats)\n"
35  "Usage: gmm-est-gaussians-ebw [options] <model-in> <stats-num-in> <stats-den-in> <model-out>\n"
36  "e.g.: gmm-est-gaussians-ebw 1.mdl num.acc den.acc 2.mdl\n";
37 
38  bool binary_write = false;
39  std::string update_flags_str = "mv";
40 
41  EbwOptions ebw_opts;
42  ParseOptions po(usage);
43  po.Register("binary", &binary_write, "Write output in binary mode");
44  po.Register("update-flags", &update_flags_str, "Which GMM parameters to "
45  "update: e.g. m or mv (w, t ignored).");
46 
47  ebw_opts.Register(&po);
48 
49  po.Read(argc, argv);
50 
51  if (po.NumArgs() != 4) {
52  po.PrintUsage();
53  exit(1);
54  }
55 
56  kaldi::GmmFlagsType update_flags =
57  StringToGmmFlags(update_flags_str);
58 
59  std::string model_in_filename = po.GetArg(1),
60  num_stats_filename = po.GetArg(2),
61  den_stats_filename = po.GetArg(3),
62  model_out_filename = po.GetArg(4);
63 
64  AmDiagGmm am_gmm;
65  TransitionModel trans_model;
66  {
67  bool binary_read;
68  Input ki(model_in_filename, &binary_read);
69  trans_model.Read(ki.Stream(), binary_read);
70  am_gmm.Read(ki.Stream(), binary_read);
71  }
72 
73  Vector<double> num_transition_accs; // won't be used.
74  Vector<double> den_transition_accs; // won't be used.
75 
76  AccumAmDiagGmm num_stats;
77  AccumAmDiagGmm den_stats;
78  {
79  bool binary;
80  Input ki(num_stats_filename, &binary);
81  num_transition_accs.Read(ki.Stream(), binary);
82  num_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
83  }
84 
85  {
86  bool binary;
87  Input ki(den_stats_filename, &binary);
88  num_transition_accs.Read(ki.Stream(), binary);
89  den_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
90  }
91 
92 
93  { // Update GMMs.
94  BaseFloat auxf_impr, count;
95  int32 num_floored;
96  UpdateEbwAmDiagGmm(num_stats, den_stats, update_flags, ebw_opts, &am_gmm,
97  &auxf_impr, &count, &num_floored);
98  KALDI_LOG << "Num count " << num_stats.TotStatsCount() << ", den count "
99  << den_stats.TotStatsCount();
100  KALDI_LOG << "Overall auxf impr/frame from Gaussian update is " << (auxf_impr/count)
101  << " over " << count << " frames; floored D for "
102  << num_floored << " Gaussians.";
103  }
104 
105  {
106  Output ko(model_out_filename, binary_write);
107  trans_model.Write(ko.Stream(), binary_write);
108  am_gmm.Write(ko.Stream(), binary_write);
109  }
110 
111  KALDI_LOG << "Written model to " << model_out_filename;
112 
113  } catch(const std::exception &e) {
114  std::cerr << e.what() << '\n';
115  return -1;
116  }
117 }
kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
GmmFlagsType StringToGmmFlags(std::string str)
Convert string which is some subset of "mvwta" to flags.
Definition: model-common.cc:26
kaldi::int32 int32
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
BaseFloat TotStatsCount() const
void Read(std::istream &in_stream, bool binary, bool add=false)
void Write(std::ostream &os, bool binary) const
void Write(std::ostream &out_stream, bool binary) const
Definition: am-diag-gmm.cc:163
void UpdateEbwAmDiagGmm(const AccumAmDiagGmm &num_stats, const AccumAmDiagGmm &den_stats, GmmFlagsType flags, const EbwOptions &opts, AmDiagGmm *am_gmm, BaseFloat *auxf_change_out, BaseFloat *count_out, int32 *num_floored_out)
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.
void Register(OptionsItf *opts)
Definition: ebw-diag-gmm.h:39