build-tree-two-level.cc
Go to the documentation of this file.
1 // bin/build-tree-two-level.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "hmm/hmm-topology.h"
24 #include "tree/context-dep.h"
25 #include "tree/build-tree.h"
26 #include "tree/build-tree-utils.h"
27 #include "tree/context-dep.h"
29 #include "util/text-utils.h"
30 
31 namespace kaldi {
32 void GetSeenPhones(BuildTreeStatsType &stats, int P, std::vector<int32> *phones_out) {
33  // Get list of phones that we saw (in the central position P, although it
34  // shouldn't matter what position).
35 
36  std::set<int32> phones_set;
37  for (size_t i = 0 ; i < stats.size(); i++) {
38  const EventType &evec = stats[i].first;
39  for (size_t j = 0; j < evec.size(); j++) {
40  if (evec[j].first == P) { // "key" is position P
41  KALDI_ASSERT(evec[j].second != 0);
42  phones_set.insert(evec[j].second); // insert "value" of this
43  // phone.
44  }
45  }
46  CopySetToVector(phones_set, phones_out);
47  }
48 }
49 
50 
51 }
52 
53 int main(int argc, char *argv[]) {
54  using namespace kaldi;
55  try {
56  using namespace kaldi;
57  typedef kaldi::int32 int32;
58 
59  const char *usage =
60  "Trains two-level decision tree. Outputs the larger tree, and a mapping from the\n"
61  "leaf-ids of the larger tree to those of the smaller tree. Useful, for instance,\n"
62  "in tied-mixture systems with multiple codebooks.\n"
63  "\n"
64  "Usage: build-tree-two-level [options] <tree-stats-in> <roots-file> <questions-file> <topo-file> <tree-out> <mapping-out>\n"
65  "e.g.: \n"
66  " build-tree-two-level treeacc roots.txt 1.qst topo tree tree.map\n";
67 
68  bool binary = true;
69  int32 P = 1, N = 3;
70 
71  bool cluster_leaves = true;
72  int32 max_leaves_first = 1000;
73  int32 max_leaves_second = 5000;
74  std::string occs_out_filename;
75 
76  ParseOptions po(usage);
77  po.Register("binary", &binary, "Write output in binary mode");
78  po.Register("context-width", &N, "Context window size [must match "
79  "acc-tree-stats]");
80  po.Register("central-position", &P, "Central position in context window "
81  "[must match acc-tree-stats]");
82  po.Register("max-leaves-first", &max_leaves_first, "Maximum number of "
83  "leaves in first-level decision tree.");
84  po.Register("max-leaves-second", &max_leaves_second, "Maximum number of "
85  "leaves in second-level decision tree.");
86  po.Register("cluster-leaves", &cluster_leaves, "If true, do a post-clustering"
87  " of the leaves of the final decision tree.");
88 
89  po.Read(argc, argv);
90 
91  if (po.NumArgs() != 6) {
92  po.PrintUsage();
93  exit(1);
94  }
95 
96  std::string stats_filename = po.GetArg(1),
97  roots_filename = po.GetArg(2),
98  questions_filename = po.GetArg(3),
99  topo_filename = po.GetArg(4),
100  tree_out_filename = po.GetArg(5),
101  map_out_filename = po.GetArg(6);
102 
103 
104  // Following 2 variables derived from roots file.
105  // phone_sets is sets of phones that share their roots.
106  // Just one phone each for normal systems.
107  std::vector<std::vector<int32> > phone_sets;
108  std::vector<bool> is_shared_root;
109  std::vector<bool> is_split_root;
110  {
111  Input ki(roots_filename.c_str());
112  ReadRootsFile(ki.Stream(), &phone_sets, &is_shared_root, &is_split_root);
113  }
114 
115  HmmTopology topo;
116  ReadKaldiObject(topo_filename, &topo);
117 
118  BuildTreeStatsType stats;
119  {
120  bool binary_in;
121  GaussClusterable gc; // dummy needed to provide type.
122  Input ki(stats_filename, &binary_in);
123  ReadBuildTreeStats(ki.Stream(), binary_in, gc, &stats);
124  }
125  std::cerr << "Number of separate statistics is " << stats.size() << '\n';
126 
127  Questions qo;
128  {
129  bool binary_in;
130  try {
131  Input ki(questions_filename, &binary_in);
132  qo.Read(ki.Stream(), binary_in);
133  } catch (const std::exception &e) {
134  KALDI_ERR << "Error reading questions file "<<questions_filename<<", error is: " << e.what();
135  }
136  }
137 
138 
139  std::vector<int32> phone2num_pdf_classes;
140  topo.GetPhoneToNumPdfClasses(&phone2num_pdf_classes);
141 
142  EventMap *to_pdf = NULL;
143 
144  std::vector<int32> mapping;
145 
147 
148  to_pdf = BuildTreeTwoLevel(qo,
149  phone_sets,
150  phone2num_pdf_classes,
151  is_shared_root,
152  is_split_root,
153  stats,
154  max_leaves_first,
155  max_leaves_second,
156  cluster_leaves,
157  P,
158  &mapping);
159 
160  ContextDependency ctx_dep(N, P, to_pdf); // takes ownership
161  // of pointer "to_pdf", so set it NULL.
162  to_pdf = NULL;
163 
164  WriteKaldiObject(ctx_dep, tree_out_filename, binary);
165 
166  {
167  Output ko(map_out_filename, binary);
168  WriteIntegerVector(ko.Stream(), binary, mapping);
169  }
170 
171  { // This block is just doing some checks.
172 
173  std::vector<int32> all_phones;
174  for (size_t i = 0; i < phone_sets.size(); i++)
175  all_phones.insert(all_phones.end(),
176  phone_sets[i].begin(), phone_sets[i].end());
177  SortAndUniq(&all_phones);
178  if (all_phones != topo.GetPhones()) {
179  std::ostringstream ss;
180  WriteIntegerVector(ss, false, all_phones);
181  ss << " vs. ";
182  WriteIntegerVector(ss, false, topo.GetPhones());
183  KALDI_WARN << "Mismatch between phone sets provided in roots file, and those in topology: " << ss.str();
184  }
185  std::vector<int32> phones_vec; // phones we saw.
186  GetSeenPhones(stats, P, &phones_vec);
187 
188  std::vector<int32> unseen_phones; // diagnostic.
189  for (size_t i = 0; i < all_phones.size(); i++)
190  if (!std::binary_search(phones_vec.begin(), phones_vec.end(), all_phones[i]))
191  unseen_phones.push_back(all_phones[i]);
192  for (size_t i = 0; i < phones_vec.size(); i++)
193  if (!std::binary_search(all_phones.begin(), all_phones.end(), phones_vec[i]))
194  KALDI_ERR << "Phone "<< (phones_vec[i]) << " appears in stats but is not listed in roots file.";
195  if (!unseen_phones.empty()) {
196  std::ostringstream ss;
197  for (size_t i = 0; i < unseen_phones.size(); i++)
198  ss << unseen_phones[i] << ' ';
199  // Note, unseen phones is just a warning as in certain kinds of
200  // systems, this can be OK (e.g. where phone encodes position and
201  // stress information).
202  KALDI_WARN << "Saw no stats for following phones: " << ss.str();
203  }
204  }
205 
206  std::cerr << "Wrote tree and mapping\n";
207 
208  DeleteBuildTreeStats(&stats);
209  } catch(const std::exception &e) {
210  std::cerr << e.what();
211  return -1;
212  }
213 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
EventMap * BuildTreeTwoLevel(Questions &qopts, const std::vector< std::vector< int32 > > &phone_sets, const std::vector< int32 > &phone2num_pdf_classes, const std::vector< bool > &share_roots, const std::vector< bool > &do_split, const BuildTreeStatsType &stats, int32 max_leaves_first, int32 max_leaves_second, bool cluster_leaves, int32 P, std::vector< int32 > *leaf_map)
BuildTreeTwoLevel builds a two-level tree, useful for example in building tied mixture systems with m...
Definition: build-tree.cc:387
void CopySetToVector(const std::set< T > &s, std::vector< T > *v)
Copies the elements of a set to a vector.
Definition: stl-utils.h:86
This class defines, for each EventKeyType, a set of initial questions that it tries and also a number...
A class for storing topology information for phones.
Definition: hmm-topology.h:93
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
void SortAndUniq(std::vector< T > *vec)
Sorts and uniq's (removes duplicates) from a vector.
Definition: stl-utils.h:39
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
void DeleteBuildTreeStats(BuildTreeStatsType *stats)
This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL...
void ReadBuildTreeStats(std::istream &is, bool binary, const Clusterable &example, BuildTreeStatsType *stats)
Reads BuildTreeStats object.
std::istream & Stream()
Definition: kaldi-io.cc:826
std::vector< std::pair< EventKeyType, EventValueType > > EventType
Definition: event-map.h:58
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void Read(std::istream &is, bool binary)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int main(int argc, char *argv[])
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void GetSeenPhones(BuildTreeStatsType &stats, int P, std::vector< int32 > *phones_out)
A class that is capable of representing a generic mapping from EventType (which is a vector of (key...
Definition: event-map.h:86
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
Definition: io-funcs-inl.h:198
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
void ReadRootsFile(std::istream &is, std::vector< std::vector< int32 > > *phone_sets, std::vector< bool > *is_shared_root, std::vector< bool > *is_split_root)
Reads the roots file (throws on error).
Definition: build-tree.cc:857
std::vector< std::pair< EventType, Clusterable * > > BuildTreeStatsType
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...