gmm-basis-fmllr-accs.cc
Go to the documentation of this file.
1 // gmmbin/gmm-basis-fmllr-accs.cc
2 
3 // Copyright 2012 Carnegie Mellon University (author: Yajie Miao)
4 // 2014 Guoguo Chen
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <string>
22 using std::string;
23 #include <vector>
24 using std::vector;
25 
26 #include "base/kaldi-common.h"
27 #include "util/common-utils.h"
28 #include "gmm/am-diag-gmm.h"
29 #include "hmm/transition-model.h"
32 #include "hmm/posterior.h"
33 
34 namespace kaldi {
36  const Posterior &post,
37  const TransitionModel &trans_model,
38  const AmDiagGmm &am_gmm,
39  FmllrDiagGmmAccs *spk_stats) {
40  Posterior pdf_post;
41  ConvertPosteriorToPdfs(trans_model, post, &pdf_post);
42  for (size_t i = 0; i < post.size(); i++) {
43  for (size_t j = 0; j < pdf_post[i].size(); j++) {
44  int32 pdf_id = pdf_post[i][j].first;
45  spk_stats->AccumulateForGmm(am_gmm.GetPdf(pdf_id),
46  feats.Row(i),
47  pdf_post[i][j].second);
48  }
49  }
50 }
51 
52 
53 }
54 
55 int main(int argc, char *argv[]) {
56  try {
57  typedef kaldi::int32 int32;
58  using namespace kaldi;
59  const char *usage =
60  "Accumulate gradient scatter from training set, either per utterance or \n"
61  "for the supplied set of speakers (spk2utt option). Reads posterior to accumulate \n"
62  "fMLLR stats for each speaker/utterance. Writes gradient scatter matrix.\n"
63  "Usage: gmm-basis-fmllr-accs [options] <model-in> <feature-rspecifier>"
64  "<post-rspecifier> <accs-wspecifier>\n";
65 
66  bool binary_write = true;
67  string spk2utt_rspecifier;
68  ParseOptions po(usage);
69  po.Register("binary", &binary_write, "Write output in binary mode");
70  po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to "
71  "utterance-list map");
72 
73  po.Read(argc, argv);
74  if (po.NumArgs() != 4) {
75  po.PrintUsage();
76  exit(1);
77  }
78 
79  string
80  model_rxfilename = po.GetArg(1),
81  feature_rspecifier = po.GetArg(2),
82  post_rspecifier = po.GetArg(3),
83  accs_wspecifier = po.GetArg(4);
84 
85  TransitionModel trans_model;
86  AmDiagGmm am_gmm;
87  {
88  bool binary;
89  Input ki(model_rxfilename, &binary);
90  trans_model.Read(ki.Stream(), binary);
91  am_gmm.Read(ki.Stream(), binary);
92  }
93 
94  RandomAccessPosteriorReader post_reader(post_rspecifier);
95  BasisFmllrAccus basis_accs(am_gmm.Dim());
96 
97  int32 num_done = 0, num_no_post = 0, num_other_error = 0;
98  if (spk2utt_rspecifier != "") { // per-speaker mode
99  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
100  RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
101 
102  int32 num_spk = 0;
103  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
104  FmllrDiagGmmAccs spk_stats(am_gmm.Dim());
105  string spk = spk2utt_reader.Key();
106  const vector<string> &uttlist = spk2utt_reader.Value();
107  for (size_t i = 0; i < uttlist.size(); i++) {
108  std::string utt = uttlist[i];
109  if (!feature_reader.HasKey(utt)) {
110  KALDI_WARN << "Did not find features for utterance " << utt;
111  num_other_error++;
112  continue;
113  }
114  if (!post_reader.HasKey(utt)) {
115  KALDI_WARN << "Did not find posteriors for utterance " << utt;
116  num_no_post++;
117  continue;
118  }
119  const Matrix<BaseFloat> &feats = feature_reader.Value(utt);
120  const Posterior &post = post_reader.Value(utt);
121  if (static_cast<int32>(post.size()) != feats.NumRows()) {
122  KALDI_WARN << "Posterior vector has wrong size " << (post.size())
123  << " vs. " << (feats.NumRows());
124  num_other_error++;
125  continue;
126  }
127 
128  AccumulateForUtterance(feats, post, trans_model, am_gmm, &spk_stats);
129 
130  num_done++;
131  } // end looping over all utterances of this speaker
132  basis_accs.AccuGradientScatter(spk_stats);
133  num_spk++;
134  } // end looping over speakers
135  KALDI_LOG << "Accumulate statistics from " << num_spk << " speakers";
136 
137  } else { // per-utterance mode
138  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
139  for (; !feature_reader.Done(); feature_reader.Next()) {
140  string utt = feature_reader.Key();
141  if (!post_reader.HasKey(utt)) {
142  KALDI_WARN << "Did not find posts for utterance "
143  << utt;
144  num_no_post++;
145  continue;
146  }
147  const Matrix<BaseFloat> &feats = feature_reader.Value();
148  const Posterior &post = post_reader.Value(utt);
149 
150  if (static_cast<int32>(post.size()) != feats.NumRows()) {
151  KALDI_WARN << "Posterior has wrong size " << (post.size())
152  << " vs. " << (feats.NumRows());
153  num_other_error++;
154  continue;
155  }
156  // Accumulate stats for this utterance
157  FmllrDiagGmmAccs utt_stats(am_gmm.Dim());
158  AccumulateForUtterance(feats, post, trans_model, am_gmm, &utt_stats);
159  num_done++;
160 
161  basis_accs.AccuGradientScatter(utt_stats);
162  } // end looping over utterances
163  }
164  // Write out accumulations
165  {
166  Output ko(accs_wspecifier, binary_write);
167  basis_accs.Write(ko.Stream(), binary_write);
168  }
169  KALDI_LOG << "Done " << num_done << " files, " << num_no_post
170  << " with no posts, " << num_other_error << " with other errors.";
171  KALDI_LOG << "Written gradient scatter to " << accs_wspecifier;
172  return (num_done != 0 ? 0 : 1);
173  } catch(const std::exception& e) {
174  std::cerr << e.what();
175  return -1;
176  }
177 }
178 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
This does not work with multiple feature transforms.
void AccumulateForUtterance(const Matrix< BaseFloat > &feats, const GaussPost &gpost, const TransitionModel &trans_model, const AmDiagGmm &am_gmm, FmllrDiagGmmAccs *spk_stats)
kaldi::int32 int32
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::istream & Stream()
Definition: kaldi-io.cc:826
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an utterance.
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
Stats for fMLLR subspace estimation.
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
int32 Dim() const
Definition: am-diag-gmm.h:79
int NumArgs() const
Number of positional parameters (c.f. argc-1).
DiagGmm & GetPdf(int32 pdf_index)
Accessors.
Definition: am-diag-gmm.h:119
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
int main(int argc, char *argv[])
void ConvertPosteriorToPdfs(const TransitionModel &tmodel, const Posterior &post_in, Posterior *post_out)
Converts a posterior over transition-ids to be a posterior over pdf-ids.
Definition: posterior.cc:322
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
BaseFloat AccumulateForGmm(const DiagGmm &gmm, const VectorBase< BaseFloat > &data, BaseFloat weight)
Accumulate stats for a single GMM in the model; returns log likelihood.