sgmm2-est.cc File Reference
Include dependency graph for sgmm2-est.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 30 of file sgmm2-est.cc.

References MleAmSgmm2Accs::Check(), AmSgmm2::ComputeDerivedVars(), kaldi::ComputeFeatureNormalizingTransform(), count, AmSgmm2::full_ubm(), kaldi::g_num_threads, ParseOptions::GetArg(), MleAmSgmm2Accs::GetStateOccupancies(), AmSgmm2::IncreasePhoneSpaceDim(), AmSgmm2::IncreaseSpkSpaceDim(), KALDI_LOG, kaldi::kSgmmTransitions, TransitionModel::MleUpdate(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), MleAmSgmm2Accs::Read(), TransitionModel::Read(), AmSgmm2::Read(), Vector< Real >::Read(), ParseOptions::Register(), MleAmSgmm2Options::Register(), MleTransitionUpdateConfig::Register(), Sgmm2SplitSubstatesConfig::Register(), AmSgmm2::RemoveSpeakerSpace(), Sgmm2SplitSubstatesConfig::split_substates, AmSgmm2::SplitSubstates(), Output::Stream(), Input::Stream(), kaldi::StringToSgmmUpdateFlags(), kaldi::StringToSgmmWriteFlags(), MleAmSgmm2Updater::Update(), TransitionModel::Write(), AmSgmm2::Write(), and VectorBase< Real >::Write().

