All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
build-tree.cc File Reference
Include dependency graph for build-tree.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 30 of file build-tree.cc.

References kaldi::BuildTree(), kaldi::DeleteBuildTreeStats(), ParseOptions::GetArg(), HmmTopology::GetPhones(), HmmTopology::GetPhoneToNumPdfClasses(), rnnlm::i, KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, ParseOptions::NumArgs(), kaldi::PossibleValues(), ParseOptions::PrintUsage(), ParseOptions::Read(), Questions::Read(), kaldi::ReadBuildTreeStats(), kaldi::ReadKaldiObject(), kaldi::ReadRootsFile(), ParseOptions::Register(), kaldi::SortAndUniq(), kaldi::SplitStatsByMap(), Input::Stream(), kaldi::SumNormalizer(), kaldi::WriteIntegerVector(), and kaldi::WriteKaldiObject().

30  {
// Body of main() for the build-tree tool (the leading numbers are the line
// numbers of build-tree.cc as rendered in this listing).
// Purpose, per the usage string below: train a decision tree from accumulated
// tree statistics, a roots file, a questions file and an HMM topology, and
// write the result to <tree-out>.
// Exit paths visible here: a wrong number of positional arguments prints the
// usage and calls exit(1); a std::exception reaching the outer catch prints
// e.what() to std::cerr and returns -1; the success path falls off the end of
// main without an explicit return statement.
31  using namespace kaldi;
32  try {
// NOTE(review): this repeats the using-directive two lines up.
33  using namespace kaldi;
34  typedef kaldi::int32 int32;
35 
36  const char *usage =
37  "Train decision tree\n"
38  "Usage: build-tree [options] <tree-stats-in> <roots-file> <questions-file> <topo-file> <tree-out>\n"
39  "e.g.: \n"
40  " build-tree treeacc roots.txt 1.qst topo tree\n";
41 
// Defaults for the command-line options registered below.
42  bool binary = true;
// P is registered as --central-position and N as --context-width.
43  int32 P = 1, N = 3;
44 
45  BaseFloat thresh = 300.0;
46  BaseFloat cluster_thresh = -1.0; // negative means use smallest split in splitting phase as thresh.
47  int32 max_leaves = 0;
// NOTE(review): occs_out_filename is never referenced again in this listing;
// it looks unused.
48  std::string occs_out_filename;
49 
50  ParseOptions po(usage);
51  po.Register("binary", &binary, "Write output in binary mode");
52  po.Register("context-width", &N, "Context window size [must match "
53  "acc-tree-stats]");
54  po.Register("central-position", &P, "Central position in context window "
55  "[must match acc-tree-stats]");
// NOTE(review): "tree-buliding" in the help text below looks like a typo for
// "tree-building"; left as-is because it is a runtime string.
56  po.Register("max-leaves", &max_leaves, "Maximum number of leaves to be "
57  "used in tree-buliding (if positive)");
58  po.Register("thresh", &thresh, "Log-likelihood change threshold for "
59  "tree-building");
60  po.Register("cluster-thresh", &cluster_thresh, "Log-likelihood change "
61  "threshold for clustering after tree-building. 0 means "
62  "no clustering; -1 means use as a clustering threshold the "
63  "likelihood change of the final split.");
64 
// Process argc/argv; the positional arguments are fetched with GetArg() below.
65  po.Read(argc, argv);
66 
// Exactly five positional arguments are required (see the usage string).
67  if (po.NumArgs() != 5) {
68  po.PrintUsage();
69  exit(1);
70  }
71 
// Positional arguments, in the order given in the usage string.
72  std::string stats_filename = po.GetArg(1),
73  roots_filename = po.GetArg(2),
74  questions_filename = po.GetArg(3),
75  topo_filename = po.GetArg(4),
76  tree_out_filename = po.GetArg(5);
77 
78 
79  // Following 2 variables derived from roots file.
80  // phone_sets is sets of phones that share their roots.
81  // Just one phone each for normal systems.
82  std::vector<std::vector<int32> > phone_sets;
83  std::vector<bool> is_shared_root;
84  std::vector<bool> is_split_root;
// Inner scope: `ki` exists only for the duration of the roots-file read.
// All three output vectors above are filled by ReadRootsFile().
85  {
86  Input ki(roots_filename.c_str());
87  ReadRootsFile(ki.Stream(), &phone_sets, &is_shared_root, &is_split_root);
88  }
89 
// Load the HMM topology; it is used for the per-phone pdf-class counts and
// for the phone-set consistency check further down.
90  HmmTopology topo;
91  ReadKaldiObject(topo_filename, &topo);
92 
// Load the accumulated tree statistics into `stats`.
93  BuildTreeStatsType stats;
94  {
// binary_in is passed to Input by pointer and then forwarded to the reader;
// presumably Input sets it to the detected file mode -- TODO confirm.
95  bool binary_in;
96  GaussClusterable gc; // dummy needed to provide type.
97  Input ki(stats_filename, &binary_in);
98  ReadBuildTreeStats(ki.Stream(), binary_in, gc, &stats);
99  }
100  KALDI_LOG << "Number of separate statistics is " << stats.size();
101 
// Load the questions. Any std::exception raised while opening or reading is
// re-reported through KALDI_ERR with the questions filename added for context.
102  Questions qo;
103  {
104  bool binary_in;
105  try {
106  Input ki(questions_filename, &binary_in);
107  qo.Read(ki.Stream(), binary_in);
108  } catch (const std::exception &e) {
109  KALDI_ERR << "Error reading questions file "<<questions_filename<<", error is: " << e.what();
110  }
111  }
112 
113 
// Filled in by the topology, then passed straight to BuildTree().
114  std::vector<int32> phone2num_pdf_classes;
115  topo.GetPhoneToNumPdfClasses(&phone2num_pdf_classes);
116 
// Raw pointer that receives the result of BuildTree(); ownership is handed to
// ctx_dep further down.
117  EventMap *to_pdf = NULL;
118 
// (Original line 119 is absent from this listing.)
120 
// Build the tree from the questions, root configuration, per-phone pdf-class
// counts and stats, using the thresholds and leaf limit from the options.
121  to_pdf = BuildTree(qo,
122  phone_sets,
123  phone2num_pdf_classes,
124  is_shared_root,
125  is_split_root,
126  stats,
127  thresh,
128  max_leaves,
129  cluster_thresh,
130  P);
131 
132  { // This block is to warn about low counts.
// Stats are split by the tree's leaves; for each entry whose summed
// normalizer is below 100.0 a message is logged via KALDI_VLOG(1).
133  std::vector<BuildTreeStatsType> split_stats;
134  SplitStatsByMap(stats, *to_pdf,
135  &split_stats);
136  for (size_t i = 0; i < split_stats.size(); i++)
137  if (SumNormalizer(split_stats[i]) < 100.0)
138  KALDI_VLOG(1) << "For pdf-id " << i << ", low count "
139  << SumNormalizer(split_stats[i]);
140  }
141 
142  ContextDependency ctx_dep(N, P, to_pdf); // takes ownership
143  // of pointer "to_pdf", so set it NULL.
144  to_pdf = NULL;
145 
// Write the tree to <tree-out>, using the --binary option value.
146  WriteKaldiObject(ctx_dep, tree_out_filename, binary);
147 
148  { // This block is just doing some checks.
// Note: these checks run after the tree has already been written above.
149 
// Flatten every phone set from the roots file into one sorted, unique list.
150  std::vector<int32> all_phones;
151  for (size_t i = 0; i < phone_sets.size(); i++)
152  all_phones.insert(all_phones.end(),
153  phone_sets[i].begin(), phone_sets[i].end());
154  SortAndUniq(&all_phones);
// Compare the roots-file phones with the topology's phone list; a mismatch is
// only a warning, with both lists printed in text form.
155  if (all_phones != topo.GetPhones()) {
156  std::ostringstream ss;
157  WriteIntegerVector(ss, false, all_phones);
158  ss << " vs. ";
159  WriteIntegerVector(ss, false, topo.GetPhones());
160  KALDI_WARN << "Mismatch between phone sets provided in roots file, and those in topology: " << ss.str();
161  }
162  std::vector<int32> phones_vec; // phones we saw.
163  PossibleValues(P, stats, &phones_vec); // function in build-tree-utils.h
164 
// Both loops below use std::binary_search, which needs sorted input.
// all_phones was sorted by SortAndUniq(); assumes PossibleValues() also
// returns phones_vec sorted -- TODO confirm.
165  std::vector<int32> unseen_phones; // diagnostic.
// Phones listed in the roots file for which no stats were seen.
166  for (size_t i = 0; i < all_phones.size(); i++)
167  if (!std::binary_search(phones_vec.begin(), phones_vec.end(), all_phones[i]))
168  unseen_phones.push_back(all_phones[i]);
// Phones present in the stats but missing from the roots file are reported
// through KALDI_ERR.
169  for (size_t i = 0; i < phones_vec.size(); i++)
170  if (!std::binary_search(all_phones.begin(), all_phones.end(), phones_vec[i]))
171  KALDI_ERR << "Phone " << (phones_vec[i])
172  << " appears in stats but is not listed in roots file.";
173  if (!unseen_phones.empty()) {
174  std::ostringstream ss;
175  for (size_t i = 0; i < unseen_phones.size(); i++)
176  ss << unseen_phones[i] << ' ';
177  // Note, unseen phones is just a warning as in certain kinds of
178  // systems, this can be OK (e.g. where phone encodes position and
179  // stress information).
180  KALDI_WARN << "Saw no stats for following phones: " << ss.str();
181  }
182  }
183 
184  KALDI_LOG << "Wrote tree";
185 
// Release the stats entries.
// NOTE(review): this call is not reached if an exception is thrown earlier in
// the try block.
186  DeleteBuildTreeStats(&stats);
187  } catch(const std::exception &e) {
// Report the error text on stderr and signal failure to the caller.
188  std::cerr << e.what();
189  return -1;
190  }
191 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
BaseFloat SumNormalizer(const BuildTreeStatsType &stats_in)
Sums the normalizer [typically, data-count] over the stats.
This class defines, for each EventKeyType, a set of initial questions that it tries and also a number...
A class for storing topology information for phones.
Definition: hmm-topology.h:94
void SplitStatsByMap(const BuildTreeStatsType &stats, const EventMap &e, std::vector< BuildTreeStatsType > *stats_out)
Splits stats according to the EventMap, indexing them at output by the leaf type. ...
void SortAndUniq(std::vector< T > *vec)
Sorts and uniq's (removes duplicates) from a vector.
Definition: stl-utils.h:51
bool PossibleValues(EventKeyType key, const BuildTreeStatsType &stats, std::vector< EventValueType > *ans)
Convenience function, e.g. to work out the possible values of a key (such as the phones that were seen) from the stats.
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:818
void DeleteBuildTreeStats(BuildTreeStatsType *stats)
This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL...
void ReadBuildTreeStats(std::istream &is, bool binary, const Clusterable &example, BuildTreeStatsType *stats)
Reads BuildTreeStats object.
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const std::vector< int32 > & GetPhones() const
Returns a reference to a sorted, unique list of phones covered by the topology (these phones will be ...
Definition: hmm-topology.h:164
void GetPhoneToNumPdfClasses(std::vector< int32 > *phone2num_pdf_classes) const
Outputs a vector of int32, indexed by phone, that gives the number of pdf-classes for the...
Definition: hmm-topology.cc:31
void Read(std::istream &is, bool binary)
#define KALDI_ERR
Definition: kaldi-error.h:127
#define KALDI_WARN
Definition: kaldi-error.h:130
A class that is capable of representing a generic mapping from EventType (which is a vector of (key...
Definition: event-map.h:86
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
Definition: io-funcs-inl.h:198
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
void ReadRootsFile(std::istream &is, std::vector< std::vector< int32 > > *phone_sets, std::vector< bool > *is_shared_root, std::vector< bool > *is_split_root)
Reads the roots file (throws on error).
Definition: build-tree.cc:783
std::vector< std::pair< EventType, Clusterable * > > BuildTreeStatsType
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...
#define KALDI_LOG
Definition: kaldi-error.h:133
EventMap * BuildTree(Questions &qopts, const std::vector< std::vector< int32 > > &phone_sets, const std::vector< int32 > &phone2num_pdf_classes, const std::vector< bool > &share_roots, const std::vector< bool > &do_split, const BuildTreeStatsType &stats, BaseFloat thresh, int32 max_leaves, BaseFloat cluster_thresh, int32 P)
BuildTree is the normal way to build a set of decision trees.
Definition: build-tree.cc:135