A regression tree is a clustering of Gaussian densities in an acoustic model, such that the group of Gaussians at each node of the tree is transformed by the same transform.
#include <regression-tree.h>
Public Member Functions
RegressionTree ()
void BuildTree (const Vector< BaseFloat > &state_occs, const std::vector< int32 > &sil_indices, const AmDiagGmm &am, int32 max_clusters)
  Top-down clustering of the Gaussians in a model based on their means.
bool GatherStats (const std::vector< AffineXformStats *> &stats_in, double min_count, std::vector< int32 > *regclasses_out, std::vector< AffineXformStats *> *stats_out) const
  Parses the regression tree and finds the nodes whose occupancies (read from stats_in) are greater than min_count.
void Write (std::ostream &out, bool binary) const
void Read (std::istream &in, bool binary, const AmDiagGmm &am)
int32 NumBaseclasses () const
  Accessors (const).
const std::vector< std::pair< int32, int32 > > & GetBaseclass (int32 bclass) const
int32 Gauss2BaseclassId (size_t pdf_id, size_t gauss_id) const
Private Member Functions
void MakeGauss2Bclass (const AmDiagGmm &am)
KALDI_DISALLOW_COPY_AND_ASSIGN (RegressionTree)
Private Attributes
int32 num_nodes_
  Total number of nodes (non-leaf + leaf).
std::vector< int32 > parents_
  For each node, the index of its parent; size = num_nodes_. If 0 <= i < num_baseclasses_, then i is a leaf of the tree (a base class); otherwise it is a non-leaf node.
int32 num_baseclasses_
  Number of leaf nodes.
std::vector< std::vector< std::pair< int32, int32 > > > baseclasses_
  Each baseclass (leaf of regression tree) is a vector of Gaussian indices.
std::vector< std::vector< int32 > > gauss2bclass_
  Mapping from (pdf, gaussian) indices to baseclasses.
A regression tree is a clustering of Gaussian densities in an acoustic model, such that the group of Gaussians at each node of the tree is transformed by the same transform.
Each node is thus called a regression class.
Definition at line 41 of file regression-tree.h.
RegressionTree::RegressionTree () [inline]
Definition at line 43 of file regression-tree.h.
References RegressionTree::BuildTree(), RegressionTree::GatherStats(), RegressionTree::Read(), and RegressionTree::Write().
void RegressionTree::BuildTree (const Vector< BaseFloat > &state_occs, const std::vector< int32 > &sil_indices, const AmDiagGmm &am, int32 max_clusters)
Top-down clustering of the Gaussians in a model based on their means.
If sil_indices is nonempty, silence is put in a special class using a top-level split.
Definition at line 34 of file regression-tree.cc.
References VectorBase< Real >::AddVec2(), RegressionTree::baseclasses_, kaldi::DeletePointers(), AmDiagGmm::Dim(), RegressionTree::gauss2bclass_, AmDiagGmm::GetGaussianMean(), AmDiagGmm::GetGaussianVariance(), AmDiagGmm::GetPdf(), rnnlm::i, kaldi::IsSortedAndUniq(), rnnlm::j, KALDI_ASSERT, RegressionTree::num_baseclasses_, RegressionTree::num_nodes_, DiagGmm::NumGauss(), AmDiagGmm::NumGauss(), AmDiagGmm::NumPdfs(), RegressionTree::parents_, VectorBase< Real >::Scale(), kaldi::TreeCluster(), and DiagGmm::weights().
Referenced by main(), RegressionTree::RegressionTree(), UnitTestRegressionTree(), kaldi::UnitTestRegtreeFmllrDiagGmm(), and UnitTestRegtreeMllrDiagGmm().
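The idea of top-down clustering of Gaussian means can be illustrated with a small, self-contained sketch. It is not Kaldi's implementation: the real BuildTree also takes state_occs and the silence indices and calls kaldi::TreeCluster, none of which this sketch reproduces. It assumes plain Euclidean distance between mean vectors and repeatedly splits the cluster with the largest distortion using a two-way k-means (Lloyd) pass. The names Point, TopDownCluster, SplitTwo, Centroid and Distortion are hypothetical, and the result is only the final leaf assignment, not the internal nodes that BuildTree records in parents_.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical type: one Gaussian mean vector.
using Point = std::vector<double>;

// Squared Euclidean distance between two equal-length vectors.
static double SqDist(const Point &a, const Point &b) {
  double d = 0.0;
  for (std::size_t k = 0; k < a.size(); ++k) d += (a[k] - b[k]) * (a[k] - b[k]);
  return d;
}

// Mean of the points whose indices are listed in `members` (must be non-empty).
static Point Centroid(const std::vector<Point> &pts, const std::vector<int> &members) {
  Point c(pts[members[0]].size(), 0.0);
  for (int m : members)
    for (std::size_t k = 0; k < c.size(); ++k) c[k] += pts[m][k];
  for (double &v : c) v /= static_cast<double>(members.size());
  return c;
}

// Sum of squared distances from the members to their centroid.
static double Distortion(const std::vector<Point> &pts, const std::vector<int> &members) {
  const Point c = Centroid(pts, members);
  double d = 0.0;
  for (int m : members) d += SqDist(pts[m], c);
  return d;
}

// Two-way k-means on `members`, seeded with the member farthest from the
// centroid and the member farthest from that seed. Either side may come back
// empty if the members cannot be separated.
static void SplitTwo(const std::vector<Point> &pts, const std::vector<int> &members,
                     std::vector<int> *left, std::vector<int> *right) {
  const Point c = Centroid(pts, members);
  int a = members[0];
  for (int m : members) if (SqDist(pts[m], c) > SqDist(pts[a], c)) a = m;
  int b = members[0];
  for (int m : members) if (SqDist(pts[m], pts[a]) > SqDist(pts[b], pts[a])) b = m;
  Point ca = pts[a], cb = pts[b];
  for (int iter = 0; iter < 10; ++iter) {
    left->clear();
    right->clear();
    for (int m : members)
      (SqDist(pts[m], ca) <= SqDist(pts[m], cb) ? left : right)->push_back(m);
    if (left->empty() || right->empty()) return;
    ca = Centroid(pts, *left);
    cb = Centroid(pts, *right);
  }
}

// Divisive clustering: keep splitting the cluster with the largest distortion
// until max_clusters clusters exist or no further split is possible.
// Returns a cluster id in [0, number of clusters) for every point.
std::vector<int> TopDownCluster(const std::vector<Point> &pts, int max_clusters) {
  assert(max_clusters >= 1);
  std::vector<std::vector<int>> clusters(1);
  for (int i = 0; i < static_cast<int>(pts.size()); ++i) clusters[0].push_back(i);
  while (static_cast<int>(clusters.size()) < max_clusters) {
    int best = -1;
    double best_d = 0.0;
    for (std::size_t c = 0; c < clusters.size(); ++c) {
      if (clusters[c].size() < 2) continue;
      const double d = Distortion(pts, clusters[c]);
      if (d > best_d) { best_d = d; best = static_cast<int>(c); }
    }
    if (best < 0) break;  // all clusters are singletons or have zero spread
    std::vector<int> left, right;
    SplitTwo(pts, clusters[best], &left, &right);
    if (left.empty() || right.empty()) break;  // stop rather than loop forever
    clusters[best] = left;
    clusters.push_back(right);
  }
  std::vector<int> assign(pts.size());
  for (std::size_t c = 0; c < clusters.size(); ++c)
    for (int m : clusters[c]) assign[m] = static_cast<int>(c);
  return assign;
}
```

In a regression-tree setting each Point would play the role of one Gaussian mean, and max_clusters would bound the number of leaves (baseclasses). Splitting the highest-distortion cluster first is one common heuristic; other divisive schemes are equally valid.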
bool RegressionTree::GatherStats (const std::vector< AffineXformStats *> &stats_in, double min_count, std::vector< int32 > *regclasses_out, std::vector< AffineXformStats *> *stats_out) const
Parses the regression tree and finds the nodes whose occupancies (read from stats_in) are greater than min_count.
The regclasses_out vector has size equal to the number of baseclasses and contains the regression class index for each baseclass. The stats_out vector has size equal to the number of regression classes. The return value is true if at least one regression class passed the count cutoff, and false otherwise.
Definition at line 162 of file regression-tree.cc.
References kaldi::AssertEqual(), kaldi::DeletePointers(), kaldi::GetActiveParents(), KALDI_ASSERT, KALDI_WARN, RegressionTree::num_baseclasses_, RegressionTree::num_nodes_, and RegressionTree::parents_.
Referenced by RegressionTree::RegressionTree(), RegtreeMllrDiagGmmAccs::Update(), and RegtreeFmllrDiagGmmAccs::Update().
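One way to read the GatherStats description is as a bottom-up accumulation of occupancies followed by a climb from each leaf. The sketch below, using the hypothetical function AssignRegressionClasses, assumes plain per-baseclass counts in place of AffineXformStats occupancies and relies on the parents_ convention documented for this class (leaves first, parents_[i] > i, root last). Each baseclass is mapped to the nearest node on its path to the root, starting with the leaf itself, whose accumulated count exceeds min_count, and each distinct such node becomes one regression class. It reports the chosen node per class instead of summing statistics into stats_out, and it is one plausible reading of the description, not a transcription of regression-tree.cc.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical sketch, not Kaldi's GatherStats. `parents` follows the
// documented parents_ convention: leaves are 0..num_leaves-1, parents[i] > i
// for every non-root node, and the last node is the root (parents[root] == root).
// `leaf_counts[b]` stands in for the occupancy of baseclass b.
// On return, (*regclasses_out)[b] is the regression-class index of baseclass b
// (-1 if no node passes the cutoff) and (*regclass_nodes_out)[r] is the tree
// node chosen for regression class r.
bool AssignRegressionClasses(const std::vector<int> &parents,
                             const std::vector<double> &leaf_counts,
                             double min_count,
                             std::vector<int> *regclasses_out,
                             std::vector<int> *regclass_nodes_out) {
  const int num_nodes = static_cast<int>(parents.size());
  const int num_leaves = static_cast<int>(leaf_counts.size());
  assert(num_leaves <= num_nodes);

  // Children always have smaller indices than their parents, so a single
  // increasing pass accumulates every subtree count up to the root.
  std::vector<double> node_counts(num_nodes, 0.0);
  for (int i = 0; i < num_leaves; ++i) node_counts[i] = leaf_counts[i];
  for (int i = 0; i + 1 < num_nodes; ++i) node_counts[parents[i]] += node_counts[i];

  regclasses_out->assign(static_cast<std::size_t>(num_leaves), -1);
  regclass_nodes_out->clear();
  std::map<int, int> node_to_class;  // tree node -> regression-class index
  for (int leaf = 0; leaf < num_leaves; ++leaf) {
    int node = leaf;
    // Climb until a node's count exceeds min_count or the root is reached.
    while (node_counts[node] <= min_count && parents[node] != node)
      node = parents[node];
    if (node_counts[node] <= min_count) continue;  // not even the root passes
    auto it = node_to_class.find(node);
    if (it == node_to_class.end()) {
      const int new_class = static_cast<int>(regclass_nodes_out->size());
      it = node_to_class.emplace(node, new_class).first;
      regclass_nodes_out->push_back(node);
    }
    (*regclasses_out)[leaf] = it->second;
  }
  return !regclass_nodes_out->empty();
}
```

Because the root's count is the sum over all leaves, either every baseclass finds a qualifying node or none does, which matches the documented "at least one regression class passed" return value.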
int32 RegressionTree::Gauss2BaseclassId (size_t pdf_id, size_t gauss_id) const
Definition at line 71 of file regression-tree.h.
References RegressionTree::gauss2bclass_.
Referenced by RegtreeMllrDiagGmmAccs::AccumulateForGaussian(), RegtreeFmllrDiagGmmAccs::AccumulateForGaussian(), RegtreeMllrDiagGmmAccs::AccumulateForGmm(), RegtreeFmllrDiagGmmAccs::AccumulateForGmm(), RegtreeMllrDiagGmm::GetTransformedMeans(), and DecodableAmDiagGmmRegtreeFmllr::LogLikelihoodZeroBased().
const std::vector< std::pair< int32, int32 > > & RegressionTree::GetBaseclass (int32 bclass) const
Definition at line 69 of file regression-tree.h.
References RegressionTree::baseclasses_.
Referenced by RegtreeMllrDiagGmm::TransformModel().
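RegressionTree::KALDI_DISALLOW_COPY_AND_ASSIGN (RegressionTree) [private]
void RegressionTree::MakeGauss2Bclass (const AmDiagGmm &am) [private]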
Definition at line 351 of file regression-tree.cc.
References RegressionTree::baseclasses_, RegressionTree::gauss2bclass_, KALDI_ASSERT, KALDI_ERR, RegressionTree::num_baseclasses_, AmDiagGmm::NumGauss(), AmDiagGmm::NumGaussInPdf(), and AmDiagGmm::NumPdfs().
Referenced by RegressionTree::Read().
int32 RegressionTree::NumBaseclasses () const [inline]
Accessors (const)
Definition at line 68 of file regression-tree.h.
References RegressionTree::num_baseclasses_.
Referenced by RegtreeMllrDiagGmm::GetTransformedMeans(), main(), RegtreeMllrDiagGmm::TransformModel(), kaldi::UnitTestRegtreeFmllrDiagGmm(), and UnitTestRegtreeMllrDiagGmm().
void RegressionTree::Read (std::istream &in, bool binary, const AmDiagGmm &am)
Definition at line 308 of file regression-tree.cc.
References RegressionTree::baseclasses_, kaldi::ExpectToken(), rnnlm::i, KALDI_ASSERT, KALDI_ERR, RegressionTree::MakeGauss2Bclass(), RegressionTree::num_baseclasses_, RegressionTree::num_nodes_, AmDiagGmm::NumGauss(), RegressionTree::parents_, kaldi::ReadBasicType(), and kaldi::ReadIntegerVector().
Referenced by main(), RegressionTree::RegressionTree(), and test_io().
void RegressionTree::Write (std::ostream &out, bool binary) const
Definition at line 271 of file regression-tree.cc.
References RegressionTree::baseclasses_, RegressionTree::num_baseclasses_, RegressionTree::num_nodes_, RegressionTree::parents_, kaldi::WriteBasicType(), kaldi::WriteIntegerVector(), and kaldi::WriteToken().
Referenced by main(), RegressionTree::RegressionTree(), and test_io().
std::vector< std::vector< std::pair< int32, int32 > > > RegressionTree::baseclasses_ [private]
Each baseclass (leaf of regression tree) is a vector of Gaussian indices.
Each Gaussian in the model is indexed by a (pdf, gaussian) index pair.
Definition at line 86 of file regression-tree.h.
Referenced by RegressionTree::BuildTree(), RegressionTree::GetBaseclass(), RegressionTree::MakeGauss2Bclass(), RegressionTree::Read(), and RegressionTree::Write().
std::vector< std::vector< int32 > > RegressionTree::gauss2bclass_ [private]
Mapping from (pdf, gaussian) indices to baseclasses.
Definition at line 88 of file regression-tree.h.
Referenced by RegressionTree::BuildTree(), RegressionTree::Gauss2BaseclassId(), and RegressionTree::MakeGauss2Bclass().
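Since baseclasses_ lists the (pdf, gaussian) pairs belonging to each baseclass and gauss2bclass_ answers the reverse question, building the latter amounts to inverting a table. A minimal sketch, assuming plain int indices and a hypothetical num_gauss_in_pdf vector standing in for the per-pdf Gaussian counts of the acoustic model; InvertBaseclasses and LookupBaseclass are illustrative names, not Kaldi functions, and the requirement that every Gaussian belong to exactly one baseclass is an assumption about a consistent tree.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical sketch: invert a baseclasses_-style table (baseclass -> list of
// (pdf, gaussian) pairs) into a gauss2bclass_-style table where
// table[pdf][gaussian] is the baseclass containing that Gaussian.
// num_gauss_in_pdf[p] is an assumed stand-in for the number of Gaussians in pdf p.
std::vector<std::vector<int>> InvertBaseclasses(
    const std::vector<std::vector<std::pair<int, int>>> &baseclasses,
    const std::vector<int> &num_gauss_in_pdf) {
  std::vector<std::vector<int>> table(num_gauss_in_pdf.size());
  for (std::size_t p = 0; p < num_gauss_in_pdf.size(); ++p)
    table[p].assign(static_cast<std::size_t>(num_gauss_in_pdf[p]), -1);  // -1 = unassigned

  for (std::size_t b = 0; b < baseclasses.size(); ++b) {
    for (const std::pair<int, int> &pg : baseclasses[b]) {
      // at() throws std::out_of_range for a (pdf, gaussian) pair outside the model.
      int &slot = table.at(static_cast<std::size_t>(pg.first))
                       .at(static_cast<std::size_t>(pg.second));
      if (slot != -1) throw std::runtime_error("Gaussian appears in two baseclasses");
      slot = static_cast<int>(b);
    }
  }
  // Assumption: the baseclasses partition the model, so every slot is filled.
  for (const std::vector<int> &row : table)
    for (int b : row)
      if (b == -1) throw std::runtime_error("Gaussian missing from every baseclass");
  return table;
}

// Reverse lookup in the spirit of Gauss2BaseclassId(pdf_id, gauss_id).
int LookupBaseclass(const std::vector<std::vector<int>> &table,
                    std::size_t pdf, std::size_t gauss) {
  assert(pdf < table.size() && gauss < table[pdf].size());
  return table[pdf][gauss];
}
```

Storing the inverse table trades a little memory for constant-time lookups, which matters when the lookup runs once per Gaussian during accumulation.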
int32 RegressionTree::num_baseclasses_ [private]
Number of leaf nodes.
Definition at line 83 of file regression-tree.h.
Referenced by RegressionTree::BuildTree(), RegressionTree::GatherStats(), RegressionTree::MakeGauss2Bclass(), RegressionTree::NumBaseclasses(), RegressionTree::Read(), and RegressionTree::Write().
int32 RegressionTree::num_nodes_ [private]
Total number of nodes (non-leaf + leaf).
Definition at line 76 of file regression-tree.h.
Referenced by RegressionTree::BuildTree(), RegressionTree::GatherStats(), RegressionTree::Read(), and RegressionTree::Write().
std::vector< int32 > RegressionTree::parents_ [private]
For each node, the index of its parent; size = num_nodes_. If 0 <= i < num_baseclasses_, then i is a leaf of the tree (a base class); otherwise it is a non-leaf node.
parents_[i] > i, except for the top node (last-numbered one), for which parents_[i] == i.
Definition at line 82 of file regression-tree.h.
Referenced by RegressionTree::BuildTree(), RegressionTree::GatherStats(), RegressionTree::Read(), and RegressionTree::Write().
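The parents_ convention makes the tree easy to validate and to traverse: because every non-root node has a parent with a larger index, walking upward from any node must end at the last, self-parented node. A small sketch of both operations, using the hypothetical helpers IsValidParentsArray and PathToRoot; beyond the documented rules, the checker also assumes that leaves have no children and that every non-leaf node has at least one.

```cpp
#include <cassert>
#include <vector>

// Hypothetical checker for the documented parents_ rules: the last node is the
// root with parents[root] == root, and every other node i has parents[i] > i.
// Nodes 0..num_leaves-1 are leaves. As an extra assumption, leaves must have
// no children and every non-leaf node must have at least one.
bool IsValidParentsArray(const std::vector<int> &parents, int num_leaves) {
  const int n = static_cast<int>(parents.size());
  if (n == 0 || num_leaves < 1 || num_leaves > n) return false;
  const int root = n - 1;
  if (parents[root] != root) return false;
  std::vector<bool> has_child(n, false);
  for (int i = 0; i < root; ++i) {
    if (parents[i] <= i || parents[i] >= n) return false;
    has_child[parents[i]] = true;
  }
  for (int i = 0; i < n; ++i) {
    const bool is_leaf = i < num_leaves;
    if (is_leaf == has_child[i]) return false;  // leaf with a child, or childless non-leaf
  }
  return true;
}

// Lists the nodes from `node` up to the root. For a valid array this
// terminates, because each step moves to a strictly larger index until the
// self-parented root is reached.
std::vector<int> PathToRoot(const std::vector<int> &parents, int node) {
  assert(node >= 0 && node < static_cast<int>(parents.size()));
  std::vector<int> path{node};
  while (parents[node] != node) {
    node = parents[node];
    path.push_back(node);
  }
  return path;
}
```

For example, parents = {3, 3, 4, 4, 4} with three leaves describes leaves 0 and 1 joined under node 3, and node 3 and leaf 2 joined under the root 4; PathToRoot(parents, 0) yields {0, 3, 4}.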