build-tree-two-level.cc File Reference

Namespaces

 kaldi
 

Functions

void GetSeenPhones (BuildTreeStatsType &stats, int P, std::vector< int32 > *phones_out)
 
int main (int argc, char *argv[])
 
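GetSeenPhones() has no documentation block on this page. Its signature and its use in main() suggest that it returns the sorted, distinct phones found at context position P in the accumulated stats: main() searches its output with std::binary_search and compares it against the roots-file phones. The sketch below illustrates that reading using only the standard library. The simplified `EventType` (a vector of (key, value) pairs), the `StatsType` alias with a `const void *` placeholder in place of Kaldi's `Clusterable *`, and the name `GetSeenPhonesSketch` are assumptions, not Kaldi's actual definitions.

```cpp
#include <cstdint>
#include <set>
#include <utility>
#include <vector>

// Simplified stand-ins (assumptions, not Kaldi's real headers): an event is a
// list of (key, value) pairs, where key is a context position (or another key
// such as a pdf-class) and value is the phone or class at that key.
using int32 = std::int32_t;
using EventType = std::vector<std::pair<int32, int32>>;
// Kaldi pairs each event with a Clusterable*; a placeholder pointer is used here.
using StatsType = std::vector<std::pair<EventType, const void *>>;

// Hypothetical sketch: collect the distinct phones that appear at context
// position P in any event of the stats, returned in ascending order.
void GetSeenPhonesSketch(const StatsType &stats, int32 P,
                         std::vector<int32> *phones_out) {
  std::set<int32> seen;  // std::set keeps values sorted and unique.
  for (const auto &entry : stats)
    for (const auto &kv : entry.first)
      if (kv.first == P) seen.insert(kv.second);
  // Sorted output lets callers use std::binary_search on the result.
  phones_out->assign(seen.begin(), seen.end());
}
```

A `std::set` is used so that the result comes out sorted and de-duplicated in one pass, which is the property the later std::binary_search calls in main() depend on.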

Function Documentation

int main (int argc, char *argv[])

Definition at line 53 of file build-tree-two-level.cc.

References kaldi::BuildTreeTwoLevel(), kaldi::DeleteBuildTreeStats(), ParseOptions::GetArg(), HmmTopology::GetPhones(), HmmTopology::GetPhoneToNumPdfClasses(), kaldi::GetSeenPhones(), KALDI_ERR, KALDI_WARN, ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), Questions::Read(), kaldi::ReadBuildTreeStats(), kaldi::ReadKaldiObject(), kaldi::ReadRootsFile(), ParseOptions::Register(), kaldi::SortAndUniq(), Output::Stream(), Input::Stream(), kaldi::WriteIntegerVector(), and kaldi::WriteKaldiObject().
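Per the usage message, the program writes the larger (second-level) tree to `<tree-out>` and an integer vector to `<mapping-out>` that maps leaf-ids of the larger tree to leaf-ids of the smaller one. Assuming that vector is indexed by the larger tree's leaf-id (so `mapping[fine] == coarse`), a tied-mixture setup might invert it to find which fine leaves share a first-level leaf, for example a shared codebook. The helper below is a hypothetical illustration of that inversion, not part of Kaldi.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using int32 = std::int32_t;

// Hypothetical helper: given mapping[fine_leaf] = coarse_leaf (the assumed
// layout of the vector written to <mapping-out>), return, for each coarse
// leaf, the fine leaves that map to it, in ascending order.
std::vector<std::vector<int32>> GroupLeavesByCoarseLeaf(
    const std::vector<int32> &mapping) {
  std::vector<std::vector<int32>> groups;
  for (std::size_t fine = 0; fine < mapping.size(); fine++) {
    int32 coarse = mapping[fine];
    if (coarse < 0) throw std::invalid_argument("negative leaf id in mapping");
    // Grow the outer vector on demand so every coarse id has a slot.
    if (static_cast<std::size_t>(coarse) >= groups.size())
      groups.resize(static_cast<std::size_t>(coarse) + 1);
    groups[coarse].push_back(static_cast<int32>(fine));
  }
  return groups;
}
```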

{
  using namespace kaldi;
  try {
    using namespace kaldi;
    typedef kaldi::int32 int32;

    const char *usage =
        "Trains two-level decision tree. Outputs the larger tree, and a mapping from the\n"
        "leaf-ids of the larger tree to those of the smaller tree. Useful, for instance,\n"
        "in tied-mixture systems with multiple codebooks.\n"
        "\n"
        "Usage: build-tree-two-level [options] <tree-stats-in> <roots-file> <questions-file> <topo-file> <tree-out> <mapping-out>\n"
        "e.g.: \n"
        " build-tree-two-level treeacc roots.txt 1.qst topo tree tree.map\n";

    bool binary = true;
    int32 P = 1, N = 3;

    bool cluster_leaves = true;
    int32 max_leaves_first = 1000;
    int32 max_leaves_second = 5000;
    std::string occs_out_filename;

    ParseOptions po(usage);
    po.Register("binary", &binary, "Write output in binary mode");
    po.Register("context-width", &N, "Context window size [must match "
                "acc-tree-stats]");
    po.Register("central-position", &P, "Central position in context window "
                "[must match acc-tree-stats]");
    po.Register("max-leaves-first", &max_leaves_first, "Maximum number of "
                "leaves in first-level decision tree.");
    po.Register("max-leaves-second", &max_leaves_second, "Maximum number of "
                "leaves in second-level decision tree.");
    po.Register("cluster-leaves", &cluster_leaves, "If true, do a post-clustering"
                " of the leaves of the final decision tree.");

    po.Read(argc, argv);

    if (po.NumArgs() != 6) {
      po.PrintUsage();
      exit(1);
    }

    std::string stats_filename = po.GetArg(1),
        roots_filename = po.GetArg(2),
        questions_filename = po.GetArg(3),
        topo_filename = po.GetArg(4),
        tree_out_filename = po.GetArg(5),
        map_out_filename = po.GetArg(6);


    // Following 2 variables derived from roots file.
    // phone_sets is sets of phones that share their roots.
    // Just one phone each for normal systems.
    std::vector<std::vector<int32> > phone_sets;
    std::vector<bool> is_shared_root;
    std::vector<bool> is_split_root;
    {
      Input ki(roots_filename.c_str());
      ReadRootsFile(ki.Stream(), &phone_sets, &is_shared_root, &is_split_root);
    }

    HmmTopology topo;
    ReadKaldiObject(topo_filename, &topo);

    BuildTreeStatsType stats;
    {
      bool binary_in;
      GaussClusterable gc; // dummy needed to provide type.
      Input ki(stats_filename, &binary_in);
      ReadBuildTreeStats(ki.Stream(), binary_in, gc, &stats);
    }
    std::cerr << "Number of separate statistics is " << stats.size() << '\n';

    Questions qo;
    {
      bool binary_in;
      try {
        Input ki(questions_filename, &binary_in);
        qo.Read(ki.Stream(), binary_in);
      } catch (const std::exception &e) {
        KALDI_ERR << "Error reading questions file "<<questions_filename<<", error is: " << e.what();
      }
    }


    std::vector<int32> phone2num_pdf_classes;
    topo.GetPhoneToNumPdfClasses(&phone2num_pdf_classes);

    EventMap *to_pdf = NULL;

    std::vector<int32> mapping;


    to_pdf = BuildTreeTwoLevel(qo,
                               phone_sets,
                               phone2num_pdf_classes,
                               is_shared_root,
                               is_split_root,
                               stats,
                               max_leaves_first,
                               max_leaves_second,
                               cluster_leaves,
                               P,
                               &mapping);

    ContextDependency ctx_dep(N, P, to_pdf); // takes ownership
    // of pointer "to_pdf", so set it NULL.
    to_pdf = NULL;

    WriteKaldiObject(ctx_dep, tree_out_filename, binary);

    {
      Output ko(map_out_filename, binary);
      WriteIntegerVector(ko.Stream(), binary, mapping);
    }

    { // This block is just doing some checks.

      std::vector<int32> all_phones;
      for (size_t i = 0; i < phone_sets.size(); i++)
        all_phones.insert(all_phones.end(),
                          phone_sets[i].begin(), phone_sets[i].end());
      SortAndUniq(&all_phones);
      if (all_phones != topo.GetPhones()) {
        std::ostringstream ss;
        WriteIntegerVector(ss, false, all_phones);
        ss << " vs. ";
        WriteIntegerVector(ss, false, topo.GetPhones());
        KALDI_WARN << "Mismatch between phone sets provided in roots file, and those in topology: " << ss.str();
      }
      std::vector<int32> phones_vec; // phones we saw.
      GetSeenPhones(stats, P, &phones_vec);

      std::vector<int32> unseen_phones; // diagnostic.
      for (size_t i = 0; i < all_phones.size(); i++)
        if (!std::binary_search(phones_vec.begin(), phones_vec.end(), all_phones[i]))
          unseen_phones.push_back(all_phones[i]);
      for (size_t i = 0; i < phones_vec.size(); i++)
        if (!std::binary_search(all_phones.begin(), all_phones.end(), phones_vec[i]))
          KALDI_ERR << "Phone "<< (phones_vec[i]) << " appears in stats but is not listed in roots file.";
      if (!unseen_phones.empty()) {
        std::ostringstream ss;
        for (size_t i = 0; i < unseen_phones.size(); i++)
          ss << unseen_phones[i] << ' ';
        // Note, unseen phones is just a warning as in certain kinds of
        // systems, this can be OK (e.g. where phone encodes position and
        // stress information).
        KALDI_WARN << "Saw no stats for following phones: " << ss.str();
      }
    }

    std::cerr << "Wrote tree and mapping\n";

    DeleteBuildTreeStats(&stats);
  } catch(const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}
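The last block of main() is purely diagnostic. It flattens the roots-file phone sets with SortAndUniq(), warns if the result differs from HmmTopology::GetPhones(), raises KALDI_ERR for any phone that has stats but is missing from the roots file, and only warns about roots-file phones with no stats. The same set logic can be sketched with the standard library. `FlattenPhoneSets`, `CheckPhones`, and `PhoneCheckResult` are hypothetical names, and both inputs to `CheckPhones` are assumed to be sorted and duplicate-free.

```cpp
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

using int32 = std::int32_t;

// Hypothetical counterpart of the loop + SortAndUniq() in main():
// concatenate all phone sets, then sort and remove duplicates.
std::vector<int32> FlattenPhoneSets(
    const std::vector<std::vector<int32>> &phone_sets) {
  std::vector<int32> all;
  for (const auto &ps : phone_sets) all.insert(all.end(), ps.begin(), ps.end());
  std::sort(all.begin(), all.end());
  all.erase(std::unique(all.begin(), all.end()), all.end());
  return all;
}

// Hypothetical result type mirroring the two diagnostics in main().
struct PhoneCheckResult {
  std::vector<int32> unseen;        // in roots file, no stats (a warning in main())
  std::vector<int32> not_in_roots;  // in stats, not in roots file (fatal in main())
};

// Both inputs must be sorted and duplicate-free.
PhoneCheckResult CheckPhones(const std::vector<int32> &all_phones,
                             const std::vector<int32> &seen_phones) {
  PhoneCheckResult r;
  std::set_difference(all_phones.begin(), all_phones.end(),
                      seen_phones.begin(), seen_phones.end(),
                      std::back_inserter(r.unseen));
  std::set_difference(seen_phones.begin(), seen_phones.end(),
                      all_phones.begin(), all_phones.end(),
                      std::back_inserter(r.not_in_roots));
  return r;
}
```

std::set_difference stands in here for the per-element std::binary_search loops of the original. On sorted, duplicate-free ranges both approaches yield the same phones, and set_difference does so in a single linear pass.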