All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
build-tree.cc
Go to the documentation of this file.
1 // bin/build-tree.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "hmm/hmm-topology.h"
24 #include "tree/context-dep.h"
25 #include "tree/build-tree.h"
26 #include "tree/build-tree-utils.h"
28 #include "util/text-utils.h"
29 
30 int main(int argc, char *argv[]) {
31  using namespace kaldi;
32  try {
33  using namespace kaldi;
34  typedef kaldi::int32 int32;
35 
36  const char *usage =
37  "Train decision tree\n"
38  "Usage: build-tree [options] <tree-stats-in> <roots-file> <questions-file> <topo-file> <tree-out>\n"
39  "e.g.: \n"
40  " build-tree treeacc roots.txt 1.qst topo tree\n";
41 
42  bool binary = true;
43  int32 P = 1, N = 3;
44 
45  BaseFloat thresh = 300.0;
46  BaseFloat cluster_thresh = -1.0; // negative means use smallest split in splitting phase as thresh.
47  int32 max_leaves = 0;
48  bool round_num_leaves = true;
49  std::string occs_out_filename;
50 
51  ParseOptions po(usage);
52  po.Register("binary", &binary, "Write output in binary mode");
53  po.Register("context-width", &N, "Context window size [must match "
54  "acc-tree-stats]");
55  po.Register("central-position", &P, "Central position in context window "
56  "[must match acc-tree-stats]");
57  po.Register("max-leaves", &max_leaves, "Maximum number of leaves to be "
58  "used in tree-buliding (if positive)");
59  po.Register("thresh", &thresh, "Log-likelihood change threshold for "
60  "tree-building");
61  po.Register("cluster-thresh", &cluster_thresh, "Log-likelihood change "
62  "threshold for clustering after tree-building. 0 means "
63  "no clustering; -1 means use as a clustering threshold the "
64  "likelihood change of the final split.");
65  po.Register("round-num-leaves", &round_num_leaves,
66  "If true, then the number of leaves will be reduced to a "
67  "multiple of 8 by clustering.");
68 
69  po.Read(argc, argv);
70 
71  if (po.NumArgs() != 5) {
72  po.PrintUsage();
73  exit(1);
74  }
75 
76  std::string stats_filename = po.GetArg(1),
77  roots_filename = po.GetArg(2),
78  questions_filename = po.GetArg(3),
79  topo_filename = po.GetArg(4),
80  tree_out_filename = po.GetArg(5);
81 
82 
83  // Following 2 variables derived from roots file.
84  // phone_sets is sets of phones that share their roots.
85  // Just one phone each for normal systems.
86  std::vector<std::vector<int32> > phone_sets;
87  std::vector<bool> is_shared_root;
88  std::vector<bool> is_split_root;
89  {
90  Input ki(roots_filename.c_str());
91  ReadRootsFile(ki.Stream(), &phone_sets, &is_shared_root, &is_split_root);
92  }
93 
94  HmmTopology topo;
95  ReadKaldiObject(topo_filename, &topo);
96 
97  BuildTreeStatsType stats;
98  {
99  bool binary_in;
100  GaussClusterable gc; // dummy needed to provide type.
101  Input ki(stats_filename, &binary_in);
102  ReadBuildTreeStats(ki.Stream(), binary_in, gc, &stats);
103  }
104  KALDI_LOG << "Number of separate statistics is " << stats.size();
105 
106  Questions qo;
107  {
108  bool binary_in;
109  try {
110  Input ki(questions_filename, &binary_in);
111  qo.Read(ki.Stream(), binary_in);
112  } catch (const std::exception &e) {
113  KALDI_ERR << "Error reading questions file "<<questions_filename<<", error is: " << e.what();
114  }
115  }
116 
117 
118  std::vector<int32> phone2num_pdf_classes;
119  topo.GetPhoneToNumPdfClasses(&phone2num_pdf_classes);
120 
121  EventMap *to_pdf = NULL;
122 
124 
125  to_pdf = BuildTree(qo,
126  phone_sets,
127  phone2num_pdf_classes,
128  is_shared_root,
129  is_split_root,
130  stats,
131  thresh,
132  max_leaves,
133  cluster_thresh,
134  P,
135  round_num_leaves);
136 
137  { // This block is to warn about low counts.
138  std::vector<BuildTreeStatsType> split_stats;
139  SplitStatsByMap(stats, *to_pdf,
140  &split_stats);
141  for (size_t i = 0; i < split_stats.size(); i++)
142  if (SumNormalizer(split_stats[i]) < 100.0)
143  KALDI_VLOG(1) << "For pdf-id " << i << ", low count "
144  << SumNormalizer(split_stats[i]);
145  }
146 
147  ContextDependency ctx_dep(N, P, to_pdf); // takes ownership
148  // of pointer "to_pdf", so set it NULL.
149  to_pdf = NULL;
150 
151  WriteKaldiObject(ctx_dep, tree_out_filename, binary);
152 
153  { // This block is just doing some checks.
154 
155  std::vector<int32> all_phones;
156  for (size_t i = 0; i < phone_sets.size(); i++)
157  all_phones.insert(all_phones.end(),
158  phone_sets[i].begin(), phone_sets[i].end());
159  SortAndUniq(&all_phones);
160  if (all_phones != topo.GetPhones()) {
161  std::ostringstream ss;
162  WriteIntegerVector(ss, false, all_phones);
163  ss << " vs. ";
164  WriteIntegerVector(ss, false, topo.GetPhones());
165  KALDI_WARN << "Mismatch between phone sets provided in roots file, and those in topology: " << ss.str();
166  }
167  std::vector<int32> phones_vec; // phones we saw.
168  PossibleValues(P, stats, &phones_vec); // function in build-tree-utils.h
169 
170  std::vector<int32> unseen_phones; // diagnostic.
171  for (size_t i = 0; i < all_phones.size(); i++)
172  if (!std::binary_search(phones_vec.begin(), phones_vec.end(), all_phones[i]))
173  unseen_phones.push_back(all_phones[i]);
174  for (size_t i = 0; i < phones_vec.size(); i++)
175  if (!std::binary_search(all_phones.begin(), all_phones.end(), phones_vec[i]))
176  KALDI_ERR << "Phone " << (phones_vec[i])
177  << " appears in stats but is not listed in roots file.";
178  if (!unseen_phones.empty()) {
179  std::ostringstream ss;
180  for (size_t i = 0; i < unseen_phones.size(); i++)
181  ss << unseen_phones[i] << ' ';
182  // Note, unseen phones is just a warning as in certain kinds of
183  // systems, this can be OK (e.g. where phone encodes position and
184  // stress information).
185  KALDI_WARN << "Saw no stats for following phones: " << ss.str();
186  }
187  }
188 
189  KALDI_LOG << "Wrote tree";
190 
191  DeleteBuildTreeStats(&stats);
192  } catch(const std::exception &e) {
193  std::cerr << e.what();
194  return -1;
195  }
196 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
BaseFloat SumNormalizer(const BuildTreeStatsType &stats_in)
Sums the normalizer [typically, data-count] over the stats.
This class defines, for each EventKeyType, a set of initial questions that it tries and also a number...
A class for storing topology information for phones.
Definition: hmm-topology.h:94
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void SplitStatsByMap(const BuildTreeStatsType &stats, const EventMap &e, std::vector< BuildTreeStatsType > *stats_out)
Splits stats according to the EventMap, indexing them at output by the leaf type. ...
void SortAndUniq(std::vector< T > *vec)
Sorts and uniq's (removes duplicates) from a vector.
Definition: stl-utils.h:39
void Register(const std::string &name, bool *ptr, const std::string &doc)
bool PossibleValues(EventKeyType key, const BuildTreeStatsType &stats, std::vector< EventValueType > *ans)
Convenience function e.g.
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
void DeleteBuildTreeStats(BuildTreeStatsType *stats)
This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL...
void ReadBuildTreeStats(std::istream &is, bool binary, const Clusterable &example, BuildTreeStatsType *stats)
Reads BuildTreeStats object.
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const std::vector< int32 > & GetPhones() const
Returns a reference to a sorted, unique list of phones covered by the topology (these phones will be ...
Definition: hmm-topology.h:164
EventMap * BuildTree(Questions &qopts, const std::vector< std::vector< int32 > > &phone_sets, const std::vector< int32 > &phone2num_pdf_classes, const std::vector< bool > &share_roots, const std::vector< bool > &do_split, const BuildTreeStatsType &stats, BaseFloat thresh, int32 max_leaves, BaseFloat cluster_thresh, int32 P, bool round_num_leaves)
BuildTree is the normal way to build a set of decision trees.
Definition: build-tree.cc:136
void GetPhoneToNumPdfClasses(std::vector< int32 > *phone2num_pdf_classes) const
Outputs a vector of int32, indexed by phone, that gives the number of pdf-classes for the...
Definition: hmm-topology.cc:31
void Read(std::istream &is, bool binary)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:127
#define KALDI_WARN
Definition: kaldi-error.h:130
int main(int argc, char *argv[])
Definition: build-tree.cc:30
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class that is capable of representing a generic mapping from EventType (which is a vector of (key...
Definition: event-map.h:86
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
Definition: io-funcs-inl.h:198
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
void ReadRootsFile(std::istream &is, std::vector< std::vector< int32 > > *phone_sets, std::vector< bool > *is_shared_root, std::vector< bool > *is_split_root)
Reads the roots file (throws on error).
Definition: build-tree.cc:857
std::vector< std::pair< EventType, Clusterable * > > BuildTreeStatsType
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...
#define KALDI_LOG
Definition: kaldi-error.h:133
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.