All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
fgmm-global-acc-stats-post.cc File Reference
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "gmm/model-common.h"
#include "gmm/full-gmm.h"
#include "gmm/diag-gmm.h"
#include "gmm/mle-full-gmm.h"
#include "hmm/posterior.h"
Include dependency graph for fgmm-global-acc-stats-post.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 32 of file fgmm-global-acc-stats-post.cc.

References VectorBase< Real >::Dim(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, rnnlm::j, KALDI_LOG, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), ParseOptions::Register(), kaldi::ScalePosterior(), kaldi::StringToGmmFlags(), kaldi::TotalPosterior(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and kaldi::WriteKaldiObject().

32  {
33  try {
34  using namespace kaldi;
35  typedef kaldi::int32 int32;
36 
37  const char *usage =
38  "Accumulate stats from posteriors and features for instantiating "
39  "a full-covariance GMM. See also fgmm-global-acc-stats.\n"
40  "Usage: fgmm-global-acc-stats-post [options] <posterior-rspecifier> "
41  "<number-of-components> <feature-rspecifier> <stats-out>\n"
42  "e.g.: fgmm-global-acc-stats-post scp:post.scp 2048 "
43  "scp:train.scp 1.acc\n";
44 
45  ParseOptions po(usage);
46  bool binary = true;
47  std::string update_flags_str = "mvw";
48  std::string weights_rspecifier;
49  po.Register("binary", &binary, "Write output in binary mode");
50  po.Register("update-flags", &update_flags_str, "Which GMM parameters will be "
51  "updated: subset of mvw.");
52  po.Register("weights", &weights_rspecifier, "rspecifier for a vector of floats "
53  "for each utterance, that's a per-frame weight.");
54  po.Read(argc, argv);
55 
56  if (po.NumArgs() != 4) {
57  po.PrintUsage();
58  exit(1);
59  }
60 
61  std::string post_rspecifier = po.GetArg(1),
62  feature_rspecifier = po.GetArg(3),
63  accs_wxfilename = po.GetArg(4);
64 
65  int32 num_components = atoi(po.GetArg(2).c_str());
66 
67  AccumFullGmm fgmm_accs;
68 
69  double tot_like = 0.0, tot_weight = 0.0;
70 
71  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
72  RandomAccessPosteriorReader post_reader(post_rspecifier);
73  RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier);
74  int32 num_done = 0, num_err = 0;
75 
76  for (; !feature_reader.Done(); feature_reader.Next()) {
77  std::string key = feature_reader.Key();
78  const Matrix<BaseFloat> &mat = feature_reader.Value();
79  int32 file_frames = mat.NumRows();
80  if (!post_reader.HasKey(key)) {
81  KALDI_WARN << "No posteriors available for utterance "
82  << key;
83  num_err++;
84  continue;
85  }
86 
87  Posterior post = post_reader.Value(key);
88  // Initialize the FGMM accs before processing the first utt.
89  if (num_done == 0) {
90  fgmm_accs.Resize(num_components, mat.NumCols(),
91  StringToGmmFlags(update_flags_str));
92  }
93 
94  BaseFloat file_like = 0.0,
95  file_weight = 0.0; // total of weights of frames (will each be
96  // 1 unless --weights option supplied.
97  Vector<BaseFloat> weights;
98  if (weights_rspecifier != "") { // We have per-frame weighting.
99  if (!weights_reader.HasKey(key)) {
100  KALDI_WARN << "No per-frame weights available for utterance "
101  << key;
102  num_err++;
103  continue;
104  }
105  weights = weights_reader.Value(key);
106  if (weights.Dim() != file_frames) {
107  KALDI_WARN << "Weights for utterance " << key << " have wrong dim "
108  << weights.Dim() << " vs. " << file_frames;
109  num_err++;
110  continue;
111  }
112  }
113 
114  if (post.size() != static_cast<size_t>(file_frames)) {
115  KALDI_WARN << "posterior information for utterance " << key
116  << " has wrong size " << post.size() << " vs. "
117  << file_frames;
118  num_err++;
119  continue;
120  }
121 
122  for (int32 i = 0; i < file_frames; i++) {
123  BaseFloat weight = (weights.Dim() != 0) ? weights(i) : 1.0;
124  if (weight == 0.0) continue;
125  file_weight += weight;
126  SubVector<BaseFloat> data(mat, i);
127  ScalePosterior(weight, &post);
128  file_like += TotalPosterior(post);
129  for (int32 j = 0; j < post[i].size(); j++)
130  fgmm_accs.AccumulateForComponent(data, post[i][j].first,
131  post[i][j].second);
132  }
133 
134  KALDI_VLOG(2) << "File '" << key << "': Average likelihood = "
135  << (file_like/file_weight) << " over "
136  << file_weight <<" frames.";
137  tot_like += file_like;
138  tot_weight += file_weight;
139  num_done++;
140  }
141  KALDI_LOG << "Done " << num_done << " files; "
142  << num_err << " with errors.";
143  KALDI_LOG << "Overall likelihood per "
144  << "frame = " << (tot_like/tot_weight) << " over "
145  << tot_weight << " (weighted) frames.";
146 
147  WriteKaldiObject(fgmm_accs, accs_wxfilename, binary);
148  KALDI_LOG << "Written accs to " << accs_wxfilename;
149  return (num_done != 0 ? 0 : 1);
150  } catch(const std::exception &e) {
151  std::cerr << e.what();
152  return -1;
153  }
154 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
GmmFlagsType StringToGmmFlags(std::string str)
Convert string which is some subset of "mSwa" to flags.
Definition: model-common.cc:26
BaseFloat TotalPosterior(const Posterior &post)
Returns the total of all the weights in "post".
Definition: posterior.cc:230
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:43
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
Class for computing the maximum-likelihood estimates of the parameters of a Gaussian mixture model...
Definition: mle-full-gmm.h:74
#define KALDI_WARN
Definition: kaldi-error.h:130
void ScalePosterior(BaseFloat scale, Posterior *post)
Scales the BaseFloat (weight) element in the posterior entries.
Definition: posterior.cc:218
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:58
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:61
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
#define KALDI_LOG
Definition: kaldi-error.h:133
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:482
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:62