sgmm2-acc-stats2.cc File Reference
Include dependency graph for sgmm2-acc-stats2.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 29 of file sgmm2-acc-stats2.cc.

References TransitionModel::Accumulate(), Sgmm2PerSpkDerivedVars::Clear(), MleAmSgmm2Accs::CommitStatsForSpk(), AmSgmm2::ComputePerFrameVars(), AmSgmm2::ComputePerSpkDerivedVars(), SequentialTableReader< Holder >::Done(), Sgmm2PerSpkDerivedVars::Empty(), ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), RandomAccessTableReaderMapped< Holder >::HasKey(), AmSgmm2::HasSpeakerDependentWeights(), AmSgmm2::HasSpeakerSpace(), rnnlm::i, TransitionModel::InitStats(), RandomAccessTableReaderMapped< Holder >::IsOpen(), rnnlm::j, KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), kaldi::kSgmmSpeakerProjections, kaldi::kSgmmSpeakerWeightProjections, kaldi::kSgmmTransitions, SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), TransitionModel::Read(), AmSgmm2::Read(), ParseOptions::Register(), MleAmSgmm2Accs::ResizeAccumulators(), MatrixBase< Real >::Row(), Sgmm2PerSpkDerivedVars::SetSpeakerVector(), Output::Stream(), Input::Stream(), kaldi::StringToSgmmUpdateFlags(), TransitionModel::TransitionIdToPdf(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), RandomAccessTableReaderMapped< Holder >::Value(), MleAmSgmm2Accs::Write(), and VectorBase< Real >::Write().

