All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
analyze-counts.cc File Reference
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fst/fstlib.h"
#include <iomanip>
#include <algorithm>
#include <numeric>
Include dependency graph for analyze-counts.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 Sums the pdf vectors to counts; this is used to obtain prior counts for hybrid decoding. More...
 

Function Documentation

int main(int argc, char *argv[])

Sums the pdf vectors to counts; this is used to obtain prior counts for hybrid decoding.

Definition at line 30 of file analyze-counts.cc.

References VectorBase< Real >::Add(), VectorBase< Real >::Dim(), SequentialTableReader< Holder >::Done(), kaldi::g_kaldi_verbose_level, ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_WARN, kaldi::kCopyData, SequentialTableReader< Holder >::Key(), kaldi::kSetZero, SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), RandomAccessTableReader< Holder >::Open(), ParseOptions::PrintUsage(), ParseOptions::Read(), ParseOptions::Register(), Vector< Real >::Resize(), Output::Stream(), VectorBase< Real >::Sum(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and VectorBase< Real >::Write().

30  {
31  using namespace kaldi;
32  typedef kaldi::int32 int32;
33  typedef kaldi::uint64 uint64;
34  try {
35  const char *usage =
36  "Computes element counts from integer vector table.\n"
37  "(e.g. get pdf-counts to estimate DNN-output priors "
38  "for data analysis)\n"
39  "Verbosity : level 1 => print frequencies and histogram\n"
40  "\n"
41  "Usage: analyze-counts [options] <alignments-rspecifier> "
42  "<counts-wxfilname>\n"
43  "e.g.: \n"
44  " analyze-counts ark:1.ali prior.counts\n"
45  " Show phone counts by:\n"
46  " ali-to-phones --per-frame=true ark:1.ali ark:- |"
47  " analyze-counts --verbose=1 ark:- - >/dev/null\n"
48  "Note: this is deprecated, see post-to-tacc.\n";
49 
50  ParseOptions po(usage);
51 
52  bool binary = false;
53  std::string symbol_table_filename = "";
54 
55  po.Register("binary", &binary, "write in binary mode");
56  po.Register("symbol-table", &symbol_table_filename,
57  "Read symbol table for display of counts");
58 
59  int32 counts_dim = 0;
60  po.Register("counts-dim", &counts_dim,
61  "Output dimension of the counts, "
62  "a hint for dimension auto-detection.");
63 
64  std::string frame_weights;
65  po.Register("frame-weights", &frame_weights,
66  "Per-frame weights (counting weighted frames).");
67  std::string utt_weights;
68  po.Register("utt-weights", &utt_weights,
69  "Per-utterance weights (counting weighted frames).");
70 
71  po.Read(argc, argv);
72 
73  if (po.NumArgs() != 2) {
74  po.PrintUsage();
75  exit(1);
76  }
77 
78  std::string alignments_rspecifier = po.GetArg(1),
79  wxfilename = po.GetArg(2);
80 
81  SequentialInt32VectorReader alignment_reader(alignments_rspecifier);
82 
83  RandomAccessBaseFloatVectorReader weights_reader;
84  if (frame_weights != "") {
85  weights_reader.Open(frame_weights);
86  }
87  RandomAccessBaseFloatReader utt_weights_reader;
88  if (utt_weights != "") {
89  utt_weights_reader.Open(utt_weights);
90  }
91 
92  // Buffer for accumulating the counts
93  Vector<double> counts(counts_dim, kSetZero);
94 
95  int32 num_done = 0, num_other_error = 0;
96  for (; !alignment_reader.Done(); alignment_reader.Next()) {
97  std::string utt = alignment_reader.Key();
98  const std::vector<int32> &alignment = alignment_reader.Value();
99 
100  BaseFloat utt_w = 1.0;
101  // Check if per-utterance weights are provided
102  if (utt_weights != "") {
103  if (!utt_weights_reader.HasKey(utt)) {
104  KALDI_WARN << utt << ", missing per-utterance weight";
105  num_other_error++;
106  continue;
107  } else {
108  utt_w = utt_weights_reader.Value(utt);
109  }
110  }
111 
112  Vector<BaseFloat> frame_w;
113  // Check if per-frame weights are provided
114  if (frame_weights != "") {
115  if (!weights_reader.HasKey(utt)) {
116  KALDI_WARN << utt << ", missing per-frame weights";
117  num_other_error++;
118  continue;
119  } else {
120  frame_w = weights_reader.Value(utt);
121  KALDI_ASSERT(frame_w.Dim() == alignment.size());
122  }
123  }
124 
125  // Accumulate the counts
126  for (size_t i = 0; i < alignment.size(); i++) {
127  KALDI_ASSERT(alignment[i] >= 0);
128  // Extend the vector if it is not large enough to hold every pdf-ids
129  if (alignment[i] >= counts.Dim()) {
130  counts.Resize(alignment[i]+1, kCopyData);
131  }
132  if (frame_weights != "") {
133  counts(alignment[i]) += 1.0 * utt_w * frame_w(i);
134  } else {
135  counts(alignment[i]) += 1.0 * utt_w;
136  }
137  }
138  num_done++;
139  }
140 
141  // Report elements with zero counts
142  for (size_t i = 0; i < counts.Dim(); i++) {
143  if (0.0 == counts(i)) {
144  KALDI_WARN << "Zero count for label " << i << ", this is suspicious.";
145  }
146  }
147 
148  // Add a ``half-frame'' to all the elements to
149  // avoid zero-counts which would cause problems in decoding
150  Vector<double> counts_nozero(counts);
151  counts_nozero.Add(0.5);
152 
153  Output ko(wxfilename, binary);
154  counts_nozero.Write(ko.Stream(), binary);
155 
156  //
157  // THE REST IS FOR ANALYSIS, IT GETS PRINTED TO LOG
158  //
159  if (symbol_table_filename != "" || (kaldi::g_kaldi_verbose_level >= 1)) {
160  // load the symbol table
161  fst::SymbolTable *elem_syms = NULL;
162  if (symbol_table_filename != "") {
163  elem_syms = fst::SymbolTable::ReadText(symbol_table_filename);
164  if (!elem_syms)
165  KALDI_ERR << "Could not read symbol table from file "
166  << symbol_table_filename;
167  }
168 
169  // sort the counts
170  std::vector<std::pair<double, int32> > sorted_counts;
171  for (int32 i = 0; i < counts.Dim(); i++) {
172  sorted_counts.push_back(
173  std::make_pair(static_cast<double>(counts(i)), i));
174  }
175  std::sort(sorted_counts.begin(), sorted_counts.end());
176  std::ostringstream os;
177  double sum = counts.Sum();
178  os << "Printing...\n### The sorted count table," << std::endl;
179  os << "count\t(norm),\tid\t(symbol):" << std::endl;
180  for (int32 i = 0; i < sorted_counts.size(); i++) {
181  os << sorted_counts[i].first << "\t("
182  << static_cast<float>(sorted_counts[i].first) / sum << "),\t"
183  << sorted_counts[i].second << "\t"
184  << (elem_syms != NULL ? "(" +
185  elem_syms->Find(sorted_counts[i].second) + ")" : "")
186  << std::endl;
187  }
188  os << "\n#total " << sum
189  << " (" << static_cast<float>(sum)/100/3600 << "h)"
190  << std::endl;
191  KALDI_LOG << os.str();
192  }
193 
194  KALDI_LOG << "Summed " << num_done << " int32 vectors to counts, "
195  << "skipped " << num_other_error << " vectors.";
196  KALDI_LOG << "Counts written to " << wxfilename;
197  return 0;
198  } catch(const std::exception &e) {
199  std::cerr << e.what();
200  return -1;
201  }
202 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
bool Open(const std::string &rspecifier)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:127
#define KALDI_WARN
Definition: kaldi-error.h:130
bool HasKey(const std::string &key)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
int32 g_kaldi_verbose_level
This is set by util/parse-options.{h,cc} if you set the --verbose=? option.
Definition: kaldi-error.cc:40
#define KALDI_LOG
Definition: kaldi-error.h:133
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:62