gmm-init-biphone.cc
Go to the documentation of this file.
1 // gmmbin/gmm-init-biphone.cc
2 
3 // Copyright 2017 Hossein Hadian
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "gmm/am-diag-gmm.h"
24 #include "tree/event-map.h"
25 #include "tree/context-dep.h"
26 #include "hmm/hmm-topology.h"
27 #include "hmm/transition-model.h"
28 
29 namespace kaldi {
30 // This function reads a file like:
31 // 1 2 3
32 // 4 5
33 // 6 7 8
34 // where each line is a list of integer id's of phones (that should have their pdfs shared).
35 void ReadSharedPhonesList(std::string rxfilename, std::vector<std::vector<int32> > *list_out) {
36  list_out->clear();
37  Input input(rxfilename);
38  std::istream &is = input.Stream();
39  std::string line;
40  while (std::getline(is, line)) {
41  list_out->push_back(std::vector<int32>());
42  if (!SplitStringToIntegers(line, " \t\r", true, &(list_out->back())))
43  KALDI_ERR << "Bad line in shared phones list: " << line << " (reading "
44  << PrintableRxfilename(rxfilename) << ")";
45  std::sort(list_out->rbegin()->begin(), list_out->rbegin()->end());
46  if (!IsSortedAndUniq(*(list_out->rbegin())))
47  KALDI_ERR << "Bad line in shared phones list (repeated phone): " << line
48  << " (reading " << PrintableRxfilename(rxfilename) << ")";
49  }
50 }
51 
53 *GetFullBiphoneStubMap(const std::vector<std::vector<int32> > &phone_sets,
54  const std::vector<int32> &phone2num_pdf_classes,
55  const std::vector<int32> &ci_phones_list,
56  const std::vector<std::vector<int32> > &bi_counts,
57  int32 biphone_min_count,
58  const std::vector<int32> &mono_counts,
59  int32 mono_min_count) {
60 
61  { // Check the inputs
62  KALDI_ASSERT(!phone_sets.empty());
63  std::set<int32> all_phones;
64  for (size_t i = 0; i < phone_sets.size(); i++) {
65  KALDI_ASSERT(IsSortedAndUniq(phone_sets[i]));
66  KALDI_ASSERT(!phone_sets[i].empty());
67  for (size_t j = 0; j < phone_sets[i].size(); j++) {
68  KALDI_ASSERT(all_phones.count(phone_sets[i][j]) == 0); // Check not present.
69  all_phones.insert(phone_sets[i][j]);
70  }
71  }
72  }
73 
74 
75  int32 numpdfs_per_phone = phone2num_pdf_classes[1];
76  int32 current_pdfid = 0;
77  std::map<EventValueType, EventMap*> level1_map; // key is 1
78 
79  for (size_t i = 0; i < ci_phones_list.size(); i++) {
80  std::map<EventValueType, EventAnswerType> level2_map;
81  level2_map[0] = current_pdfid++;
82  if (numpdfs_per_phone == 2) level2_map[1] = current_pdfid++;
83  level1_map[ci_phones_list[i]] = new TableEventMap(kPdfClass, level2_map);
84  }
85 
86  // If there is not enough data for a biphone, we will revert to monophone
87  // and if there is not enough data for the monophone either, we will revert
88  // to zerophone (which is like a global garbage pdf) after initializing it.
89  int32 zerophone_pdf = -1;
90  // If a monophone state is created for a phone-set, the corresponding pdf will
91  // be stored in this vector.
92  std::vector<int32> monophone_pdf(phone_sets.size(), -1);
93 
94  for (size_t i = 0; i < phone_sets.size(); i++) {
95 
96  if (numpdfs_per_phone == 1) {
97  // Create an event map for level2:
98  std::map<EventValueType, EventAnswerType> level2_map; // key is 0
99  level2_map[0] = current_pdfid++; // no-left-context case
100  for (size_t j = 0; j < phone_sets.size(); j++) {
101  int32 pdfid = current_pdfid++;
102  std::vector<int32> pset = phone_sets[j]; // All these will have a
103  // shared pdf with id=pdfid
104  for (size_t k = 0; k < pset.size(); k++)
105  level2_map[pset[k]] = pdfid;
106  }
107  std::vector<int32> pset = phone_sets[i]; // All these will have a
108  // shared event-map child
109  for (size_t k = 0; k < pset.size(); k++)
110  level1_map[pset[k]] = new TableEventMap(0, level2_map);
111  } else {
112  KALDI_ASSERT(numpdfs_per_phone == 2);
113  std::vector<int32> right_phoneset = phone_sets[i]; // All these will have a shared
114  // event-map child
115  // Create an event map for level2:
116  std::map<EventValueType, EventMap*> level2_map; // key is 0
117  { // Handle CI phones
118  std::map<EventValueType, EventAnswerType> level3_map; // key is kPdfClass
119  level3_map[0] = current_pdfid++;
120  level3_map[1] = current_pdfid++;
121  level2_map[0] = new TableEventMap(kPdfClass, level3_map); // no-left-context case
122  for (size_t i = 0; i < ci_phones_list.size(); i++) // ci-phone left-context cases
123  level2_map[ci_phones_list[i]] = new TableEventMap(kPdfClass, level3_map);
124  }
125  for (size_t j = 0; j < phone_sets.size(); j++) {
126  std::vector<int32> left_phoneset = phone_sets[j]; // All these will have a
127  // shared subtree with 2 pdfids
128  std::map<EventValueType, EventAnswerType> level3_map; // key is kPdfClass
129  if (bi_counts.empty() ||
130  bi_counts[left_phoneset[0]][right_phoneset[0]] >= biphone_min_count) {
131  level3_map[0] = current_pdfid++;
132  level3_map[1] = current_pdfid++;
133  } else if (mono_counts.empty() ||
134  mono_counts[right_phoneset[0]] > mono_min_count) {
135  // Revert to mono.
136  KALDI_VLOG(2) << "Reverting to mono for biphone (" << left_phoneset[0]
137  << "," << right_phoneset[0] << ")";
138  if (monophone_pdf[i] == -1) {
139  KALDI_VLOG(1) << "Reserving mono PDFs for phone-set " << i;
140  monophone_pdf[i] = current_pdfid++;
141  current_pdfid++; // num-pdfs-per-phone is 2
142  }
143  level3_map[0] = monophone_pdf[i];
144  level3_map[1] = monophone_pdf[i] + 1;
145  } else {
146  KALDI_VLOG(2) << "Reverting to zerophone for biphone ("
147  << left_phoneset[0]
148  << "," << right_phoneset[0] << ")";
149  // Revert to zerophone
150  if (zerophone_pdf == -1) {
151  KALDI_VLOG(1) << "Reserving zero PDFs.";
152  zerophone_pdf = current_pdfid++;
153  current_pdfid++; // num-pdfs-per-phone is 2
154  }
155  level3_map[0] = zerophone_pdf;
156  level3_map[1] = zerophone_pdf + 1;
157  }
158 
159  for (size_t k = 0; k < left_phoneset.size(); k++) {
160  int32 left_phone = left_phoneset[k];
161  level2_map[left_phone] = new TableEventMap(kPdfClass, level3_map);
162  }
163  }
164  for (size_t k = 0; k < right_phoneset.size(); k++) {
165  std::map<EventValueType, EventMap*> level2_copy;
166  for (auto const& kv: level2_map)
167  level2_copy[kv.first] = kv.second->Copy(std::vector<EventMap*>());
168  int32 right_phone = right_phoneset[k];
169  level1_map[right_phone] = new TableEventMap(0, level2_copy);
170  }
171  }
172 
173  }
174  KALDI_LOG << "Num PDFs: " << current_pdfid;
175  return new TableEventMap(1, level1_map);
176 }
177 
178 
180 BiphoneContextDependencyFull(std::vector<std::vector<int32> > phone_sets,
181  const std::vector<int32> phone2num_pdf_classes,
182  const std::vector<int32> &ci_phones_list,
183  const std::vector<std::vector<int32> > &bi_counts,
184  int32 biphone_min_count,
185  const std::vector<int32> &mono_counts,
186  int32 mono_min_count) {
187  // Remove all the CI phones from the phone sets
188  std::set<int32> ci_phones;
189  for (size_t i = 0; i < ci_phones_list.size(); i++)
190  ci_phones.insert(ci_phones_list[i]);
191  for (int32 i = phone_sets.size() - 1; i >= 0; i--) {
192  for (int32 j = phone_sets[i].size() - 1; j >= 0; j--) {
193  if (ci_phones.find(phone_sets[i][j]) != ci_phones.end()) { // Delete it
194  phone_sets[i].erase(phone_sets[i].begin() + j);
195  if (phone_sets[i].empty()) // If empty, delete the whole entry
196  phone_sets.erase(phone_sets.begin() + i);
197  }
198  }
199  }
200 
201  std::vector<bool> share_roots(phone_sets.size(), false); // Don't share roots
202  // N is context size, P = position of central phone (must be 0).
203  int32 P = 1, N = 2;
204  EventMap *pdf_map = GetFullBiphoneStubMap(phone_sets,
205  phone2num_pdf_classes,
206  ci_phones_list, bi_counts,
207  biphone_min_count, mono_counts,
208  mono_min_count);
209  return new ContextDependency(N, P, pdf_map);
210 }
211 
212 
213 } // end namespace kaldi
214 
215 /* This function reads the counts of biphones and monophones from a text file
216  generated for chain flat-start training. On each line there is either a
217  biphone count or a monophone count:
218  <left-phone-id> <right-phone-id> <count>
219  <monophone-id> <count>
220  The phone-id's are according to phones.txt.
221 
222  It's more efficient to load the biphone counts into a map because
223  most entries are zero, but since there are not many biphones, a 2-dim vector
224  is OK. */
225 static void ReadPhoneCounts(std::string &filename, int32 num_phones,
226  std::vector<int32> *mono_counts,
227  std::vector<std::vector<int32> > *bi_counts) {
228  // The actual phones start from id = 1 (so the last phone has id = num_phones).
229  mono_counts->resize(num_phones + 1, 0);
230  bi_counts->resize(num_phones + 1, std::vector<int>(num_phones + 1, 0));
231  std::ifstream infile(filename);
232  std::string line;
233  while (std::getline(infile, line)) {
234  std::istringstream iss(line);
235  int a, b;
236  long c;
237  if ((std::istringstream(line) >> a >> b >> c)) {
238  // It's a biphone count.
239  KALDI_ASSERT(a >= 0 && a <= num_phones); // 0 means no-left-context
240  KALDI_ASSERT(b > 0 && b <= num_phones);
241  KALDI_ASSERT(c >= 0);
242  (*bi_counts)[a][b] = c;
243  } else if ((std::istringstream(line) >> b >> c)) {
244  // It's a monophone count.
245  KALDI_ASSERT(b > 0 && b <= num_phones);
246  KALDI_ASSERT(c >= 0);
247  (*mono_counts)[b] = c;
248  } else {
249  KALDI_ERR << "Bad line in phone stats file: " << line;
250  }
251  }
252 }
253 
254 int main(int argc, char *argv[]) {
255  try {
256  using namespace kaldi;
257  using kaldi::int32;
258 
259  const char *usage =
260  "Initialize a biphone context-dependency tree with all the\n"
261  "leaves (i.e. a full tree). Intended for end-to-end tree-free models.\n"
262  "Usage: gmm-init-biphone <topology-in> <dim> <model-out> <tree-out> \n"
263  "e.g.: \n"
264  " gmm-init-biphone topo 39 bi.mdl bi.tree\n";
265 
266  bool binary = true;
267  std::string shared_phones_rxfilename, phone_counts_rxfilename;
268  int32 min_biphone_count = 100, min_mono_count = 20;
269  std::string ci_phones_str;
270  std::vector<int32> ci_phones; // Sorted, uniqe vector of
271  // context-independent phones.
272 
273  ParseOptions po(usage);
274  po.Register("binary", &binary, "Write output in binary mode");
275  po.Register("shared-phones", &shared_phones_rxfilename,
276  "rxfilename containing, on each line, a list of phones "
277  "whose pdfs should be shared.");
278  po.Register("ci-phones", &ci_phones_str, "Colon-separated list of "
279  "integer indices of context-independent phones.");
280  po.Register("phone-counts", &phone_counts_rxfilename,
281  "rxfilename containing, on each line, a biphone/phone and "
282  "its count in the training data.");
283  po.Register("min-biphone-count", &min_biphone_count, "Minimum number of "
284  "occurences of a biphone in training data to reserve pdfs "
285  "for it.");
286  po.Register("min-monophone-count", &min_mono_count, "Minimum number of "
287  "occurences of a monophone in training data to reserve pdfs "
288  "for it.");
289  po.Read(argc, argv);
290 
291  if (po.NumArgs() != 4) {
292  po.PrintUsage();
293  exit(1);
294  }
295 
296 
297  std::string topo_filename = po.GetArg(1);
298  int dim = 0;
299  if (!ConvertStringToInteger(po.GetArg(2), &dim) || dim <= 0 || dim > 10000)
300  KALDI_ERR << "Bad dimension:" << po.GetArg(2)
301  << ". It should be a positive integer.";
302  std::string model_filename = po.GetArg(3);
303  std::string tree_filename = po.GetArg(4);
304 
305  if (!ci_phones_str.empty()) {
306  SplitStringToIntegers(ci_phones_str, ":", false, &ci_phones);
307  std::sort(ci_phones.begin(), ci_phones.end());
308  if (!IsSortedAndUniq(ci_phones) || ci_phones.empty() || ci_phones[0] == 0)
309  KALDI_ERR << "Invalid --ci-phones option: " << ci_phones_str;
310  }
311 
312  Vector<BaseFloat> glob_inv_var(dim);
313  glob_inv_var.Set(1.0);
314  Vector<BaseFloat> glob_mean(dim);
315  glob_mean.Set(1.0);
316 
317  HmmTopology topo;
318  bool binary_in;
319  Input ki(topo_filename, &binary_in);
320  topo.Read(ki.Stream(), binary_in);
321 
322  const std::vector<int32> &phones = topo.GetPhones();
323 
324  std::vector<int32> phone2num_pdf_classes(1 + phones.back());
325  for (size_t i = 0; i < phones.size(); i++) {
326  phone2num_pdf_classes[phones[i]] = topo.NumPdfClasses(phones[i]);
327  // For now we only support 1 or 2 pdf's per phone
328  KALDI_ASSERT(phone2num_pdf_classes[phones[i]] == 1 ||
329  phone2num_pdf_classes[phones[i]] == 2);
330  }
331 
332  std::vector<int32> mono_counts;
333  std::vector<std::vector<int32> > bi_counts;
334  if (!phone_counts_rxfilename.empty()) {
335  ReadPhoneCounts(phone_counts_rxfilename, phones.size(),
336  &mono_counts, &bi_counts);
337  KALDI_LOG << "Loaded mono/bi phone counts.";
338  }
339 
340 
341  // Now the tree:
342  ContextDependency *ctx_dep = NULL;
343  std::vector<std::vector<int32> > shared_phones;
344  if (shared_phones_rxfilename == "") {
345  shared_phones.resize(phones.size());
346  for (size_t i = 0; i < phones.size(); i++)
347  shared_phones[i].push_back(phones[i]);
348  } else {
349  ReadSharedPhonesList(shared_phones_rxfilename, &shared_phones);
350  // ReadSharedPhonesList crashes on error.
351  }
352  ctx_dep = BiphoneContextDependencyFull(shared_phones, phone2num_pdf_classes,
353  ci_phones, bi_counts,
354  min_biphone_count,
355  mono_counts, min_mono_count);
356 
357  int32 num_pdfs = ctx_dep->NumPdfs();
358 
359  AmDiagGmm am_gmm;
360  DiagGmm gmm;
361  gmm.Resize(1, dim);
362  { // Initialize the gmm.
363  Matrix<BaseFloat> inv_var(1, dim);
364  inv_var.Row(0).CopyFromVec(glob_inv_var);
365  Matrix<BaseFloat> mu(1, dim);
366  mu.Row(0).CopyFromVec(glob_mean);
367  Vector<BaseFloat> weights(1);
368  weights.Set(1.0);
369  gmm.SetInvVarsAndMeans(inv_var, mu);
370  gmm.SetWeights(weights);
371  gmm.ComputeGconsts();
372  }
373 
374  for (int i = 0; i < num_pdfs; i++)
375  am_gmm.AddPdf(gmm);
376 
377  // Now the transition model:
378  TransitionModel trans_model(*ctx_dep, topo);
379 
380  {
381  Output ko(model_filename, binary);
382  trans_model.Write(ko.Stream(), binary);
383  am_gmm.Write(ko.Stream(), binary);
384  }
385 
386  // Now write the tree.
387  ctx_dep->Write(Output(tree_filename, binary).Stream(),
388  binary);
389 
390  delete ctx_dep;
391  return 0;
392  } catch(const std::exception &e) {
393  std::cerr << e.what();
394  return -1;
395  }
396 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void AddPdf(const DiagGmm &gmm)
Adds a GMM to the model, and increments the total number of PDFs.
Definition: am-diag-gmm.cc:57
bool ConvertStringToInteger(const std::string &str, Int *out)
Converts a string into an integer via strtoll and returns false if there was any kind of problem (i...
Definition: text-utils.h:118
void SetInvVarsAndMeans(const MatrixBase< Real > &invvars, const MatrixBase< Real > &means)
Use SetInvVarsAndMeans if updating both means and (inverse) variances.
Definition: diag-gmm-inl.h:63
A class for storing topology information for phones.
Definition: hmm-topology.h:93
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
Definition: text-utils.h:68
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Resize(int32 nMix, int32 dim)
Resizes arrays to this dim. Does not initialize data.
Definition: diag-gmm.cc:66
int32 ComputeGconsts()
Sets the gconsts.
Definition: diag-gmm.cc:114
kaldi::int32 int32
void Read(std::istream &is, bool binary)
Definition: hmm-topology.cc:39
void Register(const std::string &name, bool *ptr, const std::string &doc)
int32 NumPdfClasses(int32 phone) const
Returns the number of pdf-classes for this phone; throws exception if phone not covered by this topol...
void ReadSharedPhonesList(std::string rxfilename, std::vector< std::vector< int32 > > *list_out)
static const EventKeyType kPdfClass
Definition: context-dep.h:39
virtual int32 NumPdfs() const
NumPdfs() returns the number of acoustic pdfs (they are numbered 0.. NumPdfs()-1).
Definition: context-dep.h:71
std::istream & Stream()
Definition: kaldi-io.cc:826
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Write(std::ostream &os, bool binary) const
Definition: context-dep.cc:145
std::ostream & Stream()
Definition: kaldi-io.cc:701
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
int main(int argc, char *argv[])
static void ReadPhoneCounts(std::string &filename, int32 num_phones, std::vector< int32 > *mono_counts, std::vector< std::vector< int32 > > *bi_counts)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
const std::vector< int32 > & GetPhones() const
Returns a reference to a sorted, unique list of phones covered by the topology (these phones will be ...
Definition: hmm-topology.h:163
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class that is capable of representing a generic mapping from EventType (which is a vector of (key...
Definition: event-map.h:86
void Write(std::ostream &os, bool binary) const
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Set(Real f)
Set all members of a vector to a specified value.
void Write(std::ostream &out_stream, bool binary) const
Definition: am-diag-gmm.cc:163
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
std::string PrintableRxfilename(const std::string &rxfilename)
PrintableRxfilename turns the rxfilename into a more human-readable form for error reporting...
Definition: kaldi-io.cc:61
void SetWeights(const VectorBase< Real > &w)
Mutators for both float or double.
Definition: diag-gmm-inl.h:28
EventMap * GetFullBiphoneStubMap(const std::vector< std::vector< int32 > > &phone_sets, const std::vector< int32 > &phone2num_pdf_classes, const std::vector< int32 > &ci_phones_list, const std::vector< std::vector< int32 > > &bi_counts, int32 biphone_min_count, const std::vector< int32 > &mono_counts, int32 mono_min_count)
bool IsSortedAndUniq(const std::vector< T > &vec)
Returns true if the vector is sorted and contains each element only once.
Definition: stl-utils.h:63
#define KALDI_LOG
Definition: kaldi-error.h:153
ContextDependency * BiphoneContextDependencyFull(std::vector< std::vector< int32 > > phone_sets, const std::vector< int32 > phone2num_pdf_classes, const std::vector< int32 > &ci_phones_list, const std::vector< std::vector< int32 > > &bi_counts, int32 biphone_min_count, const std::vector< int32 > &mono_counts, int32 mono_min_count)