build-tree-test.cc
Go to the documentation of this file.
1 // tree/build-tree-test.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "util/stl-utils.h"
21 #include "tree/build-tree.h"
22 
23 namespace kaldi {
24 
26  for (int32 p = 0; p < 2; p++) {
27  int32 dim = 1 + Rand() % 40;
28  int32 num_phones = 1 + Rand() % 40;
29  int32 num_stats = 1 + (Rand() % 20);
30  int32 N = 2 + Rand() % 2; // 2 or 3.
31  int32 P = Rand() % N;
32  float ctx_dep_prob = 0.5 + 0.5*RandUniform();
33  std::vector<int32> phone_ids(num_phones);
34  for (size_t i = 0;i < (size_t)num_phones;i++)
35  phone_ids[i] = (i == 0 ? (Rand() % 2) : phone_ids[i-1] + 1 + (Rand()%2));
36  int32 max_phone = *std::max_element(phone_ids.begin(), phone_ids.end());
37  std::vector<int32> hmm_lengths(max_phone+1);
38  std::vector<bool> is_ctx_dep(max_phone+1);
39 
40  for (int32 i = 0; i <= max_phone; i++) {
41  hmm_lengths[i] = 1 + Rand() % 3;
42  is_ctx_dep[i] = (RandUniform() < ctx_dep_prob); // true w.p. ctx_dep_prob.
43  }
44  for (size_t i = 0;i < (size_t) num_phones;i++) {
45  KALDI_VLOG(2) << "For idx = "<< i << ", (phone_id, hmm_length, is_ctx_dep) == " << (phone_ids[i]) << " " << (hmm_lengths[phone_ids[i]]) << " " << (is_ctx_dep[phone_ids[i]]);
46  }
47  BuildTreeStatsType stats;
48  // put false for all_covered argument.
49  // if it doesn't really ensure that all are covered with true, this will induce
50  // failure in the test of context-fst.
51  GenRandStats(dim, num_stats, N, P, phone_ids, hmm_lengths, is_ctx_dep, false, &stats);
52  std::cout << "Writing random stats.";
53  std::cout <<"dim = " << dim << '\n';
54  std::cout <<"num_phones = " << num_phones << '\n';
55  std::cout <<"num_stats = " << num_stats << '\n';
56  std::cout <<"N = "<< N << '\n';
57  std::cout <<"P = "<< P << '\n';
58  std::cout << "is-ctx-dep = ";
59  for (size_t i = 0;i < is_ctx_dep.size();i++)
60  WriteBasicType(std::cout, false, static_cast<bool>(is_ctx_dep[i]));
61  std::cout << "hmm_lengths = "; WriteIntegerVector(std::cout, false, hmm_lengths);
62  std::cout << "phone_ids = "; WriteIntegerVector(std::cout, false, phone_ids);
63  std::cout << "Stats are: \n";
64  WriteBuildTreeStats(std::cout, false, stats);
65 
66 
67  // Now check the properties of the stats.
68  for (size_t i = 0;i < stats.size();i++) {
69  EventValueType central_phone;
70  bool b = EventMap::Lookup(stats[i].first, P, &central_phone);
71  KALDI_ASSERT(b);
72  EventValueType position;
73  b = EventMap::Lookup(stats[i].first, kPdfClass, &position);
74  KALDI_ASSERT(b);
75  KALDI_ASSERT(position>=0 && position < hmm_lengths[central_phone]);
76 
77  for (EventKeyType j = 0; j < N; j++) {
78  if (j != P) { // non-"central" phone.
79  EventValueType ctx_phone;
80  b = EventMap::Lookup(stats[i].first, j, &ctx_phone);
81  KALDI_ASSERT(is_ctx_dep[central_phone] == b);
82  }
83  }
84  }
85  DeleteBuildTreeStats(&stats);
86  }
87 }
88 
89 
90 void TestBuildTree() {
91  for (int32 p = 0; p < 3; p++) {
92  // First decide phone-ids, hmm lengths, is-ctx-dep...
93 
94  int32 dim = 1 + Rand() % 40;
95  int32 num_phones = 1 + Rand() % 8;
96  int32 num_stats = 1 + (Rand() % 15) * (Rand() % 15); // up to 14^2 + 1 separate stats.
97  int32 N = 2 + Rand() % 2; // 2 or 3.
98  int32 P = Rand() % N;
99  float ctx_dep_prob = 0.5 + 0.5*RandUniform();
100 
101  std::vector<int32> phone_ids(num_phones);
102  for (size_t i = 0;i < (size_t)num_phones;i++)
103  phone_ids[i] = (i == 0 ? (Rand() % 2) : phone_ids[i-1] + 1 + (Rand()%2));
104  int32 max_phone = *std::max_element(phone_ids.begin(), phone_ids.end());
105  std::vector<int32> hmm_lengths(max_phone+1);
106  std::vector<bool> is_ctx_dep(max_phone+1);
107 
108  for (int32 i = 0; i <= max_phone; i++) {
109  hmm_lengths[i] = 1 + Rand() % 3;
110  is_ctx_dep[i] = (RandUniform() < ctx_dep_prob); // true w.p. ctx_dep_prob.
111  }
112  for (size_t i = 0;i < (size_t) num_phones;i++) {
113  KALDI_VLOG(2) << "For idx = "<< i << ", (phone_id, hmm_length, is_ctx_dep) == " << (phone_ids[i]) << " " << (hmm_lengths[phone_ids[i]]) << " " << (is_ctx_dep[phone_ids[i]]);
114  }
115  // Generate rand stats. These were tested in TestGenRandStats() above.
116  BuildTreeStatsType stats;
117  bool ensure_all_covered = false;
118  GenRandStats(dim, num_stats, N, P, phone_ids, hmm_lengths, is_ctx_dep, ensure_all_covered, &stats);
119 
120  { // print out the stats.
121  std::cout << "Writing random stats.";
122  std::cout << "dim = " << dim << '\n';
123  std::cout << "num_phones = " << num_phones << '\n';
124  std::cout << "num_stats = " << num_stats << '\n';
125  std::cout << "N = "<< N << '\n';
126  std::cout << "P = "<< P << '\n';
127  std::cout << "is-ctx-dep = ";
128  for (size_t i = 0;i < is_ctx_dep.size();i++)
129  WriteBasicType(std::cout, false, static_cast<bool>(is_ctx_dep[i]));
130  std::cout << "hmm_lengths = "; WriteIntegerVector(std::cout, false, hmm_lengths);
131  std::cout << "phone_ids = "; WriteIntegerVector(std::cout, false, phone_ids);
132  std::cout << "Stats are: \n";
133  WriteBuildTreeStats(std::cout, false, stats);
134  }
135 
136  // Now build the tree.
137 
138  Questions qopts;
139  int32 num_quest = Rand() % 10, num_iters = rand () % 5;
140  qopts.InitRand(stats, num_quest, num_iters, kAllKeysUnion); // This was tested in build-tree-utils-test.cc
141 
142  {
143  std::cout << "Printing questions:\n";
144  std::vector<EventKeyType> keys;
145  qopts.GetKeysWithQuestions(&keys);
146  for (size_t i = 0;i < keys.size();i++) {
147  KALDI_ASSERT(qopts.HasQuestionsForKey(keys[i]));
148  const QuestionsForKey &opts = qopts.GetQuestionsOf(keys[i]);
149  std::cout << "num-quest: "<< opts.initial_questions.size() << '\n';
150  for (size_t j = 0;j < opts.initial_questions.size();j++) {
151  for (size_t k = 0;k < opts.initial_questions[j].size();k++)
152  std::cout << opts.initial_questions[j][k] <<" ";
153  std::cout << '\n';
154  }
155  }
156  }
157 
158  float thresh = 100.0 * RandUniform();
159  int max_leaves = 100;
160  std::cout <<"Thresh = "<<thresh<<" for building tree.\n";
161 
162  {
163  std::cout << "Building tree\n";
164  EventMap *tree = NULL;
165  std::vector<std::vector<int32> > phone_sets(phone_ids.size());
166  for (size_t i = 0; i < phone_ids.size(); i++)
167  phone_sets[i].push_back(phone_ids[i]);
168  std::vector<bool> share_roots(phone_sets.size(), true),
169  do_split(phone_sets.size(), true);
170 
171  if (p % 3 != 0) {
172  bool round_num_leaves = true;
173 
174  EventMap *tree_not_rounded =
175  BuildTree(qopts, phone_sets, hmm_lengths, share_roots,
176  do_split, stats, thresh, max_leaves, 0.0, P,
177  false);
178 
179  tree = BuildTree(qopts, phone_sets, hmm_lengths, share_roots,
180  do_split, stats, thresh, max_leaves, 0.0, P,
181  round_num_leaves);
182 
183  BuildTreeStatsType::const_iterator iter, end = stats.end();
184 
185  std::map<EventAnswerType, std::set<EventAnswerType> > mapping;
186  int32 num_removed = 0;
187  for (iter = stats.begin(); iter != end; ++iter) {
188  const EventType &evec = iter->first;
189  EventAnswerType ans_not_rounded;
190  KALDI_ASSERT(tree_not_rounded->Map(evec, &ans_not_rounded));
191 
192  EventAnswerType ans;
193  KALDI_ASSERT(tree->Map(evec, &ans));
194 
195  auto it = mapping.find(ans);
196  if (it == mapping.end()) {
197  std::set<EventAnswerType> leaf_set;
198  leaf_set.insert(ans_not_rounded);
199  mapping.insert(it, std::make_pair(ans, leaf_set));
200  } else if (it->second.count(ans_not_rounded) == 0) {
201  num_removed++;
202  it->second.insert(ans_not_rounded);
203  }
204  }
205 
206  std::cout << "Leaf rounding map:\n";
207  for (auto it = mapping.begin(); it != mapping.end(); ++it) {
208  WriteBasicType(std::cout, false, it->first);
209  for (auto it2 = it->second.begin(); it2 != it->second.end(); ++it2) {
210  WriteBasicType(std::cout, false, *it2);
211  }
212  std::cout << std::endl;
213  }
214 
215  KALDI_ASSERT(num_removed < 8);
216  } else {
217  tree = BuildTree(qopts, phone_sets, hmm_lengths, share_roots,
218  do_split, stats, thresh, max_leaves, 0.0, P,
219  false);
220  }
221 
222  // Would have print-out & testing code here.
223  std::cout << "Tree [default build] is:\n";
224  tree->Write(std::cout, false);
225  delete tree;
226  }
227  DeleteBuildTreeStats(&stats);
228  }
229 }
230 
231 
232 } // end namespace kaldi
233 
234 int main() {
237 }
238 
void GetKeysWithQuestions(std::vector< EventKeyType > *keys_out) const
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
void TestBuildTree()
This class defines, for each EventKeyType, a set of initial questions that it tries and also a number...
float RandUniform(struct RandomState *state=NULL)
Returns a random number strictly between 0 and 1.
Definition: kaldi-math.h:151
void TestGenRandStats()
const QuestionsForKey & GetQuestionsOf(EventKeyType key) const
kaldi::int32 int32
void GenRandStats(int32 dim, int32 num_stats, int32 N, int32 P, const std::vector< int32 > &phone_ids, const std::vector< int32 > &phone2hmm_length, const std::vector< bool > &is_ctx_dep, bool ensure_all_phones_covered, BuildTreeStatsType *stats_out)
GenRandStats generates random statistics of the form used by BuildTree.
Definition: build-tree.cc:30
virtual bool Map(const EventType &event, EventAnswerType *ans) const =0
void DeleteBuildTreeStats(BuildTreeStatsType *stats)
This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL...
static const EventKeyType kPdfClass
Definition: context-dep.h:39
std::vector< std::pair< EventKeyType, EventValueType > > EventType
Definition: event-map.h:58
void InitRand(const BuildTreeStatsType &stats, int32 num_quest, int32 num_iters_refine, AllKeysType all_keys_type)
InitRand attempts to generate "reasonable" random questions.
int32 EventKeyType
Things of type EventKeyType can take any value.
Definition: event-map.h:45
QuestionsForKey is a class used to define the questions for a key, and also options that allow us to ...
std::vector< std::vector< EventValueType > > initial_questions
int Rand(struct RandomState *state)
Definition: kaldi-math.cc:45
static bool Lookup(const EventType &event, EventKeyType key, EventValueType *ans)
Definition: event-map.cc:290
A class that is capable of representing a generic mapping from EventType (which is a vector of (key...
Definition: event-map.h:86
bool HasQuestionsForKey(EventKeyType key) const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
EventMap * BuildTree(Questions &qopts, const std::vector< std::vector< int32 > > &phone_sets, const std::vector< int32 > &phone2num_pdf_classes, const std::vector< bool > &share_roots, const std::vector< bool > &do_split, const BuildTreeStatsType &stats, BaseFloat thresh, int32 max_leaves, BaseFloat cluster_thresh, int32 P, bool round_num_leaves)
BuildTree is the normal way to build a set of decision trees.
Definition: build-tree.cc:136
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
Definition: io-funcs-inl.h:198
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
int32 EventAnswerType
As far as the event-map code itself is concerned, things of type EventAnswerType may take any value e...
Definition: event-map.h:56
int32 EventValueType
Given current code, things of type EventValueType should generally be nonnegative and in a reasonably...
Definition: event-map.h:51
int main()
std::vector< std::pair< EventType, Clusterable * > > BuildTreeStatsType
virtual void Write(std::ostream &os, bool binary)=0
Write to stream.
void WriteBuildTreeStats(std::ostream &os, bool binary, const BuildTreeStatsType &stats)
Writes BuildTreeStats object. This works even if pointers are NULL.