gmm-init-model.cc
Go to the documentation of this file.
1 // gmmbin/gmm-init-model.cc
2 
3 // Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)
4 // Johns Hopkins University (author: Guoguo Chen)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 
22 #include "base/kaldi-common.h"
23 #include "util/common-utils.h"
24 #include "gmm/am-diag-gmm.h"
25 #include "hmm/transition-model.h"
26 #include "gmm/mle-am-diag-gmm.h"
27 #include "tree/build-tree-utils.h"
28 #include "tree/context-dep.h"
30 #include "util/text-utils.h"
31 
32 namespace kaldi {
33 
35 void InitAmGmm(const BuildTreeStatsType &stats,
36  const EventMap &to_pdf_map,
37  AmDiagGmm *am_gmm,
38  const TransitionModel &trans_model,
39  BaseFloat var_floor) {
40  // Get stats split by tree-leaf ( == pdf):
41  std::vector<BuildTreeStatsType> split_stats;
42  SplitStatsByMap(stats, to_pdf_map, &split_stats);
43 
44  split_stats.resize(to_pdf_map.MaxResult() + 1); // ensure that
45  // if the last leaf had no stats, this vector still has the right size.
46 
47  // Make sure each leaf has stats.
48  for (size_t i = 0; i < split_stats.size(); i++) {
49  if (split_stats[i].empty()) {
50  std::vector<int32> bad_pdfs(1, i), bad_phones;
51  GetPhonesForPdfs(trans_model, bad_pdfs, &bad_phones);
52  std::ostringstream ss;
53  for (int32 idx = 0; idx < bad_phones.size(); idx ++)
54  ss << bad_phones[idx] << ' ';
55  KALDI_WARN << "Tree has pdf-id " << i
56  << " with no stats; corresponding phone list: " << ss.str();
57  /*
58  This probably means you have phones that were unseen in training
59  and were not shared with other phones in the roots file.
60  You should modify your roots file as necessary to fix this.
61  (i.e. share that phone with a similar but seen phone on one line
62  of the roots file). Be sure to regenerate roots.int from roots.txt,
63  if using s5 scripts. To work out the phone, search for
64  pdf-id i in the output of show-transitions (for this model). */
65  }
66  }
67  std::vector<Clusterable*> summed_stats;
68  SumStatsVec(split_stats, &summed_stats);
69  Clusterable *avg_stats = SumClusterable(summed_stats);
70  KALDI_ASSERT(avg_stats != NULL && "No stats available in gmm-init-model.");
71  for (size_t i = 0; i < summed_stats.size(); i++) {
72  GaussClusterable *c =
73  static_cast<GaussClusterable*>(summed_stats[i] != NULL ? summed_stats[i] : avg_stats);
74  DiagGmm gmm(*c, var_floor);
75  am_gmm->AddPdf(gmm);
76  BaseFloat count = c->count();
77  if (count < 100) {
78  std::vector<int32> bad_pdfs(1, i), bad_phones;
79  GetPhonesForPdfs(trans_model, bad_pdfs, &bad_phones);
80  std::ostringstream ss;
81  for (int32 idx = 0; idx < bad_phones.size(); idx ++)
82  ss << bad_phones[idx] << ' ';
83  KALDI_WARN << "Very small count for state " << i << ": "
84  << count << "; corresponding phone list: " << ss.str();
85  }
86  }
87  DeletePointers(&summed_stats);
88  delete avg_stats;
89 }
90 
92 void GetOccs(const BuildTreeStatsType &stats,
93  const EventMap &to_pdf_map,
94  Vector<BaseFloat> *occs) {
95 
96  // Get stats split by tree-leaf ( == pdf):
97  std::vector<BuildTreeStatsType> split_stats;
98  SplitStatsByMap(stats, to_pdf_map, &split_stats);
99  if (split_stats.size() != to_pdf_map.MaxResult()+1) {
100  KALDI_ASSERT(split_stats.size() < to_pdf_map.MaxResult()+1);
101  split_stats.resize(to_pdf_map.MaxResult()+1);
102  }
103  occs->Resize(split_stats.size());
104  for (int32 pdf = 0; pdf < occs->Dim(); pdf++)
105  (*occs)(pdf) = SumNormalizer(split_stats[pdf]);
106 }
107 
108 
109 
115 
117  const EventMap &to_pdf_map,
118  int32 N, // context-width
119  int32 P, // central-position
120  const std::string &old_tree_rxfilename,
121  const std::string &old_model_rxfilename,
122  BaseFloat var_floor,
123  AmDiagGmm *am_gmm) {
124 
125  AmDiagGmm old_am_gmm;
126  ContextDependency old_tree;
127  { // Read old_gm_gmm
128  bool binary_in;
129  TransitionModel old_trans_model;
130  Input ki(old_model_rxfilename, &binary_in);
131  old_trans_model.Read(ki.Stream(), binary_in);
132  old_am_gmm.Read(ki.Stream(), binary_in);
133  }
134  { // Read tree.
135  bool binary_in;
136  Input ki(old_tree_rxfilename, &binary_in);
137  old_tree.Read(ki.Stream(), binary_in);
138  }
139 
140 
141  // Get stats split by (new) tree-leaf ( == pdf):
142  std::vector<BuildTreeStatsType> split_stats;
143  SplitStatsByMap(stats, to_pdf_map, &split_stats);
144  // Make sure each leaf has stats.
145  for (size_t i = 0; i < split_stats.size(); i++) {
146  if (split_stats[i].empty()) {
147  KALDI_WARN << "Leaf " << i << " of new tree has no stats.";
148  }
149  }
150  if (static_cast<int32>(split_stats.size()) != to_pdf_map.MaxResult() + 1) {
151  KALDI_ASSERT(static_cast<int32>(split_stats.size()) <
152  to_pdf_map.MaxResult() + 1);
153  KALDI_WARN << "Tree may have final leaf with no stats.";
154  split_stats.resize(to_pdf_map.MaxResult() + 1);
155  // avoid indexing errors later.
156  }
157 
158  int32 oldN = old_tree.ContextWidth(), oldP = old_tree.CentralPosition();
159 
160  // avg_stats will be used for leaves that have no stats.
161  Clusterable *avg_stats = SumStats(stats);
162  GaussClusterable *avg_stats_gc = dynamic_cast<GaussClusterable*>(avg_stats);
163  KALDI_ASSERT(avg_stats_gc != NULL && "Empty stats input.");
164  DiagGmm avg_gmm(*avg_stats_gc, var_floor);
165  delete avg_stats;
166  avg_stats = NULL;
167  avg_stats_gc = NULL;
168 
169  const EventMap &old_map = old_tree.ToPdfMap();
170 
171  KALDI_ASSERT(am_gmm->NumPdfs() == 0);
172  int32 num_pdfs = static_cast<int32>(split_stats.size());
173  for (int32 pdf = 0; pdf < num_pdfs; pdf++) {
174  BuildTreeStatsType &my_stats = split_stats[pdf];
175  // The next statement converts the stats to a possibly narrower older
176  // context-width (e.g. triphone -> monophone).
177  // note: don't get confused by the "old" and "new" in the parameters
178  // to ConvertStats. The next line is correct.
179  bool ret = ConvertStats(N, P, oldN, oldP, &my_stats);
180  if (!ret)
181  KALDI_ERR << "InitAmGmmFromOld: old system has wider context "
182  "so cannot convert stats.";
183  // oldpdf_to_count works out a map from old pdf-id to count (for stats
184  // that align to this "new" pdf... we'll use it to work out the old pdf-id
185  // that's "closest" in stats overlap to this new pdf ("pdf").
186  std::map<int32, BaseFloat> oldpdf_to_count;
187  for (size_t i = 0; i < my_stats.size(); i++) {
188  EventType evec = my_stats[i].first;
189  EventAnswerType ans;
190  bool ret = old_map.Map(evec, &ans);
191  if (!ret) { KALDI_ERR << "Could not map context using old tree."; }
192  KALDI_ASSERT(my_stats[i].second != NULL);
193  BaseFloat stats_count = my_stats[i].second->Normalizer();
194  if (oldpdf_to_count.count(ans) == 0) oldpdf_to_count[ans] = stats_count;
195  else oldpdf_to_count[ans] += stats_count;
196  }
197  BaseFloat max_count = 0; int32 max_old_pdf = -1;
198  for (std::map<int32, BaseFloat>::const_iterator iter = oldpdf_to_count.begin();
199  iter != oldpdf_to_count.end();
200  ++iter) {
201  if (iter->second > max_count) {
202  max_count = iter->second;
203  max_old_pdf = iter->first;
204  }
205  }
206  if (max_count == 0) { // no overlap - probably a leaf with no stats at all.
207  KALDI_WARN << "Leaf " << pdf << " of new tree being initialized with "
208  << "globally averaged stats.";
209  am_gmm->AddPdf(avg_gmm);
210  } else {
211  am_gmm->AddPdf(old_am_gmm.GetPdf(max_old_pdf)); // Here is where we copy the relevant old PDF.
212  }
213  }
214 }
215 
216 
217 
218 }
219 
220 int main(int argc, char *argv[]) {
221  using namespace kaldi;
222  try {
223  using namespace kaldi;
224  typedef kaldi::int32 int32;
225 
226  const char *usage =
227  "Initialize GMM from decision tree and tree stats\n"
228  "Usage: gmm-init-model [options] <tree-in> <tree-stats-in> <topo-file> <model-out> [<old-tree> <old-model>]\n"
229  "e.g.: \n"
230  " gmm-init-model tree treeacc topo 1.mdl\n"
231  "or (initializing GMMs with old model):\n"
232  " gmm-init-model tree treeacc topo 1.mdl prev/tree prev/30.mdl\n";
233 
234  bool binary = true;
235  double var_floor = 0.01;
236  std::string occs_out_filename;
237 
238 
239  ParseOptions po(usage);
240  po.Register("binary", &binary, "Write output in binary mode");
241  po.Register("write-occs", &occs_out_filename, "File to write state "
242  "occupancies to.");
243  po.Register("var-floor", &var_floor, "Variance floor used while "
244  "initializing Gaussians");
245 
246  po.Read(argc, argv);
247 
248  if (po.NumArgs() != 4 && po.NumArgs() != 6) {
249  po.PrintUsage();
250  exit(1);
251  }
252 
253  std::string
254  tree_filename = po.GetArg(1),
255  stats_filename = po.GetArg(2),
256  topo_filename = po.GetArg(3),
257  model_out_filename = po.GetArg(4),
258  old_tree_filename = po.GetOptArg(5),
259  old_model_filename = po.GetOptArg(6);
260 
261  ContextDependency ctx_dep;
262  ReadKaldiObject(tree_filename, &ctx_dep);
263 
264  BuildTreeStatsType stats;
265  {
266  bool binary_in;
267  GaussClusterable gc; // dummy needed to provide type.
268  Input ki(stats_filename, &binary_in);
269  ReadBuildTreeStats(ki.Stream(), binary_in, gc, &stats);
270  }
271  KALDI_LOG << "Number of separate statistics is " << stats.size();
272 
273  HmmTopology topo;
274  ReadKaldiObject(topo_filename, &topo);
275 
276  const EventMap &to_pdf = ctx_dep.ToPdfMap(); // not owned here.
277 
278  TransitionModel trans_model(ctx_dep, topo);
279 
280  // Now, the summed_stats will be used to initialize the GMM.
281  AmDiagGmm am_gmm;
282  if (old_tree_filename.empty())
283  InitAmGmm(stats, to_pdf, &am_gmm, trans_model, var_floor); // Normal case: initialize 1 Gauss/model from tree stats.
284  else {
285  InitAmGmmFromOld(stats, to_pdf,
286  ctx_dep.ContextWidth(),
287  ctx_dep.CentralPosition(),
288  old_tree_filename,
289  old_model_filename,
290  var_floor,
291  &am_gmm);
292  }
293 
294  if (!occs_out_filename.empty()) { // write state occs
295  Vector<BaseFloat> occs;
296  GetOccs(stats, to_pdf, &occs);
297  Output ko(occs_out_filename, binary);
298  occs.Write(ko.Stream(), binary);
299  }
300 
301  {
302  Output ko(model_out_filename, binary);
303  trans_model.Write(ko.Stream(), binary);
304  am_gmm.Write(ko.Stream(), binary);
305  }
306  KALDI_LOG << "Wrote model.";
307 
308  DeleteBuildTreeStats(&stats);
309  } catch(const std::exception &e) {
310  std::cerr << e.what();
311  return -1;
312  }
313 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
BaseFloat SumNormalizer(const BuildTreeStatsType &stats_in)
Sums the normalizer [typically, data-count] over the stats.
void AddPdf(const DiagGmm &gmm)
Adds a GMM to the model, and increments the total number of PDFs.
Definition: am-diag-gmm.cc:57
virtual int32 ContextWidth() const
ContextWidth() returns the value N (e.g.
Definition: context-dep.h:61
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
const EventMap & ToPdfMap() const
Definition: context-dep.h:98
bool ConvertStats(int32 oldN, int32 oldP, int32 newN, int32 newP, BuildTreeStatsType *stats)
Converts stats from a given context-window (N) and central-position (P) to a different N and P...
void GetOccs(const BuildTreeStatsType &stats, const EventMap &to_pdf_map, Vector< BaseFloat > *occs)
Get state occupation counts.
A class for storing topology information for phones.
Definition: hmm-topology.h:93
Clusterable * SumStats(const BuildTreeStatsType &stats_in)
Sums stats, or returns NULL if stats_in has no non-NULL stats.
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void InitAmGmm(const BuildTreeStatsType &stats, const EventMap &to_pdf_map, AmDiagGmm *am_gmm, const TransitionModel &trans_model, BaseFloat var_floor)
InitAmGmm initializes the GMM with one Gaussian per state.
virtual EventAnswerType MaxResult() const
Definition: event-map.h:142
int main(int argc, char *argv[])
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
void SplitStatsByMap(const BuildTreeStatsType &stats, const EventMap &e, std::vector< BuildTreeStatsType > *stats_out)
Splits stats according to the EventMap, indexing them at output by the leaf type. ...
kaldi::int32 int32
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
virtual bool Map(const EventType &event, EventAnswerType *ans) const =0
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
void DeleteBuildTreeStats(BuildTreeStatsType *stats)
This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL...
void ReadBuildTreeStats(std::istream &is, bool binary, const Clusterable &example, BuildTreeStatsType *stats)
Reads BuildTreeStats object.
const size_t count
bool GetPhonesForPdfs(const TransitionModel &trans_model, const std::vector< int32 > &pdfs, std::vector< int32 > *phones)
Works out which phones might correspond to the given pdfs.
std::istream & Stream()
Definition: kaldi-io.cc:826
std::vector< std::pair< EventKeyType, EventValueType > > EventType
Definition: event-map.h:58
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void InitAmGmmFromOld(const BuildTreeStatsType &stats, const EventMap &to_pdf_map, int32 N, int32 P, const std::string &old_tree_rxfilename, const std::string &old_model_rxfilename, BaseFloat var_floor, AmDiagGmm *am_gmm)
InitAmGmmFromOld initializes the GMM based on a previously trained model and tree, which must require no more phonetic context than the current tree.
void Read(std::istream &is, bool binary)
void SumStatsVec(const std::vector< BuildTreeStatsType > &stats_in, std::vector< Clusterable *> *stats_out)
Sum a vector of stats.
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
virtual int32 CentralPosition() const
Central position P of the phone context, in 0-based numbering, e.g.
Definition: context-dep.h:62
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
int32 NumPdfs() const
Definition: am-diag-gmm.h:82
int NumArgs() const
Number of positional parameters (c.f. argc-1).
DiagGmm & GetPdf(int32 pdf_index)
Accessors.
Definition: am-diag-gmm.h:119
A class that is capable of representing a generic mapping from EventType (which is a vector of (key...
Definition: event-map.h:86
void Write(std::ostream &os, bool binary) const
A class representing a vector.
Definition: kaldi-vector.h:406
void Read(std::istream &is, bool binary)
Read context-dependency object from disk; throws on error.
Definition: context-dep.cc:155
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Write(std::ostream &out_stream, bool binary) const
Definition: am-diag-gmm.cc:163
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
int32 EventAnswerType
As far as the event-map code itself is concerned, things of type EventAnswerType may take any value e...
Definition: event-map.h:56
std::vector< std::pair< EventType, Clusterable * > > BuildTreeStatsType
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
Clusterable * SumClusterable(const std::vector< Clusterable *> &vec)
Sums stats (ptrs may be NULL). Returns NULL if no non-NULL stats present.
std::string GetOptArg(int param) const