29  {
30  using namespace kaldi;
31  try {
32  const char *usage =
33  "Accumulate numerator and denominator stats for discriminative training\n"
34  "of SGMMs (input is posteriors of mixed sign)\n"
35  "Usage: sgmm2-acc-stats2 [options] <model-in> <feature-rspecifier> "
36  "<posteriors-rspecifier> <num-stats-out> <den-stats-out>\n"
37  "e.g.: sgmm2-acc-stats2 1.mdl 1.ali scp:train.scp ark:1.posts num.acc den.acc\n";
38 
39  ParseOptions po(usage);
40  bool binary = true;
41  std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier;
42  std::string update_flags_str = "vMNwucSt";
43  BaseFloat rand_prune = 1.0e-05;
44 
45  po.Register("binary", &binary, "Write output in binary mode");
46  po.Register("gselect", &gselect_rspecifier, "Precomputed Gaussian indices (rspecifier)");
47  po.Register("spk-vecs", &spkvecs_rspecifier, "Speaker vectors (rspecifier)");
48  po.Register("utt2spk", &utt2spk_rspecifier,
49  "rspecifier for utterance to speaker map");
50  po.Register("rand-prune", &rand_prune, "Pruning threshold for posteriors");
51  po.Register("update-flags", &update_flags_str, "Which SGMM parameters to accumulate "
52  "stats for: subset of vMNwcS.");
53 
54  po.Read(argc, argv);
55 
56  kaldi::SgmmUpdateFlagsType acc_flags = StringToSgmmUpdateFlags(update_flags_str);
57 
58  if (po.NumArgs() != 5) {
59  po.PrintUsage();
60  exit(1);
61  }
62 
63  std::string model_filename = po.GetArg(1),
64  feature_rspecifier = po.GetArg(2),
65  posteriors_rspecifier = po.GetArg(3),
66  num_accs_wxfilename = po.GetArg(4),
67  den_accs_wxfilename = po.GetArg(5);
68 
69 
70  using namespace kaldi;
71  typedef kaldi::int32 int32;
72  typedef kaldi::int64 int64;
73 
74  // Initialize the readers before the model, as the model can
75  // be large, and we don't want to call fork() after reading it if
76  // virtual memory may be low.
77  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
78  RandomAccessPosteriorReader posteriors_reader(posteriors_rspecifier);
79  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
80  RandomAccessBaseFloatVectorReaderMapped spkvecs_reader(spkvecs_rspecifier,
81  utt2spk_rspecifier);
82  RandomAccessTokenReader utt2spk_map(utt2spk_rspecifier);
83 
84  AmSgmm2 am_sgmm;
85  TransitionModel trans_model;
86  {
87  bool binary;
88  Input ki(model_filename, &binary);
89  trans_model.Read(ki.Stream(), binary);
90  am_sgmm.Read(ki.Stream(), binary);
91  }
92 
93  if (acc_flags & kSgmmSpeakerWeightProjections && !am_sgmm.HasSpeakerDependentWeights()) {
94  acc_flags &= ~kSgmmSpeakerWeightProjections;
95  KALDI_WARN << "Removing speaker weight projections (u) from flags "
96  "as not present in model\n";
97  }
98  if (acc_flags & kSgmmSpeakerProjections && !am_sgmm.HasSpeakerSpace()) {
99  acc_flags &= ~kSgmmSpeakerProjections;
100  KALDI_WARN << "Removing speaker projections (N) from flags "
101  "as not present in model\n";
102  }
103 
104  Vector<double> num_transition_accs, den_transition_accs;
105  if (acc_flags & kaldi::kSgmmTransitions) {
106  trans_model.InitStats(&num_transition_accs);
107  trans_model.InitStats(&den_transition_accs);
108  }
109  MleAmSgmm2Accs num_sgmm_accs(rand_prune), den_sgmm_accs(rand_prune);
110  bool have_spk_vecs = (spkvecs_rspecifier != "");
111  num_sgmm_accs.ResizeAccumulators(am_sgmm, acc_flags, have_spk_vecs);
112  den_sgmm_accs.ResizeAccumulators(am_sgmm, acc_flags, have_spk_vecs);
113 
114  double tot_like = 0.0, tot_weight = 0.0, tot_abs_weight = 0.0;
115  int64 tot_frames = 0;
116 
117  kaldi::Sgmm2PerFrameDerivedVars per_frame_vars;
118 
119  int32 num_done = 0, num_err = 0;
120  std::string cur_spk;
121  Sgmm2PerSpkDerivedVars spk_vars;
122 
123  for (; !feature_reader.Done(); feature_reader.Next()) {
124  std::string utt = feature_reader.Key();
125  std::string spk = utt;
126  if (!utt2spk_rspecifier.empty()) {
127  if (!utt2spk_map.HasKey(utt)) {
128  KALDI_WARN << "utt2spk map does not have value for " << utt
129  << ", ignoring this utterance.";
130  continue;
131  } else { spk = utt2spk_map.Value(utt); }
132  }
133  if (spk != cur_spk && cur_spk != "") {
134  num_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
135  den_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
136  }
137  if (spk != cur_spk || spk_vars.Empty()) {
138  spk_vars.Clear();
139  if (spkvecs_reader.IsOpen()) {
140  if (spkvecs_reader.HasKey(utt)) {
141  spk_vars.SetSpeakerVector(spkvecs_reader.Value(utt));
142  am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
143  } else {
144  KALDI_WARN << "Cannot find speaker vector for " << utt;
145  num_err++;
146  continue;
147  }
148  } // else spk_vars is "empty"
149  }
150  cur_spk = spk;
151 
152  const Matrix<BaseFloat> &features = feature_reader.Value();
153  if (!posteriors_reader.HasKey(utt) ||
154  posteriors_reader.Value(utt).size() != features.NumRows()) {
155  KALDI_WARN << "No posterior info available for utterance "
156  << utt << " (or wrong size)";
157  num_err++;
158  continue;
159  }
160 
161  const Posterior &posterior = posteriors_reader.Value(utt);
162  if (!gselect_reader.HasKey(utt)
163  && gselect_reader.Value(utt).size() != features.NumRows()) {
164  KALDI_WARN << "No Gaussian-selection info available for utterance "
165  << utt << " (or wrong size)";
166  num_err++;
167  }
168  const std::vector<std::vector<int32> > &gselect =
169  gselect_reader.Value(utt);
170 
171  num_done++;
172  BaseFloat tot_like_this_file = 0.0, tot_weight_this_file = 0.0,
173  tot_abs_weight_this_file = 0.0;
174 
175  for (size_t i = 0; i < posterior.size(); i++) {
176  if (posterior[i].empty())
177  continue;
178  am_sgmm.ComputePerFrameVars(features.Row(i), gselect[i], spk_vars,
179  &per_frame_vars);
180 
181  for (size_t j = 0; j < posterior[i].size(); j++) {
182  int32 tid = posterior[i][j].first, // transition identifier.
183  pdf_id = trans_model.TransitionIdToPdf(tid);
184  BaseFloat weight = posterior[i][j].second,
185  abs_weight = std::abs(weight);
186 
187  if (acc_flags & kaldi::kSgmmTransitions) {
188  trans_model.Accumulate(abs_weight, tid, weight > 0 ?
189  &num_transition_accs : &den_transition_accs);
190  }
191  tot_like_this_file +=
192  (weight > 0 ? num_sgmm_accs : den_sgmm_accs).Accumulate(
193  am_sgmm, per_frame_vars, pdf_id, abs_weight, &spk_vars)
194  * weight;
195  tot_weight_this_file += weight;
196  tot_abs_weight_this_file += abs_weight;
197  }
198  }
199  // Commit stats for the last speaker.
200  num_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
201  den_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
202 
203 
204  tot_like += tot_like_this_file;
205  tot_weight += tot_weight_this_file;
206  tot_abs_weight += tot_abs_weight_this_file;
207  tot_frames += posterior.size();
208  if (num_done % 50 == 0)
209  KALDI_LOG << "Processed " << num_done << " utterances.";
210  }
211  // Commit stats for last speaker.
212  num_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
213  den_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
214 
215  KALDI_LOG << "Overall weighted acoustic likelihood per frame was "
216  << (tot_like/tot_frames) << " over " << tot_frames << " frames; "
217  << "average weight per frame is " << (tot_weight/tot_frames)
218  << ", average abs(weight) per frame is "
219  << (tot_abs_weight/tot_frames);
220 
221  KALDI_LOG << "Done " << num_done << " files, " << num_err
222  << " with errors.";
223 
224  {
225  Output ko(num_accs_wxfilename, binary);
226  num_transition_accs.Write(ko.Stream(), binary);
227  num_sgmm_accs.Write(ko.Stream(), binary);
228  }
229  {
230  Output ko(den_accs_wxfilename, binary);
231  den_transition_accs.Write(ko.Stream(), binary);
232  den_sgmm_accs.Write(ko.Stream(), binary);
233  }
234  KALDI_LOG << "Written accs.";
235  return (num_done != 0 ? 0 : 1);
236  } catch(const std::exception &e) {
237  std::cerr << e.what();
238  return -1;
239  }
240 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
Class for definition of the subspace Gmm acoustic model.
Definition: am-sgmm2.h:231
bool HasSpeakerSpace() const
Definition: am-sgmm2.h:368
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
void Read(std::istream &is, bool binary)
Definition: am-sgmm2.cc:89
kaldi::int32 int32
SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str)
Definition: model-common.cc:64
int32 TransitionIdToPdf(int32 trans_id) const
void ComputePerSpkDerivedVars(Sgmm2PerSpkDerivedVars *vars) const
Computes the per-speaker derived vars; assumes vars->v_s is already set up.
Definition: am-sgmm2.cc:1369
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
t: transition-model parameters (not really part of the SGMM itself).
Definition: model-common.h:55
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
void InitStats(Vector< double > *stats) const
uint16 SgmmUpdateFlagsType
Bitwise OR of the kSgmm* update flags (e.g. kSgmmSpeakerProjections, kSgmmSpeakerWeightProjections, kSgmmTransitions).
Definition: model-common.h:59
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void ComputePerFrameVars(const VectorBase< BaseFloat > &data, const std::vector< int32 > &gselect, const Sgmm2PerSpkDerivedVars &spk_vars, Sgmm2PerFrameDerivedVars *per_frame_vars) const
This needs to be called with each new frame of data, prior to accumulation or likelihood evaluation: ...
Definition: am-sgmm2.cc:442
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void SetSpeakerVector(const Vector< BaseFloat > &v_s_in)
Definition: am-sgmm2.h:180
bool HasSpeakerDependentWeights() const
True if doing SSGMM.
Definition: am-sgmm2.h:366
Class for the accumulators associated with the phonetic-subspace model parameters.
#define KALDI_LOG
Definition: kaldi-error.h:153
Holds the per-frame precomputed quantities x(t), x_{i}(t), z_{i}(t), and n_{i}(t) (cf...
Definition: am-sgmm2.h:142