All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
build-tree.cc
Go to the documentation of this file.
1 // bin/build-tree.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "hmm/hmm-topology.h"
24 #include "tree/context-dep.h"
25 #include "tree/build-tree.h"
26 #include "tree/build-tree-utils.h"
28 #include "util/text-utils.h"
29 
30 int main(int argc, char *argv[]) {
31  using namespace kaldi;
32  try {
33  using namespace kaldi;
34  typedef kaldi::int32 int32;
35 
36  const char *usage =
37  "Train decision tree\n"
38  "Usage: build-tree [options] <tree-stats-in> <roots-file> <questions-file> <topo-file> <tree-out>\n"
39  "e.g.: \n"
40  " build-tree treeacc roots.txt 1.qst topo tree\n";
41 
42  bool binary = true;
43  int32 P = 1, N = 3;
44 
45  BaseFloat thresh = 300.0;
46  BaseFloat cluster_thresh = -1.0; // negative means use smallest split in splitting phase as thresh.
47  int32 max_leaves = 0;
48  std::string occs_out_filename;
49 
50  ParseOptions po(usage);
51  po.Register("binary", &binary, "Write output in binary mode");
52  po.Register("context-width", &N, "Context window size [must match "
53  "acc-tree-stats]");
54  po.Register("central-position", &P, "Central position in context window "
55  "[must match acc-tree-stats]");
56  po.Register("max-leaves", &max_leaves, "Maximum number of leaves to be "
57  "used in tree-buliding (if positive)");
58  po.Register("thresh", &thresh, "Log-likelihood change threshold for "
59  "tree-building");
60  po.Register("cluster-thresh", &cluster_thresh, "Log-likelihood change "
61  "threshold for clustering after tree-building. 0 means "
62  "no clustering; -1 means use as a clustering threshold the "
63  "likelihood change of the final split.");
64 
65  po.Read(argc, argv);
66 
67  if (po.NumArgs() != 5) {
68  po.PrintUsage();
69  exit(1);
70  }
71 
72  std::string stats_filename = po.GetArg(1),
73  roots_filename = po.GetArg(2),
74  questions_filename = po.GetArg(3),
75  topo_filename = po.GetArg(4),
76  tree_out_filename = po.GetArg(5);
77 
78 
79  // Following 2 variables derived from roots file.
80  // phone_sets is sets of phones that share their roots.
81  // Just one phone each for normal systems.
82  std::vector<std::vector<int32> > phone_sets;
83  std::vector<bool> is_shared_root;
84  std::vector<bool> is_split_root;
85  {
86  Input ki(roots_filename.c_str());
87  ReadRootsFile(ki.Stream(), &phone_sets, &is_shared_root, &is_split_root);
88  }
89 
90  HmmTopology topo;
91  ReadKaldiObject(topo_filename, &topo);
92 
93  BuildTreeStatsType stats;
94  {
95  bool binary_in;
96  GaussClusterable gc; // dummy needed to provide type.
97  Input ki(stats_filename, &binary_in);
98  ReadBuildTreeStats(ki.Stream(), binary_in, gc, &stats);
99  }
100  KALDI_LOG << "Number of separate statistics is " << stats.size();
101 
102  Questions qo;
103  {
104  bool binary_in;
105  try {
106  Input ki(questions_filename, &binary_in);
107  qo.Read(ki.Stream(), binary_in);
108  } catch (const std::exception &e) {
109  KALDI_ERR << "Error reading questions file "<<questions_filename<<", error is: " << e.what();
110  }
111  }
112 
113 
114  std::vector<int32> phone2num_pdf_classes;
115  topo.GetPhoneToNumPdfClasses(&phone2num_pdf_classes);
116 
117  EventMap *to_pdf = NULL;
118 
120 
121  to_pdf = BuildTree(qo,
122  phone_sets,
123  phone2num_pdf_classes,
124  is_shared_root,
125  is_split_root,
126  stats,
127  thresh,
128  max_leaves,
129  cluster_thresh,
130  P);
131 
132  { // This block is to warn about low counts.
133  std::vector<BuildTreeStatsType> split_stats;
134  SplitStatsByMap(stats, *to_pdf,
135  &split_stats);
136  for (size_t i = 0; i < split_stats.size(); i++)
137  if (SumNormalizer(split_stats[i]) < 100.0)
138  KALDI_VLOG(1) << "For pdf-id " << i << ", low count "
139  << SumNormalizer(split_stats[i]);
140  }
141 
142  ContextDependency ctx_dep(N, P, to_pdf); // takes ownership
143  // of pointer "to_pdf", so set it NULL.
144  to_pdf = NULL;
145 
146  WriteKaldiObject(ctx_dep, tree_out_filename, binary);
147 
148  { // This block is just doing some checks.
149 
150  std::vector<int32> all_phones;
151  for (size_t i = 0; i < phone_sets.size(); i++)
152  all_phones.insert(all_phones.end(),
153  phone_sets[i].begin(), phone_sets[i].end());
154  SortAndUniq(&all_phones);
155  if (all_phones != topo.GetPhones()) {
156  std::ostringstream ss;
157  WriteIntegerVector(ss, false, all_phones);
158  ss << " vs. ";
159  WriteIntegerVector(ss, false, topo.GetPhones());
160  KALDI_WARN << "Mismatch between phone sets provided in roots file, and those in topology: " << ss.str();
161  }
162  std::vector<int32> phones_vec; // phones we saw.
163  PossibleValues(P, stats, &phones_vec); // function in build-tree-utils.h
164 
165  std::vector<int32> unseen_phones; // diagnostic.
166  for (size_t i = 0; i < all_phones.size(); i++)
167  if (!std::binary_search(phones_vec.begin(), phones_vec.end(), all_phones[i]))
168  unseen_phones.push_back(all_phones[i]);
169  for (size_t i = 0; i < phones_vec.size(); i++)
170  if (!std::binary_search(all_phones.begin(), all_phones.end(), phones_vec[i]))
171  KALDI_ERR << "Phone " << (phones_vec[i])
172  << " appears in stats but is not listed in roots file.";
173  if (!unseen_phones.empty()) {
174  std::ostringstream ss;
175  for (size_t i = 0; i < unseen_phones.size(); i++)
176  ss << unseen_phones[i] << ' ';
177  // Note, unseen phones is just a warning as in certain kinds of
178  // systems, this can be OK (e.g. where phone encodes position and
179  // stress information).
180  KALDI_WARN << "Saw no stats for following phones: " << ss.str();
181  }
182  }
183 
184  KALDI_LOG << "Wrote tree";
185 
186  DeleteBuildTreeStats(&stats);
187  } catch(const std::exception &e) {
188  std::cerr << e.what();
189  return -1;
190  }
191 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
BaseFloat SumNormalizer(const BuildTreeStatsType &stats_in)
Sums the normalizer [typically, data-count] over the stats.
This class defines, for each EventKeyType, a set of initial questions that it tries and also a number...
A class for storing topology information for phones.
Definition: hmm-topology.h:94
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void SplitStatsByMap(const BuildTreeStatsType &stats, const EventMap &e, std::vector< BuildTreeStatsType > *stats_out)
Splits stats according to the EventMap, indexing them at output by the leaf type. ...
void SortAndUniq(std::vector< T > *vec)
Sorts a vector and uniq's it (removes duplicates).
Definition: stl-utils.h:39
void Register(const std::string &name, bool *ptr, const std::string &doc)
bool PossibleValues(EventKeyType key, const BuildTreeStatsType &stats, std::vector< EventValueType > *ans)
Convenience function e.g.
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:818
void DeleteBuildTreeStats(BuildTreeStatsType *stats)
This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL...
void ReadBuildTreeStats(std::istream &is, bool binary, const Clusterable &example, BuildTreeStatsType *stats)
Reads BuildTreeStats object.
std::istream & Stream()
Definition: kaldi-io.cc:812
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const std::vector< int32 > & GetPhones() const
Returns a reference to a sorted, unique list of phones covered by the topology (these phones will be ...
Definition: hmm-topology.h:164
void GetPhoneToNumPdfClasses(std::vector< int32 > *phone2num_pdf_classes) const
Outputs a vector of int32, indexed by phone, that gives the number of pdf-classes for the...
Definition: hmm-topology.cc:31
void Read(std::istream &is, bool binary)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:127
#define KALDI_WARN
Definition: kaldi-error.h:130
int main(int argc, char *argv[])
Definition: build-tree.cc:30
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class that is capable of representing a generic mapping from EventType (which is a vector of (key...
Definition: event-map.h:86
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
Definition: io-funcs-inl.h:198
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
void ReadRootsFile(std::istream &is, std::vector< std::vector< int32 > > *phone_sets, std::vector< bool > *is_shared_root, std::vector< bool > *is_split_root)
Reads the roots file (throws on error).
Definition: build-tree.cc:783
std::vector< std::pair< EventType, Clusterable * > > BuildTreeStatsType
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...
#define KALDI_LOG
Definition: kaldi-error.h:133
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
EventMap * BuildTree(Questions &qopts, const std::vector< std::vector< int32 > > &phone_sets, const std::vector< int32 > &phone2num_pdf_classes, const std::vector< bool > &share_roots, const std::vector< bool > &do_split, const BuildTreeStatsType &stats, BaseFloat thresh, int32 max_leaves, BaseFloat cluster_thresh, int32 P)
BuildTree is the normal way to build a set of decision trees.
Definition: build-tree.cc:135