All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
analyze-counts.cc
Go to the documentation of this file.
1 // bin/analyze-counts.cc
2 
3 // Copyright 2012-2016 Brno University of Technology (Author: Karel Vesely)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fst/fstlib.h"

#include <algorithm>
#include <iomanip>
#include <memory>
#include <numeric>
29 
30 int main(int argc, char *argv[]) {
31  using namespace kaldi;
32  typedef kaldi::int32 int32;
33  typedef kaldi::uint64 uint64;
34  try {
35  const char *usage =
36  "Computes element counts from integer vector table.\n"
37  "(e.g. get pdf-counts to estimate DNN-output priors "
38  "for data analysis)\n"
39  "Verbosity : level 1 => print frequencies and histogram\n"
40  "\n"
41  "Usage: analyze-counts [options] <alignments-rspecifier> "
42  "<counts-wxfilname>\n"
43  "e.g.: \n"
44  " analyze-counts ark:1.ali prior.counts\n"
45  " Show phone counts by:\n"
46  " ali-to-phone --per-frame=true ark:1.ali ark:- |"
47  " analyze-counts --verbose=1 ark:- - >/dev/null\n";
48 
49  ParseOptions po(usage);
50 
51  bool binary = false;
52  std::string symbol_table_filename = "";
53 
54  po.Register("binary", &binary, "write in binary mode");
55  po.Register("symbol-table", &symbol_table_filename,
56  "Read symbol table for display of counts");
57 
58  int32 counts_dim = 0;
59  po.Register("counts-dim", &counts_dim,
60  "Output dimension of the counts, "
61  "a hint for dimension auto-detection.");
62 
63  std::string frame_weights;
64  po.Register("frame-weights", &frame_weights,
65  "Per-frame weights (counting weighted frames).");
66  std::string utt_weights;
67  po.Register("utt-weights", &utt_weights,
68  "Per-utterance weights (counting weighted frames).");
69 
70  po.Read(argc, argv);
71 
72  if (po.NumArgs() != 2) {
73  po.PrintUsage();
74  exit(1);
75  }
76 
77  std::string alignments_rspecifier = po.GetArg(1),
78  wxfilename = po.GetArg(2);
79 
80  SequentialInt32VectorReader alignment_reader(alignments_rspecifier);
81 
82  RandomAccessBaseFloatVectorReader weights_reader;
83  if (frame_weights != "") {
84  weights_reader.Open(frame_weights);
85  }
86  RandomAccessBaseFloatReader utt_weights_reader;
87  if (utt_weights != "") {
88  utt_weights_reader.Open(utt_weights);
89  }
90 
91  // Buffer for accumulating the counts
92  Vector<double> counts(counts_dim, kSetZero);
93 
94  int32 num_done = 0, num_other_error = 0;
95  for (; !alignment_reader.Done(); alignment_reader.Next()) {
96  std::string utt = alignment_reader.Key();
97  const std::vector<int32> &alignment = alignment_reader.Value();
98 
99  BaseFloat utt_w = 1.0;
100  // Check if per-utterance weights are provided
101  if (utt_weights != "") {
102  if (!utt_weights_reader.HasKey(utt)) {
103  KALDI_WARN << utt << ", missing per-utterance weight";
104  num_other_error++;
105  continue;
106  } else {
107  utt_w = utt_weights_reader.Value(utt);
108  }
109  }
110 
111  Vector<BaseFloat> frame_w;
112  // Check if per-frame weights are provided
113  if (frame_weights != "") {
114  if (!weights_reader.HasKey(utt)) {
115  KALDI_WARN << utt << ", missing per-frame weights";
116  num_other_error++;
117  continue;
118  } else {
119  frame_w = weights_reader.Value(utt);
120  KALDI_ASSERT(frame_w.Dim() == alignment.size());
121  }
122  }
123 
124  // Accumulate the counts
125  for (size_t i = 0; i < alignment.size(); i++) {
126  KALDI_ASSERT(alignment[i] >= 0);
127  // Extend the vector if it is not large enough to hold every pdf-ids
128  if (alignment[i] >= counts.Dim()) {
129  counts.Resize(alignment[i]+1, kCopyData);
130  }
131  if (frame_weights != "") {
132  counts(alignment[i]) += 1.0 * utt_w * frame_w(i);
133  } else {
134  counts(alignment[i]) += 1.0 * utt_w;
135  }
136  }
137  num_done++;
138  }
139 
140  // Report elements with zero counts
141  for (size_t i = 0; i < counts.Dim(); i++) {
142  if (0.0 == counts(i)) {
143  KALDI_WARN << "Zero count for label " << i << ", this is suspicious.";
144  }
145  }
146 
147  // Add a ``half-frame'' to all the elements to
148  // avoid zero-counts which would cause problems in decoding
149  Vector<double> counts_nozero(counts);
150  counts_nozero.Add(0.5);
151 
152  Output ko(wxfilename, binary);
153  counts_nozero.Write(ko.Stream(), binary);
154 
155  //
156  // THE REST IS FOR ANALYSIS, IT GETS PRINTED TO LOG
157  //
158  if (symbol_table_filename != "" || (kaldi::g_kaldi_verbose_level >= 1)) {
159  // load the symbol table
160  fst::SymbolTable *elem_syms = NULL;
161  if (symbol_table_filename != "") {
162  elem_syms = fst::SymbolTable::ReadText(symbol_table_filename);
163  if (!elem_syms)
164  KALDI_ERR << "Could not read symbol table from file "
165  << symbol_table_filename;
166  }
167 
168  // sort the counts
169  std::vector<std::pair<double, int32> > sorted_counts;
170  for (int32 i = 0; i < counts.Dim(); i++) {
171  sorted_counts.push_back(
172  std::make_pair(static_cast<double>(counts(i)), i));
173  }
174  std::sort(sorted_counts.begin(), sorted_counts.end());
175  std::ostringstream os;
176  double sum = counts.Sum();
177  os << "Printing...\n### The sorted count table," << std::endl;
178  os << "count\t(norm),\tid\t(symbol):" << std::endl;
179  for (int32 i = 0; i < sorted_counts.size(); i++) {
180  os << sorted_counts[i].first << "\t("
181  << static_cast<float>(sorted_counts[i].first) / sum << "),\t"
182  << sorted_counts[i].second << "\t"
183  << (elem_syms != NULL ? "(" +
184  elem_syms->Find(sorted_counts[i].second) + ")" : "")
185  << std::endl;
186  }
187  os << "\n#total " << sum
188  << " (" << static_cast<float>(sum)/100/3600 << "h)"
189  << std::endl;
190  KALDI_LOG << os.str();
191  }
192 
193  KALDI_LOG << "Summed " << num_done << " int32 vectors to counts, "
194  << "skipped " << num_other_error << " vectors.";
195  KALDI_LOG << "Counts written to " << wxfilename;
196  return 0;
197  } catch(const std::exception &e) {
198  std::cerr << e.what();
199  return -1;
200  }
201 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
bool Open(const std::string &rspecifier)
Real Sum() const
Returns sum of the elements.
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:687
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:127
#define KALDI_WARN
Definition: kaldi-error.h:130
bool HasKey(const std::string &key)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
int32 g_kaldi_verbose_level
This is set by util/parse-options.{h,cc} if you set the --verbose=? option.
Definition: kaldi-error.cc:40
void Add(Real c)
Add a constant to each element of a vector.
#define KALDI_LOG
Definition: kaldi-error.h:133
int main(int argc, char *argv[])
Sums the pdf vectors into counts; this is used to obtain prior counts for hybrid decoding.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:59
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.