All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
acc-tree-stats.cc
Go to the documentation of this file.
1 // bin/acc-tree-stats.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation, GoVivace Inc.
4 // 2013 Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "tree/context-dep.h"
24 #include "tree/build-tree-utils.h"
25 #include "hmm/transition-model.h"
26 #include "hmm/tree-accu.h"
27 
34 int main(int argc, char *argv[]) {
35  using namespace kaldi;
36  typedef kaldi::int32 int32;
37  try {
38  const char *usage =
39  "Accumulate statistics for phonetic-context tree building.\n"
40  "Usage: acc-tree-stats [options] <model-in> <features-rspecifier> <alignments-rspecifier> <tree-accs-out>\n"
41  "e.g.: \n"
42  " acc-tree-stats 1.mdl scp:train.scp ark:1.ali 1.tacc\n";
43 
44  bool binary = true;
46  ParseOptions po(usage);
47  po.Register("binary", &binary, "Write output in binary mode");
48  opts.Register(&po);
49 
50  po.Read(argc, argv);
51 
52  if (po.NumArgs() != 4) {
53  po.PrintUsage();
54  exit(1);
55  }
56 
57  std::string model_filename = po.GetArg(1),
58  feature_rspecifier = po.GetArg(2),
59  alignment_rspecifier = po.GetArg(3),
60  accs_out_wxfilename = po.GetOptArg(4);
61 
62 
63  AccumulateTreeStatsInfo acc_tree_stats_info(opts);
64 
65  TransitionModel trans_model;
66  {
67  bool binary;
68  Input ki(model_filename, &binary);
69  trans_model.Read(ki.Stream(), binary);
70  // There is more in this file but we don't need it.
71  }
72 
73  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
74  RandomAccessInt32VectorReader alignment_reader(alignment_rspecifier);
75 
76  std::map<EventType, GaussClusterable*> tree_stats;
77 
78  int num_done = 0, num_no_alignment = 0, num_other_error = 0;
79 
80  for (; !feature_reader.Done(); feature_reader.Next()) {
81  std::string key = feature_reader.Key();
82  if (!alignment_reader.HasKey(key)) {
83  num_no_alignment++;
84  } else {
85  const Matrix<BaseFloat> &mat = feature_reader.Value();
86  const std::vector<int32> &alignment = alignment_reader.Value(key);
87 
88  if (alignment.size() != mat.NumRows()) {
89  KALDI_WARN << "Alignments has wrong size "<< (alignment.size())<<" vs. "<< (mat.NumRows());
90  num_other_error++;
91  continue;
92  }
93 
94  AccumulateTreeStats(trans_model,
95  acc_tree_stats_info,
96  alignment,
97  mat,
98  &tree_stats);
99  num_done++;
100  if (num_done % 1000 == 0)
101  KALDI_LOG << "Processed " << num_done << " utterances.";
102  }
103  }
104 
105  BuildTreeStatsType stats; // vectorized form.
106 
107  for (std::map<EventType, GaussClusterable*>::const_iterator iter = tree_stats.begin();
108  iter != tree_stats.end();
109  ++iter) {
110  stats.push_back(std::make_pair(iter->first, iter->second));
111  }
112  tree_stats.clear();
113 
114  {
115  Output ko(accs_out_wxfilename, binary);
116  WriteBuildTreeStats(ko.Stream(), binary, stats);
117  }
118  KALDI_LOG << "Accumulated stats for " << num_done << " files, "
119  << num_no_alignment << " failed due to no alignment, "
120  << num_other_error << " failed for other reasons.";
121  KALDI_LOG << "Number of separate stats (context-dependent states) is "
122  << stats.size();
123  DeleteBuildTreeStats(&stats);
124  if (num_done != 0) return 0;
125  else return 1;
126  } catch(const std::exception &e) {
127  std::cerr << e.what();
128  return -1;
129  }
130 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
void Register(OptionsItf *opts)
Definition: tree-accu.h:47
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void AccumulateTreeStats(const TransitionModel &trans_model, const AccumulateTreeStatsInfo &info, const std::vector< int32 > &alignment, const Matrix< BaseFloat > &features, std::map< EventType, GaussClusterable * > *stats)
Accumulates the stats needed for training context-dependency trees (in the "normal" way)...
Definition: tree-accu.cc:36
std::string GetOptArg(int param) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
void DeleteBuildTreeStats(BuildTreeStatsType *stats)
This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL...
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::istream & Stream()
Definition: kaldi-io.cc:812
int main(int argc, char *argv[])
Accumulate tree statistics for decision tree training.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:687
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_WARN
Definition: kaldi-error.h:130
bool HasKey(const std::string &key)
int NumArgs() const
Number of positional parameters (cf. argc-1).
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:58
std::vector< std::pair< EventType, Clusterable * > > BuildTreeStatsType
#define KALDI_LOG
Definition: kaldi-error.h:133
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
void WriteBuildTreeStats(std::ostream &os, bool binary, const BuildTreeStatsType &stats)
Writes BuildTreeStats object. This works even if pointers are NULL.