30  {
31  try {
32  using namespace kaldi;
33  typedef kaldi::int32 int32;
34  const char *usage =
35  "Estimate SGMM model parameters from accumulated stats.\n"
36  "Usage: sgmm2-est [options] <model-in> <stats-in> <model-out>\n";
37 
38  bool binary_write = true;
39  std::string update_flags_str = "vMNwucSt";
40  std::string write_flags_str = "gsnu";
42  kaldi::MleAmSgmm2Options sgmm_opts;
44  int32 increase_phn_dim = 0;
45  int32 increase_spk_dim = 0;
46  bool remove_speaker_space = false;
47  bool spk_dep_weights = false;
48  std::string occs_out_filename;
49 
50  ParseOptions po(usage);
51  po.Register("binary", &binary_write, "Write output in binary mode");
52  po.Register("increase-phn-dim", &increase_phn_dim, "Increase phone-space "
53  "dimension as far as allowed towards this target.");
54  po.Register("increase-spk-dim", &increase_spk_dim, "Increase speaker-space "
55  "dimension as far as allowed towards this target.");
56  po.Register("spk-dep-weights", &spk_dep_weights, "If true, have speaker-"
57  "dependent weights (symmetric SGMM)-- this option only makes"
58  "a difference if you use the --increase-spk-dim option and "
59  "are increasing the speaker dimension from zero.");
60  po.Register("remove-speaker-space", &remove_speaker_space, "Remove speaker-specific "
61  "projections N");
62  po.Register("write-occs", &occs_out_filename, "File to write pdf "
63  "occupantion counts to.");
64  po.Register("update-flags", &update_flags_str, "Which SGMM parameters to "
65  "update: subset of vMNwcSt.");
66  po.Register("write-flags", &write_flags_str, "Which SGMM parameters to "
67  "write: subset of gsnu");
68  po.Register("num-threads", &g_num_threads, "Number of threads to use in "
69  "weight update and normalizer computation");
70  tcfg.Register(&po);
71  sgmm_opts.Register(&po);
72  split_opts.Register(&po);
73 
74  po.Read(argc, argv);
75  if (po.NumArgs() != 3) {
76  po.PrintUsage();
77  exit(1);
78  }
79  std::string model_in_filename = po.GetArg(1),
80  stats_filename = po.GetArg(2),
81  model_out_filename = po.GetArg(3);
82 
83  kaldi::SgmmUpdateFlagsType update_flags =
84  StringToSgmmUpdateFlags(update_flags_str);
85  kaldi::SgmmWriteFlagsType write_flags =
86  StringToSgmmWriteFlags(write_flags_str);
87 
88  AmSgmm2 am_sgmm;
89  TransitionModel trans_model;
90  {
91  bool binary;
92  Input ki(model_in_filename, &binary);
93  trans_model.Read(ki.Stream(), binary);
94  am_sgmm.Read(ki.Stream(), binary);
95  }
96 
97  Vector<double> transition_accs;
98  MleAmSgmm2Accs sgmm_accs;
99  {
100  bool binary;
101  Input ki(stats_filename, &binary);
102  transition_accs.Read(ki.Stream(), binary);
103  sgmm_accs.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
104  }
105 
106  if (update_flags & kSgmmTransitions) { // Update transition model.
107  BaseFloat objf_impr, count;
108  trans_model.MleUpdate(transition_accs, tcfg, &objf_impr, &count);
109  KALDI_LOG << "Transition model update: Overall " << (objf_impr/count)
110  << " log-like improvement per frame over " << (count)
111  << " frames.";
112  }
113 
114  sgmm_accs.Check(am_sgmm, true); // Will check consistency and print some diagnostics.
115 
116  { // Do the update.
117  kaldi::MleAmSgmm2Updater updater(sgmm_opts);
118  updater.Update(sgmm_accs, &am_sgmm, update_flags);
119  }
120 
121  Vector<BaseFloat> pdf_occs;
122  sgmm_accs.GetStateOccupancies(&pdf_occs);
123 
124  if (split_opts.split_substates != 0)
125  am_sgmm.SplitSubstates(pdf_occs, split_opts);
126 
127  if (!occs_out_filename.empty()) {
128  kaldi::Output ko(occs_out_filename, binary_write);
129  pdf_occs.Write(ko.Stream(), binary_write);
130  }
131 
132  if (increase_phn_dim != 0 || increase_spk_dim != 0) {
133  // Feature normalizing transform matrix used to initialize the new columns
134  // of the phonetic- or speaker-space projection matrices.
135  kaldi::Matrix<BaseFloat> norm_xform;
136  ComputeFeatureNormalizingTransform(am_sgmm.full_ubm(), &norm_xform);
137  if (increase_phn_dim != 0)
138  am_sgmm.IncreasePhoneSpaceDim(increase_phn_dim, norm_xform);
139  if (increase_spk_dim != 0)
140  am_sgmm.IncreaseSpkSpaceDim(increase_spk_dim, norm_xform,
141  spk_dep_weights);
142  }
143  if (remove_speaker_space) {
144  KALDI_LOG << "Removing speaker space (projections N_)";
145  am_sgmm.RemoveSpeakerSpace();
146  }
147 
148  am_sgmm.ComputeDerivedVars(); // recompute normalizers, and possibly
149  // weights.
150 
151  {
152  Output ko(model_out_filename, binary_write);
153  trans_model.Write(ko.Stream(), binary_write);
154  am_sgmm.Write(ko.Stream(), binary_write, write_flags);
155  }
156 
157 
158  KALDI_LOG << "Written model to " << model_out_filename;
159  return 0;
160  } catch(const std::exception &e) {
161  std::cerr << e.what();
162  return -1;
163  }
164 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
uint16 SgmmWriteFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:70
void Write(std::ostream &os, bool binary, SgmmWriteFlagsType write_params) const
Definition: am-sgmm2.cc:203
Class for definition of the subspace Gmm acoustic model.
Definition: am-sgmm2.h:231
void IncreasePhoneSpaceDim(int32 target_dim, const Matrix< BaseFloat > &norm_xform)
Functions for increasing the phonetic and speaker space dimensions.
Definition: am-sgmm2.cc:699
void MleUpdate(const Vector< double > &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum Likelihood estimation.
void GetStateOccupancies(Vector< BaseFloat > *occs) const
Accessors.
void Read(std::istream &in_stream, bool binary, bool add)
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
int32 g_num_threads
Definition: kaldi-thread.cc:25
void Read(std::istream &is, bool binary)
Definition: am-sgmm2.cc:89
kaldi::int32 int32
SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str)
Definition: model-common.cc:64
void IncreaseSpkSpaceDim(int32 target_dim, const Matrix< BaseFloat > &norm_xform, bool speaker_dependent_weights)
Increase the subspace dimension for speakers.
Definition: am-sgmm2.cc:747
const FullGmm & full_ubm() const
Accessors.
Definition: am-sgmm2.h:378
void Register(OptionsItf *opts)
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
uint16 SgmmUpdateFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:59
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void SplitSubstates(const Vector< BaseFloat > &state_occupancies, const Sgmm2SplitSubstatesConfig &config)
Increases the total number of substates based on the state occupancies.
Definition: am-sgmm2.cc:657
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
void ComputeDerivedVars()
Computes (and initializes if necessary) derived vars...
Definition: am-sgmm2.cc:810
void RemoveSpeakerSpace()
Definition: am-sgmm2.h:370
void Register(OptionsItf *opts)
SgmmUpdateFlagsType StringToSgmmWriteFlags(std::string str)
Definition: model-common.cc:86
void Check(const AmSgmm2 &model, bool show_properties=true) const
Checks the various accumulators for correct sizes given a model.
Configuration variables needed in the SGMM estimation process.
void Register(OptionsItf *opts)
Definition: am-sgmm2.h:102
Class for the accumulators associated with the phonetic-subspace model parameters.
void ComputeFeatureNormalizingTransform(const FullGmm &gmm, Matrix< BaseFloat > *xform)
Computes the inverse of an LDA transform (without dimensionality reduction). The computed transform is...
Definition: am-sgmm2.cc:1297
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.