All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
fgmm-global-gselect-to-post.cc File Reference
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "gmm/full-gmm.h"
#include "hmm/posterior.h"
Include dependency graph for fgmm-global-gselect-to-post.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 27 of file fgmm-global-gselect-to-post.cc.

References VectorBase< Real >::ApplySoftMax(), VectorBase< Real >::Dim(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, KALDI_ASSERT, KALDI_LOG, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), FullGmm::LogLikelihoodsPreselect(), VectorBase< Real >::Max(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), VectorBase< Real >::Scale(), VectorBase< Real >::Sum(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and TableWriter< Holder >::Write().

27  {
28  try {
29  using namespace kaldi;
30  typedef kaldi::int32 int32;
31  typedef kaldi::int64 int64;
32 
33  const char *usage =
34  "Given features and Gaussian-selection (gselect) information for\n"
35  "a full-covariance GMM, output per-frame posteriors for the selected\n"
36  "indices. Also supports pruning the posteriors if they are below\n"
37  "a stated threshold, (and renormalizing the rest to sum to one)\n"
38  "See also: gmm-gselect, fgmm-gselect, gmm-global-get-post,\n"
39  " gmm-global-gselect-to-post\n"
40  "\n"
41  "Usage: fgmm-global-gselect-to-post [options] <model-in> <feature-rspecifier> "
42  "<gselect-rspecifier> <post-wspecifier>\n"
43  "e.g.: fgmm-global-gselect-to-post 1.ubm ark:- 'ark:gunzip -c 1.gselect|' ark:-\n";
44 
45  ParseOptions po(usage);
46 
47  BaseFloat min_post = 0.0;
48  po.Register("min-post", &min_post, "If nonzero, posteriors below this "
49  "threshold will be pruned away and the rest will be renormalized "
50  "to sum to one.");
51 
52  po.Read(argc, argv);
53 
54  if (po.NumArgs() != 4) {
55  po.PrintUsage();
56  exit(1);
57  }
58 
59  std::string model_rxfilename = po.GetArg(1),
60  feature_rspecifier = po.GetArg(2),
61  gselect_rspecifier = po.GetArg(3),
62  post_wspecifier = po.GetArg(4);
63 
64  FullGmm fgmm;
65  ReadKaldiObject(model_rxfilename, &fgmm);
66 
67  double tot_loglike = 0.0, tot_frames = 0.0;
68  int64 tot_posts = 0;
69 
70  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
71  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
72  PosteriorWriter post_writer(post_wspecifier);
73  int32 num_done = 0, num_err = 0;
74 
75  for (; !feature_reader.Done(); feature_reader.Next()) {
76  std::string utt = feature_reader.Key();
77  const Matrix<BaseFloat> &mat = feature_reader.Value();
78 
79  int32 num_frames = mat.NumRows();
80  // typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
81  Posterior post(num_frames);
82 
83  if (!gselect_reader.HasKey(utt)) {
84  KALDI_WARN << "No gselect information for utterance " << utt;
85  num_err++;
86  continue;
87  }
88  const std::vector<std::vector<int32> > &gselect(gselect_reader.Value(utt));
89  if (static_cast<int32>(gselect.size()) != num_frames) {
90  KALDI_WARN << "gselect information for utterance " << utt
91  << " has wrong size " << gselect.size() << " vs. "
92  << num_frames;
93  num_err++;
94  continue;
95  }
96 
97  double this_tot_loglike = 0;
98  bool utt_ok = true;
99 
100  for (int32 t = 0; t < num_frames; t++) {
101  SubVector<BaseFloat> frame(mat, t);
102  const std::vector<int32> &this_gselect = gselect[t];
103  KALDI_ASSERT(!gselect[t].empty());
104  Vector<BaseFloat> loglikes;
105  fgmm.LogLikelihoodsPreselect(frame, this_gselect, &loglikes);
106  this_tot_loglike += loglikes.ApplySoftMax();
107  // now "loglikes" contains posteriors.
108  if (fabs(loglikes.Sum() - 1.0) > 0.01) {
109  utt_ok = false;
110  } else {
111  if (min_post != 0.0) {
112  int32 max_index = 0; // in case all pruned away...
113  loglikes.Max(&max_index);
114  for (int32 i = 0; i < loglikes.Dim(); i++)
115  if (loglikes(i) < min_post)
116  loglikes(i) = 0.0;
117  BaseFloat sum = loglikes.Sum();
118  if (sum == 0.0) {
119  loglikes(max_index) = 1.0;
120  } else {
121  loglikes.Scale(1.0 / sum);
122  }
123  }
124  for (int32 i = 0; i < loglikes.Dim(); i++) {
125  if (loglikes(i) != 0.0) {
126  post[t].push_back(std::make_pair(this_gselect[i], loglikes(i)));
127  tot_posts++;
128  }
129  }
130  KALDI_ASSERT(!post[t].empty());
131  }
132  }
133  if (!utt_ok) {
134  KALDI_WARN << "Skipping utterance " << utt
135  << " because bad posterior-sum encountered (NaN?)";
136  num_err++;
137  } else {
138  post_writer.Write(utt, post);
139  num_done++;
140  KALDI_VLOG(2) << "Like/frame for utt " << utt << " was "
141  << (this_tot_loglike/num_frames) << " per frame over "
142  << num_frames << " frames.";
143  tot_loglike += this_tot_loglike;
144  tot_frames += num_frames;
145  }
146  }
147 
148  KALDI_LOG << "Done " << num_done << " files; " << num_err << " had errors.";
149  KALDI_LOG << "Overall loglike per frame is " << (tot_loglike / tot_frames)
150  << " with " << (tot_posts / tot_frames) << " entries per frame, "
151  << " over " << tot_frames << " frames";
152  return (num_done != 0 ? 0 : 1);
153  } catch(const std::exception &e) {
154  std::cerr << e.what();
155  return -1;
156  }
157 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
Definition for Gaussian Mixture Model with full covariances.
Definition: full-gmm.h:40
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:366
Real Sum() const
Returns sum of the elements.
Real ApplySoftMax()
Apply soft-max to vector and return normalizer (log sum of exponentials).
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:818
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
Real Max() const
Returns the maximum value of any element, or -infinity for the empty vector.
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:43
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void LogLikelihoodsPreselect(const VectorBase< BaseFloat > &data, const std::vector< int32 > &indices, Vector< BaseFloat > *loglikes) const
Outputs the per-component log-likelihoods of a subset of mixture components.
Definition: full-gmm.cc:613
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_WARN
Definition: kaldi-error.h:130
void Scale(Real alpha)
Multiplies all elements by this constant.
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:58
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
#define KALDI_LOG
Definition: kaldi-error.h:133
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:482
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:62