All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages File Reference
Include dependency graph for sgmm2-acc-stats.cc:

Go to the source code of this file.


int main (int argc, char *argv[])

Function Documentation

◆ main()

int main ( int argc, char * argv[] )

Definition at line 30 of file sgmm2-acc-stats.cc.

References MleAmSgmm2Accs::Accumulate(), TransitionModel::Accumulate(), Sgmm2PerSpkDerivedVars::Clear(), MleAmSgmm2Accs::CommitStatsForSpk(), AmSgmm2::ComputePerFrameVars(), AmSgmm2::ComputePerSpkDerivedVars(), kaldi::ConvertPosteriorToPdfs(), SequentialTableReader< Holder >::Done(), Sgmm2PerSpkDerivedVars::Empty(), ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), RandomAccessTableReaderMapped< Holder >::HasKey(), rnnlm::i, TransitionModel::InitStats(), RandomAccessTableReaderMapped< Holder >::IsOpen(), rnnlm::j, KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), TransitionModel::Read(), AmSgmm2::Read(), ParseOptions::Register(), MleAmSgmm2Accs::ResizeAccumulators(), MatrixBase< Real >::Row(), Sgmm2PerSpkDerivedVars::SetSpeakerVector(), Output::Stream(), Input::Stream(), kaldi::StringToSgmmUpdateFlags(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), RandomAccessTableReaderMapped< Holder >::Value(), MleAmSgmm2Accs::Write(), and VectorBase< Real >::Write().

