gmm-global-gselect-to-post.cc File Reference
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "gmm/diag-gmm.h"
#include "hmm/posterior.h"
Include dependency graph for gmm-global-gselect-to-post.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 27 of file gmm-global-gselect-to-post.cc.

References VectorBase< Real >::ApplySoftMax(), VectorBase< Real >::Dim(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, KALDI_ASSERT, KALDI_LOG, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), DiagGmm::LogLikelihoodsPreselect(), VectorBase< Real >::Max(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), VectorBase< Real >::Scale(), VectorBase< Real >::Sum(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and TableWriter< Holder >::Write().

27  {
28  try {
29  using namespace kaldi;
30  typedef kaldi::int32 int32;
31  typedef kaldi::int64 int64;
32 
33  const char *usage =
34  "Given features and Gaussian-selection (gselect) information for\n"
35  "a diagonal-covariance GMM, output per-frame posteriors for the selected\n"
36  "indices. Also supports pruning the posteriors if they are below\n"
37  "a stated threshold, (and renormalizing the rest to sum to one)\n"
38  "See also: gmm-gselect, fgmm-gselect, gmm-global-get-post,\n"
39  " fgmm-global-gselect-to-post\n"
40  "\n"
41  "Usage: gmm-global-gselect-to-post [options] <model-in> <feature-rspecifier> "
42  "<gselect-rspecifier> <post-wspecifier>\n"
43  "e.g.: gmm-global-gselect-to-post 1.dubm ark:- 'ark:gunzip -c 1.gselect|' ark:-\n";
44 
45  ParseOptions po(usage);
46 
47  BaseFloat min_post = 0.0;
48  po.Register("min-post", &min_post, "If nonzero, posteriors below this "
49  "threshold will be pruned away and the rest will be renormalized "
50  "to sum to one.");
51 
52  po.Read(argc, argv);
53 
54  if (po.NumArgs() != 4) {
55  po.PrintUsage();
56  exit(1);
57  }
58 
59  std::string model_rxfilename = po.GetArg(1),
60  feature_rspecifier = po.GetArg(2),
61  gselect_rspecifier = po.GetArg(3),
62  post_wspecifier = po.GetArg(4);
63 
64  DiagGmm gmm;
65  ReadKaldiObject(model_rxfilename, &gmm);
66 
67  double tot_loglike = 0.0, tot_frames = 0.0;
68  int64 tot_posts = 0;
69 
70  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
71  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
72  PosteriorWriter post_writer(post_wspecifier);
73  int32 num_done = 0, num_err = 0;
74 
75  for (; !feature_reader.Done(); feature_reader.Next()) {
76  std::string utt = feature_reader.Key();
77  const Matrix<BaseFloat> &mat = feature_reader.Value();
78 
79  int32 num_frames = mat.NumRows();
80  // typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
81  Posterior post(num_frames);
82 
83  if (!gselect_reader.HasKey(utt)) {
84  KALDI_WARN << "No gselect information for utterance " << utt;
85  num_err++;
86  continue;
87  }
88  const std::vector<std::vector<int32> > &gselect(gselect_reader.Value(utt));
89  if (static_cast<int32>(gselect.size()) != num_frames) {
90  KALDI_WARN << "gselect information for utterance " << utt
91  << " has wrong size " << gselect.size() << " vs. "
92  << num_frames;
93  num_err++;
94  continue;
95  }
96 
97  double this_tot_loglike = 0;
98  bool utt_ok = true;
99 
100  for (int32 t = 0; t < num_frames; t++) {
101  SubVector<BaseFloat> frame(mat, t);
102  const std::vector<int32> &this_gselect = gselect[t];
103  KALDI_ASSERT(!gselect[t].empty());
104  Vector<BaseFloat> loglikes;
105  gmm.LogLikelihoodsPreselect(frame, this_gselect, &loglikes);
106  this_tot_loglike += loglikes.ApplySoftMax();
107  // now "loglikes" contains posteriors.
108  if (fabs(loglikes.Sum() - 1.0) > 0.01) {
109  utt_ok = false;
110  } else {
111  if (min_post != 0.0) {
112  int32 max_index = 0; // in case all pruned away...
113  loglikes.Max(&max_index);
114  for (int32 i = 0; i < loglikes.Dim(); i++)
115  if (loglikes(i) < min_post)
116  loglikes(i) = 0.0;
117  BaseFloat sum = loglikes.Sum();
118  if (sum == 0.0) {
119  loglikes(max_index) = 1.0;
120  } else {
121  loglikes.Scale(1.0 / sum);
122  }
123  }
124  for (int32 i = 0; i < loglikes.Dim(); i++) {
125  if (loglikes(i) != 0.0) {
126  post[t].push_back(std::make_pair(this_gselect[i], loglikes(i)));
127  tot_posts++;
128  }
129  }
130  KALDI_ASSERT(!post[t].empty());
131  }
132  }
133  if (!utt_ok) {
134  KALDI_WARN << "Skipping utterance " << utt
135  << " because bad posterior-sum encountered (NaN?)";
136  num_err++;
137  } else {
138  post_writer.Write(utt, post);
139  num_done++;
140  KALDI_VLOG(2) << "Like/frame for utt " << utt << " was "
141  << (this_tot_loglike/num_frames) << " per frame over "
142  << num_frames << " frames.";
143  tot_loglike += this_tot_loglike;
144  tot_frames += num_frames;
145  }
146  }
147 
148  KALDI_LOG << "Done " << num_done << " files; " << num_err << " had errors.";
149  KALDI_LOG << "Overall loglike per frame is " << (tot_loglike / tot_frames)
150  << " with " << (tot_posts / tot_frames) << " entries per frame, "
151  << " over " << tot_frames << " frames";
152  return (num_done != 0 ? 0 : 1);
153  } catch(const std::exception &e) {
154  std::cerr << e.what();
155  return -1;
156  }
157 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void LogLikelihoodsPreselect(const VectorBase< BaseFloat > &data, const std::vector< int32 > &indices, Vector< BaseFloat > *loglikes) const
Outputs the per-component log-likelihoods of a subset of mixture components.
Definition: diag-gmm.cc:566
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
Real ApplySoftMax()
Apply soft-max to vector and return normalizer (log sum of exponentials).
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
Real Max() const
Returns the maximum value of any element, or -infinity for the empty vector.
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
void Scale(Real alpha)
Multiplies all elements by this constant.
Real Sum() const
Returns sum of the elements.
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
#define KALDI_LOG
Definition: kaldi-error.h:153
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501