fgmm-global-acc-stats-post.cc
Go to the documentation of this file.
1 // fgmmbin/fgmm-global-acc-stats-post.cc
2 
3 // Copyright 2015 David Snyder
4 // 2015 Johns Hopkins University (Author: Daniel Povey)
5 // 2015 Johns Hopkins University (Author: Daniel Garcia-Romero)
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 
23 #include "base/kaldi-common.h"
24 #include "util/common-utils.h"
25 #include "gmm/model-common.h"
26 #include "gmm/full-gmm.h"
27 #include "gmm/diag-gmm.h"
28 #include "gmm/mle-full-gmm.h"
29 #include "hmm/posterior.h"
30 
31 
32 int main(int argc, char *argv[]) {
33  try {
34  using namespace kaldi;
35  typedef kaldi::int32 int32;
36 
37  const char *usage =
38  "Accumulate stats from posteriors and features for instantiating "
39  "a full-covariance GMM. See also fgmm-global-acc-stats.\n"
40  "Usage: fgmm-global-acc-stats-post [options] <posterior-rspecifier> "
41  "<number-of-components> <feature-rspecifier> <stats-out>\n"
42  "e.g.: fgmm-global-acc-stats-post scp:post.scp 2048 "
43  "scp:train.scp 1.acc\n";
44 
45  ParseOptions po(usage);
46  bool binary = true;
47  std::string update_flags_str = "mvw";
48  std::string weights_rspecifier;
49  po.Register("binary", &binary, "Write output in binary mode");
50  po.Register("update-flags", &update_flags_str, "Which GMM parameters will be "
51  "updated: subset of mvw.");
52  po.Register("weights", &weights_rspecifier, "rspecifier for a vector of floats "
53  "for each utterance, that's a per-frame weight.");
54  po.Read(argc, argv);
55 
56  if (po.NumArgs() != 4) {
57  po.PrintUsage();
58  exit(1);
59  }
60 
61  std::string post_rspecifier = po.GetArg(1),
62  feature_rspecifier = po.GetArg(3),
63  accs_wxfilename = po.GetArg(4);
64 
65  int32 num_components = atoi(po.GetArg(2).c_str());
66 
67  AccumFullGmm fgmm_accs;
68 
69  double tot_like = 0.0, tot_weight = 0.0;
70 
71  SequentialPosteriorReader post_reader(post_rspecifier);
72  RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
73  RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier);
74  int32 num_done = 0, num_err = 0;
75 
76  for (; !post_reader.Done(); post_reader.Next()) {
77  std::string key = post_reader.Key();
78  Posterior post = post_reader.Value();
79  if (!feature_reader.HasKey(key)) {
80  KALDI_WARN << "No features available for utterance "
81  << key;
82  num_err++;
83  continue;
84  }
85  const Matrix<BaseFloat> &mat = feature_reader.Value(key);
86  int32 file_frames = mat.NumRows();
87 
88  // Initialize the FGMM accs before processing the first utt.
89  if (num_done == 0) {
90  fgmm_accs.Resize(num_components, mat.NumCols(),
91  StringToGmmFlags(update_flags_str));
92  }
93 
94  BaseFloat file_like = 0.0,
95  file_weight = 0.0; // total of weights of frames (will each be
96  // 1 unless --weights option supplied.
97  Vector<BaseFloat> weights;
98  if (weights_rspecifier != "") { // We have per-frame weighting.
99  if (!weights_reader.HasKey(key)) {
100  KALDI_WARN << "No per-frame weights available for utterance "
101  << key;
102  num_err++;
103  continue;
104  }
105  weights = weights_reader.Value(key);
106  if (weights.Dim() != file_frames) {
107  KALDI_WARN << "Weights for utterance " << key << " have wrong dim "
108  << weights.Dim() << " vs. " << file_frames;
109  num_err++;
110  continue;
111  }
112  }
113 
114  if (post.size() != static_cast<size_t>(file_frames)) {
115  KALDI_WARN << "posterior information for utterance " << key
116  << " has wrong size " << post.size() << " vs. "
117  << file_frames;
118  num_err++;
119  continue;
120  }
121 
122  for (int32 i = 0; i < file_frames; i++) {
123  BaseFloat weight = (weights.Dim() != 0) ? weights(i) : 1.0;
124  if (weight == 0.0) continue;
125  file_weight += weight;
126  SubVector<BaseFloat> data(mat, i);
127  ScalePosterior(weight, &post);
128  file_like += TotalPosterior(post);
129  for (int32 j = 0; j < post[i].size(); j++)
130  fgmm_accs.AccumulateForComponent(data, post[i][j].first,
131  post[i][j].second);
132  }
133 
134  KALDI_VLOG(2) << "File '" << key << "': Average likelihood = "
135  << (file_like/file_weight) << " over "
136  << file_weight <<" frames.";
137  tot_like += file_like;
138  tot_weight += file_weight;
139  num_done++;
140  }
141  KALDI_LOG << "Done " << num_done << " files; "
142  << num_err << " with errors.";
143  KALDI_LOG << "Overall likelihood per "
144  << "frame = " << (tot_like/tot_weight) << " over "
145  << tot_weight << " (weighted) frames.";
146 
147  WriteKaldiObject(fgmm_accs, accs_wxfilename, binary);
148  KALDI_LOG << "Written accs to " << accs_wxfilename;
149  return (num_done != 0 ? 0 : 1);
150  } catch(const std::exception &e) {
151  std::cerr << e.what();
152  return -1;
153  }
154 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
GmmFlagsType StringToGmmFlags(std::string str)
Convert string which is some subset of "mSwa" to flags.
Definition: model-common.cc:26
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
void Register(const std::string &name, bool *ptr, const std::string &doc)
BaseFloat TotalPosterior(const Posterior &post)
Returns the total of all the weights in "post".
Definition: posterior.cc:230
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
int main(int argc, char *argv[])
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
Class for computing the maximum-likelihood estimates of the parameters of a Gaussian mixture model...
Definition: mle-full-gmm.h:74
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
bool HasKey(const std::string &key)
void ScalePosterior(BaseFloat scale, Posterior *post)
Scales the BaseFloat (weight) element in the posterior entries.
Definition: posterior.cc:218
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
#define KALDI_LOG
Definition: kaldi-error.h:153
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501