regression-tree.cc
Go to the documentation of this file.
1 // transform/regression-tree.cc
2 
3 // Copyright 2009-2011 Saarland University
4 // Author: Arnab Ghoshal
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <string>
22 #include <utility>
23 using std::pair;
24 #include <vector>
25 using std::vector;
26 
29 #include "util/common-utils.h"
30 
31 namespace kaldi {
32 
35  const std::vector<int32> &sil_indices,
36  const AmDiagGmm &am,
37  int32 max_clusters) {
38  KALDI_ASSERT(IsSortedAndUniq(sil_indices));
39  int32 dim = am.Dim(),
40  num_pdfs = static_cast<int32>(am.NumPdfs());
41  vector<Clusterable*> gauss_means;
42  // For each Gaussianin the model, the pair of (pdf, gaussian) indices.
43  vector< pair<int32, int32> > gauss_indices;
44  Vector<BaseFloat> tmp_mean(dim);
45  Vector<BaseFloat> tmp_var(dim);
46  BaseFloat var_floor = 0.01;
47 
48  gauss2bclass_.resize(num_pdfs);
49  gauss_means.reserve(am.NumGauss()); // NOT resize, uses push_back
50  gauss_indices.reserve(am.NumGauss()); // NOT resize, uses push_back
51 
52  for (int32 pdf_index = 0; pdf_index < num_pdfs; pdf_index++) {
53  gauss2bclass_[pdf_index].resize(am.GetPdf(pdf_index).NumGauss());
54  for (int32 num_gauss = am.GetPdf(pdf_index).NumGauss(),
55  gauss_index = 0; gauss_index < num_gauss; ++gauss_index) {
56  // don't include silence while clustering...
57  if (std::binary_search(sil_indices.begin(), sil_indices.end(), pdf_index))
58  continue;
59 
60  am.GetGaussianMean(pdf_index, gauss_index, &tmp_mean);
61  am.GetGaussianVariance(pdf_index, gauss_index, &tmp_var);
62  tmp_var.AddVec2(1.0, tmp_mean); // make it x^2 stats.
63  BaseFloat this_weight = state_occs(pdf_index) *
64  (am.GetPdf(pdf_index).weights())(gauss_index);
65  tmp_mean.Scale(this_weight);
66  tmp_var.Scale(this_weight);
67  gauss_indices.push_back(std::make_pair(pdf_index, gauss_index));
68  gauss_means.push_back(new GaussClusterable(tmp_mean, tmp_var, var_floor,
69  this_weight));
70  }
71  }
72 
73  vector<int32> leaves;
74  vector<int32> clust_parents;
75  int32 num_leaves;
76  TreeClusterOptions opts; // Use default options or get from somewhere else
77  TreeCluster(gauss_means,
78  (sil_indices.empty() ? max_clusters : max_clusters-1),
79  NULL /* clusters not needed */,
80  &leaves, &clust_parents, &num_leaves, opts);
81 
82  if (sil_indices.empty()) { // no special treatment of silence...
83  num_baseclasses_ = static_cast<int32>(num_leaves);
84  baseclasses_.resize(num_leaves);
85  parents_.resize(clust_parents.size());
86  for (int32 i = 0, num_nodes = clust_parents.size(); i < num_nodes; i++) {
87  parents_[i] = static_cast<int32>(clust_parents[i]);
88  }
89  num_nodes_ = static_cast<int32>(clust_parents.size());
90  for (int32 i = 0; i < static_cast<int32>(gauss_indices.size()); i++) {
91  baseclasses_[leaves[i]].push_back(gauss_indices[i]);
92  gauss2bclass_[gauss_indices[i].first][gauss_indices[i].second] = leaves[i];
93  }
94  } else {
95  // separate top-level split between silence and speech...
96  // silence is node zero and new parent is last-numbered one.
97  num_baseclasses_ = static_cast<int32>(num_leaves+1); // +1 to include 0 == silence
98  baseclasses_.resize(num_leaves+1); // +1 to include 0 == silence
99  parents_.resize(clust_parents.size()+2); // +1 to include 0 == silence, +parent.
100 
101  int32 top_node = clust_parents.size() + 1;
102  for (int32 i = 0; i < static_cast<int32>(clust_parents.size()); i++) {
103  parents_[i+1] = clust_parents[i]+1; // handle offsets
104  }
105  parents_[0] = top_node;
106  parents_[clust_parents.size()] = top_node; // old top node's parent is new top node.
107  parents_[top_node] = top_node; // being own parent is sign of being top node.
108 
109  num_nodes_ = static_cast<int32>(clust_parents.size() + 2);
110  // Assign nonsilence Gaussians to their assigned classes (add one
111  // to all leaf indices, make room for silence class).
112  for (int32 i = 0; i < static_cast<int32>(gauss_indices.size()); i++) {
113  baseclasses_[leaves[i]+1].push_back(gauss_indices[i]);
114  gauss2bclass_[gauss_indices[i].first][gauss_indices[i].second] = leaves[i]+1;
115  }
116  // Assign silence Gaussians to zero'th baseclass.
117  for (int32 i = 0; i < static_cast<int32>(sil_indices.size()); i++) {
118  int32 pdf_index = sil_indices[i];
119  for (int32 j = 0; j < am.GetPdf(pdf_index).NumGauss(); j++) {
120  baseclasses_[0].push_back(std::make_pair(pdf_index, j));
121  gauss2bclass_[pdf_index][j] = 0;
122  }
123  }
124  }
125  DeletePointers(&gauss_means);
126 }
127 
128 
129 static bool GetActiveParents(int32 node, const vector<int32> &parents,
130  const vector<bool> &is_active,
131  vector<int32> *active_parents_out) {
132  KALDI_ASSERT(parents.size() == is_active.size());
133  KALDI_ASSERT(static_cast<size_t>(node) < parents.size());
134  active_parents_out->clear();
135 
136  if (node == static_cast<int32> (parents.size() - 1)) { // root node
137  if (is_active[node]) {
138  active_parents_out->push_back(node);
139  return true;
140  } else {
141  return false;
142  }
143  }
144 
145  bool ret_val = false;
146  while (node < static_cast<int32> (parents.size() - 1)) { // exclude the root
147  node = parents[node];
148  if (is_active[node]) {
149  active_parents_out->push_back(node);
150  ret_val = true;
151  }
152  }
153  return ret_val; // will return if not starting from root
154 }
155 
162 bool RegressionTree::GatherStats(const vector<AffineXformStats*> &stats_in,
163  double min_count,
164  vector<int32> *regclasses_out,
165  vector<AffineXformStats*> *stats_out) const {
166  KALDI_ASSERT(static_cast<int32>(stats_in.size()) == num_baseclasses_);
167  if (static_cast<int32>(regclasses_out->size()) != num_baseclasses_)
168  regclasses_out->resize(static_cast<size_t>(num_baseclasses_), -1);
169  if (num_baseclasses_ == 1) // Only root node in tree
170  KALDI_ASSERT(num_nodes_ == 1);
171 
172  double total_occ = 0.0;
173  int32 num_regclasses = 0;
174  vector<double> node_occupancies(num_nodes_, 0.0);
175  vector<bool> generate_xform(num_nodes_, false);
176  vector<int32> regclasses(num_nodes_, -1);
177 
178  // Go through the leaves (baseclasses) and find where to generate transforms
179  for (int32 bclass = 0; bclass < num_baseclasses_; bclass++) {
180  total_occ += stats_in[bclass]->beta_;
181  node_occupancies[bclass] = stats_in[bclass]->beta_;
182  if (num_baseclasses_ != 1) { // Don't count twice if tree only has root.
183  node_occupancies[parents_[bclass]] += node_occupancies[bclass];
184  }
185  if (node_occupancies[bclass] < min_count) {
186  // Not enough count, so pass the responsibility to the parent.
187  generate_xform[bclass] = false;
188  generate_xform[parents_[bclass]] = true;
189  } else { // generate at the leaf level.
190  generate_xform[bclass] = true;
191  regclasses[bclass] = num_regclasses++;
192  }
193  }
194  // Check whether there is enough data for the single global transform (at
195  // the root of the regression tree). If not, no transforms will be computed.
196  if (total_occ < min_count) {
197  // Make all baseclasses use the unit transform at the root.
198  for (int32 bclass = 0; bclass < num_baseclasses_; bclass++) {
199  (*regclasses_out)[bclass] = 0;
200  }
201  DeletePointers(stats_out);
202  stats_out->clear();
203  KALDI_WARN << "Not enough data to compute global transform. Occupancy at "
204  << "root = " << total_occ << "<" << min_count;
205  return false;
206  }
207 
208  // Now go through the non-leaf nodes and find where to generate transforms.
209  // Iterates only till num_nodes_ - 1 so that it doesn't count root twice.
210  for (int32 node = num_baseclasses_; node < num_nodes_ - 1; node++) {
211  node_occupancies[parents_[node]] += node_occupancies[node];
212  // Only bother with generating transforms if a child asked for it.
213  if (generate_xform[node]) {
214  if (node_occupancies[node] < min_count) {
215  // Not enough count, so pass the responsibility to the parent.
216  generate_xform[node] = false;
217  generate_xform[parents_[node]] = true;
218  } else { // transform will be generated at this level.
219  regclasses[node] = num_regclasses++;
220  }
221  }
222  }
223 
224  AssertEqual(node_occupancies[num_nodes_-1], total_occ, 1.0e-9);
225  // If needed, generate a transform at the root.
226  if (generate_xform[num_nodes_-1] && regclasses[num_nodes_-1] < 0) {
227  KALDI_ASSERT(node_occupancies[num_nodes_-1] >= min_count);
228  regclasses[num_nodes_-1] = num_regclasses++;
229  }
230 
231  // Initialize the accumulators for output stats.
232  // NOTE: memory is allocated here; be careful to delete the pointers
233  stats_out->resize(num_regclasses);
234  for (int32 r = 0; r < num_regclasses; r++) {
235  (*stats_out)[r] = new AffineXformStats();
236  (*stats_out)[r]->Init(stats_in[0]->dim_, stats_in[0]->G_.size());
237  }
238 
239  // Finally go through the tree again and add stats
240  vector<int32> active_parents;
241  for (int32 bclass = 0; bclass < num_baseclasses_; bclass++) {
242  if (generate_xform[bclass]) {
243  KALDI_ASSERT(regclasses[bclass] > -1);
244  (*stats_out)[regclasses[bclass]]->CopyStats(*(stats_in[bclass]));
245  (*regclasses_out)[bclass] = regclasses[bclass];
246  if (GetActiveParents(bclass, parents_, generate_xform, &active_parents)) {
247  // Some other baseclass has less count
248  for (vector<int32>::const_iterator p = active_parents.begin(),
249  endp = active_parents.end(); p != endp; ++p) {
250  KALDI_ASSERT(regclasses[*p] > -1);
251  (*stats_out)[regclasses[*p]]->Add(*(stats_in[bclass]));
252  }
253  }
254  } else {
255  bool found = GetActiveParents(bclass, parents_, generate_xform,
256  &active_parents);
257  KALDI_ASSERT(found); // must have active parents
258  for (vector<int32>::const_iterator p = active_parents.begin(),
259  endp = active_parents.end(); p != endp; ++p) {
260  KALDI_ASSERT(regclasses[*p] > -1);
261  (*stats_out)[regclasses[*p]]->Add(*(stats_in[bclass]));
262  }
263  (*regclasses_out)[bclass] = regclasses[active_parents[0]];
264  }
265  }
266 
267  KALDI_ASSERT(num_regclasses <= num_baseclasses_);
268  return true;
269 }
270 
271 void RegressionTree::Write(std::ostream &out, bool binary) const {
272  WriteToken(out, binary, "<REGTREE>");
273  WriteToken(out, binary, "<NUMNODES>");
274  WriteBasicType(out, binary, num_nodes_);
275  if (!binary) out << '\n';
276  WriteToken(out, binary, "<PARENTS>");
277  if (!binary) out << '\n';
278  WriteIntegerVector(out, binary, parents_);
279  WriteToken(out, binary, "</PARENTS>");
280  if (!binary) out << '\n';
281 
282  WriteToken(out, binary, "<BASECLASSES>");
283  if (!binary) out << '\n';
284  WriteToken(out, binary, "<NUMBASECLASSES>");
285  WriteBasicType(out, binary, num_baseclasses_);
286  if (!binary) out << '\n';
287  for (int32 bclass = 0; bclass < num_baseclasses_; bclass++) {
288  WriteToken(out, binary, "<CLASS>");
289  WriteBasicType(out, binary, bclass);
290  WriteBasicType(out, binary, static_cast<int32>(
291  baseclasses_[bclass].size()));
292  if (!binary) out << '\n';
293  for (vector< pair<int32, int32> >::const_iterator
294  it = baseclasses_[bclass].begin(), end = baseclasses_[bclass].end();
295  it != end; it++) {
296  WriteBasicType(out, binary, it->first);
297  WriteBasicType(out, binary, it->second);
298  if (!binary) out << '\n';
299  }
300 
301  WriteToken(out, binary, "</CLASS>");
302  if (!binary) out << '\n';
303  }
304  WriteToken(out, binary, "</BASECLASSES>");
305  if (!binary) out << '\n';
306 }
307 
308 void RegressionTree::Read(std::istream &in, bool binary,
309  const AmDiagGmm &am) {
310  int32 total_gauss = 0;
311  ExpectToken(in, binary, "<REGTREE>");
312  ExpectToken(in, binary, "<NUMNODES>");
313  ReadBasicType(in, binary, &num_nodes_);
315  parents_.resize(static_cast<size_t>(num_nodes_));
316  ExpectToken(in, binary, "<PARENTS>");
317  ReadIntegerVector(in, binary, &parents_);
318  ExpectToken(in, binary, "</PARENTS>");
319 
320  ExpectToken(in, binary, "<BASECLASSES>");
321  ExpectToken(in, binary, "<NUMBASECLASSES>");
322  ReadBasicType(in, binary, &num_baseclasses_);
324  baseclasses_.resize(static_cast<size_t>(num_baseclasses_));
325  for (int32 bclass = 0; bclass < num_baseclasses_; bclass++) {
326  ExpectToken(in, binary, "<CLASS>");
327  int32 class_id, num_comp, pdf_id, gauss_id;
328  ReadBasicType(in, binary, &class_id);
329  ReadBasicType(in, binary, &num_comp);
330  KALDI_ASSERT(class_id == bclass && num_comp > 0);
331  total_gauss += num_comp;
332  baseclasses_[bclass].reserve(num_comp);
333 
334  for (int32 i = 0; i < num_comp; i++) {
335  ReadBasicType(in, binary, &pdf_id);
336  ReadBasicType(in, binary, &gauss_id);
337  KALDI_ASSERT(pdf_id >= 0 && gauss_id >= 0);
338  baseclasses_[bclass].push_back(std::make_pair(pdf_id, gauss_id));
339  }
340 
341  ExpectToken(in, binary, "</CLASS>");
342  }
343  ExpectToken(in, binary, "</BASECLASSES>");
344 
345  if (total_gauss != am.NumGauss())
346  KALDI_ERR << "Expecting " << am.NumGauss() << " Gaussians in "
347  "regression tree, found " << total_gauss;
348  MakeGauss2Bclass(am);
349 }
350 
352  gauss2bclass_.resize(am.NumPdfs());
353  for (int32 pdf_index = 0, num_pdfs = am.NumPdfs(); pdf_index < num_pdfs;
354  ++pdf_index) {
355  gauss2bclass_[pdf_index].resize(am.NumGaussInPdf(pdf_index));
356  }
357 
358  int32 total_gauss = 0;
359  for (int32 bclass_index = 0; bclass_index < num_baseclasses_;
360  ++bclass_index) {
361  vector< pair<int32, int32> >::const_iterator itr =
362  baseclasses_[bclass_index].begin(), end =
363  baseclasses_[bclass_index].end();
364  for (; itr != end; ++itr) {
365  KALDI_ASSERT(itr->first < am.NumPdfs() &&
366  itr->second < am.NumGaussInPdf(itr->first));
367  gauss2bclass_[itr->first][itr->second] = bclass_index;
368  total_gauss++;
369  }
370  }
371 
372  if (total_gauss != am.NumGauss())
373  KALDI_ERR << "Expecting " << am.NumGauss() << " Gaussians in "
374  "regression tree, found " << total_gauss;
375 }
376 
377 } // namespace kaldi
378 
void Read(std::istream &in, bool binary, const AmDiagGmm &am)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
int32 NumGauss() const
Definition: am-diag-gmm.cc:72
std::vector< std::vector< int32 > > gauss2bclass_
Mapping from (pdf, gaussian) indices to baseclasses.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
int32 num_baseclasses_
Number of leaf nodes.
kaldi::int32 int32
std::vector< std::vector< std::pair< int32, int32 > > > baseclasses_
Each baseclass (leaf of regression tree) is a vector of Gaussian indices.
int32 NumGaussInPdf(int32 pdf_index) const
Definition: am-diag-gmm.h:113
std::vector< int32 > parents_
For each node, index of its parent: size = num_nodes_ If 0 <= i < num_baseclasses_, then i is a leaf of the tree (a base class); else a non-leaf node.
void MakeGauss2Bclass(const AmDiagGmm &am)
void Write(std::ostream &out, bool binary) const
bool GatherStats(const std::vector< AffineXformStats *> &stats_in, double min_count, std::vector< int32 > *regclasses_out, std::vector< AffineXformStats *> *stats_out) const
Parses the regression tree and finds the nodes whose occupancies (read from stats_in) are greater tha...
void AddVec2(const Real alpha, const VectorBase< Real > &v)
Add vector : *this = *this + alpha * rv^2 [element-wise squaring].
void ReadIntegerVector(std::istream &is, bool binary, std::vector< T > *v)
Function for reading STL vector of integer types.
Definition: io-funcs-inl.h:232
static bool GetActiveParents(int32 node, const vector< int32 > &parents, const vector< bool > &is_active, vector< int32 > *active_parents_out)
void BuildTree(const Vector< BaseFloat > &state_occs, const std::vector< int32 > &sil_indices, const AmDiagGmm &am, int32 max_clusters)
Top-down clustering of the Gaussians in a model based on their means.
void GetGaussianVariance(int32 pdf_index, int32 gauss, VectorBase< BaseFloat > *out) const
Definition: am-diag-gmm.h:138
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
const Vector< BaseFloat > & weights() const
Definition: diag-gmm.h:178
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: diag-gmm.h:72
void Scale(Real alpha)
Multiplies all elements by this constant.
int32 Dim() const
Definition: am-diag-gmm.h:79
int32 NumPdfs() const
Definition: am-diag-gmm.h:82
DiagGmm & GetPdf(int32 pdf_index)
Accessors.
Definition: am-diag-gmm.h:119
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void GetGaussianMean(int32 pdf_index, int32 gauss, VectorBase< BaseFloat > *out) const
Definition: am-diag-gmm.h:131
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
Definition: kaldi-math.h:276
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
Definition: io-funcs-inl.h:198
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
BaseFloat TreeCluster(const std::vector< Clusterable *> &points, int32 max_clust, std::vector< Clusterable *> *clusters_out, std::vector< int32 > *assignments_out, std::vector< int32 > *clust_assignments_out, int32 *num_leaves_out, TreeClusterOptions cfg)
TreeCluster is a top-down clustering algorithm, using a binary tree (not necessarily balanced)...
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...
bool IsSortedAndUniq(const std::vector< T > &vec)
Returns true if the vector is sorted and contains each element only once.
Definition: stl-utils.h:63
int32 num_nodes_
Total (non-leaf+leaf) nodes.