30 int main(
int argc,
char *argv[]) {
31 using namespace kaldi;
33 using namespace kaldi;
37 "Train decision tree\n" 38 "Usage: build-tree [options] <tree-stats-in> <roots-file> <questions-file> <topo-file> <tree-out>\n" 40 " build-tree treeacc roots.txt 1.qst topo tree\n";
48 bool round_num_leaves =
true;
49 std::string occs_out_filename;
52 po.
Register(
"binary", &binary,
"Write output in binary mode");
53 po.
Register(
"context-width", &N,
"Context window size [must match " 55 po.
Register(
"central-position", &P,
"Central position in context window " 56 "[must match acc-tree-stats]");
57 po.
Register(
"max-leaves", &max_leaves,
"Maximum number of leaves to be " 58 "used in tree-buliding (if positive)");
59 po.
Register(
"thresh", &thresh,
"Log-likelihood change threshold for " 61 po.
Register(
"cluster-thresh", &cluster_thresh,
"Log-likelihood change " 62 "threshold for clustering after tree-building. 0 means " 63 "no clustering; -1 means use as a clustering threshold the " 64 "likelihood change of the final split.");
65 po.
Register(
"round-num-leaves", &round_num_leaves,
66 "If true, then the number of leaves will be reduced to a " 67 "multiple of 8 by clustering.");
76 std::string stats_filename = po.
GetArg(1),
77 roots_filename = po.
GetArg(2),
78 questions_filename = po.
GetArg(3),
79 topo_filename = po.
GetArg(4),
80 tree_out_filename = po.
GetArg(5);
86 std::vector<std::vector<int32> > phone_sets;
87 std::vector<bool> is_shared_root;
88 std::vector<bool> is_split_root;
90 Input ki(roots_filename.c_str());
91 ReadRootsFile(ki.Stream(), &phone_sets, &is_shared_root, &is_split_root);
101 Input ki(stats_filename, &binary_in);
104 KALDI_LOG <<
"Number of separate statistics is " << stats.size();
110 Input ki(questions_filename, &binary_in);
112 }
catch (
const std::exception &e) {
113 KALDI_ERR <<
"Error reading questions file "<<questions_filename<<
", error is: " << e.what();
118 std::vector<int32> phone2num_pdf_classes;
119 topo.GetPhoneToNumPdfClasses(&phone2num_pdf_classes);
127 phone2num_pdf_classes,
138 std::vector<BuildTreeStatsType> split_stats;
141 for (
size_t i = 0;
i < split_stats.size();
i++)
143 KALDI_VLOG(1) <<
"For pdf-id " << i <<
", low count " 155 std::vector<int32> all_phones;
156 for (
size_t i = 0;
i < phone_sets.size();
i++)
157 all_phones.insert(all_phones.end(),
158 phone_sets[
i].begin(), phone_sets[
i].end());
160 if (all_phones != topo.GetPhones()) {
161 std::ostringstream ss;
165 KALDI_WARN <<
"Mismatch between phone sets provided in roots file, and those in topology: " << ss.str();
167 std::vector<int32> phones_vec;
170 std::vector<int32> unseen_phones;
171 for (
size_t i = 0;
i < all_phones.size();
i++)
172 if (!std::binary_search(phones_vec.begin(), phones_vec.end(), all_phones[
i]))
173 unseen_phones.push_back(all_phones[
i]);
174 for (
size_t i = 0; i < phones_vec.size(); i++)
175 if (!std::binary_search(all_phones.begin(), all_phones.end(), phones_vec[
i]))
177 <<
" appears in stats but is not listed in roots file.";
178 if (!unseen_phones.empty()) {
179 std::ostringstream ss;
180 for (
size_t i = 0; i < unseen_phones.size(); i++)
181 ss << unseen_phones[i] <<
' ';
185 KALDI_WARN <<
"Saw no stats for following phones: " << ss.str();
192 }
catch(
const std::exception &e) {
193 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
BaseFloat SumNormalizer(const BuildTreeStatsType &stats_in)
Sums the normalizer [typically, data-count] over the stats.
This class defines, for each EventKeyType, a set of initial questions that it tries and also a number...
A class for storing topology information for phones.
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void SplitStatsByMap(const BuildTreeStatsType &stats, const EventMap &e, std::vector< BuildTreeStatsType > *stats_out)
Splits stats according to the EventMap, indexing them at output by the leaf type. ...
void SortAndUniq(std::vector< T > *vec)
Sorts a vector and removes duplicate elements from it.
void Register(const std::string &name, bool *ptr, const std::string &doc)
bool PossibleValues(EventKeyType key, const BuildTreeStatsType &stats, std::vector< EventValueType > *ans)
Convenience function e.g.
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
void DeleteBuildTreeStats(BuildTreeStatsType *stats)
This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL...
void ReadBuildTreeStats(std::istream &is, bool binary, const Clusterable &example, BuildTreeStatsType *stats)
Reads BuildTreeStats object.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int main(int argc, char *argv[])
int NumArgs() const
Number of positional parameters (cf. argc-1).
A class that is capable of representing a generic mapping from EventType (which is a vector of (key...
EventMap * BuildTree(Questions &qopts, const std::vector< std::vector< int32 > > &phone_sets, const std::vector< int32 > &phone2num_pdf_classes, const std::vector< bool > &share_roots, const std::vector< bool > &do_split, const BuildTreeStatsType &stats, BaseFloat thresh, int32 max_leaves, BaseFloat cluster_thresh, int32 P, bool round_num_leaves)
BuildTree is the normal way to build a set of decision trees.
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
void ReadRootsFile(std::istream &is, std::vector< std::vector< int32 > > *phone_sets, std::vector< bool > *is_shared_root, std::vector< bool > *is_split_root)
Reads the roots file (throws on error).
std::vector< std::pair< EventType, Clusterable * > > BuildTreeStatsType
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...