35 const std::vector<int32> &sil_indices,
41 vector<Clusterable*> gauss_means;
43 vector< pair<int32, int32> > gauss_indices;
50 gauss_indices.reserve(am.
NumGauss());
52 for (
int32 pdf_index = 0; pdf_index < num_pdfs; pdf_index++) {
55 gauss_index = 0; gauss_index < num_gauss; ++gauss_index) {
57 if (std::binary_search(sil_indices.begin(), sil_indices.end(), pdf_index))
63 BaseFloat this_weight = state_occs(pdf_index) *
65 tmp_mean.
Scale(this_weight);
66 tmp_var.
Scale(this_weight);
67 gauss_indices.push_back(std::make_pair(pdf_index, gauss_index));
74 vector<int32> clust_parents;
78 (sil_indices.empty() ? max_clusters : max_clusters-1),
80 &leaves, &clust_parents, &num_leaves, opts);
82 if (sil_indices.empty()) {
85 parents_.resize(clust_parents.size());
86 for (
int32 i = 0, num_nodes = clust_parents.size();
i < num_nodes;
i++) {
90 for (
int32 i = 0; i < static_cast<int32>(gauss_indices.size());
i++) {
99 parents_.resize(clust_parents.size()+2);
101 int32 top_node = clust_parents.size() + 1;
102 for (
int32 i = 0; i < static_cast<int32>(clust_parents.size());
i++) {
106 parents_[clust_parents.size()] = top_node;
112 for (
int32 i = 0; i < static_cast<int32>(gauss_indices.size());
i++) {
114 gauss2bclass_[gauss_indices[
i].first][gauss_indices[
i].second] = leaves[
i]+1;
117 for (
int32 i = 0; i < static_cast<int32>(sil_indices.size());
i++) {
118 int32 pdf_index = sil_indices[
i];
130 const vector<bool> &is_active,
131 vector<int32> *active_parents_out) {
133 KALDI_ASSERT(static_cast<size_t>(node) < parents.size());
134 active_parents_out->clear();
136 if (node == static_cast<int32> (parents.size() - 1)) {
137 if (is_active[node]) {
138 active_parents_out->push_back(node);
145 bool ret_val =
false;
146 while (node < static_cast<int32> (parents.size() - 1)) {
147 node = parents[node];
148 if (is_active[node]) {
149 active_parents_out->push_back(node);
164 vector<int32> *regclasses_out,
165 vector<AffineXformStats*> *stats_out)
const {
172 double total_occ = 0.0;
173 int32 num_regclasses = 0;
174 vector<double> node_occupancies(
num_nodes_, 0.0);
175 vector<bool> generate_xform(
num_nodes_,
false);
180 total_occ += stats_in[bclass]->beta_;
181 node_occupancies[bclass] = stats_in[bclass]->beta_;
182 if (num_baseclasses_ != 1) {
183 node_occupancies[
parents_[bclass]] += node_occupancies[bclass];
185 if (node_occupancies[bclass] < min_count) {
187 generate_xform[bclass] =
false;
188 generate_xform[
parents_[bclass]] =
true;
190 generate_xform[bclass] =
true;
191 regclasses[bclass] = num_regclasses++;
196 if (total_occ < min_count) {
199 (*regclasses_out)[bclass] = 0;
203 KALDI_WARN <<
"Not enough data to compute global transform. Occupancy at " 204 <<
"root = " << total_occ <<
"<" << min_count;
211 node_occupancies[
parents_[node]] += node_occupancies[node];
213 if (generate_xform[node]) {
214 if (node_occupancies[node] < min_count) {
216 generate_xform[node] =
false;
217 generate_xform[parents_[node]] =
true;
219 regclasses[node] = num_regclasses++;
224 AssertEqual(node_occupancies[num_nodes_-1], total_occ, 1.0e-9);
226 if (generate_xform[num_nodes_-1] && regclasses[num_nodes_-1] < 0) {
227 KALDI_ASSERT(node_occupancies[num_nodes_-1] >= min_count);
228 regclasses[num_nodes_-1] = num_regclasses++;
233 stats_out->resize(num_regclasses);
234 for (
int32 r = 0; r < num_regclasses; r++) {
236 (*stats_out)[r]->Init(stats_in[0]->dim_, stats_in[0]->G_.size());
240 vector<int32> active_parents;
242 if (generate_xform[bclass]) {
244 (*stats_out)[regclasses[bclass]]->CopyStats(*(stats_in[bclass]));
245 (*regclasses_out)[bclass] = regclasses[bclass];
248 for (vector<int32>::const_iterator p = active_parents.begin(),
249 endp = active_parents.end(); p != endp; ++p) {
251 (*stats_out)[regclasses[*p]]->Add(*(stats_in[bclass]));
258 for (vector<int32>::const_iterator p = active_parents.begin(),
259 endp = active_parents.end(); p != endp; ++p) {
261 (*stats_out)[regclasses[*p]]->Add(*(stats_in[bclass]));
263 (*regclasses_out)[bclass] = regclasses[active_parents[0]];
275 if (!binary) out <<
'\n';
277 if (!binary) out <<
'\n';
280 if (!binary) out <<
'\n';
283 if (!binary) out <<
'\n';
286 if (!binary) out <<
'\n';
292 if (!binary) out <<
'\n';
293 for (vector< pair<int32, int32> >::const_iterator
298 if (!binary) out <<
'\n';
302 if (!binary) out <<
'\n';
305 if (!binary) out <<
'\n';
310 int32 total_gauss = 0;
327 int32 class_id, num_comp, pdf_id, gauss_id;
331 total_gauss += num_comp;
334 for (
int32 i = 0;
i < num_comp;
i++) {
338 baseclasses_[bclass].push_back(std::make_pair(pdf_id, gauss_id));
347 "regression tree, found " << total_gauss;
353 for (
int32 pdf_index = 0, num_pdfs = am.
NumPdfs(); pdf_index < num_pdfs;
358 int32 total_gauss = 0;
361 vector< pair<int32, int32> >::const_iterator itr =
364 for (; itr != end; ++itr) {
374 "regression tree, found " << total_gauss;
void Read(std::istream &in, bool binary, const AmDiagGmm &am)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
std::vector< std::vector< int32 > > gauss2bclass_
Mapping from (pdf, gaussian) indices to baseclasses.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
int32 num_baseclasses_
Number of leaf nodes.
std::vector< std::vector< std::pair< int32, int32 > > > baseclasses_
Each baseclass (leaf of regression tree) is a vector of Gaussian indices.
int32 NumGaussInPdf(int32 pdf_index) const
std::vector< int32 > parents_
For each node, the index of its parent (size = num_nodes_). If 0 <= i < num_baseclasses_, then i is a leaf of the tree (a base class); otherwise it is a non-leaf node.
void MakeGauss2Bclass(const AmDiagGmm &am)
void Write(std::ostream &out, bool binary) const
bool GatherStats(const std::vector< AffineXformStats *> &stats_in, double min_count, std::vector< int32 > *regclasses_out, std::vector< AffineXformStats *> *stats_out) const
Parses the regression tree and finds the nodes whose occupancies (read from stats_in) are greater tha...
void AddVec2(const Real alpha, const VectorBase< Real > &v)
Adds a scaled, element-wise squared vector: *this = *this + alpha * v^2 [element-wise squaring].
void ReadIntegerVector(std::istream &is, bool binary, std::vector< T > *v)
Function for reading STL vector of integer types.
static bool GetActiveParents(int32 node, const vector< int32 > &parents, const vector< bool > &is_active, vector< int32 > *active_parents_out)
void BuildTree(const Vector< BaseFloat > &state_occs, const std::vector< int32 > &sil_indices, const AmDiagGmm &am, int32 max_clusters)
Top-down clustering of the Gaussians in a model based on their means.
void GetGaussianVariance(int32 pdf_index, int32 gauss, VectorBase< BaseFloat > *out) const
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
const Vector< BaseFloat > & weights() const
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
int32 NumGauss() const
Returns the number of mixture components in the GMM.
void Scale(Real alpha)
Multiplies all elements by this constant.
DiagGmm & GetPdf(int32 pdf_index)
Accessors.
A class representing a vector.
#define KALDI_ASSERT(cond)
void GetGaussianMean(int32 pdf_index, int32 gauss, VectorBase< BaseFloat > *out) const
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
Asserts that abs(a - b) <= relative_tolerance * (abs(a) + abs(b)).
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
BaseFloat TreeCluster(const std::vector< Clusterable *> &points, int32 max_clust, std::vector< Clusterable *> *clusters_out, std::vector< int32 > *assignments_out, std::vector< int32 > *clust_assignments_out, int32 *num_leaves_out, TreeClusterOptions cfg)
TreeCluster is a top-down clustering algorithm, using a binary tree (not necessarily balanced)...
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...
bool IsSortedAndUniq(const std::vector< T > &vec)
Returns true if the vector is sorted and contains each element only once.
int32 num_nodes_
Total (non-leaf+leaf) nodes.