sgmm2-acc-stats.cc File Reference
Include dependency graph for sgmm2-acc-stats.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main (int argc, char *argv[])

Definition at line 30 of file sgmm2-acc-stats.cc.

References MleAmSgmm2Accs::Accumulate(), TransitionModel::Accumulate(), Sgmm2PerSpkDerivedVars::Clear(), MleAmSgmm2Accs::CommitStatsForSpk(), AmSgmm2::ComputePerFrameVars(), AmSgmm2::ComputePerSpkDerivedVars(), kaldi::ConvertPosteriorToPdfs(), SequentialTableReader< Holder >::Done(), Sgmm2PerSpkDerivedVars::Empty(), ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), RandomAccessTableReaderMapped< Holder >::HasKey(), rnnlm::i, TransitionModel::InitStats(), RandomAccessTableReaderMapped< Holder >::IsOpen(), rnnlm::j, KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), TransitionModel::Read(), AmSgmm2::Read(), ParseOptions::Register(), MleAmSgmm2Accs::ResizeAccumulators(), MatrixBase< Real >::Row(), Sgmm2PerSpkDerivedVars::SetSpeakerVector(), Output::Stream(), Input::Stream(), kaldi::StringToSgmmUpdateFlags(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), RandomAccessTableReaderMapped< Holder >::Value(), MleAmSgmm2Accs::Write(), and VectorBase< Real >::Write().

