analyze-counts.cc
Go to the documentation of this file.
1 // bin/analyze-counts.cc
2 
3 // Copyright 2012-2016 Brno University of Technology (Author: Karel Vesely)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
22 #include "base/kaldi-common.h"
23 #include "util/common-utils.h"
24 #include "fst/fstlib.h"
25 
26 #include <iomanip>
27 #include <algorithm>
28 #include <numeric>
29 
30 int main(int argc, char *argv[]) {
31  using namespace kaldi;
32  typedef kaldi::int32 int32;
33  typedef kaldi::uint64 uint64;
34  try {
35  const char *usage =
36  "Computes element counts from integer vector table.\n"
37  "(e.g. get pdf-counts to estimate DNN-output priors "
38  "for data analysis)\n"
39  "Verbosity : level 1 => print frequencies and histogram\n"
40  "\n"
41  "Usage: analyze-counts [options] <alignments-rspecifier> "
42  "<counts-wxfilname>\n"
43  "e.g.: \n"
44  " analyze-counts ark:1.ali prior.counts\n"
45  " Show phone counts by:\n"
46  " ali-to-phones --per-frame=true ark:1.ali ark:- |"
47  " analyze-counts --verbose=1 ark:- - >/dev/null\n"
48  "Note: this is deprecated, see post-to-tacc.\n";
49 
50  ParseOptions po(usage);
51 
52  bool binary = false;
53  std::string symbol_table_filename = "";
54 
55  po.Register("binary", &binary, "write in binary mode");
56  po.Register("symbol-table", &symbol_table_filename,
57  "Read symbol table for display of counts");
58 
59  int32 counts_dim = 0;
60  po.Register("counts-dim", &counts_dim,
61  "Output dimension of the counts, "
62  "a hint for dimension auto-detection.");
63 
64  std::string frame_weights;
65  po.Register("frame-weights", &frame_weights,
66  "Per-frame weights (counting weighted frames).");
67  std::string utt_weights;
68  po.Register("utt-weights", &utt_weights,
69  "Per-utterance weights (counting weighted frames).");
70 
71  po.Read(argc, argv);
72 
73  if (po.NumArgs() != 2) {
74  po.PrintUsage();
75  exit(1);
76  }
77 
78  std::string alignments_rspecifier = po.GetArg(1),
79  wxfilename = po.GetArg(2);
80 
81  SequentialInt32VectorReader alignment_reader(alignments_rspecifier);
82 
83  RandomAccessBaseFloatVectorReader weights_reader;
84  if (frame_weights != "") {
85  weights_reader.Open(frame_weights);
86  }
87  RandomAccessBaseFloatReader utt_weights_reader;
88  if (utt_weights != "") {
89  utt_weights_reader.Open(utt_weights);
90  }
91 
92  // Buffer for accumulating the counts
93  Vector<double> counts(counts_dim, kSetZero);
94 
95  int32 num_done = 0, num_other_error = 0;
96  for (; !alignment_reader.Done(); alignment_reader.Next()) {
97  std::string utt = alignment_reader.Key();
98  const std::vector<int32> &alignment = alignment_reader.Value();
99 
100  BaseFloat utt_w = 1.0;
101  // Check if per-utterance weights are provided
102  if (utt_weights != "") {
103  if (!utt_weights_reader.HasKey(utt)) {
104  KALDI_WARN << utt << ", missing per-utterance weight";
105  num_other_error++;
106  continue;
107  } else {
108  utt_w = utt_weights_reader.Value(utt);
109  }
110  }
111 
112  Vector<BaseFloat> frame_w;
113  // Check if per-frame weights are provided
114  if (frame_weights != "") {
115  if (!weights_reader.HasKey(utt)) {
116  KALDI_WARN << utt << ", missing per-frame weights";
117  num_other_error++;
118  continue;
119  } else {
120  frame_w = weights_reader.Value(utt);
121  KALDI_ASSERT(frame_w.Dim() == alignment.size());
122  }
123  }
124 
125  // Accumulate the counts
126  for (size_t i = 0; i < alignment.size(); i++) {
127  KALDI_ASSERT(alignment[i] >= 0);
128  // Extend the vector if it is not large enough to hold every pdf-ids
129  if (alignment[i] >= counts.Dim()) {
130  counts.Resize(alignment[i]+1, kCopyData);
131  }
132  if (frame_weights != "") {
133  counts(alignment[i]) += 1.0 * utt_w * frame_w(i);
134  } else {
135  counts(alignment[i]) += 1.0 * utt_w;
136  }
137  }
138  num_done++;
139  }
140 
141  // Report elements with zero counts
142  for (size_t i = 0; i < counts.Dim(); i++) {
143  if (0.0 == counts(i)) {
144  KALDI_WARN << "Zero count for label " << i << ", this is suspicious.";
145  }
146  }
147 
148  // Add a ``half-frame'' to all the elements to
149  // avoid zero-counts which would cause problems in decoding
150  Vector<double> counts_nozero(counts);
151  counts_nozero.Add(0.5);
152 
153  Output ko(wxfilename, binary);
154  counts_nozero.Write(ko.Stream(), binary);
155 
156  //
157  // THE REST IS FOR ANALYSIS, IT GETS PRINTED TO LOG
158  //
159  if (symbol_table_filename != "" || (kaldi::g_kaldi_verbose_level >= 1)) {
160  // load the symbol table
161  fst::SymbolTable *elem_syms = NULL;
162  if (symbol_table_filename != "") {
163  elem_syms = fst::SymbolTable::ReadText(symbol_table_filename);
164  if (!elem_syms)
165  KALDI_ERR << "Could not read symbol table from file "
166  << symbol_table_filename;
167  }
168 
169  // sort the counts
170  std::vector<std::pair<double, int32> > sorted_counts;
171  for (int32 i = 0; i < counts.Dim(); i++) {
172  sorted_counts.push_back(
173  std::make_pair(static_cast<double>(counts(i)), i));
174  }
175  std::sort(sorted_counts.begin(), sorted_counts.end());
176  std::ostringstream os;
177  double sum = counts.Sum();
178  os << "Printing...\n### The sorted count table," << std::endl;
179  os << "count\t(norm),\tid\t(symbol):" << std::endl;
180  for (int32 i = 0; i < sorted_counts.size(); i++) {
181  os << sorted_counts[i].first << "\t("
182  << static_cast<float>(sorted_counts[i].first) / sum << "),\t"
183  << sorted_counts[i].second << "\t"
184  << (elem_syms != NULL ? "(" +
185  elem_syms->Find(sorted_counts[i].second) + ")" : "")
186  << std::endl;
187  }
188  os << "\n#total " << sum
189  << " (" << static_cast<float>(sum)/100/3600 << "h)"
190  << std::endl;
191  KALDI_LOG << os.str();
192  }
193 
194  KALDI_LOG << "Summed " << num_done << " int32 vectors to counts, "
195  << "skipped " << num_other_error << " vectors.";
196  KALDI_LOG << "Counts written to " << wxfilename;
197  return 0;
198  } catch(const std::exception &e) {
199  std::cerr << e.what();
200  return -1;
201  }
202 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
bool Open(const std::string &rspecifier)
kaldi::int32 int32
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
bool HasKey(const std::string &key)
Real Sum() const
Returns sum of the elements.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int32 g_kaldi_verbose_level
This is set by util/parse-options.
Definition: kaldi-error.cc:46
void Add(Real c)
Add a constant to each element of a vector.
#define KALDI_LOG
Definition: kaldi-error.h:153
int main(int argc, char *argv[])
Sums the pdf vectors to counts, this is used to obtain prior counts for hybrid decoding.