gmm-global-gselect-to-post.cc
Go to the documentation of this file.
1 // gmmbin/gmm-global-gselect-to-post.cc
2 
3 // Copyright 2013 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "gmm/diag-gmm.h"
24 #include "hmm/posterior.h"
25 
26 
27 int main(int argc, char *argv[]) {
28  try {
29  using namespace kaldi;
30  typedef kaldi::int32 int32;
31  typedef kaldi::int64 int64;
32 
33  const char *usage =
34  "Given features and Gaussian-selection (gselect) information for\n"
35  "a diagonal-covariance GMM, output per-frame posteriors for the selected\n"
36  "indices. Also supports pruning the posteriors if they are below\n"
37  "a stated threshold, (and renormalizing the rest to sum to one)\n"
38  "See also: gmm-gselect, fgmm-gselect, gmm-global-get-post,\n"
39  " fgmm-global-gselect-to-post\n"
40  "\n"
41  "Usage: gmm-global-gselect-to-post [options] <model-in> <feature-rspecifier> "
42  "<gselect-rspecifier> <post-wspecifier>\n"
43  "e.g.: gmm-global-gselect-to-post 1.dubm ark:- 'ark:gunzip -c 1.gselect|' ark:-\n";
44 
45  ParseOptions po(usage);
46 
47  BaseFloat min_post = 0.0;
48  po.Register("min-post", &min_post, "If nonzero, posteriors below this "
49  "threshold will be pruned away and the rest will be renormalized "
50  "to sum to one.");
51 
52  po.Read(argc, argv);
53 
54  if (po.NumArgs() != 4) {
55  po.PrintUsage();
56  exit(1);
57  }
58 
59  std::string model_rxfilename = po.GetArg(1),
60  feature_rspecifier = po.GetArg(2),
61  gselect_rspecifier = po.GetArg(3),
62  post_wspecifier = po.GetArg(4);
63 
64  DiagGmm gmm;
65  ReadKaldiObject(model_rxfilename, &gmm);
66 
67  double tot_loglike = 0.0, tot_frames = 0.0;
68  int64 tot_posts = 0;
69 
70  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
71  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
72  PosteriorWriter post_writer(post_wspecifier);
73  int32 num_done = 0, num_err = 0;
74 
75  for (; !feature_reader.Done(); feature_reader.Next()) {
76  std::string utt = feature_reader.Key();
77  const Matrix<BaseFloat> &mat = feature_reader.Value();
78 
79  int32 num_frames = mat.NumRows();
80  // typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
81  Posterior post(num_frames);
82 
83  if (!gselect_reader.HasKey(utt)) {
84  KALDI_WARN << "No gselect information for utterance " << utt;
85  num_err++;
86  continue;
87  }
88  const std::vector<std::vector<int32> > &gselect(gselect_reader.Value(utt));
89  if (static_cast<int32>(gselect.size()) != num_frames) {
90  KALDI_WARN << "gselect information for utterance " << utt
91  << " has wrong size " << gselect.size() << " vs. "
92  << num_frames;
93  num_err++;
94  continue;
95  }
96 
97  double this_tot_loglike = 0;
98  bool utt_ok = true;
99 
100  for (int32 t = 0; t < num_frames; t++) {
101  SubVector<BaseFloat> frame(mat, t);
102  const std::vector<int32> &this_gselect = gselect[t];
103  KALDI_ASSERT(!gselect[t].empty());
104  Vector<BaseFloat> loglikes;
105  gmm.LogLikelihoodsPreselect(frame, this_gselect, &loglikes);
106  this_tot_loglike += loglikes.ApplySoftMax();
107  // now "loglikes" contains posteriors.
108  if (fabs(loglikes.Sum() - 1.0) > 0.01) {
109  utt_ok = false;
110  } else {
111  if (min_post != 0.0) {
112  int32 max_index = 0; // in case all pruned away...
113  loglikes.Max(&max_index);
114  for (int32 i = 0; i < loglikes.Dim(); i++)
115  if (loglikes(i) < min_post)
116  loglikes(i) = 0.0;
117  BaseFloat sum = loglikes.Sum();
118  if (sum == 0.0) {
119  loglikes(max_index) = 1.0;
120  } else {
121  loglikes.Scale(1.0 / sum);
122  }
123  }
124  for (int32 i = 0; i < loglikes.Dim(); i++) {
125  if (loglikes(i) != 0.0) {
126  post[t].push_back(std::make_pair(this_gselect[i], loglikes(i)));
127  tot_posts++;
128  }
129  }
130  KALDI_ASSERT(!post[t].empty());
131  }
132  }
133  if (!utt_ok) {
134  KALDI_WARN << "Skipping utterance " << utt
135  << " because bad posterior-sum encountered (NaN?)";
136  num_err++;
137  } else {
138  post_writer.Write(utt, post);
139  num_done++;
140  KALDI_VLOG(2) << "Like/frame for utt " << utt << " was "
141  << (this_tot_loglike/num_frames) << " per frame over "
142  << num_frames << " frames.";
143  tot_loglike += this_tot_loglike;
144  tot_frames += num_frames;
145  }
146  }
147 
148  KALDI_LOG << "Done " << num_done << " files; " << num_err << " had errors.";
149  KALDI_LOG << "Overall loglike per frame is " << (tot_loglike / tot_frames)
150  << " with " << (tot_posts / tot_frames) << " entries per frame, "
151  << " over " << tot_frames << " frames";
152  return (num_done != 0 ? 0 : 1);
153  } catch(const std::exception &e) {
154  std::cerr << e.what();
155  return -1;
156  }
157 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void LogLikelihoodsPreselect(const VectorBase< BaseFloat > &data, const std::vector< int32 > &indices, Vector< BaseFloat > *loglikes) const
Outputs the per-component log-likelihoods of a subset of mixture components.
Definition: diag-gmm.cc:566
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
Real ApplySoftMax()
Apply soft-max to vector and return normalizer (log sum of exponentials).
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
Real Max() const
Returns the maximum value of any element, or -infinity for the empty vector.
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int main(int argc, char *argv[])
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
void Scale(Real alpha)
Multiplies all elements by this constant.
bool HasKey(const std::string &key)
Real Sum() const
Returns sum of the elements.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
#define KALDI_LOG
Definition: kaldi-error.h:153
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501