30  {
31  using namespace kaldi;
32  try {
33  const char *usage =
34  "Accumulate stats for SGMM training.\n"
35  "Usage: sgmm2-acc-stats [options] <model-in> <feature-rspecifier> "
36  "<posteriors-rspecifier> <stats-out>\n"
37  "e.g.: sgmm2-acc-stats --gselect=ark:gselect.ark 1.mdl 1.ali scp:train.scp 'ark:ali-to-post 1.ali ark:-|' 1.acc\n"
38  "(note: gselect option is mandatory)\n";
40  ParseOptions po(usage);
41  bool binary = true;
42  std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier;
43  std::string update_flags_str = "vMNwcSt";
44  BaseFloat rand_prune = 1.0e-05;
46  po.Register("binary", &binary, "Write output in binary mode");
47  po.Register("gselect", &gselect_rspecifier, "Precomputed Gaussian indices (rspecifier)");
48  po.Register("spk-vecs", &spkvecs_rspecifier, "Speaker vectors (rspecifier)");
49  po.Register("utt2spk", &utt2spk_rspecifier,
50  "rspecifier for utterance to speaker map");
51  po.Register("rand-prune", &rand_prune, "Pruning threshold for posteriors");
52  po.Register("update-flags", &update_flags_str, "Which SGMM parameters to accumulate "
53  "stats for: subset of vMNwcS.");
55  po.Read(argc, argv);
57  kaldi::SgmmUpdateFlagsType acc_flags = StringToSgmmUpdateFlags(update_flags_str);
59  if (po.NumArgs() != 4) {
60  po.PrintUsage();
61  exit(1);
62  }
63  if (gselect_rspecifier == "")
64  KALDI_ERR << "--gselect option is mandatory.";
66  std::string model_filename = po.GetArg(1),
67  feature_rspecifier = po.GetArg(2),
68  posteriors_rspecifier = po.GetArg(3),
69  accs_wxfilename = po.GetArg(4);
71  using namespace kaldi;
72  typedef kaldi::int32 int32;
74  int32 num_done = 0, num_err = 0;
75  Vector<double> transition_accs;
76  MleAmSgmm2Accs sgmm_accs(rand_prune);
78  { // this anonymous scope is to ensure deallocation of unnecessary stuff
79  // while we're writing out the accs, which could be a long time for large
80  // models.
82  // Initialize the readers before the model, as the model can
83  // be large, and we don't want to call fork() after reading it if
84  // virtual memory may be low.
85  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
86  RandomAccessPosteriorReader posteriors_reader(posteriors_rspecifier);
87  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
88  RandomAccessBaseFloatVectorReaderMapped spkvecs_reader(spkvecs_rspecifier,
89  utt2spk_rspecifier);
90  RandomAccessTokenReader utt2spk_map(utt2spk_rspecifier);
92  AmSgmm2 am_sgmm;
93  TransitionModel trans_model;
94  {
95  bool binary;
96  Input ki(model_filename, &binary);
97  trans_model.Read(ki.Stream(), binary);
98  am_sgmm.Read(ki.Stream(), binary);
99  }
102  trans_model.InitStats(&transition_accs);
103  sgmm_accs.ResizeAccumulators(am_sgmm, acc_flags, (spkvecs_rspecifier!=""));
105  double tot_like = 0.0;
106  double tot_t = 0;
108  kaldi::Sgmm2PerFrameDerivedVars per_frame_vars;
109  std::string cur_spk;
110  Sgmm2PerSpkDerivedVars spk_vars;
112  for (; !feature_reader.Done(); feature_reader.Next()) {
113  std::string utt = feature_reader.Key();
114  std::string spk = utt;
115  if (!utt2spk_rspecifier.empty()) {
116  if (!utt2spk_map.HasKey(utt)) {
117  KALDI_WARN << "utt2spk map does not have value for " << utt
118  << ", ignoring this utterance.";
119  continue;
120  } else { spk = utt2spk_map.Value(utt); }
121  }
123  if (spk != cur_spk && cur_spk != "")
124  sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
126  if (spk != cur_spk || spk_vars.Empty()) {
127  spk_vars.Clear();
128  if (spkvecs_reader.IsOpen()) {
129  if (spkvecs_reader.HasKey(utt)) {
130  spk_vars.SetSpeakerVector(spkvecs_reader.Value(utt));
131  am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
132  } else {
133  KALDI_WARN << "Cannot find speaker vector for " << utt;
134  num_err++;
135  continue;
136  }
137  } // else spk_vars is "empty"
138  }
140  cur_spk = spk;
142  const Matrix<BaseFloat> &features = feature_reader.Value();
143  if (!posteriors_reader.HasKey(utt) ||
144  posteriors_reader.Value(utt).size() != features.NumRows()) {
145  KALDI_WARN << "No posterior info available for utterance "
146  << utt << " (or wrong size)";
147  num_err++;
148  continue;
149  }
150  const Posterior &posterior = posteriors_reader.Value(utt);
152  if (!gselect_reader.HasKey(utt)
153  && gselect_reader.Value(utt).size() != features.NumRows()) {
154  KALDI_WARN << "No Gaussian-selection info available for utterance "
155  << utt << " (or wrong size)";
156  num_err++;
157  }
158  const std::vector<std::vector<int32> > &gselect =
159  gselect_reader.Value(utt);
161  num_done++;
163  BaseFloat tot_like_this_file = 0.0, tot_weight = 0.0;
165  Posterior pdf_posterior;
166  ConvertPosteriorToPdfs(trans_model, posterior, &pdf_posterior);
167  for (size_t i = 0; i < posterior.size(); i++) {
168  am_sgmm.ComputePerFrameVars(features.Row(i), gselect[i], spk_vars,
169  &per_frame_vars);
170  // Accumulates for SGMM.
171  for (size_t j = 0; j < pdf_posterior[i].size(); j++) {
172  int32 pdf_id = pdf_posterior[i][j].first;
173  BaseFloat weight = pdf_posterior[i][j].second;
174  tot_like_this_file += sgmm_accs.Accumulate(am_sgmm, per_frame_vars,
175  pdf_id, weight, &spk_vars)
176  * weight;
177  tot_weight += weight;
178  }
180  // Accumulates for transitions.
181  for (size_t j = 0; j < posterior[i].size(); j++) {
182  int32 tid = posterior[i][j].first;
183  BaseFloat weight = posterior[i][j].second;
184  trans_model.Accumulate(weight, tid, &transition_accs);
185  }
186  }
188  KALDI_VLOG(2) << "Average like for this file is "
189  << (tot_like_this_file/tot_weight) << " over "
190  << tot_weight <<" frames.";
191  tot_like += tot_like_this_file;
192  tot_t += tot_weight;
193  if (num_done % 50 == 0) {
194  KALDI_LOG << "Processed " << num_done << " utterances; for utterance "
195  << utt << " avg. like is "
196  << (tot_like_this_file/tot_weight)
197  << " over " << tot_weight <<" frames.";
198  }
199  }
200  sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars); // commit stats for
201  // last speaker.
203  KALDI_LOG << "Overall like per frame (Gaussian only) = "
204  << (tot_like/tot_t) << " over " << tot_t << " frames.";
206  KALDI_LOG << "Done " << num_done << " files, " << num_err
207  << " with errors.";
208  }
210  {
211  Output ko(accs_wxfilename, binary);
212  transition_accs.Write(ko.Stream(), binary);
213  sgmm_accs.Write(ko.Stream(), binary);
214  }
215  KALDI_LOG << "Written accs.";
216  return (num_done != 0 ? 0 : 1);
217  } catch(const std::exception &e) {
218  std::cerr << e.what();
219  return -1;
220  }
221 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
Class for definition of the subspace Gmm acoustic model.
Definition: am-sgmm2.h:231
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
void Read(std::istream &is, bool binary)
kaldi::int32 int32
SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str)
void ComputePerSpkDerivedVars(Sgmm2PerSpkDerivedVars *vars) const
Computes the per-speaker derived vars; assumes vars->v_s is already set up.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
void InitStats(Vector< double > *stats) const
uint16 SgmmUpdateFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:59
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:147
void ComputePerFrameVars(const VectorBase< BaseFloat > &data, const std::vector< int32 > &gselect, const Sgmm2PerSpkDerivedVars &spk_vars, Sgmm2PerFrameDerivedVars *per_frame_vars) const
This needs to be called with each new frame of data, prior to accumulation or likelihood evaluation: ...
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void SetSpeakerVector(const Vector< BaseFloat > &v_s_in)
Definition: am-sgmm2.h:180
void ConvertPosteriorToPdfs(const TransitionModel &tmodel, const Posterior &post_in, Posterior *post_out)
Converts a posterior over transition-ids to be a posterior over pdf-ids.
Class for the accumulators associated with the phonetic-subspace model parameters.
#define KALDI_LOG
Definition: kaldi-error.h:153
Holds the per-frame precomputed quantities x(t), x_{i}(t), z_{i}(t), and n_{i}(t) (cf...
Definition: am-sgmm2.h:142