sgmm2-est-spkvecs.cc
Go to the documentation of this file.
1 // sgmm2bin/sgmm2-est-spkvecs.cc
2 
3 // Copyright 2009-2012 Saarland University Microsoft Corporation
4 // Johns Hopkins University (Author: Daniel Povey)
5 // 2014 Guoguo Chen
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #include <string>
23 using std::string;
24 #include <vector>
25 using std::vector;
26 
27 #include "base/kaldi-common.h"
28 #include "util/common-utils.h"
29 #include "sgmm2/am-sgmm2.h"
31 #include "hmm/transition-model.h"
32 #include "hmm/posterior.h"
33 
34 namespace kaldi {
35 
37  const Posterior &post,
38  const TransitionModel &trans_model,
39  const AmSgmm2 &am_sgmm,
40  const vector< vector<int32> > &gselect,
41  Sgmm2PerSpkDerivedVars *spk_vars,
42  MleSgmm2SpeakerAccs *spk_stats) {
43  kaldi::Sgmm2PerFrameDerivedVars per_frame_vars;
44 
45  KALDI_ASSERT(gselect.size() == feats.NumRows());
46  Posterior pdf_post;
47  ConvertPosteriorToPdfs(trans_model, post, &pdf_post);
48  for (size_t i = 0; i < post.size(); i++) {
49  am_sgmm.ComputePerFrameVars(feats.Row(i), gselect[i],
50  *spk_vars, &per_frame_vars);
51 
52  for (size_t j = 0; j < pdf_post[i].size(); j++) {
53  int32 pdf_id = pdf_post[i][j].first;
54  spk_stats->Accumulate(am_sgmm, per_frame_vars, pdf_id,
55  pdf_post[i][j].second, spk_vars);
56  }
57  }
58 }
59 
60 } // end namespace kaldi
61 
62 int main(int argc, char *argv[]) {
63  try {
64  typedef kaldi::int32 int32;
65  using namespace kaldi;
66  const char *usage =
67  "Estimate SGMM speaker vectors, either per utterance or for the "
68  "supplied set of speakers (with spk2utt option).\n"
69  "Reads Gaussian-level posteriors. Writes to a table of vectors.\n"
70  "Usage: sgmm2-est-spkvecs [options] <model-in> <feature-rspecifier> "
71  "<post-rspecifier> <vecs-wspecifier>\n"
72  "note: --gselect option is required.";
73 
74  ParseOptions po(usage);
75  string gselect_rspecifier, spk2utt_rspecifier, spkvecs_rspecifier;
76  BaseFloat min_count = 100;
77  BaseFloat rand_prune = 1.0e-05;
78 
79  po.Register("gselect", &gselect_rspecifier,
80  "rspecifier for precomputed per-frame Gaussian indices from.");
81  po.Register("spk2utt", &spk2utt_rspecifier,
82  "File to read speaker to utterance-list map from.");
83  po.Register("spkvec-min-count", &min_count,
84  "Minimum count needed to estimate speaker vectors");
85  po.Register("rand-prune", &rand_prune, "Pruning threshold for posteriors");
86  po.Register("spk-vecs", &spkvecs_rspecifier, "Speaker vectors to use during aligment (rspecifier)");
87  po.Read(argc, argv);
88 
89  if (po.NumArgs() != 4) {
90  po.PrintUsage();
91  exit(1);
92  }
93  if (gselect_rspecifier == "")
94  KALDI_ERR << "--gselect option is mandatory.";
95 
96  string model_rxfilename = po.GetArg(1),
97  feature_rspecifier = po.GetArg(2),
98  post_rspecifier = po.GetArg(3),
99  vecs_wspecifier = po.GetArg(4);
100 
101  TransitionModel trans_model;
102  AmSgmm2 am_sgmm;
103  {
104  bool binary;
105  Input ki(model_rxfilename, &binary);
106  trans_model.Read(ki.Stream(), binary);
107  am_sgmm.Read(ki.Stream(), binary);
108  }
109  MleSgmm2SpeakerAccs spk_stats(am_sgmm, rand_prune);
110 
111  RandomAccessPosteriorReader post_reader(post_rspecifier);
112  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
113  RandomAccessBaseFloatVectorReader spkvecs_reader(spkvecs_rspecifier);
114 
115  BaseFloatVectorWriter vecs_writer(vecs_wspecifier);
116 
117  double tot_impr = 0.0, tot_t = 0.0;
118  int32 num_done = 0, num_err = 0;
119  std::vector<std::vector<int32> > empty_gselect;
120 
121  if (!spk2utt_rspecifier.empty()) { // per-speaker adaptation
122  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
123  RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
124 
125  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
126  spk_stats.Clear();
127  string spk = spk2utt_reader.Key();
128  const vector<string> &uttlist = spk2utt_reader.Value();
129 
130  Sgmm2PerSpkDerivedVars spk_vars;
131  if (spkvecs_reader.IsOpen()) {
132  if (spkvecs_reader.HasKey(spk)) {
133  spk_vars.SetSpeakerVector(spkvecs_reader.Value(spk));
134  am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
135  } else {
136  KALDI_WARN << "Cannot find speaker vector for speaker " << spk
137  << ", not processing this speaker.";
138  num_err++; // standard Kaldi behavior is to not process data
139  // when errors like this happen, as it's generally a script error;
140  continue;
141  }
142  } // else spk_vars is "empty"
143 
144  for (size_t i = 0; i < uttlist.size(); i++) {
145  std::string utt = uttlist[i];
146  if (!feature_reader.HasKey(utt)) {
147  KALDI_WARN << "Did not find features for utterance " << utt;
148  continue;
149  }
150  if (!post_reader.HasKey(utt)) {
151  KALDI_WARN << "Did not find posteriors for utterance " << utt;
152  num_err++;
153  continue;
154  }
155  const Matrix<BaseFloat> &feats = feature_reader.Value(utt);
156  const Posterior &post = post_reader.Value(utt);
157  if (static_cast<int32>(post.size()) != feats.NumRows()) {
158  KALDI_WARN << "Posterior vector has wrong size " << (post.size())
159  << " vs. " << (feats.NumRows());
160  num_err++;
161  continue;
162  }
163  if (!gselect_reader.HasKey(utt) ||
164  gselect_reader.Value(utt).size() != feats.NumRows()) {
165  KALDI_WARN << "No Gaussian-selection info available for utterance "
166  << utt << " (or wrong size)";
167  num_err++;
168  continue;
169  }
170  const std::vector<std::vector<int32> > &gselect =
171  gselect_reader.Value(utt);
172 
173  AccumulateForUtterance(feats, post, trans_model, am_sgmm,
174  gselect, &spk_vars, &spk_stats);
175  num_done++;
176  } // end looping over all utterances of the current speaker
177 
178  BaseFloat impr, spk_tot_t;
179  { // Compute the spk_vec and write it out.
180  Vector<BaseFloat> spk_vec(am_sgmm.SpkSpaceDim(), kSetZero);
181  if (spk_vars.GetSpeakerVector().Dim() != 0)
182  spk_vec.CopyFromVec(spk_vars.GetSpeakerVector());
183  spk_stats.Update(am_sgmm, min_count, &spk_vec, &impr, &spk_tot_t);
184  vecs_writer.Write(spk, spk_vec);
185  }
186  KALDI_LOG << "For speaker " << spk << ", auxf-impr from speaker vector is "
187  << (impr/spk_tot_t) << ", over " << spk_tot_t << " frames.";
188  tot_impr += impr;
189  tot_t += spk_tot_t;
190  } // end looping over speakers
191  } else { // per-utterance adaptation
192  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
193  for (; !feature_reader.Done(); feature_reader.Next()) {
194  string utt = feature_reader.Key();
195  const Matrix<BaseFloat> &feats = feature_reader.Value();
196  if (!post_reader.HasKey(utt) ||
197  post_reader.Value(utt).size() != feats.NumRows()) {
198  KALDI_WARN << "Did not find posts for utterance "
199  << utt << " (or wrong size).";
200  num_err++;
201  continue;
202  }
203  const Posterior &post = post_reader.Value(utt);
204 
205  Sgmm2PerSpkDerivedVars spk_vars;
206  if (spkvecs_reader.IsOpen()) {
207  if (spkvecs_reader.HasKey(utt)) {
208  spk_vars.SetSpeakerVector(spkvecs_reader.Value(utt));
209  am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
210  } else {
211  KALDI_WARN << "Cannot find speaker vector for utterance " << utt
212  << ", not processing it.";
213  num_err++;
214  continue;
215  }
216  } // else spk_vars is "empty"
217 
218  num_done++;
219 
220  if (!gselect_reader.HasKey(utt) ||
221  gselect_reader.Value(utt).size() != feats.NumRows()) {
222  KALDI_WARN << "No Gaussian-selection info available for utterance "
223  << utt << " (or wrong size)";
224  num_err++;
225  continue;
226  }
227  const std::vector<std::vector<int32> > &gselect =
228  gselect_reader.Value(utt);
229 
230  spk_stats.Clear();
231 
232  AccumulateForUtterance(feats, post, trans_model, am_sgmm,
233  gselect, &spk_vars, &spk_stats);
234 
235  BaseFloat impr, utt_tot_t;
236  { // Compute the spk_vec and write it out.
237  Vector<BaseFloat> spk_vec(am_sgmm.SpkSpaceDim(), kSetZero);
238  if (spk_vars.GetSpeakerVector().Dim() != 0)
239  spk_vec.CopyFromVec(spk_vars.GetSpeakerVector());
240  spk_stats.Update(am_sgmm, min_count, &spk_vec, &impr, &utt_tot_t);
241  vecs_writer.Write(utt, spk_vec);
242  }
243  KALDI_LOG << "For utterance " << utt << ", auxf-impr from speaker vectors is "
244  << (impr/utt_tot_t) << ", over " << utt_tot_t << " frames.";
245  tot_impr += impr;
246  tot_t += utt_tot_t;
247  }
248  }
249 
250  KALDI_LOG << "Overall auxf impr per frame is "
251  << (tot_impr / tot_t) << " over " << tot_t << " frames.";
252  KALDI_LOG << "Done " << num_done << " files, " << num_err << " with errors.";
253  return (num_done != 0 ? 0 : 1);
254  } catch(const std::exception &e) {
255  std::cerr << e.what();
256  return -1;
257  }
258 }
259 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
Class for definition of the subspace Gmm acoustic model.
Definition: am-sgmm2.h:231
BaseFloat Accumulate(const AmSgmm2 &model, const Sgmm2PerFrameDerivedVars &frame_vars, int32 pdf_index, BaseFloat weight, Sgmm2PerSpkDerivedVars *spk_vars)
Accumulate statistics. Returns per-frame log-likelihood.
Class for the accumulators required to update the speaker vectors v_s.
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
const Vector< BaseFloat > & GetSpeakerVector()
Definition: am-sgmm2.h:178
void Read(std::istream &is, bool binary)
Definition: am-sgmm2.cc:89
void AccumulateForUtterance(const Matrix< BaseFloat > &feats, const GaussPost &gpost, const TransitionModel &trans_model, const AmDiagGmm &am_gmm, FmllrDiagGmmAccs *spk_stats)
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void ComputePerSpkDerivedVars(Sgmm2PerSpkDerivedVars *vars) const
Computes the per-speaker derived vars; assumes vars->v_s is already set up.
Definition: am-sgmm2.cc:1369
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void Update(const AmSgmm2 &model, BaseFloat min_count, Vector< BaseFloat > *v_s, BaseFloat *objf_impr_out, BaseFloat *count_out)
Update speaker vector.
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
int main(int argc, char *argv[])
void ComputePerFrameVars(const VectorBase< BaseFloat > &data, const std::vector< int32 > &gselect, const Sgmm2PerSpkDerivedVars &spk_vars, Sgmm2PerFrameDerivedVars *per_frame_vars) const
This needs to be called with each new frame of data, prior to accumulation or likelihood evaluation: ...
Definition: am-sgmm2.cc:442
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
void Clear()
Clear the statistics.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void SetSpeakerVector(const Vector< BaseFloat > &v_s_in)
Definition: am-sgmm2.h:180
void ConvertPosteriorToPdfs(const TransitionModel &tmodel, const Posterior &post_in, Posterior *post_out)
Converts a posterior over transition-ids to be a posterior over pdf-ids.
Definition: posterior.cc:322
int32 SpkSpaceDim() const
Definition: am-sgmm2.h:362
#define KALDI_LOG
Definition: kaldi-error.h:153
Holds the per-frame precomputed quantities x(t), x_{i}(t), z_{i}(t), and n_{i}(t) (cf...
Definition: am-sgmm2.h:142