gmm-est-weights-ebw.cc File Reference

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main (int argc, char *argv[])

Definition at line 27 of file gmm-est-weights-ebw.cc.

References count, ParseOptions::GetArg(), KALDI_LOG, kaldi::kGmmWeights, ParseOptions::NumArgs(), ParseOptions::PrintUsage(), AccumAmDiagGmm::Read(), AmDiagGmm::Read(), ParseOptions::Read(), TransitionModel::Read(), Vector< Real >::Read(), EbwWeightOptions::Register(), ParseOptions::Register(), Output::Stream(), Input::Stream(), kaldi::StringToGmmFlags(), AccumAmDiagGmm::TotCount(), kaldi::UpdateEbwWeightsAmDiagGmm(), AmDiagGmm::Write(), and TransitionModel::Write().
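main() only runs the weight update when the flags parsed from --update-flags include kaldi::kGmmWeights. The gating idea can be sketched without Kaldi as below. Everything in this block is an assumption made for illustration: the type UpdateFlags, the constants kMeans, kVariances and kWeights, their bit values, the accepted letters, and the helper ParseUpdateFlags are hypothetical stand-ins, not Kaldi's definitions.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical stand-ins for kaldi::GmmFlagsType and its flag constants.
// The names, letters, and bit values are assumptions for illustration only.
using UpdateFlags = std::uint16_t;
constexpr UpdateFlags kMeans = 0x1;
constexpr UpdateFlags kVariances = 0x2;
constexpr UpdateFlags kWeights = 0x4;

// Converts a flag string such as "mw" into a bitmask.
// Returns false (leaving *flags_out untouched) on an unknown letter.
bool ParseUpdateFlags(const std::string &str, UpdateFlags *flags_out) {
  assert(flags_out != nullptr);
  UpdateFlags flags = 0;
  for (char c : str) {
    switch (c) {
      case 'm': flags |= kMeans; break;
      case 'v': flags |= kVariances; break;
      case 'w': flags |= kWeights; break;
      default: return false;
    }
  }
  *flags_out = flags;
  return true;
}

// True when the weight bit is set, mirroring the test in main().
bool ShouldUpdateWeights(UpdateFlags flags) {
  return (flags & kWeights) != 0;
}
```

With the default string "w", ShouldUpdateWeights returns true; with a string that omits "w", the program would fall through to the "Doing nothing" branch and still write the model out unchanged.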

{
  try {
    using namespace kaldi;
    typedef kaldi::int32 int32;

    const char *usage =
        "Do EBW update on weights for MMI, MPE or MCE discriminative training.\n"
        "Numerator stats should not be I-smoothed\n"
        "Usage: gmm-est-weights-ebw [options] <model-in> <stats-num-in> <stats-den-in> <model-out>\n"
        "e.g.: gmm-est-weights-ebw 1.mdl num.acc den.acc 2.mdl\n";

    bool binary_write = false;
    std::string update_flags_str = "w";

    EbwWeightOptions ebw_weight_opts;
    ParseOptions po(usage);
    po.Register("binary", &binary_write, "Write output in binary mode");
    po.Register("update-flags", &update_flags_str, "Which GMM parameters to "
                "update; only \"w\" flag is looked at.");

    ebw_weight_opts.Register(&po);

    po.Read(argc, argv);

    if (po.NumArgs() != 4) {
      po.PrintUsage();
      exit(1);
    }

    kaldi::GmmFlagsType update_flags =
        StringToGmmFlags(update_flags_str);

    std::string model_in_filename = po.GetArg(1),
        num_stats_filename = po.GetArg(2),
        den_stats_filename = po.GetArg(3),
        model_out_filename = po.GetArg(4);

    AmDiagGmm am_gmm;
    TransitionModel trans_model;
    {
      bool binary_read;
      Input ki(model_in_filename, &binary_read);
      trans_model.Read(ki.Stream(), binary_read);
      am_gmm.Read(ki.Stream(), binary_read);
    }

    Vector<double> num_transition_accs; // won't be used.
    Vector<double> den_transition_accs; // won't be used.

    AccumAmDiagGmm num_stats;
    AccumAmDiagGmm den_stats;
    {
      bool binary;
      Input ki(num_stats_filename, &binary);
      num_transition_accs.Read(ki.Stream(), binary);
      num_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
    }

    {
      bool binary;
      Input ki(den_stats_filename, &binary);
      num_transition_accs.Read(ki.Stream(), binary);
      den_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
    }

    if (update_flags & kGmmWeights) { // Update weights.
      BaseFloat auxf_impr, count;
      UpdateEbwWeightsAmDiagGmm(num_stats, den_stats, ebw_weight_opts, &am_gmm,
                                &auxf_impr, &count);
      KALDI_LOG << "Num count " << num_stats.TotCount() << ", den count "
                << den_stats.TotCount();
      KALDI_LOG << "Overall auxf impr/frame from weight update is " << (auxf_impr/count)
                << " over " << count << " frames.";
    } else {
      KALDI_LOG << "Doing nothing because flags do not specify to update the weights.";
    }

    {
      Output ko(model_out_filename, binary_write);
      trans_model.Write(ko.Stream(), binary_write);
      am_gmm.Write(ko.Stream(), binary_write);
    }

    KALDI_LOG << "Written model to " << model_out_filename;

  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}
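The log message in the listing reports auxf_impr/count, the objective-function improvement averaged over frames. If the accumulated count were zero, that quotient would not be a meaningful number. Whether a zero count can occur in practice depends on the statistics supplied, so the guard below is only a sketch; the helper AuxfImprPerFrame is hypothetical and not part of Kaldi.

```cpp
#include <cassert>

// Hypothetical helper (not a Kaldi function): average improvement per frame,
// defined as 0 when no frames were accumulated, so that the reported value
// stays finite.
double AuxfImprPerFrame(double auxf_impr, double count) {
  assert(count >= 0.0);  // a frame count is assumed to be non-negative
  return count > 0.0 ? auxf_impr / count : 0.0;
}
```

A caller could log AuxfImprPerFrame(auxf_impr, count) in place of the raw division; returning 0 for an empty count is one design choice among several (skipping the log line entirely would be another).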
GmmFlagsType StringToGmmFlags(std::string str): Convert string which is some subset of "mSwa" to flags. (Definition: model-common.cc:26)
BaseFloat TotCount() const
uint16 GmmFlagsType: Bitwise OR of the above flags. (Definition: model-common.h:35)
void Write(std::ostream &out_stream, bool binary) const (Definition: am-diag-gmm.cc:163)
void UpdateEbwWeightsAmDiagGmm(const AccumAmDiagGmm &num_stats, const AccumAmDiagGmm &den_stats, const EbwWeightOptions &opts, AmDiagGmm *am_gmm, BaseFloat *auxf_change_out, BaseFloat *count_out)
const size_t count
float BaseFloat (Definition: kaldi-types.h:29)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more... (Definition: parse-options.h:36)
void Read(std::istream &is, bool binary)
void Register(OptionsItf *opts) (Definition: ebw-diag-gmm.h:55)
void Read(std::istream &in_stream, bool binary, bool add=false)
void Write(std::ostream &os, bool binary) const
#define KALDI_LOG (Definition: kaldi-error.h:133)
void Read(std::istream &in_stream, bool binary) (Definition: am-diag-gmm.cc:147)
void Read(std::istream &in, bool binary, bool add=false): Read function using C++ streams.
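Two of the Read signatures above take a bool add parameter, and main() passes true with the comment "doesn't matter here". One way to picture why: if add means "sum into what is already there", then adding into a freshly constructed, empty accumulator gives the same result as overwriting it. The sketch below illustrates that reading of the parameter under stated assumptions; ReadCounts and its text format (a size followed by that many values) are hypothetical and do not reflect Kaldi's file format or implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <sstream>
#include <vector>

// Hypothetical reader, not Kaldi's implementation. Reads "<n> v1 ... vn"
// from a text stream. With add == false the destination is overwritten;
// with add == true the values are summed into the existing contents.
// An empty destination is simply filled, so add has no effect on it.
bool ReadCounts(std::istream &in, bool add, std::vector<double> *counts) {
  assert(counts != nullptr);
  std::size_t n = 0;
  if (!(in >> n)) return false;
  std::vector<double> values(n);
  for (double &v : values) {
    if (!(in >> v)) return false;
  }
  if (!add || counts->empty()) {
    *counts = values;
    return true;
  }
  if (counts->size() != n) return false;  // refuse mismatched dimensions
  for (std::size_t i = 0; i < n; ++i) (*counts)[i] += values[i];
  return true;
}
```

Reading "2 1.0 2.0" into an empty vector yields {1.0, 2.0} whether add is true or false; reading "2 0.5 0.5" afterwards with add == true yields {1.5, 2.5}.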