sgmm2-est-spkvecs-gpost.cc File Reference
#include <string>
#include <vector>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "sgmm2/am-sgmm2.h"
#include "sgmm2/estimate-am-sgmm2.h"
#include "hmm/transition-model.h"
Include dependency graph for sgmm2-est-spkvecs-gpost.cc:

Go to the source code of this file.

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks (the cited reference is not reproduced on this page).
 

Functions

void AccumulateForUtterance (const Matrix< BaseFloat > &feats, const Sgmm2GauPost &gpost, const TransitionModel &trans_model, const AmSgmm2 &am_sgmm, Sgmm2PerSpkDerivedVars *spk_vars, MleSgmm2SpeakerAccs *spk_stats)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 57 of file sgmm2-est-spkvecs-gpost.cc.

References kaldi::AccumulateForUtterance(), MleSgmm2SpeakerAccs::Clear(), AmSgmm2::ComputePerSpkDerivedVars(), VectorBase< Real >::CopyFromVec(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), Sgmm2PerSpkDerivedVars::GetSpeakerVector(), RandomAccessTableReader< Holder >::HasKey(), RandomAccessTableReader< Holder >::IsOpen(), KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), kaldi::kSetZero, SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), TransitionModel::Read(), AmSgmm2::Read(), ParseOptions::Register(), Sgmm2PerSpkDerivedVars::SetSpeakerVector(), AmSgmm2::SpkSpaceDim(), Input::Stream(), MleSgmm2SpeakerAccs::Update(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and TableWriter< Holder >::Write().

57  {
58  try {
59  typedef kaldi::int32 int32;
60  using namespace kaldi;
61  const char *usage =
62  "Estimate SGMM speaker vectors, either per utterance or for the "
63  "supplied set of speakers (with spk2utt option).\n"
64  "Reads Gaussian-level posteriors. Writes to a table of vectors.\n"
65  "Usage: sgmm2-est-spkvecs-gpost [options] <model-in> <feature-rspecifier> "
66  "<gpost-rspecifier> <vecs-wspecifier>\n";
67 
68  ParseOptions po(usage);
69  string spk2utt_rspecifier, spkvecs_rspecifier;
70  BaseFloat min_count = 100;
71  BaseFloat rand_prune = 1.0e-05;
72 
73  po.Register("spk2utt", &spk2utt_rspecifier,
74  "File to read speaker to utterance-list map from.");
75  po.Register("spkvec-min-count", &min_count,
76  "Minimum count needed to estimate speaker vectors");
77  po.Register("rand-prune", &rand_prune, "Randomized pruning parameter for posteriors (more->faster).");
78  po.Register("spk-vecs", &spkvecs_rspecifier, "Speaker vectors to use during aligment (rspecifier)");
79  po.Read(argc, argv);
80 
81  if (po.NumArgs() != 4) {
82  po.PrintUsage();
83  exit(1);
84  }
85 
86  string model_rxfilename = po.GetArg(1),
87  feature_rspecifier = po.GetArg(2),
88  gpost_rspecifier = po.GetArg(3),
89  vecs_wspecifier = po.GetArg(4);
90 
91  TransitionModel trans_model;
92  AmSgmm2 am_sgmm;
93  {
94  bool binary;
95  Input ki(model_rxfilename, &binary);
96  trans_model.Read(ki.Stream(), binary);
97  am_sgmm.Read(ki.Stream(), binary);
98  }
99  MleSgmm2SpeakerAccs spk_stats(am_sgmm, rand_prune);
100 
101  RandomAccessSgmm2GauPostReader gpost_reader(gpost_rspecifier);
102 
103  RandomAccessBaseFloatVectorReader spkvecs_reader(spkvecs_rspecifier);
104 
105  BaseFloatVectorWriter vecs_writer(vecs_wspecifier);
106 
107  double tot_impr = 0.0, tot_t = 0.0;
108  int32 num_done = 0, num_err = 0;
109 
110  if (!spk2utt_rspecifier.empty()) { // per-speaker adaptation
111  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
112  RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
113 
114  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
115  spk_stats.Clear();
116  string spk = spk2utt_reader.Key();
117  const vector<string> &uttlist = spk2utt_reader.Value();
118 
119  Sgmm2PerSpkDerivedVars spk_vars;
120  if (spkvecs_reader.IsOpen()) {
121  if (spkvecs_reader.HasKey(spk)) {
122  spk_vars.SetSpeakerVector(spkvecs_reader.Value(spk));
123  am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
124  } else {
125  KALDI_WARN << "Cannot find speaker vector for " << spk;
126  }
127  } // else spk_vars is "empty"
128 
129  for (size_t i = 0; i < uttlist.size(); i++) {
130  std::string utt = uttlist[i];
131  if (!feature_reader.HasKey(utt)) {
132  KALDI_WARN << "Did not find features for utterance " << utt;
133  continue;
134  }
135  const Matrix<BaseFloat> &feats = feature_reader.Value(utt);
136  if (!gpost_reader.HasKey(utt) ||
137  gpost_reader.Value(utt).size() != feats.NumRows()) {
138  KALDI_WARN << "Did not find posteriors for utterance " << utt
139  << " (or wrong size).";
140  num_err++;
141  continue;
142  }
143  const Sgmm2GauPost &gpost = gpost_reader.Value(utt);
144 
145  AccumulateForUtterance(feats, gpost, trans_model, am_sgmm,
146  &spk_vars, &spk_stats);
147  num_done++;
148  } // end looping over all utterances of the current speaker
149 
150  BaseFloat impr, spk_tot_t;
151  { // Compute the spk_vec and write it out.
152  Vector<BaseFloat> spk_vec(am_sgmm.SpkSpaceDim(), kSetZero);
153  if (spk_vars.GetSpeakerVector().Dim() != 0)
154  spk_vec.CopyFromVec(spk_vars.GetSpeakerVector());
155  spk_stats.Update(am_sgmm, min_count, &spk_vec, &impr, &spk_tot_t);
156  vecs_writer.Write(spk, spk_vec);
157  }
158  KALDI_LOG << "For speaker " << spk << ", auxf-impr from speaker vector is "
159  << (impr/spk_tot_t) << ", over " << spk_tot_t << " frames.\n";
160  tot_impr += impr;
161  tot_t += spk_tot_t;
162  } // end looping over speakers
163  } else { // per-utterance adaptation
164  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
165  for (; !feature_reader.Done(); feature_reader.Next()) {
166  string utt = feature_reader.Key();
167  const Matrix<BaseFloat> &feats = feature_reader.Value();
168  if (!gpost_reader.HasKey(utt) ||
169  gpost_reader.Value(utt).size() != feats.NumRows()) {
170  KALDI_WARN << "Did not find posts for utterance "
171  << utt;
172  num_err++;
173  continue;
174  }
175  const Sgmm2GauPost &gpost = gpost_reader.Value(utt);
176 
177  Sgmm2PerSpkDerivedVars spk_vars;
178  if (spkvecs_reader.IsOpen()) {
179  if (spkvecs_reader.HasKey(utt)) {
180  spk_vars.SetSpeakerVector(spkvecs_reader.Value(utt));
181  am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
182  } else {
183  KALDI_WARN << "Cannot find speaker vector for " << utt;
184  }
185  } // else spk_vars is "empty"
186 
187  num_done++;
188  spk_stats.Clear();
189 
190  AccumulateForUtterance(feats, gpost, trans_model, am_sgmm,
191  &spk_vars, &spk_stats);
192 
193  BaseFloat impr, utt_tot_t;
194  { // Compute the spk_vec and write it out.
195  Vector<BaseFloat> spk_vec(am_sgmm.SpkSpaceDim(), kSetZero);
196  if (spk_vars.GetSpeakerVector().Dim() != 0)
197  spk_vec.CopyFromVec(spk_vars.GetSpeakerVector());
198  spk_stats.Update(am_sgmm, min_count, &spk_vec, &impr, &utt_tot_t);
199  vecs_writer.Write(utt, spk_vec);
200  }
201  KALDI_LOG << "For utterance " << utt << ", auxf-impr from speaker vectors is "
202  << (impr/utt_tot_t) << ", over " << utt_tot_t << " frames.";
203  tot_impr += impr;
204  tot_t += utt_tot_t;
205  }
206  }
207 
208  KALDI_LOG << "Done " << num_done << " files, " << num_err
209  << " with errors.";
210  KALDI_LOG << "Overall auxf impr per frame is " << (tot_impr / tot_t)
211  << " over " << tot_t << " frames.";
212  return (num_done != 0 ? 0 : 1);
213  } catch(const std::exception &e) {
214  std::cerr << e.what();
215  return -1;
216  }
217 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
Class for definition of the subspace Gmm acoustic model.
Definition: am-sgmm2.h:231
Class for the accumulators required to update the speaker vectors v_s.
const Vector< BaseFloat > & GetSpeakerVector()
Definition: am-sgmm2.h:178
void Read(std::istream &is, bool binary)
Definition: am-sgmm2.cc:89
void AccumulateForUtterance(const Matrix< BaseFloat > &feats, const Sgmm2GauPost &gpost, const TransitionModel &trans_model, const AmSgmm2 &am_sgmm, Sgmm2PerSpkDerivedVars *spk_vars, MleSgmm2SpeakerAccs *spk_stats)
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void ComputePerSpkDerivedVars(Sgmm2PerSpkDerivedVars *vars) const
Computes the per-speaker derived vars; assumes vars->v_s is already set up.
Definition: am-sgmm2.cc:1369
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_WARN
Definition: kaldi-error.h:150
Sgmm2GauPost: Gaussian-level posteriors for an utterance, indexed by time (one entry per frame).
Definition: am-sgmm2.h:568
A class representing a vector.
Definition: kaldi-vector.h:406
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void SetSpeakerVector(const Vector< BaseFloat > &v_s_in)
Definition: am-sgmm2.h:180
int32 SpkSpaceDim() const
Definition: am-sgmm2.h:362
#define KALDI_LOG
Definition: kaldi-error.h:153