sgmm2-post-to-gpost.cc
Go to the documentation of this file.
1 // sgmm2bin/sgmm2-post-to-gpost.cc
2 
3 // Copyright 2009-2012 Saarland University Microsoft Corporation
4 // Johns Hopkins University (Author: Daniel Povey)
5 // 2014 Guoguo Chen
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 
23 #include "base/kaldi-common.h"
24 #include "util/common-utils.h"
25 #include "sgmm2/am-sgmm2.h"
26 #include "hmm/transition-model.h"
28 #include "hmm/posterior.h"
29 
30 
31 int main(int argc, char *argv[]) {
32  using namespace kaldi;
33  try {
34  const char *usage =
35  "Convert posteriors to Gaussian-level posteriors for SGMM training.\n"
36  "Usage: sgmm2-post-to-gpost [options] <model-in> <feature-rspecifier> "
37  "<posteriors-rspecifier> <gpost-wspecifier>\n"
38  "e.g.: sgmm2-post-to-gpost 1.mdl 1.ali scp:train.scp 'ark:ali-to-post ark:1.ali ark:-|' ark:-";
39 
40  ParseOptions po(usage);
41  std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier;
42 
43  po.Register("gselect", &gselect_rspecifier, "Precomputed Gaussian indices (rspecifier)");
44  po.Register("spk-vecs", &spkvecs_rspecifier, "Speaker vectors (rspecifier)");
45  po.Register("utt2spk", &utt2spk_rspecifier,
46  "rspecifier for utterance to speaker map");
47 
48  po.Read(argc, argv);
49 
50  if (po.NumArgs() != 4) {
51  po.PrintUsage();
52  exit(1);
53  }
54  if (gselect_rspecifier == "")
55  KALDI_ERR << "--gselect option is required";
56 
57  std::string model_filename = po.GetArg(1),
58  feature_rspecifier = po.GetArg(2),
59  posteriors_rspecifier = po.GetArg(3),
60  gpost_wspecifier = po.GetArg(4);
61 
62  using namespace kaldi;
63  typedef kaldi::int32 int32;
64 
65  AmSgmm2 am_sgmm;
66  TransitionModel trans_model;
67  {
68  bool binary;
69  Input ki(model_filename, &binary);
70  trans_model.Read(ki.Stream(), binary);
71  am_sgmm.Read(ki.Stream(), binary);
72  }
73 
74  double tot_like = 0.0;
75  kaldi::int64 tot_t = 0;
76 
77  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
78  RandomAccessPosteriorReader posteriors_reader(posteriors_rspecifier);
79  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
80  RandomAccessBaseFloatVectorReaderMapped spkvecs_reader(spkvecs_rspecifier,
81  utt2spk_rspecifier);
82 
83  Sgmm2PerFrameDerivedVars per_frame_vars;
84 
85  Sgmm2GauPostWriter gpost_writer(gpost_wspecifier);
86 
87  int32 num_done = 0, num_err = 0;
88  for (; !feature_reader.Done(); feature_reader.Next()) {
89  const Matrix<BaseFloat> &mat = feature_reader.Value();
90  std::string utt = feature_reader.Key();
91  if (!posteriors_reader.HasKey(utt)
92  || posteriors_reader.Value(utt).size() != mat.NumRows()) {
93  KALDI_WARN << "No posteriors available for utterance " << utt
94  << " (or wrong size)";
95  num_err++;
96  continue;
97  }
98  Posterior posterior = posteriors_reader.Value(utt);
99 
100  if (!gselect_reader.HasKey(utt) ||
101  gselect_reader.Value(utt).size() != mat.NumRows()) {
102  KALDI_WARN << "No Gaussian-selection info available for utterance "
103  << utt << " (or wrong size)";
104  num_err++;
105  continue;
106  }
107  const std::vector<std::vector<int32> > &gselect =
108  gselect_reader.Value(utt);
109 
110  Sgmm2PerSpkDerivedVars spk_vars;
111  if (spkvecs_reader.IsOpen()) {
112  if (spkvecs_reader.HasKey(utt)) {
113  spk_vars.SetSpeakerVector(spkvecs_reader.Value(utt));
114  am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
115  } else {
116  KALDI_WARN << "Cannot find speaker vector for " << utt;
117  num_err++;
118  continue;
119  }
120  } // else spk_vars is "empty"
121 
122  num_done++;
123  BaseFloat tot_like_this_file = 0.0, tot_weight = 0.0;
124 
125  Sgmm2GauPost gpost(posterior.size()); // posterior.size() == T.
126 
127  SortPosteriorByPdfs(trans_model, &posterior);
128  int32 prev_pdf_id = -1;
129  BaseFloat prev_like = 0;
130  Matrix<BaseFloat> prev_posterior;
131  for (size_t i = 0; i < posterior.size(); i++) {
132  am_sgmm.ComputePerFrameVars(mat.Row(i), gselect[i],
133  spk_vars, &per_frame_vars);
134 
135  gpost[i].gselect = gselect[i];
136  gpost[i].tids.resize(posterior[i].size());
137  gpost[i].posteriors.resize(posterior[i].size());
138 
139  prev_pdf_id = -1; // Only cache for the same frame.
140  for (size_t j = 0; j < posterior[i].size(); j++) {
141  int32 tid = posterior[i][j].first, // transition identifier.
142  pdf_id = trans_model.TransitionIdToPdf(tid);
143  BaseFloat weight = posterior[i][j].second;
144  gpost[i].tids[j] = tid;
145 
146  if (pdf_id != prev_pdf_id) {
147  // First time see this pdf-id for this frame, update the cached
148  // variables.
149  prev_pdf_id = pdf_id;
150  prev_like = am_sgmm.ComponentPosteriors(per_frame_vars, pdf_id,
151  &spk_vars,
152  &prev_posterior);
153  }
154 
155  gpost[i].posteriors[j] = prev_posterior;
156  tot_like_this_file += prev_like * weight;
157  tot_weight += weight;
158  gpost[i].posteriors[j].Scale(weight);
159  }
160  }
161 
162  KALDI_VLOG(2) << "Average like for this file is "
163  << (tot_like_this_file/posterior.size()) << " over "
164  << posterior.size() <<" frames.";
165  tot_like += tot_like_this_file;
166  tot_t += posterior.size();
167  if (num_done % 10 == 0)
168  KALDI_LOG << "Avg like per frame so far is "
169  << (tot_like/tot_t);
170  gpost_writer.Write(utt, gpost);
171  }
172 
173  KALDI_LOG << "Overall like per frame (Gaussian only) = "
174  << (tot_like/tot_t) << " over " << tot_t << " frames.";
175 
176  KALDI_LOG << "Done " << num_done << " files, " << num_err
177  << " with errors.";
178 
179  return (num_done != 0 ? 0 : 1);
180  } catch(const std::exception &e) {
181  std::cerr << e.what();
182  return -1;
183  }
184 }
185 
186 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
Class for definition of the subspace Gmm acoustic model.
Definition: am-sgmm2.h:231
BaseFloat ComponentPosteriors(const Sgmm2PerFrameDerivedVars &per_frame_vars, int32 j2, Sgmm2PerSpkDerivedVars *spk_vars, Matrix< BaseFloat > *post) const
Similar to LogLikelihood() function above, but also computes the posterior probabilities for the pre-...
Definition: am-sgmm2.cc:574
int main(int argc, char *argv[])
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
void Read(std::istream &is, bool binary)
Definition: am-sgmm2.cc:89
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
int32 TransitionIdToPdf(int32 trans_id) const
void ComputePerSpkDerivedVars(Sgmm2PerSpkDerivedVars *vars) const
Computes the per-speaker derived vars; assumes vars->v_s is already set up.
Definition: am-sgmm2.cc:1369
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::istream & Stream()
Definition: kaldi-io.cc:826
void SortPosteriorByPdfs(const TransitionModel &tmodel, Posterior *post)
Sorts posterior entries so that transition-ids with same pdf-id are next to each other.
Definition: posterior.cc:314
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
void Scale(Real alpha)
Multiply each element with a scalar value.
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
void ComputePerFrameVars(const VectorBase< BaseFloat > &data, const std::vector< int32 > &gselect, const Sgmm2PerSpkDerivedVars &spk_vars, Sgmm2PerFrameDerivedVars *per_frame_vars) const
This needs to be called with each new frame of data, prior to accumulation or likelihood evaluation: ...
Definition: am-sgmm2.cc:442
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
indexed by time.
Definition: am-sgmm2.h:568
bool HasKey(const std::string &key)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
const T & Value(const std::string &key)
void SetSpeakerVector(const Vector< BaseFloat > &v_s_in)
Definition: am-sgmm2.h:180
#define KALDI_LOG
Definition: kaldi-error.h:153
Holds the per-frame precomputed quantities x(t), x_{i}(t), z_{i}(t), and n_{i}(t) (cf...
Definition: am-sgmm2.h:142