31 uint32 size = stats.size();
33 for (
size_t i = 0;
i < size;
i++) {
35 bool nonNull = (stats[
i].second != NULL);
37 if (nonNull) stats[
i].second->Write(os, binary);
40 KALDI_ERR <<
"WriteBuildTreeStats: write failed.";
42 if (!binary) os <<
'\n';
53 for (
size_t i = 0;
i < size;
i++) {
57 if (nonNull) (*stats)[
i].second = example.
ReadNew(is, binary);
58 else (*stats)[
i].second = NULL;
65 std::vector<EventValueType> *ans) {
66 bool all_present =
true;
67 std::set<EventValueType> values;
68 BuildTreeStatsType::const_iterator iter = stats.begin(), end = stats.end();
69 for (; iter != end; ++iter) {
82 keys->resize(vec.size());
83 EventType::const_iterator iter = vec.begin(), end = vec.end();
84 std::vector<EventKeyType>::iterator out_iter = keys->begin();
85 for (; iter!= end; ++iter, ++out_iter)
86 *out_iter = iter->first;
94 BuildTreeStatsType::const_iterator iter = stats.begin(), end = stats.end();
95 if (iter == end)
return;
96 std::vector<EventKeyType> keys;
99 for (; iter!= end; ++iter) {
100 std::vector<EventKeyType> keys2;
104 KALDI_ERR <<
"AllKeys: keys in events are not all the same [called with kAllKeysInsistIdentical and all are not identical.";
106 std::vector<EventKeyType> new_keys(std::max(keys.size(), keys2.size()));
108 std::vector<EventKeyType>::iterator end_iter =
109 std::set_intersection(keys.begin(), keys.end(), keys2.begin(), keys2.end(), new_keys.begin());
110 new_keys.erase(end_iter, new_keys.end());
114 std::vector<EventKeyType> new_keys(keys.size()+keys2.size());
116 std::vector<EventKeyType>::iterator end_iter =
117 std::set_union(keys.begin(), keys.end(), keys2.begin(), keys2.end(), new_keys.begin());
118 new_keys.erase(end_iter, new_keys.end());
129 std::vector<BuildTreeStatsType> split_stats;
132 std::vector<EventMap*> splits(split_stats.size(), NULL);
134 if (!split_stats[leaf].empty()) {
136 std::vector<EventValueType> vals;
141 std::vector<EventMap*> table(vals.back()+1, (
EventMap*)NULL);
142 for (
size_t idx = 0;idx < vals.size();idx++) {
158 if (keys.empty())
return orig.
Copy();
161 for (
size_t i = 0;
i < keys.size();
i++) {
173 BuildTreeStatsType::const_iterator iter, end = stats.end();
177 for (iter = stats.begin(); iter != end; ++iter) {
180 if (!e.
Map(evec, &ans))
182 <<
"if error seen during tree-building, check that " 183 <<
"--context-width and --central-position match stats, " 184 <<
"and that phones that are context-independent (CI) during " 185 <<
"stats accumulation do not share roots with non-CI phones.";
186 size = std::max(size, (
size_t)(ans+1));
188 stats_out->resize(size);
189 for (iter = stats.begin(); iter != end; ++iter) {
192 bool b = e.
Map(evec, &ans);
194 (*stats_out)[ans].push_back(*iter);
199 BuildTreeStatsType::const_iterator iter, end = stats_in.end();
204 for (iter = stats_in.begin(); iter != end; ++iter) {
209 size = std::max(size, (
size_t)(val+1));
211 stats_out->resize(size);
213 for (iter = stats_in.begin(); iter != end; ++iter) {
217 (*stats_out)[val].push_back(*iter);
224 std::vector<EventValueType> &values,
225 bool include_if_present,
229 BuildTreeStatsType::const_iterator iter, end = stats_in.end();
233 for (iter = stats_in.begin(); iter != end; ++iter) {
238 bool in_values = std::binary_search(values.begin(), values.end(), val);
239 if (in_values == include_if_present)
240 stats_out->push_back(*iter);
247 BuildTreeStatsType::const_iterator iter = stats_in.begin(), end = stats_in.end();
248 for (; iter != end; ++iter) {
251 if (!ans) ans = cl->
Copy();
260 BuildTreeStatsType::const_iterator iter = stats_in.begin(), end = stats_in.end();
261 for (; iter != end; ++iter) {
270 BuildTreeStatsType::const_iterator iter = stats_in.begin(), end = stats_in.end();
271 for (; iter != end; ++iter) {
273 if (cl != NULL) ans += cl->
Objf();
279 void SumStatsVec(
const std::vector<BuildTreeStatsType> &stats_in, std::vector<Clusterable*> *stats_out) {
281 stats_out->resize(stats_in.size(), NULL);
282 for (
size_t i = 0;
i < stats_in.size();
i++) (*stats_out)[
i] =
SumStats(stats_in[
i]);
286 std::vector<BuildTreeStatsType> split_stats;
288 std::vector<Clusterable*> summed_stats;
299 std::vector<EventValueType> *yes_set) {
307 if (total == NULL)
return 0.0;
310 const std::vector<std::vector<EventValueType> > &questions_of_this_key = key_opts.
initial_questions;
315 for (
size_t i = 0;
i < questions_of_this_key.size();
i++) {
316 const std::vector<EventValueType> &yes_set = questions_of_this_key[
i];
317 std::vector<int32> assignments(summed_stats.size(), 0);
318 std::vector<Clusterable*> clusters(2);
319 for (std::vector<EventValueType>::const_iterator iter = yes_set.begin(); iter != yes_set.end(); ++iter) {
321 if (*iter < (
EventValueType)assignments.size()) assignments[*iter] = 1;
326 if (this_objf < unsplit_objf- 0.001*std::abs(unsplit_objf)) {
328 KALDI_WARN <<
"Objective function got worse when building tree: "<< this_objf <<
" < " << unsplit_objf;
329 KALDI_ASSERT(!(this_objf < unsplit_objf - 0.01*(200 + std::abs(unsplit_objf))));
332 BaseFloat this_objf_change = this_objf - unsplit_objf;
333 if (this_objf_change > best_objf_change) {
334 best_objf_change = this_objf_change;
341 *yes_set = questions_of_this_key[best_idx];
342 return best_objf_change;
351 std::vector<EventValueType> *yes_set_out) {
352 if (stats.size()<=1)
return 0.0;
354 yes_set_out->clear();
357 std::vector<Clusterable*> summed_stats;
359 std::vector<BuildTreeStatsType> split_stats;
364 std::vector<EventValueType> yes_set;
366 q_opts, key, &yes_set);
369 std::vector<int32> assignments(summed_stats.size(), 0);
370 for (std::vector<EventValueType>::const_iterator iter = yes_set.begin(); iter != yes_set.end(); ++iter) {
375 assignments[*iter] = 1;
378 std::vector<Clusterable*> clusters(2, (
Clusterable*)NULL);
396 KALDI_ASSERT(refine_impr > std::min(-1.0, -0.1*fabs(improvement)));
398 improvement += refine_impr;
400 for (
size_t i = 0;
i < assignments.size();
i++)
if (assignments[
i] == 1) yes_set.push_back(
i);
402 *yes_set_out = yes_set;
405 #ifdef KALDI_PARANOID 407 KALDI_ASSERT(clusters.size() == 2 && clusters[0] == 0 && clusters[1] == 0);
410 if (clusters[0] == NULL || clusters[1] == NULL) impr = 0.0;
411 else impr = clusters[0]->Distance(*(clusters[1]));
412 if (!
ApproxEqual(impr, improvement) && fabs(impr-improvement) > 0.01) {
413 KALDI_WARN <<
"FindBestSplitForKey: improvements do not agree: "<< impr
414 <<
" vs. " << improvement;
465 yes_stats.reserve(
stats_.size()); no_stats.reserve(
stats_.size());
466 for (BuildTreeStatsType::const_iterator iter =
stats_.begin(); iter !=
stats_.end(); ++iter) {
470 if (std::binary_search(
yes_set_.begin(),
yes_set_.end(), val)) yes_stats.push_back(*iter);
471 else no_stats.push_back(*iter);
473 #ifdef KALDI_PARANOID 481 delete yes_clust;
delete no_clust;
493 std::vector<EventKeyType> all_keys;
495 if (all_keys.size() == 0) {
496 KALDI_WARN <<
"DecisionTreeSplitter::FindBestSplit(), no keys available to split on (maybe no key covered all of your events, or there was a problem with your questions configuration?)";
499 for (
size_t i = 0;
i < all_keys.size();
i++) {
501 std::vector<EventValueType> temp_yes_set;
541 int32 num_empty_leaves = 0;
543 BaseFloat smallest_split_change = 1.0e+20;
544 std::vector<DecisionTreeSplitter*> builders;
547 std::vector<BuildTreeStatsType> split_stats;
550 builders.resize(split_stats.size());
551 for (
size_t i = 0;
i < split_stats.size();
i++) {
553 if (split_stats[
i].size() == 0) num_empty_leaves++;
560 std::priority_queue<std::pair<BaseFloat, size_t> > queue;
563 for (
size_t i = 0;
i < builders.size();
i++)
564 queue.push(std::make_pair(builders[
i]->BestSplit(),
i));
567 while (queue.top().first > thresh
568 && (max_leaves<=0 || *num_leaves < max_leaves)) {
569 smallest_split_change = std::min(smallest_split_change, queue.top().first);
570 size_t i = queue.top().second;
571 like_impr += queue.top().first;
572 builders[
i]->DoSplit(num_leaves);
574 queue.push(std::make_pair(builders[i]->
BestSplit(), i));
577 KALDI_LOG <<
"DoDecisionTreeSplit: split "<< count <<
" times, #leaves now " << (*num_leaves);
580 if (smallest_split_change_out)
581 *smallest_split_change_out = smallest_split_change;
586 std::vector<EventMap*> sub_trees(builders.size());
587 for (
size_t i = 0;
i < sub_trees.size();
i++) sub_trees[
i] = builders[
i]->
GetMap();
588 answer = input_map.
Copy(sub_trees);
589 for (
size_t i = 0;
i < sub_trees.size();
i++)
delete sub_trees[
i];
592 for (
size_t i = 0;
i < builders.size();
i++)
delete builders[
i];
594 if (obj_impr_out != NULL) *obj_impr_out = like_impr;
602 std::vector<EventMap*> *mapping) {
605 std::vector<BuildTreeStatsType> split_stats;
607 std::vector<Clusterable*> summed_stats;
610 std::vector<int32> indexes;
611 std::vector<Clusterable*> summed_stats_contiguous;
612 size_t max_index = 0;
613 for (
size_t i = 0;
i < summed_stats.size();
i++) {
614 if (summed_stats[
i] != NULL) {
615 indexes.push_back(
i);
616 summed_stats_contiguous.push_back(summed_stats[
i]);
617 if (i > max_index) max_index =
i;
620 if (summed_stats_contiguous.empty()) {
621 KALDI_WARN <<
"ClusterBottomUp: nothing to cluster.";
625 std::vector<int32> assignments;
635 KALDI_ASSERT(assignments.size() == summed_stats_contiguous.size() && !assignments.empty());
636 size_t num_clust = * std::max_element(assignments.begin(), assignments.end()) + 1;
637 int32 num_combined = summed_stats_contiguous.size() - num_clust;
640 KALDI_VLOG(2) <<
"ClusterBottomUp combined "<< num_combined
641 <<
" leaves and gave a likelihood change of " << change
642 <<
", normalized = " << (change/normalizer)
643 <<
", normalizer = " << normalizer;
647 if (max_index >= mapping->size()) mapping->resize(max_index+1, NULL);
649 for (
size_t i = 0;
i < summed_stats_contiguous.size();
i++) {
650 size_t index = indexes[
i];
651 size_t new_index = indexes[assignments[
i]];
654 KALDI_ASSERT((*mapping)[index] == NULL ||
"Error: Cluster seems to have been " 655 "called for different parts of the tree with overlapping sets of " 666 std::vector<EventAnswerType> initial_leaves;
667 e_in.
MultiMap(empty_vec, &initial_leaves);
668 if (initial_leaves.empty()) {
670 if (num_leaves) *num_leaves = 0;
675 std::vector<EventMap*> mapping(max_leaf_plus_one, (
EventMap*)NULL);
676 std::vector<EventAnswerType>::iterator iter = initial_leaves.begin(), end = initial_leaves.end();
678 for (; iter != end; ++iter) {
684 KALDI_ASSERT((
size_t)cur_leaf == initial_leaves.size());
685 if (num_leaves) *num_leaves = cur_leaf;
690 const std::vector<int32> &mapping_in) {
691 std::vector<EventMap*> mapping(mapping_in.size());
692 for (
size_t i = 0;
i < mapping_in.size();
i++)
701 std::vector<EventMap*> mapping;
705 if (num_removed_ptr != NULL) *num_removed_ptr = num_removed;
711 std::vector<std::vector<EventValueType> > &values,
715 std::vector<std::vector<EventAnswerType> > pdfs(values.size());
716 for (
size_t i = 0;
i < values.size();
i++) {
718 for (
size_t j = 0;
j < values[
i].size();
j++) {
720 size_t size_at_start = pdfs[
i].size();
722 if (pdfs[i].size() == size_at_start) {
723 KALDI_WARN <<
"ShareEventMapLeaves: had no leaves for key = " << key
724 <<
", value = " << (values[
i][
j]);
729 std::vector<EventMap*> remapping;
730 for (
size_t i = 0;
i < values.size();
i++) {
732 KALDI_WARN <<
"ShareEventMapLeaves: no leaves in one bucket.";
736 for (
size_t j = 1;
j < pdfs[
i].size();
j++) {
739 if (remapping.size() <=
static_cast<size_t>(leaf))
740 remapping.resize(leaf+1, NULL);
756 BuildTreeStatsType::iterator iter = stats->begin(), end = stats->end();
757 for (; iter!= end; ++iter)
if (iter->second != NULL) {
delete iter->second; iter->second = NULL; }
761 const std::vector<EventValueType> *phones,
762 int32 default_length) {
763 std::vector<BuildTreeStatsType> stats_by_phone;
768 "You seem to have provided invalid stats [no central-phone key].";
770 std::map<EventValueType, EventAnswerType> phone_to_length;
771 for (
size_t p = 0; p < stats_by_phone.size(); p++) {
772 if (! stats_by_phone[p].empty()) {
773 std::vector<BuildTreeStatsType> stats_by_length;
778 "You seem to have provided invalid stats [no position key].";
780 size_t length = stats_by_length.size();
781 for (
size_t i = 0;
i < length;
i++) {
782 if (stats_by_length[
i].empty()) {
783 KALDI_ERR <<
"There are no stats available for position " <<
i 784 <<
" of phone " << p;
787 phone_to_length[p] = length;
790 if (phones != NULL) {
791 for (
size_t i = 0;
i < phones->size();
i++) {
792 if (phone_to_length.count( (*phones)[
i] ) == 0) {
793 phone_to_length[(*phones)[
i]] = default_length;
806 std::vector<EventKeyType> keys,
807 std::vector<EventMap*> *leaf_mapping) {
808 if (keys.size() == 0) {
812 std::vector<BuildTreeStatsType> split_stats;
815 for (
size_t i = 0;
i< split_stats.size();
i++)
816 if (split_stats[
i].size() != 0)
825 const std::vector<EventKeyType> &keys,
826 int32 *num_removed) {
827 std::vector<EventMap*> leaf_mapping;
830 if (num_removed != NULL) *num_removed = nr;
842 int32 *num_removed_ptr) {
843 std::vector<EventMap*> leaf_mapping;
845 std::vector<BuildTreeStatsType> split_stats;
848 for (
size_t i = 0;
i < split_stats.size();
i++) {
849 if (!split_stats[
i].empty())
854 if (num_removed_ptr != NULL) *num_removed_ptr = num_removed;
864 int32 num_clusters_required,
866 int32 *num_removed_ptr) {
867 std::vector<BuildTreeStatsType> split_stats;
870 if (num_clusters_required < split_stats.size()) {
871 KALDI_WARN <<
"num-clusters-required is less than size of map. Not doing anything.";
872 if (num_removed_ptr) *num_removed_ptr = 0;
876 std::vector<std::vector<int32> > indexes(split_stats.size());
877 std::vector<std::vector<Clusterable*> > summed_stats_contiguous(split_stats.size());
881 size_t max_index = 0;
883 int32 num_non_empty_clusters_required = num_clusters_required;
885 int32 num_non_empty_clusters_in_map = 0;
886 int32 num_non_empty_clusters = 0;
888 for (
size_t i = 0;
i < split_stats.size();
i++) {
889 if (!split_stats[
i].empty()) {
890 num_non_empty_clusters_in_map++;
892 std::vector<BuildTreeStatsType> split_stats_i;
894 std::vector<Clusterable*> summed_stats_i;
897 for (
size_t j = 0;
j < summed_stats_i.size();
j++) {
898 if (summed_stats_i[
j] != NULL) {
899 num_non_empty_clusters++;
900 indexes[
i].push_back(
j);
901 summed_stats_contiguous[
i].push_back(summed_stats_i[
j]);
902 if (j > max_index) max_index =
j;
910 num_non_empty_clusters_required--;
914 KALDI_VLOG(1) <<
"Number of non-empty clusters in map = " << num_non_empty_clusters_in_map;
915 KALDI_VLOG(1) <<
"Number of non-empty clusters = " << num_non_empty_clusters;
917 if (num_non_empty_clusters_required > num_non_empty_clusters) {
918 KALDI_WARN <<
"Cannot get required num-clusters " << num_clusters_required
919 <<
" as number of non-empty clusters required is larger than " 920 <<
" number of non-empty clusters: " << num_non_empty_clusters_required
921 <<
" > " << num_non_empty_clusters;
922 if (num_removed_ptr) *num_removed_ptr = 0;
926 std::vector<std::vector<int32> > assignments;
928 summed_stats_contiguous,
929 std::numeric_limits<BaseFloat>::infinity(),
930 num_non_empty_clusters_required,
936 int32 num_combined = 0;
937 for (
size_t i = 0;
i < split_stats.size();
i++) {
938 KALDI_ASSERT(assignments[
i].size() == summed_stats_contiguous[
i].size());
939 if (assignments[
i].size() == 0)
continue;
940 size_t num_clust_i = *std::max_element(assignments[
i].begin(),
941 assignments[
i].end()) + 1;
942 num_combined += summed_stats_contiguous[
i].size() - num_clust_i;
945 KALDI_VLOG(2) <<
"ClusterBottomUpCompartmentalized combined " << num_combined
946 <<
" leaves and gave a likelihood change of " << change
947 <<
", normalized = " << (change / normalizer)
948 <<
", normalizer = " << normalizer;
951 std::vector<EventMap*> leaf_mapping(max_index + 1, NULL);
953 for (
size_t i = 0;
i < split_stats.size();
i++) {
954 for (
size_t j = 0;
j < summed_stats_contiguous[
i].size();
j++) {
955 size_t index = indexes[
i][
j];
956 size_t new_index = indexes[
i][assignments[
i][
j]];
960 KALDI_ASSERT(leaf_mapping[index] == NULL ||
"Error: Cluster seems to have been " 961 "called for different parts of the tree with overlapping sets of " 968 if (num_removed_ptr) *num_removed_ptr = num_combined;
976 const std::vector<std::vector<int32> > &phone_sets,
977 const std::vector<int32> &phone2num_pdf_classes,
978 const std::vector<bool> &share_roots,
979 int32 *num_leaves_out) {
982 KALDI_ASSERT(!phone_sets.empty() && share_roots.size() == phone_sets.size());
983 std::set<int32> all_phones;
984 for (
size_t i = 0;
i < phone_sets.size();
i++) {
987 for (
size_t j = 0;
j < phone_sets[
i].size();
j++) {
989 all_phones.insert(phone_sets[i][j]);
996 size_t max_set_size = 0;
997 int32 highest_numbered_phone = 0;
998 for (
size_t i = 0;
i < phone_sets.size();
i++) {
999 max_set_size = std::max(max_set_size, phone_sets[
i].size());
1000 highest_numbered_phone =
1001 std::max(highest_numbered_phone,
1002 * std::max_element(phone_sets[
i].begin(), phone_sets[
i].end()));
1005 if (phone_sets.size() == 1) {
1006 if (share_roots[0]) {
1011 for (
size_t i = 0;
i < phone_sets[0].size();
i++) {
1014 KALDI_ASSERT(static_cast<size_t>(phone) < phone2num_pdf_classes.size());
1015 len = phone2num_pdf_classes[phone];
1017 if (
i == 0) max_len = len;
1019 if (len != max_len) {
1020 KALDI_WARN <<
"Mismatching lengths within a phone set: " << len
1021 <<
" vs. " << max_len <<
" [unusual, but not necessarily fatal]. ";
1022 max_len = std::max(len, max_len);
1026 std::map<EventValueType, EventAnswerType> m;
1028 m[p] = (*num_leaves_out)++;
1032 }
else if (max_set_size == 1
1033 && static_cast<int32>(phone_sets.size()) <= 2*highest_numbered_phone) {
1036 std::map<EventValueType, EventMap*> m;
1037 for (
size_t i = 0;
i < phone_sets.size();
i++) {
1038 std::vector<std::vector<int32> > phone_sets_tmp;
1039 phone_sets_tmp.push_back(phone_sets[
i]);
1040 std::vector<bool> share_roots_tmp;
1041 share_roots_tmp.push_back(share_roots[i]);
1046 m[phone_sets_tmp[0][0]] = this_stub;
1051 size_t half_sz = phone_sets.size() / 2;
1052 std::vector<std::vector<int32> >::const_iterator half_phones =
1053 phone_sets.begin() + half_sz;
1054 std::vector<bool>::const_iterator half_share =
1055 share_roots.begin() + half_sz;
1056 std::vector<std::vector<int32> > phone_sets_1, phone_sets_2;
1057 std::vector<bool> share_roots_1, share_roots_2;
1058 phone_sets_1.insert(phone_sets_1.end(), phone_sets.begin(), half_phones);
1059 phone_sets_2.insert(phone_sets_2.end(), half_phones, phone_sets.end());
1060 share_roots_1.insert(share_roots_1.end(), share_roots.begin(), half_share);
1061 share_roots_2.insert(share_roots_2.end(), half_share, share_roots.end());
1063 EventMap *map1 =
GetStubMap(P, phone_sets_1, phone2num_pdf_classes, share_roots_1, num_leaves_out);
1064 EventMap *map2 =
GetStubMap(P, phone_sets_2, phone2num_pdf_classes, share_roots_2, num_leaves_out);
1066 std::vector<EventKeyType> all_in_first_set;
1067 for (
size_t i = 0;
i < half_sz;
i++)
1068 for (
size_t j = 0;
j < phone_sets[
i].size();
j++)
1069 all_in_first_set.push_back(phone_sets[
i][
j]);
1070 std::sort(all_in_first_set.begin(), all_in_first_set.end());
1079 bool warned =
false;
1080 KALDI_ASSERT(stats != NULL && oldN > 0 && newN > 0 && oldP >= 0
1081 && newP >= 0 && newP < newN && oldP < oldN);
1083 KALDI_WARN <<
"Cannot convert stats to larger context: " << newN
1088 KALDI_WARN <<
"Cannot convert stats to have more left-context: " << newP
1091 if (newN-newP-1 > oldN-oldP-1) {
1092 KALDI_WARN <<
"Cannot convert stats to have more right-context: " << (newN-newP-1)
1093 <<
" > " << (oldN-oldP-1);
1097 int32 shift = newP - oldP;
1099 for (
size_t i = 0;
i < stats->size();
i++) {
1102 for (
size_t j = 0;
j < evec.size();
j++) {
1104 if (key >= 0 && key < oldN) {
1106 if (key >= 0 && key < newN)
1107 evec_new.push_back(std::make_pair(key, evec[
j].second));
1113 KALDI_WARN <<
"Stats had keys defined that we cannot interpret";
1117 evec_new.push_back(evec[
j]);
void GetKeysWithQuestions(std::vector< EventKeyType > *keys_out) const
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
BaseFloat SumNormalizer(const BuildTreeStatsType &stats_in)
Sums the normalizer [typically, data-count] over the stats.
static void GetEventKeys(const EventType &vec, std::vector< EventKeyType > *keys)
void CopySetToVector(const std::set< T > &s, std::vector< T > *v)
Copies the elements of a set to a vector.
virtual void Add(const Clusterable &other)=0
Add other stats.
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
std::pair< EventKeyType, EventValueType > MakeEventPair(EventKeyType k, EventValueType v)
bool ConvertStats(int32 oldN, int32 oldP, int32 newN, int32 newP, BuildTreeStatsType *stats)
Converts stats from a given context-window (N) and central-position (P) to a different N and P...
This class defines, for each EventKeyType, a set of initial questions that it tries and also a number...
BaseFloat RefineClusters(const std::vector< Clusterable *> &points, std::vector< Clusterable *> *clusters, std::vector< int32 > *assignments, RefineClustersOptions cfg)
RefineClusters is mainly used internally by other clustering algorithms.
DecisionTreeSplitter * yes_
Clusterable * SumStats(const BuildTreeStatsType &stats_in)
Sums stats, or returns NULL if stats_in has no non-NULL stats.
BaseFloat best_split_impr_
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
virtual BaseFloat Objf() const =0
Return the objective function associated with the stats [assuming ML estimation]. ...
BaseFloat SumClusterableNormalizer(const std::vector< Clusterable *> &vec)
Returns the total normalizer (usually count) of the cluster (pointers may be NULL).
std::string EventTypeToString(const EventType &evec)
void SplitStatsByMap(const BuildTreeStatsType &stats, const EventMap &e, std::vector< BuildTreeStatsType > *stats_out)
Splits stats according to the EventMap, indexing them at output by the leaf type. ...
const Questions & q_opts_
void FindAllKeys(const BuildTreeStatsType &stats, AllKeysType keys_type, std::vector< EventKeyType > *keys_out)
FindAllKeys puts in *keys the (sorted, unique) list of all key identities in the stats.
const QuestionsForKey & GetQuestionsOf(EventKeyType key) const
static int32 ClusterEventMapRestrictedHelper(const EventMap &e_in, const BuildTreeStatsType &stats, BaseFloat thresh, std::vector< EventKeyType > keys, std::vector< EventMap *> *leaf_mapping)
BaseFloat ClusterBottomUpCompartmentalized(const std::vector< std::vector< Clusterable *> > &points, BaseFloat thresh, int32 min_clust, std::vector< std::vector< Clusterable *> > *clusters_out, std::vector< std::vector< int32 > > *assignments_out)
This is a bottom-up clustering where the points are pre-clustered in a set of compartments, such that only points in the same compartment are clustered together.
Kaldi fatal runtime error exception.
void SortAndUniq(std::vector< T > *vec)
Sorts a vector and removes duplicate elements from it ("uniq").
std::vector< EventValueType > yes_set_
void SplitStatsByKey(const BuildTreeStatsType &stats_in, EventKeyType key, std::vector< BuildTreeStatsType > *stats_out)
SplitStatsByKey splits stats up according to the value of a particular key, which must be always defi...
virtual bool Map(const EventType &event, EventAnswerType *ans) const =0
virtual Clusterable * Copy() const =0
Return a copy of this object.
bool PossibleValues(EventKeyType key, const BuildTreeStatsType &stats, std::vector< EventValueType > *ans)
Convenience function, e.g. for finding the possible values of a given key from the stats alone.
void DeleteBuildTreeStats(BuildTreeStatsType *stats)
This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL...
void ReadBuildTreeStats(std::istream &is, bool binary, const Clusterable &example, BuildTreeStatsType *stats)
Reads BuildTreeStats object.
static const EventKeyType kPdfClass
void WriteEventType(std::ostream &os, bool binary, const EventType &evec)
virtual BaseFloat Distance(const Clusterable &other) const
Return the objective function decrease from merging the two clusters, negated to be a positive number...
EventMap * GetToLengthMap(const BuildTreeStatsType &stats, int32 P, const std::vector< EventValueType > *phones, int32 default_length)
AllKeysType
Typedef used when we get "all keys" from a set of stats – used in specifying which kinds of questions...
std::vector< std::pair< EventKeyType, EventValueType > > EventType
void DoSplitInternal(int32 *next_leaf)
void EnsureClusterableVectorNotNull(std::vector< Clusterable *> *stats)
Fills in any (NULL) holes in "stats" vector, with empty stats, because certain algorithms require non...
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
BaseFloat ClusterBottomUp(const std::vector< Clusterable *> &points, BaseFloat max_merge_thresh, int32 min_clust, std::vector< Clusterable *> *clusters_out, std::vector< int32 > *assignments_out)
A bottom-up clustering algorithm.
int32 EventKeyType
Things of type EventKeyType can take any value.
void SumStatsVec(const std::vector< BuildTreeStatsType > &stats_in, std::vector< Clusterable *> *stats_out)
Sum a vector of stats.
QuestionsForKey is a class used to define the questions for a key, and also options that allow us to ...
std::vector< std::vector< EventValueType > > initial_questions
void AddToClusters(const std::vector< Clusterable *> &stats, const std::vector< int32 > &assignments, std::vector< Clusterable *> *clusters)
Given stats and a vector "assignments" of the same size (that maps to cluster indices), sums the stats up into "clusters." It will add to any stats already present in "clusters" (although typically "clusters" will be empty when called), and it will extend with NULL pointers for any unseen indices.
void ReadEventType(std::istream &is, bool binary, EventType *evec)
RefineClustersOptions refine_opts
void AddToClustersOptimized(const std::vector< Clusterable *> &stats, const std::vector< int32 > &assignments, const Clusterable &total, std::vector< Clusterable *> *clusters)
AddToClustersOptimized does the same as AddToClusters (it sums up the stats within each cluster...
DecisionTreeSplitter * no_
virtual EventMap * Copy(const std::vector< EventMap *> &new_leaves) const =0
BaseFloat ObjfGivenMap(const BuildTreeStatsType &stats_in, const EventMap &e)
Cluster the stats given the event map, and return the total objf given those clusters. ...
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
virtual BaseFloat Normalizer() const =0
Return the normalizer (typically, count) associated with the stats.
DecisionTreeSplitter(EventAnswerType leaf, const BuildTreeStatsType &stats, const Questions &q_opts)
static bool Lookup(const EventType &event, EventKeyType key, EventValueType *ans)
A class that is capable of representing a generic mapping from EventType (which is a vector of (key...
bool HasQuestionsForKey(EventKeyType key) const
#define KALDI_ASSERT(cond)
BaseFloat SumObjf(const BuildTreeStatsType &stats_in)
Sums the objective function over the stats.
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
BaseFloat SumClusterableObjf(const std::vector< Clusterable *> &vec)
Returns the total objective function after adding up all the statistics in the vector (pointers may b...
BaseFloat ComputeInitialSplit(const std::vector< Clusterable *> &summed_stats, const Questions &q_opts, EventKeyType key, std::vector< EventValueType > *yes_set)
int32 EventAnswerType
As far as the event-map code itself is concerned, things of type EventAnswerType may take any value e...
void FilterStatsByKey(const BuildTreeStatsType &stats_in, EventKeyType key, std::vector< EventValueType > &values, bool include_if_present, BuildTreeStatsType *stats_out)
FilterStatsByKey filters the stats according to the value of a specified key.
int32 EventValueType
Given current code, things of type EventValueType should generally be nonnegative and in a reasonably...
std::vector< std::pair< EventType, Clusterable * > > BuildTreeStatsType
virtual Clusterable * ReadNew(std::istream &os, bool binary) const =0
Read data from a stream and return the corresponding object (const function; it's a class member beca...
virtual void MultiMap(const EventType &event, std::vector< EventAnswerType > *ans) const =0
bool IsSortedAndUniq(const std::vector< T > &vec)
Returns true if the vector is sorted and contains each element only once.
BuildTreeStatsType stats_
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
Returns true if abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
Clusterable * SumClusterable(const std::vector< Clusterable *> &vec)
Sums stats (ptrs may be NULL). Returns NULL if no non-NULL stats present.
void DoSplit(int32 *next_leaf)
void WriteBuildTreeStats(std::ostream &os, bool binary, const BuildTreeStatsType &stats)
Writes BuildTreeStats object. This works even if pointers are NULL.