RegressionTree Class Reference

A regression tree is a clustering of Gaussian densities in an acoustic model, such that the group of Gaussians at each node of the tree is transformed by the same transform.

#include <regression-tree.h>

Public Member Functions

 RegressionTree ()
 
void BuildTree (const Vector< BaseFloat > &state_occs, const std::vector< int32 > &sil_indices, const AmDiagGmm &am, int32 max_clusters)
 Top-down clustering of the Gaussians in a model based on their means.
 
bool GatherStats (const std::vector< AffineXformStats *> &stats_in, double min_count, std::vector< int32 > *regclasses_out, std::vector< AffineXformStats *> *stats_out) const
 Parses the regression tree and finds the nodes whose occupancies (read from stats_in) are greater than min_count.
 
void Write (std::ostream &out, bool binary) const
 
void Read (std::istream &in, bool binary, const AmDiagGmm &am)
 
int32 NumBaseclasses () const
 Accessors (const)
 
const std::vector< std::pair< int32, int32 > > & GetBaseclass (int32 bclass) const
 
int32 Gauss2BaseclassId (size_t pdf_id, size_t gauss_id) const
 

Private Member Functions

void MakeGauss2Bclass (const AmDiagGmm &am)
 
 KALDI_DISALLOW_COPY_AND_ASSIGN (RegressionTree)
 

Private Attributes

int32 num_nodes_
 Total (non-leaf+leaf) nodes.
 
std::vector< int32 > parents_
 For each node, index of its parent: size = num_nodes_. If 0 <= i < num_baseclasses_, then i is a leaf of the tree (a base class); else a non-leaf node.
 
int32 num_baseclasses_
 Number of leaf nodes.
 
std::vector< std::vector< std::pair< int32, int32 > > > baseclasses_
 Each baseclass (leaf of regression tree) is a vector of Gaussian indices.
 
std::vector< std::vector< int32 > > gauss2bclass_
 Mapping from (pdf, gaussian) indices to baseclasses.
 

Detailed Description

A regression tree is a clustering of Gaussian densities in an acoustic model, such that the group of Gaussians at each node of the tree is transformed by the same transform.

Each node is thus called a regression class.

Definition at line 41 of file regression-tree.h.

Constructor & Destructor Documentation

◆ RegressionTree()

Member Function Documentation

◆ BuildTree()

void BuildTree ( const Vector< BaseFloat > &  state_occs,
const std::vector< int32 > &  sil_indices,
const AmDiagGmm &  am,
int32  max_clusters 
)

Top-down clustering of the Gaussians in a model based on their means.

If sil_indices is nonempty, will put silence in a special class using a top-level split.
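The top-level silence split amounts to a re-indexing of the parent array returned by the clustering step. The following is a minimal sketch of that re-indexing on plain integer vectors; `AddSilenceSplit` is a hypothetical standalone helper written for illustration and is not part of Kaldi. It assumes the input array already has its top node as the last entry.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of Kaldi): given the parent array of a tree
// built over non-silence Gaussians only, return the parent array of a tree
// in which node 0 is a new silence leaf and a new top node is appended.
std::vector<int> AddSilenceSplit(const std::vector<int> &clust_parents) {
  const int n = static_cast<int>(clust_parents.size());
  std::vector<int> parents(n + 2);
  const int top_node = n + 1;  // the new root is the last-numbered node
  for (int i = 0; i < n; ++i) {
    assert(clust_parents[i] >= 0 && clust_parents[i] < n);
    parents[i + 1] = clust_parents[i] + 1;  // shift every node up by one
  }
  parents[0] = top_node;         // silence leaf hangs off the new root
  parents[n] = top_node;         // the old root (now index n) gets the new root as parent
  parents[top_node] = top_node;  // the root is its own parent
  return parents;
}
```

For a speech-only tree with two leaves and a root, `{2, 2, 2}`, the sketch yields `{4, 3, 3, 4, 4}`: node 0 is silence, nodes 1 and 2 are the speech leaves, node 3 is the old root, and node 4 is the new top node.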

Definition at line 34 of file regression-tree.cc.

References VectorBase< Real >::AddVec2(), RegressionTree::baseclasses_, kaldi::DeletePointers(), AmDiagGmm::Dim(), RegressionTree::gauss2bclass_, AmDiagGmm::GetGaussianMean(), AmDiagGmm::GetGaussianVariance(), AmDiagGmm::GetPdf(), rnnlm::i, kaldi::IsSortedAndUniq(), rnnlm::j, KALDI_ASSERT, RegressionTree::num_baseclasses_, RegressionTree::num_nodes_, DiagGmm::NumGauss(), AmDiagGmm::NumGauss(), AmDiagGmm::NumPdfs(), RegressionTree::parents_, VectorBase< Real >::Scale(), kaldi::TreeCluster(), and DiagGmm::weights().

Referenced by main(), RegressionTree::RegressionTree(), UnitTestRegressionTree(), kaldi::UnitTestRegtreeFmllrDiagGmm(), and UnitTestRegtreeMllrDiagGmm().

void RegressionTree::BuildTree(const Vector<BaseFloat> &state_occs,
                               const std::vector<int32> &sil_indices,
                               const AmDiagGmm &am,
                               int32 max_clusters) {
  KALDI_ASSERT(IsSortedAndUniq(sil_indices));
  int32 dim = am.Dim(),
      num_pdfs = static_cast<int32>(am.NumPdfs());
  vector<Clusterable*> gauss_means;
  // For each Gaussian in the model, the pair of (pdf, gaussian) indices.
  vector< pair<int32, int32> > gauss_indices;
  Vector<BaseFloat> tmp_mean(dim);
  Vector<BaseFloat> tmp_var(dim);
  BaseFloat var_floor = 0.01;

  gauss2bclass_.resize(num_pdfs);
  gauss_means.reserve(am.NumGauss());  // NOT resize, uses push_back
  gauss_indices.reserve(am.NumGauss());  // NOT resize, uses push_back

  for (int32 pdf_index = 0; pdf_index < num_pdfs; pdf_index++) {
    gauss2bclass_[pdf_index].resize(am.GetPdf(pdf_index).NumGauss());
    for (int32 num_gauss = am.GetPdf(pdf_index).NumGauss(),
        gauss_index = 0; gauss_index < num_gauss; ++gauss_index) {
      // don't include silence while clustering...
      if (std::binary_search(sil_indices.begin(), sil_indices.end(), pdf_index))
        continue;

      am.GetGaussianMean(pdf_index, gauss_index, &tmp_mean);
      am.GetGaussianVariance(pdf_index, gauss_index, &tmp_var);
      tmp_var.AddVec2(1.0, tmp_mean);  // make it x^2 stats.
      BaseFloat this_weight = state_occs(pdf_index) *
          (am.GetPdf(pdf_index).weights())(gauss_index);
      tmp_mean.Scale(this_weight);
      tmp_var.Scale(this_weight);
      gauss_indices.push_back(std::make_pair(pdf_index, gauss_index));
      gauss_means.push_back(new GaussClusterable(tmp_mean, tmp_var, var_floor,
                                                 this_weight));
    }
  }

  vector<int32> leaves;
  vector<int32> clust_parents;
  int32 num_leaves;
  TreeClusterOptions opts;  // Use default options or get from somewhere else
  TreeCluster(gauss_means,
              (sil_indices.empty() ? max_clusters : max_clusters-1),
              NULL /* clusters not needed */,
              &leaves, &clust_parents, &num_leaves, opts);

  if (sil_indices.empty()) {  // no special treatment of silence...
    num_baseclasses_ = static_cast<int32>(num_leaves);
    baseclasses_.resize(num_leaves);
    parents_.resize(clust_parents.size());
    for (int32 i = 0, num_nodes = clust_parents.size(); i < num_nodes; i++) {
      parents_[i] = static_cast<int32>(clust_parents[i]);
    }
    num_nodes_ = static_cast<int32>(clust_parents.size());
    for (int32 i = 0; i < static_cast<int32>(gauss_indices.size()); i++) {
      baseclasses_[leaves[i]].push_back(gauss_indices[i]);
      gauss2bclass_[gauss_indices[i].first][gauss_indices[i].second] = leaves[i];
    }
  } else {
    // separate top-level split between silence and speech...
    // silence is node zero and new parent is last-numbered one.
    num_baseclasses_ = static_cast<int32>(num_leaves+1);  // +1 to include 0 == silence
    baseclasses_.resize(num_leaves+1);  // +1 to include 0 == silence
    parents_.resize(clust_parents.size()+2);  // +1 to include 0 == silence, +parent.

    int32 top_node = clust_parents.size() + 1;
    for (int32 i = 0; i < static_cast<int32>(clust_parents.size()); i++) {
      parents_[i+1] = clust_parents[i]+1;  // handle offsets
    }
    parents_[0] = top_node;
    parents_[clust_parents.size()] = top_node;  // old top node's parent is new top node.
    parents_[top_node] = top_node;  // being own parent is sign of being top node.

    num_nodes_ = static_cast<int32>(clust_parents.size() + 2);
    // Assign nonsilence Gaussians to their assigned classes (add one
    // to all leaf indices, make room for silence class).
    for (int32 i = 0; i < static_cast<int32>(gauss_indices.size()); i++) {
      baseclasses_[leaves[i]+1].push_back(gauss_indices[i]);
      gauss2bclass_[gauss_indices[i].first][gauss_indices[i].second] = leaves[i]+1;
    }
    // Assign silence Gaussians to zero'th baseclass.
    for (int32 i = 0; i < static_cast<int32>(sil_indices.size()); i++) {
      int32 pdf_index = sil_indices[i];
      for (int32 j = 0; j < am.GetPdf(pdf_index).NumGauss(); j++) {
        baseclasses_[0].push_back(std::make_pair(pdf_index, j));
        gauss2bclass_[pdf_index][j] = 0;
      }
    }
  }
  DeletePointers(&gauss_means);
}

◆ GatherStats()

bool GatherStats ( const std::vector< AffineXformStats *> &  stats_in,
double  min_count,
std::vector< int32 > *  regclasses_out,
std::vector< AffineXformStats *> *  stats_out 
) const

Parses the regression tree and finds the nodes whose occupancies (read from stats_in) are greater than min_count.

The regclasses_out vector has size equal to the number of baseclasses, and contains the regression class index for each baseclass. The stats_out vector has size equal to the number of regression classes. The return value is true if at least one regression class passed the count cutoff, false otherwise.
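The back-off logic can be illustrated on occupancies alone. The following is a simplified sketch, not the Kaldi implementation: `AssignRegclasses` is a hypothetical function that takes per-leaf occupancy counts instead of `AffineXformStats`, and it assumes that a leaf without enough data is mapped to its nearest ancestor that received a transform. The parent array is assumed to follow the layout of `parents_` (leaves first, top node last and its own parent).

```cpp
#include <cassert>
#include <vector>

// Hypothetical, occupancy-only stand-in for the regression class selection.
// Returns false if even the root has less than min_count; otherwise fills
// leaf_regclass with one regression class index per leaf.
bool AssignRegclasses(const std::vector<int> &parents,
                      const std::vector<double> &leaf_occs,
                      double min_count,
                      std::vector<int> *leaf_regclass) {
  const int num_nodes = static_cast<int>(parents.size());
  const int num_leaves = static_cast<int>(leaf_occs.size());
  assert(num_leaves >= 1 && num_leaves <= num_nodes);
  std::vector<double> occ(num_nodes, 0.0);
  std::vector<bool> wants_xform(num_nodes, false);
  std::vector<int> regclass(num_nodes, -1);
  int num_regclasses = 0;
  double total = 0.0;

  // Leaves: keep the transform at the leaf if it has enough data,
  // otherwise pass the responsibility to the parent.
  for (int b = 0; b < num_leaves; ++b) {
    total += leaf_occs[b];
    occ[b] = leaf_occs[b];
    if (num_leaves != 1) occ[parents[b]] += occ[b];
    if (occ[b] < min_count) {
      wants_xform[parents[b]] = true;
    } else {
      wants_xform[b] = true;
      regclass[b] = num_regclasses++;
    }
  }
  leaf_regclass->assign(num_leaves, -1);
  if (total < min_count) return false;  // not even a global transform

  // Non-leaf nodes below the root, children before parents.
  for (int node = num_leaves; node < num_nodes - 1; ++node) {
    occ[parents[node]] += occ[node];
    if (!wants_xform[node]) continue;
    if (occ[node] < min_count) {
      wants_xform[node] = false;
      wants_xform[parents[node]] = true;
    } else {
      regclass[node] = num_regclasses++;
    }
  }
  const int root = num_nodes - 1;
  if (wants_xform[root] && regclass[root] < 0) regclass[root] = num_regclasses++;

  // Map each leaf to its own class or to the nearest ancestor that has one.
  for (int b = 0; b < num_leaves; ++b) {
    int node = b;
    while (regclass[node] < 0) node = parents[node];
    (*leaf_regclass)[b] = regclass[node];
  }
  return true;
}
```

With parents `{3, 3, 4, 4, 4}`, leaf occupancies `{5, 50, 200}` and a `min_count` of 100, leaf 2 keeps its own class (index 0), while leaves 0 and 1 back off past node 3 (combined count 55) to the root (index 1), giving `{1, 1, 0}`.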

Definition at line 162 of file regression-tree.cc.

References kaldi::AssertEqual(), kaldi::DeletePointers(), kaldi::GetActiveParents(), KALDI_ASSERT, KALDI_WARN, RegressionTree::num_baseclasses_, RegressionTree::num_nodes_, and RegressionTree::parents_.

Referenced by RegressionTree::RegressionTree(), RegtreeMllrDiagGmmAccs::Update(), and RegtreeFmllrDiagGmmAccs::Update().

bool RegressionTree::GatherStats(const std::vector<AffineXformStats*> &stats_in,
                                 double min_count,
                                 std::vector<int32> *regclasses_out,
                                 std::vector<AffineXformStats*> *stats_out) const {
  KALDI_ASSERT(static_cast<int32>(stats_in.size()) == num_baseclasses_);
  if (static_cast<int32>(regclasses_out->size()) != num_baseclasses_)
    regclasses_out->resize(static_cast<size_t>(num_baseclasses_), -1);
  if (num_baseclasses_ == 1)  // Only root node in tree
    KALDI_ASSERT(num_nodes_ == 1);

  double total_occ = 0.0;
  int32 num_regclasses = 0;
  vector<double> node_occupancies(num_nodes_, 0.0);
  vector<bool> generate_xform(num_nodes_, false);
  vector<int32> regclasses(num_nodes_, -1);

  // Go through the leaves (baseclasses) and find where to generate transforms
  for (int32 bclass = 0; bclass < num_baseclasses_; bclass++) {
    total_occ += stats_in[bclass]->beta_;
    node_occupancies[bclass] = stats_in[bclass]->beta_;
    if (num_baseclasses_ != 1) {  // Don't count twice if tree only has root.
      node_occupancies[parents_[bclass]] += node_occupancies[bclass];
    }
    if (node_occupancies[bclass] < min_count) {
      // Not enough count, so pass the responsibility to the parent.
      generate_xform[bclass] = false;
      generate_xform[parents_[bclass]] = true;
    } else {  // generate at the leaf level.
      generate_xform[bclass] = true;
      regclasses[bclass] = num_regclasses++;
    }
  }
  // Check whether there is enough data for the single global transform (at
  // the root of the regression tree). If not, no transforms will be computed.
  if (total_occ < min_count) {
    // Make all baseclasses use the unit transform at the root.
    for (int32 bclass = 0; bclass < num_baseclasses_; bclass++) {
      (*regclasses_out)[bclass] = 0;
    }
    DeletePointers(stats_out);
    stats_out->clear();
    KALDI_WARN << "Not enough data to compute global transform. Occupancy at "
               << "root = " << total_occ << "<" << min_count;
    return false;
  }

  // Now go through the non-leaf nodes and find where to generate transforms.
  // Iterates only till num_nodes_ - 1 so that it doesn't count root twice.
  for (int32 node = num_baseclasses_; node < num_nodes_ - 1; node++) {
    node_occupancies[parents_[node]] += node_occupancies[node];
    // Only bother with generating transforms if a child asked for it.
    if (generate_xform[node]) {
      if (node_occupancies[node] < min_count) {
        // Not enough count, so pass the responsibility to the parent.
        generate_xform[node] = false;
        generate_xform[parents_[node]] = true;
      } else {  // transform will be generated at this level.
        regclasses[node] = num_regclasses++;
      }
    }
  }

  AssertEqual(node_occupancies[num_nodes_-1], total_occ, 1.0e-9);
  // If needed, generate a transform at the root.
  if (generate_xform[num_nodes_-1] && regclasses[num_nodes_-1] < 0) {
    KALDI_ASSERT(node_occupancies[num_nodes_-1] >= min_count);
    regclasses[num_nodes_-1] = num_regclasses++;
  }

  // Initialize the accumulators for output stats.
  // NOTE: memory is allocated here; be careful to delete the pointers
  stats_out->resize(num_regclasses);
  for (int32 r = 0; r < num_regclasses; r++) {
    (*stats_out)[r] = new AffineXformStats();
    (*stats_out)[r]->Init(stats_in[0]->dim_, stats_in[0]->G_.size());
  }

  // Finally go through the tree again and add stats
  vector<int32> active_parents;
  for (int32 bclass = 0; bclass < num_baseclasses_; bclass++) {
    if (generate_xform[bclass]) {
      KALDI_ASSERT(regclasses[bclass] > -1);
      (*stats_out)[regclasses[bclass]]->CopyStats(*(stats_in[bclass]));
      (*regclasses_out)[bclass] = regclasses[bclass];
      if (GetActiveParents(bclass, parents_, generate_xform, &active_parents)) {
        // Some other baseclass has less count
        for (vector<int32>::const_iterator p = active_parents.begin(),
            endp = active_parents.end(); p != endp; ++p) {
          KALDI_ASSERT(regclasses[*p] > -1);
          (*stats_out)[regclasses[*p]]->Add(*(stats_in[bclass]));
        }
      }
    } else {
      bool found = GetActiveParents(bclass, parents_, generate_xform,
                                    &active_parents);
      KALDI_ASSERT(found);  // must have active parents
      for (vector<int32>::const_iterator p = active_parents.begin(),
          endp = active_parents.end(); p != endp; ++p) {
        KALDI_ASSERT(regclasses[*p] > -1);
        (*stats_out)[regclasses[*p]]->Add(*(stats_in[bclass]));
      }
      (*regclasses_out)[bclass] = regclasses[active_parents[0]];
    }
  }

  KALDI_ASSERT(num_regclasses <= num_baseclasses_);
  return true;
}

◆ Gauss2BaseclassId()

int32 Gauss2BaseclassId ( size_t  pdf_id,
size_t  gauss_id 
) const
inline

◆ GetBaseclass()

const std::vector< std::pair<int32, int32> >& GetBaseclass ( int32  bclass) const
inline

Definition at line 69 of file regression-tree.h.

References RegressionTree::baseclasses_.

Referenced by RegtreeMllrDiagGmm::TransformModel().

{ return baseclasses_[bclass]; }

◆ KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( RegressionTree  )
private

◆ MakeGauss2Bclass()

void MakeGauss2Bclass ( const AmDiagGmm &  am)
private

Definition at line 351 of file regression-tree.cc.

References RegressionTree::baseclasses_, RegressionTree::gauss2bclass_, KALDI_ASSERT, KALDI_ERR, RegressionTree::num_baseclasses_, AmDiagGmm::NumGauss(), AmDiagGmm::NumGaussInPdf(), and AmDiagGmm::NumPdfs().

Referenced by RegressionTree::Read().

void RegressionTree::MakeGauss2Bclass(const AmDiagGmm &am) {
  gauss2bclass_.resize(am.NumPdfs());
  for (int32 pdf_index = 0, num_pdfs = am.NumPdfs(); pdf_index < num_pdfs;
      ++pdf_index) {
    gauss2bclass_[pdf_index].resize(am.NumGaussInPdf(pdf_index));
  }

  int32 total_gauss = 0;
  for (int32 bclass_index = 0; bclass_index < num_baseclasses_;
      ++bclass_index) {
    vector< pair<int32, int32> >::const_iterator itr =
        baseclasses_[bclass_index].begin(), end =
        baseclasses_[bclass_index].end();
    for (; itr != end; ++itr) {
      KALDI_ASSERT(itr->first < am.NumPdfs() &&
                   itr->second < am.NumGaussInPdf(itr->first));
      gauss2bclass_[itr->first][itr->second] = bclass_index;
      total_gauss++;
    }
  }

  if (total_gauss != am.NumGauss())
    KALDI_ERR << "Expecting " << am.NumGauss() << " Gaussians in "
        "regression tree, found " << total_gauss;
}

◆ NumBaseclasses()

int32 NumBaseclasses ( ) const
inline

◆ Read()

void Read ( std::istream &  in,
bool  binary,
const AmDiagGmm &  am 
)

Definition at line 308 of file regression-tree.cc.

References RegressionTree::baseclasses_, kaldi::ExpectToken(), rnnlm::i, KALDI_ASSERT, KALDI_ERR, RegressionTree::MakeGauss2Bclass(), RegressionTree::num_baseclasses_, RegressionTree::num_nodes_, AmDiagGmm::NumGauss(), RegressionTree::parents_, kaldi::ReadBasicType(), and kaldi::ReadIntegerVector().

Referenced by main(), RegressionTree::RegressionTree(), and test_io().

void RegressionTree::Read(std::istream &in, bool binary,
                          const AmDiagGmm &am) {
  int32 total_gauss = 0;
  ExpectToken(in, binary, "<REGTREE>");
  ExpectToken(in, binary, "<NUMNODES>");
  ReadBasicType(in, binary, &num_nodes_);
  parents_.resize(static_cast<size_t>(num_nodes_));
  ExpectToken(in, binary, "<PARENTS>");
  ReadIntegerVector(in, binary, &parents_);
  ExpectToken(in, binary, "</PARENTS>");

  ExpectToken(in, binary, "<BASECLASSES>");
  ExpectToken(in, binary, "<NUMBASECLASSES>");
  ReadBasicType(in, binary, &num_baseclasses_);
  baseclasses_.resize(static_cast<size_t>(num_baseclasses_));
  for (int32 bclass = 0; bclass < num_baseclasses_; bclass++) {
    ExpectToken(in, binary, "<CLASS>");
    int32 class_id, num_comp, pdf_id, gauss_id;
    ReadBasicType(in, binary, &class_id);
    ReadBasicType(in, binary, &num_comp);
    KALDI_ASSERT(class_id == bclass && num_comp > 0);
    total_gauss += num_comp;
    baseclasses_[bclass].reserve(num_comp);

    for (int32 i = 0; i < num_comp; i++) {
      ReadBasicType(in, binary, &pdf_id);
      ReadBasicType(in, binary, &gauss_id);
      KALDI_ASSERT(pdf_id >= 0 && gauss_id >= 0);
      baseclasses_[bclass].push_back(std::make_pair(pdf_id, gauss_id));
    }

    ExpectToken(in, binary, "</CLASS>");
  }
  ExpectToken(in, binary, "</BASECLASSES>");

  if (total_gauss != am.NumGauss())
    KALDI_ERR << "Expecting " << am.NumGauss() << " Gaussians in "
        "regression tree, found " << total_gauss;
  MakeGauss2Bclass(am);
}

◆ Write()

void Write ( std::ostream &  out,
bool  binary 
) const

Definition at line 271 of file regression-tree.cc.

References RegressionTree::baseclasses_, RegressionTree::num_baseclasses_, RegressionTree::num_nodes_, RegressionTree::parents_, kaldi::WriteBasicType(), kaldi::WriteIntegerVector(), and kaldi::WriteToken().

Referenced by main(), RegressionTree::RegressionTree(), and test_io().

void RegressionTree::Write(std::ostream &out, bool binary) const {
  WriteToken(out, binary, "<REGTREE>");
  WriteToken(out, binary, "<NUMNODES>");
  WriteBasicType(out, binary, num_nodes_);
  if (!binary) out << '\n';
  WriteToken(out, binary, "<PARENTS>");
  if (!binary) out << '\n';
  WriteIntegerVector(out, binary, parents_);
  WriteToken(out, binary, "</PARENTS>");
  if (!binary) out << '\n';

  WriteToken(out, binary, "<BASECLASSES>");
  if (!binary) out << '\n';
  WriteToken(out, binary, "<NUMBASECLASSES>");
  WriteBasicType(out, binary, num_baseclasses_);
  if (!binary) out << '\n';
  for (int32 bclass = 0; bclass < num_baseclasses_; bclass++) {
    WriteToken(out, binary, "<CLASS>");
    WriteBasicType(out, binary, bclass);
    WriteBasicType(out, binary, static_cast<int32>(
        baseclasses_[bclass].size()));
    if (!binary) out << '\n';
    for (vector< pair<int32, int32> >::const_iterator
        it = baseclasses_[bclass].begin(), end = baseclasses_[bclass].end();
        it != end; it++) {
      WriteBasicType(out, binary, it->first);
      WriteBasicType(out, binary, it->second);
      if (!binary) out << '\n';
    }

    WriteToken(out, binary, "</CLASS>");
    if (!binary) out << '\n';
  }
  WriteToken(out, binary, "</BASECLASSES>");
  if (!binary) out << '\n';
}

Member Data Documentation

◆ baseclasses_

std::vector< std::vector< std::pair<int32, int32> > > baseclasses_
private

Each baseclass (leaf of regression tree) is a vector of Gaussian indices.

Each Gaussian in the model is indexed by a (pdf, gaussian) index pair.

Definition at line 86 of file regression-tree.h.

Referenced by RegressionTree::BuildTree(), RegressionTree::GetBaseclass(), RegressionTree::MakeGauss2Bclass(), RegressionTree::Read(), and RegressionTree::Write().

◆ gauss2bclass_

std::vector< std::vector<int32> > gauss2bclass_
private

Mapping from (pdf, gaussian) indices to baseclasses.
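This lookup table is the inverse of the per-baseclass lists of (pdf, gaussian) pairs. A minimal sketch of the inversion on plain containers follows; `InvertBaseclasses` and the `gauss_per_pdf` argument are hypothetical names introduced for illustration (the class itself derives the sizes from the acoustic model), and `-1` is used here only to mark entries that no baseclass claimed.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

typedef std::vector<std::vector<std::pair<int, int> > > Baseclasses;

// Hypothetical helper (not part of Kaldi): builds the (pdf, gaussian) ->
// baseclass table from the per-baseclass lists of (pdf, gaussian) pairs.
// gauss_per_pdf[p] is the number of Gaussians in pdf p.
std::vector<std::vector<int> > InvertBaseclasses(
    const Baseclasses &baseclasses, const std::vector<int> &gauss_per_pdf) {
  std::vector<std::vector<int> > gauss2bclass(gauss_per_pdf.size());
  for (std::size_t p = 0; p < gauss_per_pdf.size(); ++p)
    gauss2bclass[p].assign(gauss_per_pdf[p], -1);  // -1: not assigned yet
  for (std::size_t b = 0; b < baseclasses.size(); ++b) {
    for (std::size_t k = 0; k < baseclasses[b].size(); ++k) {
      const int pdf = baseclasses[b][k].first;
      const int gauss = baseclasses[b][k].second;
      assert(pdf >= 0 && pdf < static_cast<int>(gauss_per_pdf.size()));
      assert(gauss >= 0 && gauss < gauss_per_pdf[pdf]);
      gauss2bclass[pdf][gauss] = static_cast<int>(b);
    }
  }
  return gauss2bclass;
}
```

Storing both directions trades a little memory for constant-time lookups either way: iterating over all Gaussians of a baseclass, and finding the baseclass of a given Gaussian.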

Definition at line 88 of file regression-tree.h.

Referenced by RegressionTree::BuildTree(), RegressionTree::Gauss2BaseclassId(), and RegressionTree::MakeGauss2Bclass().

◆ num_baseclasses_

◆ num_nodes_

int32 num_nodes_
private

Total (non-leaf+leaf) nodes.

Definition at line 76 of file regression-tree.h.

Referenced by RegressionTree::BuildTree(), RegressionTree::GatherStats(), RegressionTree::Read(), and RegressionTree::Write().

◆ parents_

std::vector<int32> parents_
private

For each node, index of its parent: size = num_nodes_. If 0 <= i < num_baseclasses_, then i is a leaf of the tree (a base class); else a non-leaf node.

parents_[i] > i, except for the top node (last-numbered one), for which parents_[i] == i.
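These ordering rules can be checked mechanically, and they guarantee that following parent links from any node reaches the top node. The sketch below works on a plain vector of indices; `IsValidParents` and `PathToRoot` are hypothetical helpers written for illustration and are not part of Kaldi.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: true if the array satisfies the layout described
// above (every node's parent has a higher index; the last node is the top
// node and is its own parent).
bool IsValidParents(const std::vector<int> &parents) {
  if (parents.empty()) return false;
  const int num_nodes = static_cast<int>(parents.size());
  const int root = num_nodes - 1;
  if (parents[root] != root) return false;
  for (int i = 0; i < root; ++i) {
    if (parents[i] <= i || parents[i] >= num_nodes) return false;
  }
  return true;
}

// Hypothetical helper: lists the nodes visited when walking from `node` up
// to the top node. Assumes the array passed IsValidParents.
std::vector<int> PathToRoot(const std::vector<int> &parents, int node) {
  assert(node >= 0 && node < static_cast<int>(parents.size()));
  std::vector<int> path(1, node);
  while (parents[node] != node) {  // the top node is its own parent
    node = parents[node];
    path.push_back(node);
  }
  return path;
}
```

For example, with three leaves (0, 1, 2), one internal node (3) and the top node (4), the array `{3, 3, 4, 4, 4}` is valid, and the path from leaf 0 is 0, 3, 4.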

Definition at line 82 of file regression-tree.h.

Referenced by RegressionTree::BuildTree(), RegressionTree::GatherStats(), RegressionTree::Read(), and RegressionTree::Write().


The documentation for this class was generated from the following files: