gmm-adapt-map.cc
Go to the documentation of this file.
1 // gmmbin/gmm-adapt-map.cc
2 
3 // Copyright 2012 Cisco Systems (author: Neha Agrawal)
4 // Johns Hopkins University (author: Daniel Povey)
5 // 2014 Guoguo Chen
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #include <string>
23 #include <vector>
24 
25 #include "base/kaldi-common.h"
26 #include "util/common-utils.h"
27 #include "gmm/am-diag-gmm.h"
28 #include "hmm/transition-model.h"
29 #include "gmm/mle-am-diag-gmm.h"
30 #include "hmm/posterior.h"
31 
32 int main(int argc, char *argv[]) {
33  try {
34  typedef kaldi::int32 int32;
35  using namespace kaldi;
36  const char *usage =
37  "Compute MAP estimates per-utterance (default) or per-speaker for\n"
38  "the supplied set of speakers (spk2utt option). This will typically\n"
39  "be piped into gmm-latgen-map\n"
40  "\n"
41  "Usage: gmm-adapt-map [options] <model-in> <feature-rspecifier> "
42  "<posteriors-rspecifier> <map-am-wspecifier>\n";
43 
44  ParseOptions po(usage);
45  std::string spk2utt_rspecifier;
46  bool binary = true;
47  MapDiagGmmOptions map_config;
48  std::string update_flags_str = "mw";
49 
50  po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to "
51  "utterance-list map");
52  po.Register("binary", &binary, "Write output in binary mode");
53  po.Register("update-flags", &update_flags_str, "Which GMM parameters will be "
54  "updated: subset of mvw.");
55  map_config.Register(&po);
56 
57  po.Read(argc, argv);
58 
59  if (po.NumArgs() != 4) {
60  po.PrintUsage();
61  exit(1);
62  }
63 
64  std::string model_filename = po.GetArg(1),
65  feature_rspecifier = po.GetArg(2),
66  posteriors_rspecifier = po.GetArg(3),
67  map_am_wspecifier = po.GetArg(4);
68 
69  GmmFlagsType update_flags = StringToGmmFlags(update_flags_str);
70 
71  RandomAccessPosteriorReader posteriors_reader(posteriors_rspecifier);
72  MapAmDiagGmmWriter map_am_writer(map_am_wspecifier);
73 
74  AmDiagGmm am_gmm;
75  TransitionModel trans_model;
76  {
77  bool binary;
78  Input is(model_filename, &binary);
79  trans_model.Read(is.Stream(), binary);
80  am_gmm.Read(is.Stream(), binary);
81  }
82 
83  double tot_like = 0.0, tot_like_change = 0.0, tot_t = 0.0,
84  tot_t_check = 0.0;
85  int32 num_done = 0, num_err = 0;
86 
87  if (spk2utt_rspecifier != "") { // per-speaker adaptation
88  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
89  RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
90  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
91  std::string spk = spk2utt_reader.Key();
92  AmDiagGmm copy_am_gmm;
93  copy_am_gmm.CopyFromAmDiagGmm(am_gmm);
94  AccumAmDiagGmm map_accs;
95  map_accs.Init(am_gmm, update_flags);
96 
97  const std::vector<std::string> &uttlist = spk2utt_reader.Value();
98 
99  // for each speaker, estimate MAP means
100  std::vector<std::string>::const_iterator iter = uttlist.begin(),
101  end = uttlist.end();
102  for (; iter != end; ++iter) {
103  std::string utt = *iter;
104  if (!feature_reader.HasKey(utt)) {
105  KALDI_WARN << "Did not find features for utterance " << utt;
106  continue;
107  }
108  if (!posteriors_reader.HasKey(utt)) {
109  KALDI_WARN << "Did not find posteriors for utterance " << utt;
110  num_err++;
111  continue;
112  }
113  const Matrix<BaseFloat> &feats = feature_reader.Value(utt);
114  const Posterior &posterior = posteriors_reader.Value(utt);
115  if (posterior.size() != feats.NumRows()) {
116  KALDI_WARN << "Posteriors has wrong size " << (posterior.size())
117  << " vs. " << (feats.NumRows());
118  num_err++;
119  continue;
120  }
121 
122  BaseFloat file_like = 0.0, file_t = 0.0;
123  Posterior pdf_posterior;
124  ConvertPosteriorToPdfs(trans_model, posterior, &pdf_posterior);
125  for ( size_t i = 0; i < posterior.size(); i++ ) {
126  for ( size_t j = 0; j < pdf_posterior[i].size(); j++ ) {
127  int32 pdf_id = pdf_posterior[i][j].first;
128  BaseFloat weight = pdf_posterior[i][j].second;
129  file_like += map_accs.AccumulateForGmm(copy_am_gmm,
130  feats.Row(i),
131  pdf_id, weight);
132  file_t += weight;
133  }
134  }
135 
136  KALDI_VLOG(2) << "Average like for utterance " << utt << " is "
137  << (file_like/file_t) << " over " << file_t << " frames.";
138 
139  tot_like += file_like;
140  tot_t += file_t;
141  num_done++;
142 
143  if (num_done % 10 == 0)
144  KALDI_VLOG(1) << "Avg like per frame so far is "
145  << (tot_like / tot_t);
146  } // end looping over all utterances of the current speaker
147 
148  // MAP estimation.
149  BaseFloat spk_objf_change = 0.0, spk_frames = 0.0;
150  MapAmDiagGmmUpdate(map_config, map_accs, update_flags, &copy_am_gmm,
151  &spk_objf_change, &spk_frames);
152  KALDI_LOG << "For speaker " << spk << ", objective function change "
153  << "from MAP was " << (spk_objf_change / spk_frames)
154  << " over " << spk_frames << " frames.";
155  tot_like_change += spk_objf_change;
156  tot_t_check += spk_frames;
157 
158  // Writing AM for each speaker in a table
159  map_am_writer.Write(spk,copy_am_gmm);
160  } // end looping over speakers
161  } else { // per-utterance adaptation
162  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
163  for ( ; !feature_reader.Done(); feature_reader.Next() ) {
164  std::string utt = feature_reader.Key();
165  AmDiagGmm copy_am_gmm;
166  copy_am_gmm.CopyFromAmDiagGmm(am_gmm);
167  AccumAmDiagGmm map_accs;
168  map_accs.Init(am_gmm, update_flags);
169  map_accs.SetZero(update_flags);
170 
171  if ( !posteriors_reader.HasKey(utt) ) {
172  KALDI_WARN << "Did not find aligned transcription for utterance "
173  << utt;
174  num_err++;
175  continue;
176  }
177  const Matrix<BaseFloat> &feats = feature_reader.Value();
178  const Posterior &posterior = posteriors_reader.Value(utt);
179 
180  if ( posterior.size() != feats.NumRows() ) {
181  KALDI_WARN << "Posteriors has wrong size " << (posterior.size())
182  << " vs. " << (feats.NumRows());
183  num_err++;
184  continue;
185  }
186  num_done++;
187  BaseFloat file_like = 0.0, file_t = 0.0;
188  Posterior pdf_posterior;
189  ConvertPosteriorToPdfs(trans_model, posterior, &pdf_posterior);
190  for ( size_t i = 0; i < posterior.size(); i++ ) {
191  for ( size_t j = 0; j < pdf_posterior[i].size(); j++ ) {
192  int32 pdf_id = pdf_posterior[i][j].first;
193  BaseFloat prob = pdf_posterior[i][j].second;
194  file_like += map_accs.AccumulateForGmm(copy_am_gmm,feats.Row(i),
195  pdf_id, prob);
196  file_t += prob;
197  }
198  }
199  KALDI_VLOG(2) << "Average like for utterance " << utt << " is "
200  << (file_like/file_t) << " over " << file_t << " frames.";
201  tot_like += file_like;
202  tot_t += file_t;
203  if ( num_done % 10 == 0 )
204  KALDI_VLOG(1) << "Avg like per frame so far is "
205  << (tot_like / tot_t);
206 
207  // MAP
208  BaseFloat utt_objf_change = 0.0, utt_frames = 0.0;
209  MapAmDiagGmmUpdate(map_config, map_accs, update_flags, &copy_am_gmm,
210  &utt_objf_change, &utt_frames);
211  KALDI_LOG << "For utterance " << utt << ", objective function change "
212  << "from MAP was " << (utt_objf_change / utt_frames)
213  << " over " << utt_frames << " frames.";
214  tot_like_change += utt_objf_change;
215  tot_t_check += utt_frames;
216 
217  // Writing AM for each utterance in a table
218  map_am_writer.Write(feature_reader.Key(), copy_am_gmm);
219  }
220  }
221  KALDI_ASSERT(ApproxEqual(tot_t, tot_t_check));
222  KALDI_LOG << "Done " << num_done << " files, " << num_err
223  << " with errors";
224  KALDI_LOG << "Overall acoustic likelihood was " << (tot_like / tot_t)
225  << " and change in likelihod per frame was "
226  << (tot_like_change / tot_t) << " over " << tot_t << " frames.";
227  return (num_done != 0 ? 0 : 1);
228  } catch(const std::exception& e) {
229  std::cerr << e.what();
230  return -1;
231  }
232 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CopyFromAmDiagGmm(const AmDiagGmm &other)
Copies the parameters from another model. Allocates necessary memory.
Definition: am-diag-gmm.cc:79
void MapAmDiagGmmUpdate(const MapDiagGmmOptions &config, const AccumAmDiagGmm &am_diag_gmm_acc, GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
Maximum A Posteriori update.
GmmFlagsType StringToGmmFlags(std::string str)
Convert string which is some subset of "mSwa" to flags.
Definition: model-common.cc:26
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void SetZero(GmmFlagsType flags)
BaseFloat AccumulateForGmm(const AmDiagGmm &model, const VectorBase< BaseFloat > &data, int32 gmm_index, BaseFloat weight)
Accumulate stats for a single GMM in the model; returns log likelihood.
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
void Register(OptionsItf *opts)
Definition: mle-diag-gmm.h:93
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
int main(int argc, char *argv[])
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void ConvertPosteriorToPdfs(const TransitionModel &tmodel, const Posterior &post_in, Posterior *post_out)
Converts a posterior over transition-ids to be a posterior over pdf-ids.
Definition: posterior.cc:322
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
void Init(const AmDiagGmm &model, GmmFlagsType flags)
Initializes accumulators for each GMM based on the number of components and dimension.
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
Definition: kaldi-math.h:265
Configuration variables for Maximum A Posteriori (MAP) update.
Definition: mle-diag-gmm.h:76