acc-tree-stats.cc File Reference

Functions

int main (int argc, char *argv[])
 Accumulate tree statistics for decision tree training.
 

Function Documentation

int main (int argc, char *argv[])

Accumulate tree statistics for decision tree training.

The program reads a feature archive and the corresponding alignments, and generates the sufficient statistics needed to build the decision tree. The context width and central-phone position are used to identify the phonetic contexts. The transition model is also an input; it is used to map the alignments to PDFs and phones.

Definition at line 34 of file acc-tree-stats.cc.

References kaldi::AccumulateTreeStats(), kaldi::DeleteBuildTreeStats(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), ParseOptions::GetOptArg(), RandomAccessTableReader< Holder >::HasKey(), KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), TransitionModel::Read(), AccumulateTreeStatsOptions::Register(), ParseOptions::Register(), Output::Stream(), Input::Stream(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and kaldi::WriteBuildTreeStats().

int main(int argc, char *argv[]) {
  using namespace kaldi;
  typedef kaldi::int32 int32;
  try {
    const char *usage =
        "Accumulate statistics for phonetic-context tree building.\n"
        "Usage: acc-tree-stats [options] <model-in> <features-rspecifier> <alignments-rspecifier> <tree-accs-out>\n"
        "e.g.: \n"
        " acc-tree-stats 1.mdl scp:train.scp ark:1.ali 1.tacc\n";

    bool binary = true;
    AccumulateTreeStatsOptions opts;  // context width, central position, etc.
    ParseOptions po(usage);
    po.Register("binary", &binary, "Write output in binary mode");
    opts.Register(&po);

    po.Read(argc, argv);

    if (po.NumArgs() != 4) {
      po.PrintUsage();
      exit(1);
    }

    std::string model_filename = po.GetArg(1),
        feature_rspecifier = po.GetArg(2),
        alignment_rspecifier = po.GetArg(3),
        accs_out_wxfilename = po.GetOptArg(4);

    AccumulateTreeStatsInfo acc_tree_stats_info(opts);

    TransitionModel trans_model;
    {
      bool binary;  // Format of the model file, as detected by Input.
      Input ki(model_filename, &binary);
      trans_model.Read(ki.Stream(), binary);
      // There is more in this file but we don't need it.
    }

    SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
    RandomAccessInt32VectorReader alignment_reader(alignment_rspecifier);

    std::map<EventType, GaussClusterable*> tree_stats;

    int num_done = 0, num_no_alignment = 0, num_other_error = 0;

    for (; !feature_reader.Done(); feature_reader.Next()) {
      std::string key = feature_reader.Key();
      if (!alignment_reader.HasKey(key)) {
        num_no_alignment++;
      } else {
        const Matrix<BaseFloat> &mat = feature_reader.Value();
        const std::vector<int32> &alignment = alignment_reader.Value(key);

        // One alignment entry is expected per feature frame.
        if (alignment.size() != mat.NumRows()) {
          KALDI_WARN << "Alignments has wrong size " << (alignment.size())
                     << " vs. " << (mat.NumRows());
          num_other_error++;
          continue;
        }

        AccumulateTreeStats(trans_model,
                            acc_tree_stats_info,
                            alignment,
                            mat,
                            &tree_stats);
        num_done++;
        if (num_done % 1000 == 0)
          KALDI_LOG << "Processed " << num_done << " utterances.";
      }
    }

    BuildTreeStatsType stats;  // vectorized form.

    for (std::map<EventType, GaussClusterable*>::const_iterator iter = tree_stats.begin();
         iter != tree_stats.end();
         ++iter) {
      stats.push_back(std::make_pair(iter->first, iter->second));
    }
    tree_stats.clear();  // Ownership of the pointers has passed to "stats".

    {
      Output ko(accs_out_wxfilename, binary);
      WriteBuildTreeStats(ko.Stream(), binary, stats);
    }
    KALDI_LOG << "Accumulated stats for " << num_done << " files, "
              << num_no_alignment << " failed due to no alignment, "
              << num_other_error << " failed for other reasons.";
    KALDI_LOG << "Number of separate stats (context-dependent states) is "
              << stats.size();
    DeleteBuildTreeStats(&stats);
    if (num_done != 0) return 0;
    else return 1;
  } catch(const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}