gmm-fmpe-acc-stats.cc
Go to the documentation of this file.
1 // gmmbin/gmm-fmpe-acc-stats.cc
2 
3 // Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "gmm/am-diag-gmm.h"
24 #include "hmm/transition-model.h"
25 #include "transform/fmpe.h"
26 
27 
28 int main(int argc, char *argv[]) {
29  using namespace kaldi;
30  using kaldi::int32;
31  try {
32  const char *usage =
33  "Accumulate stats for fMPE training, using GMM model. Note: this could\n"
34  "be done using gmm-get-feat-deriv and fmpe-acc-stats (but you'd be computing\n"
35  "the features twice). Features input should be pre-fMPE features.\n"
36  "\n"
37  "Usage: gmm-fmpe-acc-stats [options] <model-in> <fmpe-in> <feature-rspecifier> "
38  "<gselect-rspecifier> <posteriors-rspecifier> <fmpe-stats-out>\n"
39  "e.g.: \n"
40  " gmm-fmpe-acc-stats --model-derivative 1.accs 1.mdl 1.fmpe \"$feats\" ark:1.gselect ark:1.post 1.fmpe_stats\n";
41 
42  ParseOptions po(usage);
43  bool binary = true;
44  std::string model_derivative_rxfilename;
45  po.Register("binary", &binary, "If true, write stats in binary mode.");
46  po.Register("model-derivative", &model_derivative_rxfilename,
47  "GMM-accs file containing model derivative [note: contains no transition stats]. Used for indirect differential. Warning: this will only work correctly in the case of MMI/BMMI objective function, with non-canceled stats.");
48  po.Read(argc, argv);
49 
50  if (po.NumArgs() != 6) {
51  po.PrintUsage();
52  exit(1);
53  }
54 
55  std::string model_rxfilename = po.GetArg(1),
56  fmpe_rxfilename = po.GetArg(2),
57  feature_rspecifier = po.GetArg(3),
58  gselect_rspecifier = po.GetArg(4),
59  posteriors_rspecifier = po.GetArg(5),
60  stats_wxfilename = po.GetArg(6);
61 
62  AmDiagGmm am_gmm;
63  TransitionModel trans_model;
64  {
65  bool binary;
66  Input ki(model_rxfilename, &binary);
67  trans_model.Read(ki.Stream(), binary);
68  am_gmm.Read(ki.Stream(), binary);
69  }
70 
71  Fmpe fmpe;
72  ReadKaldiObject(fmpe_rxfilename, &fmpe);
73 
74 
75  bool have_indirect = (model_derivative_rxfilename != "");
76  AccumAmDiagGmm model_derivative;
77  if (have_indirect)
78  ReadKaldiObject(model_derivative_rxfilename, &model_derivative);
79 
80  FmpeStats fmpe_stats(fmpe);
81 
82  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
83  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
84  RandomAccessPosteriorReader posteriors_reader(posteriors_rspecifier);
85 
86  BaseFloat tot_like = 0.0; // tot like weighted by posterior.
87  int32 num_frames = 0;
88  int32 num_done = 0, num_err = 0;
89 
90  for (; !feature_reader.Done(); feature_reader.Next()) {
91  std::string key = feature_reader.Key();
92  if (!posteriors_reader.HasKey(key)) {
93  num_err++;
94  KALDI_WARN << "No posteriors for utterance " << key;
95  continue;
96  }
97  const Matrix<BaseFloat> &feat_in = feature_reader.Value();
98  const Posterior &posterior = posteriors_reader.Value(key);
99 
100  if (static_cast<int32>(posterior.size()) != feat_in.NumRows()) {
101  KALDI_WARN << "Posterior vector has wrong size " <<
102  (posterior.size()) << " vs. "<< (feat_in.NumRows());
103  num_err++;
104  continue;
105  }
106 
107  if (!gselect_reader.HasKey(key)) {
108  KALDI_WARN << "No gselect information for key " << key;
109  num_err++;
110  continue;
111  }
112  const std::vector<std::vector<int32> > &gselect =
113  gselect_reader.Value(key);
114  if (static_cast<int32>(gselect.size()) != feat_in.NumRows()) {
115  KALDI_WARN << "gselect information has wrong size";
116  num_err++;
117  continue;
118  }
119 
120  num_done++;
121  Matrix<BaseFloat> fmpe_feat(feat_in.NumRows(), feat_in.NumCols());
122  fmpe.ComputeFeatures(feat_in, gselect, &fmpe_feat);
123  fmpe_feat.AddMat(1.0, feat_in);
124 
125  Matrix<BaseFloat> direct_deriv, indirect_deriv;
126 
127  tot_like += ComputeAmGmmFeatureDeriv(am_gmm, trans_model, posterior,
128  fmpe_feat, &direct_deriv,
129  (have_indirect ? &model_derivative : NULL),
130  (have_indirect ? &indirect_deriv : NULL));
131  num_frames += feat_in.NumRows();
132 
133  fmpe.AccStats(feat_in, gselect, direct_deriv,
134  (have_indirect ? &indirect_deriv : NULL), &fmpe_stats);
135 
136  if (num_done % 100 == 0)
137  KALDI_LOG << "Processed " << num_done << " utterances.";
138  }
139 
140  KALDI_LOG << "Done " << num_done << " files, " << num_err
141  << " with errors.";
142  KALDI_LOG << "Overall weighted acoustic likelihood per frame is "
143  << (tot_like/num_frames) << " over " << num_frames << " frames.";
144 
145  Output ko(stats_wxfilename, binary);
146  fmpe_stats.Write(ko.Stream(), binary);
147 
148  return (num_done != 0 ? 0 : 1);
149  } catch(const std::exception &e) {
150  std::cerr << e.what();
151  return -1;
152  }
153 }
154 
155 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void ComputeFeatures(const MatrixBase< BaseFloat > &feat_in, const std::vector< std::vector< int32 > > &gselect, Matrix< BaseFloat > *feat_out) const
Definition: fmpe.cc:370
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
BaseFloat ComputeAmGmmFeatureDeriv(const AmDiagGmm &am_gmm, const TransitionModel &trans_model, const Posterior &posterior, const MatrixBase< BaseFloat > &features, Matrix< BaseFloat > *direct_deriv, const AccumAmDiagGmm *model_diff, Matrix< BaseFloat > *indirect_deriv)
Computes derivatives of the likelihood of these states (weighted), w.r.t.
Definition: fmpe.cc:522
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
Definition: fmpe.cc:680
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
void AccStats(const MatrixBase< BaseFloat > &feat_in, const std::vector< std::vector< int32 > > &gselect, const MatrixBase< BaseFloat > &direct_feat_deriv, const MatrixBase< BaseFloat > *indirect_feat_deriv, FmpeStats *stats) const
Definition: fmpe.cc:395
int main(int argc, char *argv[])