All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
gmm-global-get-frame-likes.cc File Reference
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "gmm/model-common.h"
#include "gmm/full-gmm.h"
#include "gmm/diag-gmm.h"
#include "gmm/mle-full-gmm.h"
Include dependency graph for gmm-global-get-frame-likes.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 29 of file gmm-global-get-frame-likes.cc.

References SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), DiagGmm::LogLikelihood(), DiagGmm::LogLikelihoodsPreselect(), VectorBase< Real >::LogSumExp(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), DiagGmm::Read(), ParseOptions::Register(), MatrixBase< Real >::Row(), Input::Stream(), VectorBase< Real >::Sum(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and TableWriter< Holder >::Write().

29  {
30  try {
31  using namespace kaldi;
32 
33  const char *usage =
34  "Print out per-frame log-likelihoods for each utterance, as an archive\n"
35  "of vectors of floats. If --average=true, prints out the average per-frame\n"
36  "log-likelihood for each utterance, as a single float.\n"
37  "Usage: gmm-global-get-frame-likes [options] <model-in> <feature-rspecifier> "
38  "<likes-out-wspecifier>\n"
39  "e.g.: gmm-global-get-frame-likes 1.mdl scp:train.scp ark:1.likes\n";
40 
41  ParseOptions po(usage);
42  bool average = false;
43  std::string gselect_rspecifier;
44  po.Register("gselect", &gselect_rspecifier, "rspecifier for gselect objects "
45  "to limit the #Gaussians accessed on each frame.");
46  po.Register("average", &average, "If true, print out the average per-frame "
47  "log-likelihood as a single float per utterance.");
48  po.Read(argc, argv);
49 
50  if (po.NumArgs() != 3) {
51  po.PrintUsage();
52  exit(1);
53  }
54 
55  std::string model_filename = po.GetArg(1),
56  feature_rspecifier = po.GetArg(2),
57  likes_wspecifier = po.GetArg(3);
58 
59  DiagGmm gmm;
60  {
61  bool binary_read;
62  Input ki(model_filename, &binary_read);
63  gmm.Read(ki.Stream(), binary_read);
64  }
65 
66  double tot_like = 0.0, tot_frames = 0.0;
67 
68  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
69  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
70  BaseFloatVectorWriter likes_writer(average ? "" : likes_wspecifier);
71  BaseFloatWriter average_likes_writer(average ? likes_wspecifier : "");
72  int32 num_done = 0, num_err = 0;
73 
74  for (; !feature_reader.Done(); feature_reader.Next()) {
75  std::string key = feature_reader.Key();
76  const Matrix<BaseFloat> &mat = feature_reader.Value();
77  int32 file_frames = mat.NumRows();
78  Vector<BaseFloat> likes(file_frames);
79 
80  if (gselect_rspecifier != "") {
81  if (!gselect_reader.HasKey(key)) {
82  KALDI_WARN << "No gselect information for utterance " << key;
83  num_err++;
84  continue;
85  }
86  const std::vector<std::vector<int32> > &gselect =
87  gselect_reader.Value(key);
88  if (gselect.size() != static_cast<size_t>(file_frames)) {
89  KALDI_WARN << "gselect information for utterance " << key
90  << " has wrong size " << gselect.size() << " vs. "
91  << file_frames;
92  num_err++;
93  continue;
94  }
95 
96  for (int32 i = 0; i < file_frames; i++) {
97  SubVector<BaseFloat> data(mat, i);
98  const std::vector<int32> &this_gselect = gselect[i];
99  int32 gselect_size = this_gselect.size();
100  KALDI_ASSERT(gselect_size > 0);
101  Vector<BaseFloat> loglikes;
102  gmm.LogLikelihoodsPreselect(data, this_gselect, &loglikes);
103  likes(i) = loglikes.LogSumExp();
104  }
105  } else { // no gselect..
106  for (int32 i = 0; i < file_frames; i++)
107  likes(i) = gmm.LogLikelihood(mat.Row(i));
108  }
109 
110  tot_like += likes.Sum();
111  tot_frames += file_frames;
112  if (average)
113  average_likes_writer.Write(key, likes.Sum() / file_frames);
114  else
115  likes_writer.Write(key, likes);
116  num_done++;
117  }
118  KALDI_LOG << "Done " << num_done << " files; " << num_err
119  << " with errors.";
120  KALDI_LOG << "Overall likelihood per "
121  << "frame = " << (tot_like/tot_frames) << " over " << tot_frames
122  << " frames.";
123  return (num_done != 0 ? 0 : 1);
124  } catch(const std::exception &e) {
125  std::cerr << e.what();
126  return -1;
127  }
128 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
void LogLikelihoodsPreselect(const VectorBase< BaseFloat > &data, const std::vector< int32 > &indices, Vector< BaseFloat > *loglikes) const
Outputs the per-component log-likelihoods of a subset of mixture components.
Definition: diag-gmm.cc:566
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:366
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:182
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
Real LogSumExp(Real prune=-1.0) const
Returns log(sum(exp())) without exp overflow. If prune > 0.0, ignores terms less than the max - prune...
BaseFloat LogLikelihood(const VectorBase< BaseFloat > &data) const
Returns the log-likelihood of a data point (vector) given the GMM.
Definition: diag-gmm.cc:517
#define KALDI_WARN
Definition: kaldi-error.h:130
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:58
void Read(std::istream &in, bool binary)
Definition: diag-gmm.cc:728
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
#define KALDI_LOG
Definition: kaldi-error.h:133
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:482