sgmm2-est-ebw.cc
Go to the documentation of this file.
1 // sgmm2bin/sgmm2-est-ebw.cc
2 
3 // Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "util/kaldi-thread.h"
24 #include "hmm/transition-model.h"
26 
27 
28 int main(int argc, char *argv[]) {
29  using namespace kaldi;
30  typedef kaldi::int32 int32;
31  using std::string;
32  try {
33  const char *usage =
34  "Estimate SGMM model parameters discriminatively using Extended\n"
35  "Baum-Welch style of update\n"
36  "Usage: sgmm2-est-ebw [options] <model-in> <num-stats-in> <den-stats-in> <model-out>\n";
37 
38 
39  string update_flags_str = "vMNwcSt";
40  bool binary_write = true;
41  string write_flags_str = "gsnu";
42  EbwAmSgmm2Options opts;
43 
44 
45  ParseOptions po(usage);
46  po.Register("binary", &binary_write, "Write output in binary mode");
47  po.Register("update-flags", &update_flags_str, "Which SGMM parameters to "
48  "update: subset of vMNwcSt.");
49  po.Register("write-flags", &write_flags_str, "Which SGMM parameters to "
50  "write: subset of gsnu");
51  po.Register("num-threads", &g_num_threads, "Number of threads to use in "
52  "weight update and normalizer computation");
53  opts.Register(&po);
54 
55  po.Read(argc, argv);
56  if (po.NumArgs() != 4) {
57  po.PrintUsage();
58  exit(1);
59  }
60  string model_in_filename = po.GetArg(1),
61  num_stats_filename = po.GetArg(2),
62  den_stats_filename = po.GetArg(3),
63  model_out_filename = po.GetArg(4);
64 
65  SgmmUpdateFlagsType update_flags = StringToSgmmUpdateFlags(update_flags_str);
66  SgmmWriteFlagsType write_flags = StringToSgmmWriteFlags(write_flags_str);
67 
68  AmSgmm2 am_sgmm;
69  TransitionModel trans_model;
70  {
71  bool binary;
72  Input ki(model_in_filename, &binary);
73  trans_model.Read(ki.Stream(), binary);
74  am_sgmm.Read(ki.Stream(), binary);
75  }
76 
77  MleAmSgmm2Accs sgmm_num_accs;
78  {
79  bool binary;
80  Vector<double> transition_accs; // won't be used.
81  Input ki(num_stats_filename, &binary);
82  transition_accs.Read(ki.Stream(), binary);
83  sgmm_num_accs.Read(ki.Stream(), binary, false); // false == add; doesn't matter.
84  }
85  MleAmSgmm2Accs sgmm_den_accs;
86  {
87  bool binary;
88  Vector<double> transition_accs; // won't be used.
89  Input ki(den_stats_filename, &binary);
90  transition_accs.Read(ki.Stream(), binary);
91  sgmm_den_accs.Read(ki.Stream(), binary, false); // false == add; doesn't matter.
92  }
93 
94  sgmm_num_accs.Check(am_sgmm, true); // Will check consistency and print some diagnostics.
95  sgmm_den_accs.Check(am_sgmm, true); // Will check consistency and print some diagnostics.
96 
97  { // Update SGMM.
98  BaseFloat auxf_impr, count;
99  kaldi::EbwAmSgmm2Updater sgmm_updater(opts);
100  sgmm_updater.Update(sgmm_num_accs, sgmm_den_accs, &am_sgmm,
101  update_flags, &auxf_impr, &count);
102  KALDI_LOG << "Overall auxf impr/frame from SGMM update is " << (auxf_impr/count)
103  << " over " << count << " frames.";
104  }
105 
106  {
107  Output ko(model_out_filename, binary_write);
108  trans_model.Write(ko.Stream(), binary_write);
109  am_sgmm.Write(ko.Stream(), binary_write, write_flags);
110  }
111 
112  KALDI_LOG << "Wrote model to " << model_out_filename;
113  return 0;
114  } catch(const std::exception &e) {
115  std::cerr << e.what();
116  return -1;
117  }
118 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
uint16 SgmmWriteFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:70
void Write(std::ostream &os, bool binary, SgmmWriteFlagsType write_params) const
Definition: am-sgmm2.cc:203
Class for definition of the subspace Gmm acoustic model.
Definition: am-sgmm2.h:231
This header implements a form of Extended Baum-Welch training for SGMMs.
int main(int argc, char *argv[])
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Read(std::istream &in_stream, bool binary, bool add)
int32 g_num_threads
Definition: kaldi-thread.cc:25
void Read(std::istream &is, bool binary)
Definition: am-sgmm2.cc:89
kaldi::int32 int32
SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str)
Definition: model-common.cc:64
void Register(const std::string &name, bool *ptr, const std::string &doc)
const size_t count
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
uint16 SgmmUpdateFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:59
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void Read(std::istream &is, bool binary)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
void Register(OptionsItf *opts)
void Update(const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, AmSgmm2 *model, SgmmUpdateFlagsType flags, BaseFloat *auxf_change_out, BaseFloat *count_out)
SgmmUpdateFlagsType StringToSgmmWriteFlags(std::string str)
Definition: model-common.cc:86
void Check(const AmSgmm2 &model, bool show_properties=true) const
Checks the various accumulators for correct sizes given a model.
Class for the accumulators associated with the phonetic-subspace model parameters.
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.