31 const std::vector<int32> &phone_ids,
32 const std::vector<int32> &phone2hmm_length,
33 const std::vector<bool> &is_ctx_dep,
34 bool ensure_all_phones_covered,
43 int32 max_phone = *std::max_element(phone_ids.begin(), phone_ids.end());
44 KALDI_ASSERT(phone2hmm_length.size() >=
static_cast<size_t>(1 + max_phone));
45 KALDI_ASSERT(is_ctx_dep.size() >=
static_cast<size_t>(1 + max_phone));
49 std::vector<int32> tmp(phone_ids);
53 size_t num_phones = phone_ids.size();
57 for (
int32 i = 0;
i < max_phone+1;
i++)
61 std::map<EventType, Clusterable*> stats_tmp;
63 std::vector<bool> covered(1 + max_phone,
false);
65 bool all_covered =
false;
66 for (
int32 i = 0;
i < num_stats || (ensure_all_phones_covered && !all_covered);
i++) {
68 std::vector<int32> phone_vec(N);
69 for (
size_t i = 0;
i < (
size_t)N;
i++) phone_vec[
i] = phone_ids[(
Rand() % num_phones)];
71 int32 hmm_length = phone2hmm_length[phone_vec[P]];
73 covered[phone_vec[P]] =
true;
76 for (
int32 j = 0;
j < hmm_length;
j++) {
80 for (
size_t pos = 0; pos < (
size_t)N; pos++) {
81 if (pos == (
size_t)(P) || is_ctx_dep[phone_vec[P]])
93 for (
int32 k = 0; k < N; k++) {
95 BaseFloat j_pos = (hmm_length - 0.5 -
j) / hmm_length;
98 BaseFloat weight = j_pos*k_pos + (1.0-j_pos)*(1.0-k_pos);
101 if (k == P) weight += 1.0;
106 for (
int32 k = 0; k < N; k++)
107 mean.
AddVec(weights(k), phone_vecs.
Row(phone_vec[k]));
112 int32 num_samples = 10;
113 for (
size_t p = 0;p < (
size_t)num_samples; p++) {
116 this_stats->
AddStats(sample, count / num_samples);
120 if (stats_tmp.count(event_vec) != 0) {
121 stats_tmp[event_vec]->Add(*this_stats);
124 stats_tmp[event_vec] = this_stats;
128 for (
size_t i = 0;
i< num_phones;
i++)
if (!covered[phone_ids[
i]]) all_covered =
false;
137 const std::vector<std::vector<int32> > &phone_sets,
138 const std::vector<int32> &phone2num_pdf_classes,
139 const std::vector<bool> &share_roots,
140 const std::vector<bool> &do_split,
146 bool round_num_leaves) {
150 && phone_sets.size() == share_roots.size()
151 && do_split.size() == phone_sets.size());
154 int32 num_leaves = 0;
158 phone2num_pdf_classes,
161 KALDI_LOG <<
"BuildTree: before building trees, map has "<< num_leaves <<
" leaves.";
168 std::vector<int32> nonsplit_phones;
169 for (
size_t i = 0;
i < phone_sets.size();
i++)
171 nonsplit_phones.insert(nonsplit_phones.end(), phone_sets[
i].begin(), phone_sets[
i].end());
173 std::sort(nonsplit_phones.begin(), nonsplit_phones.end());
183 qopts, thresh, max_leaves,
184 &num_leaves, &impr, &smallest_split);
186 if (cluster_thresh < 0.0) {
187 KALDI_LOG <<
"Setting clustering threshold to smallest split " << smallest_split;
188 cluster_thresh = smallest_split;
192 impr_normalized = impr / normalizer,
194 impr_normalized_filt = impr / normalizer_filt;
196 KALDI_VLOG(1) <<
"After decision tree split, num-leaves = " << num_leaves
197 <<
", like-impr = " << impr_normalized <<
" per frame over " 198 << normalizer <<
" frames.";
200 KALDI_VLOG(1) <<
"Including just phones that were split, improvement is " 201 << impr_normalized_filt <<
" per frame over " 202 << normalizer_filt <<
" frames.";
205 if (cluster_thresh != 0.0) {
209 int32 num_removed = 0;
215 KALDI_LOG <<
"BuildTree: removed "<< num_removed <<
" leaves.";
217 int32 num_leaves_out = 0;
219 if (round_num_leaves) {
222 int32 num_leaves_required = ((num_leaves - num_removed) / 8) * 8;
223 std::vector<EventMap*> leaf_mapping;
225 int32 num_removed_in_rounding = 0;
227 *tree_clustered, stats, num_leaves_required, *tree_stub,
228 &num_removed_in_rounding);
230 if (num_removed_in_rounding > 0)
231 KALDI_LOG <<
"BuildTree: Rounded num leaves to multiple of 8 by" 232 <<
" removing " << num_removed_in_rounding <<
" leaves.";
234 if (num_leaves - num_removed - num_removed_in_rounding !=
235 num_leaves_required) {
236 KALDI_WARN <<
"Did not get expected number of leaves: " 237 << num_leaves <<
" - " << num_removed <<
" - " 238 << num_removed_in_rounding
239 <<
" != " << num_leaves_required;
244 if (num_leaves_out != num_leaves_required) {
245 KALDI_WARN <<
"num-leaves-out != num-leaves-required: " 246 << num_leaves_out <<
" != " << num_leaves_required;
256 KALDI_VLOG(1) <<
"Objf change due to clustering " 257 << ((objf_after_cluster-objf_before_cluster) / normalizer)
259 KALDI_VLOG(1) <<
"Normalizing over only split phones, this is: " 260 << ((objf_after_cluster-objf_before_cluster) / normalizer_filt)
262 KALDI_VLOG(1) <<
"Num-leaves is now "<< num_leaves_out;
264 delete tree_clustered;
267 return tree_renumbered;
269 if (round_num_leaves) {
275 int32 num_leaves_required = (num_leaves / 8) * 8;
276 std::vector<EventMap*> leaf_mapping;
278 int32 num_removed_in_rounding = 0;
280 *tree_split, stats, num_leaves_required, *tree_stub,
281 &num_removed_in_rounding);
283 if (num_removed_in_rounding > 0)
284 KALDI_LOG <<
"BuildTree: Rounded num leaves to multiple of 8 by" 285 <<
" removing " << num_removed_in_rounding <<
" leaves.";
289 int32 num_leaves_out;
294 KALDI_VLOG(1) <<
"Objf change due to clustering " 295 << ((objf_after_cluster-objf_before_cluster) / normalizer)
297 KALDI_VLOG(1) <<
"Normalizing over only split phones, this is: " 298 << ((objf_after_cluster-objf_before_cluster) / normalizer_filt)
300 KALDI_VLOG(1) <<
"Num-leaves is now "<< num_leaves_out;
304 return tree_renumbered;
323 std::vector<int32> *leaf_map) {
324 std::vector<BuildTreeStatsType> split_stats_small;
326 num_leaves_small = small_tree.
MaxResult() + 1;
328 KALDI_ASSERT(static_cast<int32>(split_stats_small.size()) <=
331 leaf_map->resize(num_leaves_big, -1);
333 std::vector<int32> small_leaves_unseen;
339 for (
int32 i = 0;
i < num_leaves_small;
i++) {
340 if (static_cast<size_t>(
i) >= split_stats_small.size() ||
341 split_stats_small[
i].empty()) {
342 KALDI_WARN <<
"No stats mapping to " <<
i <<
" in small tree. " 343 <<
"Continuing but this is a serious error.";
344 small_leaves_unseen.push_back(i);
346 for (
size_t j = 0;
j < split_stats_small[
i].size();
j++) {
348 bool ok = big_tree.
Map(split_stats_small[
i][
j].first, &leaf);
350 KALDI_ERR <<
"Could not map stats with big tree: probable code error.";
351 if (leaf < 0 || leaf >= num_leaves_big)
352 KALDI_ERR <<
"Leaf out of range: " << leaf <<
" vs. " << num_leaves_big;
353 if ((*leaf_map)[leaf] != -1 && (*leaf_map)[leaf] !=
i)
354 KALDI_ERR <<
"Inconsistent mapping for big tree: " 355 <<
i <<
" vs. " << (*leaf_map)[leaf];
356 (*leaf_map)[leaf] =
i;
364 for (
int32 leaf = 0; leaf < num_leaves_big; leaf++) {
365 int32 small_leaf = (*leaf_map)[leaf];
366 if (small_leaf == -1) {
367 KALDI_WARN <<
"In ComputeTreeMapping, could not get mapping from leaf " 369 if (!small_leaves_unseen.empty()) {
370 small_leaf = small_leaves_unseen.back();
371 KALDI_WARN <<
"Assigning it to unseen small-tree leaf " << small_leaf;
372 small_leaves_unseen.pop_back();
373 (*leaf_map)[leaf] = small_leaf;
375 KALDI_WARN <<
"Could not find any unseen small-tree leaf to assign " 376 <<
"it to. Making it zero, but this is bad. ";
377 (*leaf_map)[leaf] = 0;
379 }
else if (small_leaf < 0 || small_leaf >= num_leaves_small)
380 KALDI_ERR <<
"Leaf in leaf mapping out of range: for big-map leaf " 381 << leaf <<
", mapped to " << small_leaf <<
", vs. " 388 const std::vector<std::vector<int32> > &phone_sets,
389 const std::vector<int32> &phone2num_pdf_classes,
390 const std::vector<bool> &share_roots,
391 const std::vector<bool> &do_split,
393 int32 max_leaves_first,
394 int32 max_leaves_second,
397 std::vector<int32> *leaf_map) {
399 KALDI_LOG <<
"****BuildTreeTwoLevel: building first level tree";
401 phone2num_pdf_classes,
402 share_roots, do_split, stats, 0.0,
403 max_leaves_first, 0.0, P);
405 KALDI_LOG <<
"****BuildTreeTwoLevel: done building first level tree";
408 std::vector<int32> nonsplit_phones;
409 for (
size_t i = 0;
i < phone_sets.size();
i++)
411 nonsplit_phones.insert(nonsplit_phones.end(), phone_sets[
i].begin(), phone_sets[
i].end());
412 std::sort(nonsplit_phones.begin(), nonsplit_phones.end());
421 old_num_leaves = num_leaves;
428 qopts, 0.0, max_leaves_second,
429 &num_leaves, &impr, &smallest_split);
431 KALDI_LOG <<
"Building second-level tree: increased #leaves from " 432 << old_num_leaves <<
" to " << num_leaves <<
", smallest split was " 436 impr_normalized = impr / normalizer;
438 KALDI_LOG <<
"After second decision tree split, num-leaves = " 439 << num_leaves <<
", like-impr = " << impr_normalized
440 <<
" per frame over " << normalizer <<
" frames.";
442 if (cluster_leaves) {
443 KALDI_LOG <<
"Clustering leaves of larger tree.";
447 int32 num_removed = 0;
453 KALDI_LOG <<
"BuildTreeTwoLevel: removed " << num_removed <<
" leaves.";
455 int32 num_leaves = 0;
460 KALDI_LOG <<
"Objf change due to clustering " 461 << ((objf_after_cluster-objf_before_cluster) /
SumNormalizer(stats))
463 KALDI_LOG <<
"Num-leaves now "<< num_leaves;
465 delete tree_clustered;
466 tree = tree_renumbered;
476 std::vector<std::pair<int32, int32> > leaf_pairs;
477 for (
size_t i = 0; i < leaf_map->size(); i++)
478 leaf_pairs.push_back(std::make_pair((*leaf_map)[i], static_cast<int32>(i)));
480 std::sort(leaf_pairs.begin(), leaf_pairs.end());
481 std::vector<int32> old2new_map(leaf_map->size()),
482 new_leaf_map(leaf_map->size());
486 for (
size_t i = 0; i < leaf_pairs.size(); i++) {
487 int32 old_number = leaf_pairs[
i].second, new_number =
i;
488 old2new_map[old_number] = new_number;
489 new_leaf_map[new_number] = (*leaf_map)[old_number];
491 *leaf_map = new_leaf_map;
494 tree = renumbered_tree;
497 delete first_level_tree;
504 std::vector<int32> *syms) {
505 std::ifstream is(filename.c_str());
507 KALDI_ERR <<
"ReadSymbolTableAsIntegers: could not open symbol table "<<filename;
511 while (getline(is, line)) {
514 std::istringstream ss(line);
515 ss >> sym >> index >> std::ws;
516 if (ss.fail() || !ss.eof()) {
517 KALDI_ERR <<
"Bad line in symbol table: "<< line<<
", file is: "<<filename;
519 if (include_eps || index != 0)
520 syms->push_back(index);
521 if (index == 0 && sym !=
"<eps>") {
522 KALDI_WARN <<
"Symbol zero is "<<sym<<
", traditionally <eps> is used. Make sure this is not a \"real\" symbol.";
525 size_t sz = syms->size();
527 if (syms->size() != sz)
528 KALDI_ERR <<
"Symbol table "<<filename<<
" seems to contain duplicate symbols.";
535 std::vector<std::vector<int32 > > new_vecs;
536 new_vecs.reserve(vecs->size());
537 int32 num_not_inserted = 0;
538 for (std::vector<std::vector<int32 > >::const_iterator iter = vecs->begin(),
539 end = vecs->end(); iter != end; iter++) {
540 if (vec_set.insert(*iter).second) {
542 new_vecs.push_back(*iter);
547 KALDI_VLOG(2) <<
"Removed " << num_not_inserted
548 <<
" duplicates from the phone sets.";
549 vecs->swap(new_vecs);
559 const std::vector<int32> &assignments,
560 const std::vector<int32> &clust_assignments,
562 std::vector<std::vector<int32> > *sets_out) {
565 std::vector<std::vector<int32> > raw_sets(clust_assignments.size());
567 KALDI_ASSERT(num_leaves < static_cast<int32>(clust_assignments.size()));
569 for (
size_t i = 0;
i < assignments.size();
i++) {
570 int32 clust = assignments[
i];
572 for (
size_t j = 0;
j < phone_sets[
i].size();
j++) {
574 raw_sets[clust].push_back(phone_sets[
i][
j]);
582 for (
int32 j = 0; j < static_cast<int32>(clust_assignments.size());
j++) {
583 int32 parent = clust_assignments[
j];
584 std::sort(raw_sets[
j].begin(), raw_sets[
j].end());
586 if (parent < static_cast<int32>(clust_assignments.size())-1) {
588 raw_sets[parent].insert(raw_sets[parent].end(),
596 std::reverse(raw_sets.begin(), raw_sets.end());
603 for (
size_t i = 0;
i < phone_sets.size();
i++) {
604 raw_sets.push_back(phone_sets[
i]);
608 sets_out->reserve(raw_sets.size());
609 for (
size_t i = 0;
i < raw_sets.size();
i++)
610 if (! raw_sets[
i].empty())
611 sets_out->push_back(raw_sets[
i]);
616 const std::vector<std::vector<int32> > &phone_sets_in,
617 const std::vector<int32> &all_pdf_classes_in,
619 std::vector<std::vector<int32> > *questions_out) {
620 std::vector<std::vector<int32> > phone_sets(phone_sets_in);
621 std::vector<int32> phones;
622 for (
size_t i = 0;
i < phone_sets.size() ;
i++) {
623 std::sort(phone_sets[
i].begin(), phone_sets[
i].end());
624 if (phone_sets[
i].empty())
625 KALDI_ERR <<
"Empty phone set in AutomaticallyObtainQuestions";
627 KALDI_ERR <<
"Phone set in AutomaticallyObtainQuestions contains duplicate phones";
628 for (
size_t j = 0;
j < phone_sets[
i].size();
j++)
629 phones.push_back(phone_sets[i][
j]);
631 std::sort(phones.begin(), phones.end());
633 KALDI_ERR <<
"Phones are present in more than one phone set.";
637 std::vector<int32> all_pdf_classes(all_pdf_classes_in);
646 if (retained_stats.size() * 10 < stats.size()) {
647 std::ostringstream ss;
648 for (
size_t i = 0;
i < all_pdf_classes.size();
i++)
649 ss << all_pdf_classes[
i] <<
' ';
650 KALDI_WARN <<
"After filtering the tree statistics to retain only stats where " 651 <<
"pdf-class is in the set { " << ss.str() <<
"}, most of your " 652 <<
"stats disappeared: the size changed from " << stats.size()
653 <<
" to " << retained_stats.size() <<
". You might be using " 654 <<
"a nonstandard topology but forgot to modify the " 655 <<
"--pdf-class-list option (it defaults to { 1 } which is " 656 <<
"the central state in a 3-state left-to-right topology)." 657 <<
" E.g. a 1-state HMM topology would require the option " 658 <<
"--pdf-class-list=0.";
662 std::vector<BuildTreeStatsType> split_stats;
665 std::vector<Clusterable*> summed_stats;
668 int32 max_phone = phones.back();
669 if (static_cast<int32>(summed_stats.size()) < max_phone+1) {
673 summed_stats.resize(max_phone+1, NULL);
676 for (
int32 i = 0;
static_cast<size_t>(
i) < summed_stats.size();
i++) {
677 if (summed_stats[
i] != NULL &&
678 !binary_search(phones.begin(), phones.end(),
i)) {
679 KALDI_WARN <<
"Phone "<<
i <<
" is present in stats but is not in phone list [make sure you intended this].";
686 std::vector<Clusterable*> summed_stats_per_set(phone_sets.size(), NULL);
687 for (
size_t i = 0;
i < phone_sets.size();
i++) {
688 const std::vector<int32> &this_set = phone_sets[
i];
689 summed_stats_per_set[
i] = summed_stats[this_set[0]]->Copy();
690 for (
size_t j = 1;
j < this_set.size();
j++)
691 summed_stats_per_set[
i]->Add(*(summed_stats[this_set[
j]]));
694 int32 num_no_data = 0;
695 for (
size_t i = 0;
i < summed_stats_per_set.size();
i++) {
696 if (summed_stats_per_set[
i]->Normalizer() == 0.0) {
698 std::ostringstream ss;
699 ss <<
"AutomaticallyObtainQuestions: no stats available for phone set: ";
700 for (
size_t j = 0;
j < phone_sets[
i].size();
j++)
701 ss << phone_sets[
i][
j] <<
' ' ;
705 if (num_no_data + 1 >= summed_stats_per_set.size()) {
706 std::ostringstream ss;
707 for (
size_t i = 0;
i < all_pdf_classes.size();
i++)
708 ss << all_pdf_classes[
i] <<
' ';
709 KALDI_WARN <<
"All or all but one of your classes of phones had no data. " 710 <<
"Note that we only consider data where pdf-class is in the " 711 <<
"set ( " << ss.str() <<
"). If you have an unusual HMM " 712 <<
"topology this may not be what you want; use the " 713 <<
"--pdf-class-list option to change this if needed. See " 714 <<
"also any warnings above.";
722 std::vector<int32> assignments;
723 std::vector<int32> clust_assignments;
726 summed_stats_per_set.size(),
749 const std::vector<std::vector<int32> > &phone_sets_in,
750 const std::vector<int32> &all_pdf_classes_in,
753 std::vector<std::vector<int32> > *sets_out) {
754 std::vector<std::vector<int32> > phone_sets(phone_sets_in);
755 std::vector<int32> phones;
756 for (
size_t i = 0;
i < phone_sets.size() ;
i++) {
757 std::sort(phone_sets[
i].begin(), phone_sets[
i].end());
758 if (phone_sets[
i].empty())
759 KALDI_ERR <<
"Empty phone set in AutomaticallyObtainQuestions";
761 KALDI_ERR <<
"Phone set in AutomaticallyObtainQuestions contains duplicate phones";
762 for (
size_t j = 0;
j < phone_sets[
i].size();
j++)
763 phones.push_back(phone_sets[i][
j]);
765 std::sort(phones.begin(), phones.end());
767 KALDI_ERR <<
"Phones are present in more than one phone set.";
771 std::vector<int32> all_pdf_classes(all_pdf_classes_in);
781 std::vector<BuildTreeStatsType> split_stats;
784 std::vector<Clusterable*> summed_stats;
787 int32 max_phone = phones.back();
788 if (static_cast<int32>(summed_stats.size()) < max_phone+1) {
792 summed_stats.resize(max_phone+1, NULL);
795 for (
int32 i = 0;
static_cast<size_t>(
i) < summed_stats.size();
i++) {
797 if (summed_stats[
i] != NULL &&
798 !binary_search(phones.begin(), phones.end(),
i)) {
799 KALDI_WARN <<
"Phone "<<
i <<
" is present in stats but is not in phone list [make sure you intended this].";
806 std::vector<Clusterable*> summed_stats_per_set(phone_sets.size(), NULL);
807 for (
size_t i = 0;
i < phone_sets.size();
i++) {
808 const std::vector<int32> &this_set = phone_sets[
i];
809 summed_stats_per_set[
i] = summed_stats[this_set[0]]->Copy();
810 for (
size_t j = 1;
j < this_set.size();
j++)
811 summed_stats_per_set[
i]->Add(*(summed_stats[this_set[
j]]));
814 for (
size_t i = 0;
i < summed_stats_per_set.size();
i++) {
815 if (summed_stats_per_set[
i]->Normalizer() == 0.0) {
816 std::ostringstream ss;
817 ss <<
"AutomaticallyObtainQuestions: no stats available for phone set: ";
818 for (
size_t j = 0;
j < phone_sets[
i].size();
j++)
819 ss << phone_sets[
i][
j] <<
' ' ;
827 std::vector<int32> assignments;
836 KALDI_LOG <<
"ClusterKMeans: objf change from clustering [versus single set] is " 837 << (objf_impr/
count) <<
" over " << count <<
" frames.";
839 sets_out->resize(num_classes);
841 for (
size_t i = 0;
i < assignments.size();
i++) {
842 int32 class_idx = assignments[
i];
843 KALDI_ASSERT(static_cast<size_t>(class_idx) < sets_out->size());
844 for (
size_t j = 0;
j < phone_sets[
i].size();
j++)
845 (*sets_out)[class_idx].push_back(phone_sets[
i][
j]);
847 for (
size_t i = 0;
i < sets_out->size();
i++) {
848 std::sort( (*sets_out)[
i].begin(), (*sets_out)[
i].end() );
858 std::vector<std::vector<int32> > *phone_sets,
859 std::vector<bool> *is_shared_root,
860 std::vector<bool> *is_split_root) {
861 KALDI_ASSERT(phone_sets != NULL && is_shared_root != NULL &&
862 is_split_root != NULL && phone_sets->empty()
863 && is_shared_root->empty() && is_split_root->empty());
867 while ( ! getline(is, line).fail() ) {
869 std::istringstream ss(line);
872 if (ss.fail() && shared !=
"shared" && shared !=
"not-shared")
873 KALDI_ERR <<
"Bad line in roots file: line "<< line_number <<
": " << line;
874 is_shared_root->push_back(shared ==
"shared");
878 if (ss.fail() && shared !=
"split" && shared !=
"not-split")
879 KALDI_ERR <<
"Bad line in roots file: line "<< line_number <<
": " << line;
880 is_split_root->push_back(split ==
"split");
882 phone_sets->push_back(std::vector<int32>());
884 while ( !(ss >> i).fail() ) {
885 phone_sets->back().push_back(i);
887 std::sort(phone_sets->back().begin(), phone_sets->back().end());
889 || phone_sets->back().front() <= 0)
890 KALDI_ERR <<
"Bad line in roots file [empty, or contains non-positive " 891 <<
" or duplicate phone-ids]: line " << line_number <<
": " 894 if (phone_sets->empty())
void AddStats(const VectorBase< BaseFloat > &vec, BaseFloat weight=1.0)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
BaseFloat SumNormalizer(const BuildTreeStatsType &stats_in)
Sums the normalizer [typically, data-count] over the stats.
EventMap * BuildTreeTwoLevel(Questions &qopts, const std::vector< std::vector< int32 > > &phone_sets, const std::vector< int32 > &phone2num_pdf_classes, const std::vector< bool > &share_roots, const std::vector< bool > &do_split, const BuildTreeStatsType &stats, int32 max_leaves_first, int32 max_leaves_second, bool cluster_leaves, int32 P, std::vector< int32 > *leaf_map)
BuildTreeTwoLevel builds a two-level tree, useful for example in building tied mixture systems with m...
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
A hashing function-object for vectors.
This class defines, for each EventKeyType, a set of initial questions that it tries and also a number...
float RandUniform(struct RandomState *state=NULL)
Returns a random number strictly between 0 and 1.
virtual EventAnswerType MaxResult() const
BaseFloat SumClusterableNormalizer(const std::vector< Clusterable *> &vec)
Returns the total normalizer (usually count) of the cluster (pointers may be NULL).
BaseFloat ClusterKMeans(const std::vector< Clusterable *> &points, int32 num_clust, std::vector< Clusterable *> *clusters_out, std::vector< int32 > *assignments_out, ClusterKMeansOptions cfg)
ClusterKMeans is a K-means-like clustering algorithm.
void AutomaticallyObtainQuestions(BuildTreeStatsType &stats, const std::vector< std::vector< int32 > > &phone_sets_in, const std::vector< int32 > &all_pdf_classes_in, int32 P, std::vector< std::vector< int32 > > *questions_out)
Outputs sets of phones that are reasonable for questions to ask in the tree-building algorithm...
void SplitStatsByMap(const BuildTreeStatsType &stats, const EventMap &e, std::vector< BuildTreeStatsType > *stats_out)
Splits stats according to the EventMap, indexing them at output by the leaf type. ...
static void RemoveDuplicates(std::vector< std::vector< int32 > > *vecs)
float RandGauss(struct RandomState *state=NULL)
void SortAndUniq(std::vector< T > *vec)
Sorts and uniq's (removes duplicates) from a vector.
void GenRandStats(int32 dim, int32 num_stats, int32 N, int32 P, const std::vector< int32 > &phone_ids, const std::vector< int32 > &phone2hmm_length, const std::vector< bool > &is_ctx_dep, bool ensure_all_phones_covered, BuildTreeStatsType *stats_out)
GenRandStats generates random statistics of the form used by BuildTree.
void SplitStatsByKey(const BuildTreeStatsType &stats_in, EventKeyType key, std::vector< BuildTreeStatsType > *stats_out)
SplitStatsByKey splits stats up according to the value of a particular key, which must be always defi...
virtual bool Map(const EventType &event, EventAnswerType *ans) const =0
static const EventKeyType kPdfClass
std::vector< std::pair< EventKeyType, EventValueType > > EventType
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
void EnsureClusterableVectorNotNull(std::vector< Clusterable *> *stats)
Fills in any (NULL) holes in "stats" vector, with empty stats, because certain algorithms require non...
void ReadSymbolTableAsIntegers(std::string filename, bool include_eps, std::vector< int32 > *syms)
included here because it's used in some tree-building calling code.
static void ObtainSetsOfPhones(const std::vector< std::vector< int32 > > &phone_sets, const std::vector< int32 > &assignments, const std::vector< int32 > &clust_assignments, int32 num_leaves, std::vector< std::vector< int32 > > *sets_out)
ObtainSetsOfPhones is called by AutomaticallyObtainQuestions.
int32 EventKeyType
Things of type EventKeyType can take any value.
void SumStatsVec(const std::vector< BuildTreeStatsType > &stats_in, std::vector< Clusterable *> *stats_out)
Sum a vector of stats.
BaseFloat ObjfGivenMap(const BuildTreeStatsType &stats_in, const EventMap &e)
Cluster the stats given the event map return the total objf given those clusters. ...
void Scale(Real alpha)
Multiplies all elements by this constant.
int Rand(struct RandomState *state)
Real Sum() const
Returns sum of the elements.
static void ComputeTreeMapping(const EventMap &small_tree, const EventMap &big_tree, const BuildTreeStatsType &stats, std::vector< int32 > *leaf_map)
A class that is capable of representing a generic mapping from EventType (which is a vector of (key...
A class representing a vector.
#define KALDI_ASSERT(cond)
EventMap * BuildTree(Questions &qopts, const std::vector< std::vector< int32 > > &phone_sets, const std::vector< int32 > &phone2num_pdf_classes, const std::vector< bool > &share_roots, const std::vector< bool > &do_split, const BuildTreeStatsType &stats, BaseFloat thresh, int32 max_leaves, BaseFloat cluster_thresh, int32 P, bool round_num_leaves)
BuildTree is the normal way to build a set of decision trees.
void ReadRootsFile(std::istream &is, std::vector< std::vector< int32 > > *phone_sets, std::vector< bool > *is_shared_root, std::vector< bool > *is_split_root)
Reads the roots file (throws on error).
void FilterStatsByKey(const BuildTreeStatsType &stats_in, EventKeyType key, std::vector< EventValueType > &values, bool include_if_present, BuildTreeStatsType *stats_out)
FilterStatsByKey filters the stats according to the value of a specified key.
BaseFloat TreeCluster(const std::vector< Clusterable *> &points, int32 max_clust, std::vector< Clusterable *> *clusters_out, std::vector< int32 > *assignments_out, std::vector< int32 > *clust_assignments_out, int32 *num_leaves_out, TreeClusterOptions cfg)
TreeCluster is a top-down clustering algorithm, using a binary tree (not necessarily balanced)...
int32 EventValueType
Given current code, things of type EventValueType should generally be nonnegative and in a reasonably...
ClusterKMeansOptions kmeans_cfg
std::vector< std::pair< EventType, Clusterable * > > BuildTreeStatsType
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...
bool IsSortedAndUniq(const std::vector< T > &vec)
Returns true if the vector is sorted and contains each element only once.
void AddVec(const Real alpha, const VectorBase< OtherReal > &v)
Add vector : *this = *this + alpha * rv (with casting between floats and doubles) ...
void CopyMapToVector(const std::map< A, B > &m, std::vector< std::pair< A, B > > *v)
Copies the (key, value) pairs in a map to a vector of pairs.
void KMeansClusterPhones(BuildTreeStatsType &stats, const std::vector< std::vector< int32 > > &phone_sets_in, const std::vector< int32 > &all_pdf_classes_in, int32 P, int32 num_classes, std::vector< std::vector< int32 > > *sets_out)
This function clusters the phones (or some initially specified sets of phones) into sets of phones...