regression-tree.h
Go to the documentation of this file.
1 // transform/regression-tree.h
2 
3 // Copyright 2009-2011 Saarland University
4 // Author: Arnab Ghoshal
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 
22 #ifndef KALDI_TRANSFORM_REGRESSION_TREE_H_
23 #define KALDI_TRANSFORM_REGRESSION_TREE_H_
24 
25 #include <utility>
26 #include <vector>
27 
28 #include "base/kaldi-common.h"
29 #include "tree/cluster-utils.h"
30 #include "gmm/am-diag-gmm.h"
32 
33 namespace kaldi {
34 
42  public:
44 
/// Top-down clustering of the Gaussians in a model based on their means.
/// @param state_occs   Occupancy counts; presumably one entry per pdf/state
///                     of `am` -- confirm against the implementation.
/// @param sil_indices  Presumably the pdf indices of silence states, which
///                     may receive special treatment -- confirm.
/// @param am           The acoustic model whose Gaussians are clustered.
/// @param max_clusters Upper limit on the number of clusters built
///                     (presumably the number of baseclasses/leaves -- confirm).
48  void BuildTree(const Vector<BaseFloat> &state_occs,
49  const std::vector<int32> &sil_indices,
50  const AmDiagGmm &am,
51  int32 max_clusters);
52 
/// Parses the regression tree and finds the nodes whose occupancies (read
/// from stats_in) are greater than a threshold -- presumably `min_count`; the
/// extracted description is truncated, so confirm against the implementation.
/// @param stats_in       Input statistics (AffineXformStats pointers).
/// @param min_count      Occupancy threshold (see note above).
/// @param regclasses_out Output: int32 indices, presumably the selected
///                       regression classes -- confirm.
/// @param stats_out      Output: statistics for the selected classes.
/// NOTE(review): the meaning of the bool return value and who owns the
/// pointers placed in stats_out are not visible here -- confirm before use.
59  bool GatherStats(const std::vector<AffineXformStats*> &stats_in,
60  double min_count,
61  std::vector<int32> *regclasses_out,
62  std::vector<AffineXformStats*> *stats_out) const;
63 
/// Serializes the tree to `out`, in binary or text form according to `binary`.
64  void Write(std::ostream &out, bool binary) const;
/// Reads the tree from `in`, in binary or text form according to `binary`.
/// NOTE(review): `am` is presumably needed to rebuild gauss2bclass_ (see
/// MakeGauss2Bclass) after reading -- confirm in the implementation.
65  void Read(std::istream &in, bool binary, const AmDiagGmm &am);
66 
69  const std::vector< std::pair<int32, int32> >& GetBaseclass(int32 bclass)
70  const { return baseclasses_[bclass]; }
71  int32 Gauss2BaseclassId(size_t pdf_id, size_t gauss_id) const {
72  return gauss2bclass_[pdf_id][gauss_id];
73  }
74 
75  private:
77 
/// For each node, the index of its parent (size = num_nodes_).
/// If 0 <= i < num_baseclasses_, node i is a leaf of the tree (a baseclass);
/// otherwise it is a non-leaf node.
82  std::vector<int32> parents_;
/// Each baseclass (leaf of the regression tree) is a vector of Gaussian
/// indices, stored as int32 pairs (presumably (pdf index, Gaussian index)
/// -- confirm).
84  std::vector< std::vector< std::pair<int32, int32> > > baseclasses_;
/// Mapping from (pdf, gaussian) indices to baseclasses, indexed as
/// gauss2bclass_[pdf][gaussian].
88  std::vector< std::vector<int32> > gauss2bclass_;
89 
/// Builds gauss2bclass_ for the model `am`; presumably by inverting
/// baseclasses_ -- confirm in the implementation.
90  void MakeGauss2Bclass(const AmDiagGmm &am);
91
92  // Cannot have copy constructor and assignment operator
// NOTE(review): the KALDI_DISALLOW_COPY_AND_ASSIGN(RegressionTree) line that
// this comment refers to (listing line 93) is missing from this extract.
94
95 };
96 
97 } // namespace kaldi
98 
99 #endif // KALDI_TRANSFORM_REGRESSION_TREE_H_
void Read(std::istream &in, bool binary, const AmDiagGmm &am)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 Gauss2BaseclassId(size_t pdf_id, size_t gauss_id) const
std::vector< std::vector< int32 > > gauss2bclass_
Mapping from (pdf, gaussian) indices to baseclasses.
const std::vector< std::pair< int32, int32 > > & GetBaseclass(int32 bclass) const
int32 num_baseclasses_
Number of leaf nodes.
kaldi::int32 int32
std::vector< std::vector< std::pair< int32, int32 > > > baseclasses_
Each baseclass (leaf of regression tree) is a vector of Gaussian indices.
std::vector< int32 > parents_
For each node, the index of its parent (size = num_nodes_). If 0 <= i < num_baseclasses_, then node i is a leaf of the tree (a baseclass); otherwise it is a non-leaf node.
void MakeGauss2Bclass(const AmDiagGmm &am)
void Write(std::ostream &out, bool binary) const
bool GatherStats(const std::vector< AffineXformStats *> &stats_in, double min_count, std::vector< int32 > *regclasses_out, std::vector< AffineXformStats *> *stats_out) const
Parses the regression tree and finds the nodes whose occupancies (read from stats_in) are greater tha...
int32 NumBaseclasses() const
Accessors (const)
void BuildTree(const Vector< BaseFloat > &state_occs, const std::vector< int32 > &sil_indices, const AmDiagGmm &am, int32 max_clusters)
Top-down clustering of the Gaussians in a model based on their means.
A regression tree is a clustering of Gaussian densities in an acoustic model, such that the group of ...
A class representing a vector.
Definition: kaldi-vector.h:406
KALDI_DISALLOW_COPY_AND_ASSIGN(RegressionTree)
int32 num_nodes_
Total (non-leaf+leaf) nodes.