gmm-global-get-frame-likes.cc
Go to the documentation of this file.
1 // gmmbin/gmm-global-get-frame-likes.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation; Saarland University
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "gmm/model-common.h"
24 #include "gmm/full-gmm.h"
25 #include "gmm/diag-gmm.h"
26 #include "gmm/mle-full-gmm.h"
27 
28 
29 int main(int argc, char *argv[]) {
30  try {
31  using namespace kaldi;
32 
33  const char *usage =
34  "Print out per-frame log-likelihoods for each utterance, as an archive\n"
35  "of vectors of floats. If --average=true, prints out the average per-frame\n"
36  "log-likelihood for each utterance, as a single float.\n"
37  "Usage: gmm-global-get-frame-likes [options] <model-in> <feature-rspecifier> "
38  "<likes-out-wspecifier>\n"
39  "e.g.: gmm-global-get-frame-likes 1.mdl scp:train.scp ark:1.likes\n";
40 
41  ParseOptions po(usage);
42  bool average = false;
43  std::string gselect_rspecifier;
44  po.Register("gselect", &gselect_rspecifier, "rspecifier for gselect objects "
45  "to limit the #Gaussians accessed on each frame.");
46  po.Register("average", &average, "If true, print out the average per-frame "
47  "log-likelihood as a single float per utterance.");
48  po.Read(argc, argv);
49 
50  if (po.NumArgs() != 3) {
51  po.PrintUsage();
52  exit(1);
53  }
54 
55  std::string model_filename = po.GetArg(1),
56  feature_rspecifier = po.GetArg(2),
57  likes_wspecifier = po.GetArg(3);
58 
59  DiagGmm gmm;
60  {
61  bool binary_read;
62  Input ki(model_filename, &binary_read);
63  gmm.Read(ki.Stream(), binary_read);
64  }
65 
66  double tot_like = 0.0, tot_frames = 0.0;
67 
68  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
69  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
70  BaseFloatVectorWriter likes_writer(average ? "" : likes_wspecifier);
71  BaseFloatWriter average_likes_writer(average ? likes_wspecifier : "");
72  int32 num_done = 0, num_err = 0;
73 
74  for (; !feature_reader.Done(); feature_reader.Next()) {
75  std::string key = feature_reader.Key();
76  const Matrix<BaseFloat> &mat = feature_reader.Value();
77  int32 file_frames = mat.NumRows();
78  Vector<BaseFloat> likes(file_frames);
79 
80  if (gselect_rspecifier != "") {
81  if (!gselect_reader.HasKey(key)) {
82  KALDI_WARN << "No gselect information for utterance " << key;
83  num_err++;
84  continue;
85  }
86  const std::vector<std::vector<int32> > &gselect =
87  gselect_reader.Value(key);
88  if (gselect.size() != static_cast<size_t>(file_frames)) {
89  KALDI_WARN << "gselect information for utterance " << key
90  << " has wrong size " << gselect.size() << " vs. "
91  << file_frames;
92  num_err++;
93  continue;
94  }
95 
96  for (int32 i = 0; i < file_frames; i++) {
97  SubVector<BaseFloat> data(mat, i);
98  const std::vector<int32> &this_gselect = gselect[i];
99  int32 gselect_size = this_gselect.size();
100  KALDI_ASSERT(gselect_size > 0);
101  Vector<BaseFloat> loglikes;
102  gmm.LogLikelihoodsPreselect(data, this_gselect, &loglikes);
103  likes(i) = loglikes.LogSumExp();
104  }
105  } else { // no gselect..
106  for (int32 i = 0; i < file_frames; i++)
107  likes(i) = gmm.LogLikelihood(mat.Row(i));
108  }
109 
110  tot_like += likes.Sum();
111  tot_frames += file_frames;
112  if (average)
113  average_likes_writer.Write(key, likes.Sum() / file_frames);
114  else
115  likes_writer.Write(key, likes);
116  num_done++;
117  }
118  KALDI_LOG << "Done " << num_done << " files; " << num_err
119  << " with errors.";
120  KALDI_LOG << "Overall likelihood per "
121  << "frame = " << (tot_like/tot_frames) << " over " << tot_frames
122  << " frames.";
123  return (num_done != 0 ? 0 : 1);
124  } catch(const std::exception &e) {
125  std::cerr << e.what();
126  return -1;
127  }
128 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
void LogLikelihoodsPreselect(const VectorBase< BaseFloat > &data, const std::vector< int32 > &indices, Vector< BaseFloat > *loglikes) const
Outputs the per-component log-likelihoods of a subset of mixture components.
Definition: diag-gmm.cc:566
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
Real LogSumExp(Real prune=-1.0) const
Returns log(sum(exp())) without exp overflow. If prune > 0.0, ignores terms less than the max - prune...
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
int main(int argc, char *argv[])
std::istream & Stream()
Definition: kaldi-io.cc:826
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
BaseFloat LogLikelihood(const VectorBase< BaseFloat > &data) const
Returns the log-likelihood of a data point (vector) given the GMM.
Definition: diag-gmm.cc:517
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
Real Sum() const
Returns sum of the elements.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Read(std::istream &in, bool binary)
Definition: diag-gmm.cc:728
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
#define KALDI_LOG
Definition: kaldi-error.h:153
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501