30  {
31  using namespace kaldi;
32  try {
33  const char *usage =
34  "Accumulate stats for SGMM training.\n"
35  "Usage: sgmm2-acc-stats [options] <model-in> <feature-rspecifier> "
36  "<posteriors-rspecifier> <stats-out>\n"
37  "e.g.: sgmm2-acc-stats --gselect=ark:gselect.ark 1.mdl 1.ali scp:train.scp 'ark:ali-to-post 1.ali ark:-|' 1.acc\n"
38  "(note: gselect option is mandatory)\n";
39 
40  ParseOptions po(usage);
41  bool binary = true;
42  std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier;
43  std::string update_flags_str = "vMNwcSt";
44  BaseFloat rand_prune = 1.0e-05;
45 
46  po.Register("binary", &binary, "Write output in binary mode");
47  po.Register("gselect", &gselect_rspecifier, "Precomputed Gaussian indices (rspecifier)");
48  po.Register("spk-vecs", &spkvecs_rspecifier, "Speaker vectors (rspecifier)");
49  po.Register("utt2spk", &utt2spk_rspecifier,
50  "rspecifier for utterance to speaker map");
51  po.Register("rand-prune", &rand_prune, "Pruning threshold for posteriors");
52  po.Register("update-flags", &update_flags_str, "Which SGMM parameters to accumulate "
53  "stats for: subset of vMNwcS.");
54 
55  po.Read(argc, argv);
56 
57  kaldi::SgmmUpdateFlagsType acc_flags = StringToSgmmUpdateFlags(update_flags_str);
58 
59  if (po.NumArgs() != 4) {
60  po.PrintUsage();
61  exit(1);
62  }
63  if (gselect_rspecifier == "")
64  KALDI_ERR << "--gselect option is mandatory.";
65 
66  std::string model_filename = po.GetArg(1),
67  feature_rspecifier = po.GetArg(2),
68  posteriors_rspecifier = po.GetArg(3),
69  accs_wxfilename = po.GetArg(4);
70 
71  using namespace kaldi;
72  typedef kaldi::int32 int32;
73 
74  int32 num_done = 0, num_err = 0;
75  Vector<double> transition_accs;
76  MleAmSgmm2Accs sgmm_accs(rand_prune);
77 
78  { // this anonymous scope is to ensure deallocation of unnecessary stuff
79  // while we're writing out the accs, which could be a long time for large
80  // models.
81 
82  // Initialize the readers before the model, as the model can
83  // be large, and we don't want to call fork() after reading it if
84  // virtual memory may be low.
85  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
86  RandomAccessPosteriorReader posteriors_reader(posteriors_rspecifier);
87  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
88  RandomAccessBaseFloatVectorReaderMapped spkvecs_reader(spkvecs_rspecifier,
89  utt2spk_rspecifier);
90  RandomAccessTokenReader utt2spk_map(utt2spk_rspecifier);
91 
92  AmSgmm2 am_sgmm;
93  TransitionModel trans_model;
94  {
95  bool binary;
96  Input ki(model_filename, &binary);
97  trans_model.Read(ki.Stream(), binary);
98  am_sgmm.Read(ki.Stream(), binary);
99  }
100 
101 
102  trans_model.InitStats(&transition_accs);
103  sgmm_accs.ResizeAccumulators(am_sgmm, acc_flags, (spkvecs_rspecifier!=""));
104 
105  double tot_like = 0.0;
106  double tot_t = 0;
107 
108  kaldi::Sgmm2PerFrameDerivedVars per_frame_vars;
109  std::string cur_spk;
110  Sgmm2PerSpkDerivedVars spk_vars;
111 
112  for (; !feature_reader.Done(); feature_reader.Next()) {
113  std::string utt = feature_reader.Key();
114  std::string spk = utt;
115  if (!utt2spk_rspecifier.empty()) {
116  if (!utt2spk_map.HasKey(utt)) {
117  KALDI_WARN << "utt2spk map does not have value for " << utt
118  << ", ignoring this utterance.";
119  continue;
120  } else { spk = utt2spk_map.Value(utt); }
121  }
122 
123  if (spk != cur_spk && cur_spk != "")
124  sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
125 
126  if (spk != cur_spk || spk_vars.Empty()) {
127  spk_vars.Clear();
128  if (spkvecs_reader.IsOpen()) {
129  if (spkvecs_reader.HasKey(utt)) {
130  spk_vars.SetSpeakerVector(spkvecs_reader.Value(utt));
131  am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
132  } else {
133  KALDI_WARN << "Cannot find speaker vector for " << utt;
134  num_err++;
135  continue;
136  }
137  } // else spk_vars is "empty"
138  }
139 
140  cur_spk = spk;
141 
142  const Matrix<BaseFloat> &features = feature_reader.Value();
143  if (!posteriors_reader.HasKey(utt) ||
144  posteriors_reader.Value(utt).size() != features.NumRows()) {
145  KALDI_WARN << "No posterior info available for utterance "
146  << utt << " (or wrong size)";
147  num_err++;
148  continue;
149  }
150  const Posterior &posterior = posteriors_reader.Value(utt);
151 
152  if (!gselect_reader.HasKey(utt)
153  && gselect_reader.Value(utt).size() != features.NumRows()) {
154  KALDI_WARN << "No Gaussian-selection info available for utterance "
155  << utt << " (or wrong size)";
156  num_err++;
157  }
158  const std::vector<std::vector<int32> > &gselect =
159  gselect_reader.Value(utt);
160 
161  num_done++;
162 
163  BaseFloat tot_like_this_file = 0.0, tot_weight = 0.0;
164 
165  Posterior pdf_posterior;
166  ConvertPosteriorToPdfs(trans_model, posterior, &pdf_posterior);
167  for (size_t i = 0; i < posterior.size(); i++) {
168  am_sgmm.ComputePerFrameVars(features.Row(i), gselect[i], spk_vars,
169  &per_frame_vars);
170  // Accumulates for SGMM.
171  for (size_t j = 0; j < pdf_posterior[i].size(); j++) {
172  int32 pdf_id = pdf_posterior[i][j].first;
173  BaseFloat weight = pdf_posterior[i][j].second;
174  tot_like_this_file += sgmm_accs.Accumulate(am_sgmm, per_frame_vars,
175  pdf_id, weight, &spk_vars)
176  * weight;
177  tot_weight += weight;
178  }
179 
180  // Accumulates for transitions.
181  for (size_t j = 0; j < posterior[i].size(); j++) {
182  int32 tid = posterior[i][j].first;
183  BaseFloat weight = posterior[i][j].second;
184  trans_model.Accumulate(weight, tid, &transition_accs);
185  }
186  }
187 
188  KALDI_VLOG(2) << "Average like for this file is "
189  << (tot_like_this_file/tot_weight) << " over "
190  << tot_weight <<" frames.";
191  tot_like += tot_like_this_file;
192  tot_t += tot_weight;
193  if (num_done % 50 == 0) {
194  KALDI_LOG << "Processed " << num_done << " utterances; for utterance "
195  << utt << " avg. like is "
196  << (tot_like_this_file/tot_weight)
197  << " over " << tot_weight <<" frames.";
198  }
199  }
200  sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars); // commit stats for
201  // last speaker.
202 
203  KALDI_LOG << "Overall like per frame (Gaussian only) = "
204  << (tot_like/tot_t) << " over " << tot_t << " frames.";
205 
206  KALDI_LOG << "Done " << num_done << " files, " << num_err
207  << " with errors.";
208  }
209 
210  {
211  Output ko(accs_wxfilename, binary);
212  transition_accs.Write(ko.Stream(), binary);
213  sgmm_accs.Write(ko.Stream(), binary);
214  }
215  KALDI_LOG << "Written accs.";
216  return (num_done != 0 ? 0 : 1);
217  } catch(const std::exception &e) {
218  std::cerr << e.what();
219  return -1;
220  }
221 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
Class for definition of the subspace Gmm acoustic model.
Definition: am-sgmm2.h:231
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
void Read(std::istream &is, bool binary)
Definition: am-sgmm2.cc:89
kaldi::int32 int32
SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str)
Definition: model-common.cc:64
void ComputePerSpkDerivedVars(Sgmm2PerSpkDerivedVars *vars) const
Computes the per-speaker derived vars; assumes vars->v_s is already set up.
Definition: am-sgmm2.cc:1369
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
void InitStats(Vector< double > *stats) const
uint16 SgmmUpdateFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:59
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:147
void ComputePerFrameVars(const VectorBase< BaseFloat > &data, const std::vector< int32 > &gselect, const Sgmm2PerSpkDerivedVars &spk_vars, Sgmm2PerFrameDerivedVars *per_frame_vars) const
This needs to be called with each new frame of data, prior to accumulation or likelihood evaluation: ...
Definition: am-sgmm2.cc:442
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void SetSpeakerVector(const Vector< BaseFloat > &v_s_in)
Definition: am-sgmm2.h:180
void ConvertPosteriorToPdfs(const TransitionModel &tmodel, const Posterior &post_in, Posterior *post_out)
Converts a posterior over transition-ids to be a posterior over pdf-ids.
Definition: posterior.cc:322
Class for the accumulators associated with the phonetic-subspace model parameters.
#define KALDI_LOG
Definition: kaldi-error.h:153
Holds the per-frame precomputed quantities x(t), x_{i}(t), z_{i}(t), and n_{i}(t) (cf...
Definition: am-sgmm2.h:142