All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
gmm-fmpe-acc-stats.cc File Reference
Include dependency graph for gmm-fmpe-acc-stats.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 28 of file gmm-fmpe-acc-stats.cc.

References Fmpe::AccStats(), kaldi::ComputeAmGmmFeatureDeriv(), Fmpe::ComputeFeatures(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), AmDiagGmm::Read(), ParseOptions::Read(), TransitionModel::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), Input::Stream(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and FmpeStats::Write().

28  {
29  using namespace kaldi;
30  using kaldi::int32;
31  try {
32  const char *usage =
33  "Accumulate stats for fMPE training, using GMM model. Note: this could\n"
34  "be done using gmm-get-feat-deriv and fmpe-acc-stats (but you'd be computing\n"
35  "the features twice). Features input should be pre-fMPE features.\n"
36  "\n"
37  "Usage: gmm-fmpe-acc-stats [options] <model-in> <fmpe-in> <feature-rspecifier> "
38  "<gselect-rspecifier> <posteriors-rspecifier> <fmpe-stats-out>\n"
39  "e.g.: \n"
40  " gmm-fmpe-acc-stats --model-derivative 1.accs 1.mdl 1.fmpe \"$feats\" ark:1.gselect ark:1.post 1.fmpe_stats\n";
41 
42  ParseOptions po(usage);
43  bool binary = true;
44  std::string model_derivative_rxfilename;
45  po.Register("binary", &binary, "If true, write stats in binary mode.");
46  po.Register("model-derivative", &model_derivative_rxfilename,
47  "GMM-accs file containing model derivative [note: contains no transition stats]. Used for indirect differential. Warning: this will only work correctly in the case of MMI/BMMI objective function, with non-canceled stats.");
48  po.Read(argc, argv);
49 
50  if (po.NumArgs() != 6) {
51  po.PrintUsage();
52  exit(1);
53  }
54 
55  std::string model_rxfilename = po.GetArg(1),
56  fmpe_rxfilename = po.GetArg(2),
57  feature_rspecifier = po.GetArg(3),
58  gselect_rspecifier = po.GetArg(4),
59  posteriors_rspecifier = po.GetArg(5),
60  stats_wxfilename = po.GetArg(6);
61 
62  AmDiagGmm am_gmm;
63  TransitionModel trans_model;
64  {
65  bool binary;
66  Input ki(model_rxfilename, &binary);
67  trans_model.Read(ki.Stream(), binary);
68  am_gmm.Read(ki.Stream(), binary);
69  }
70 
71  Fmpe fmpe;
72  ReadKaldiObject(fmpe_rxfilename, &fmpe);
73 
74 
75  bool have_indirect = (model_derivative_rxfilename != "");
76  AccumAmDiagGmm model_derivative;
77  if (have_indirect)
78  ReadKaldiObject(model_derivative_rxfilename, &model_derivative);
79 
80  FmpeStats fmpe_stats(fmpe);
81 
82  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
83  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
84  RandomAccessPosteriorReader posteriors_reader(posteriors_rspecifier);
85 
86  BaseFloat tot_like = 0.0; // tot like weighted by posterior.
87  int32 num_frames = 0;
88  int32 num_done = 0, num_err = 0;
89 
90  for (; !feature_reader.Done(); feature_reader.Next()) {
91  std::string key = feature_reader.Key();
92  if (!posteriors_reader.HasKey(key)) {
93  num_err++;
94  KALDI_WARN << "No posteriors for utterance " << key;
95  continue;
96  }
97  const Matrix<BaseFloat> &feat_in = feature_reader.Value();
98  const Posterior &posterior = posteriors_reader.Value(key);
99 
100  if (static_cast<int32>(posterior.size()) != feat_in.NumRows()) {
101  KALDI_WARN << "Posterior vector has wrong size " <<
102  (posterior.size()) << " vs. "<< (feat_in.NumRows());
103  num_err++;
104  continue;
105  }
106 
107  if (!gselect_reader.HasKey(key)) {
108  KALDI_WARN << "No gselect information for key " << key;
109  num_err++;
110  continue;
111  }
112  const std::vector<std::vector<int32> > &gselect =
113  gselect_reader.Value(key);
114  if (static_cast<int32>(gselect.size()) != feat_in.NumRows()) {
115  KALDI_WARN << "gselect information has wrong size";
116  num_err++;
117  continue;
118  }
119 
120  num_done++;
121  Matrix<BaseFloat> fmpe_feat(feat_in.NumRows(), feat_in.NumCols());
122  fmpe.ComputeFeatures(feat_in, gselect, &fmpe_feat);
123  fmpe_feat.AddMat(1.0, feat_in);
124 
125  Matrix<BaseFloat> direct_deriv, indirect_deriv;
126 
127  tot_like += ComputeAmGmmFeatureDeriv(am_gmm, trans_model, posterior,
128  fmpe_feat, &direct_deriv,
129  (have_indirect ? &model_derivative : NULL),
130  (have_indirect ? &indirect_deriv : NULL));
131  num_frames += feat_in.NumRows();
132 
133  fmpe.AccStats(feat_in, gselect, direct_deriv,
134  (have_indirect ? &indirect_deriv : NULL), &fmpe_stats);
135 
136  if (num_done % 100 == 0)
137  KALDI_LOG << "Processed " << num_done << " utterances.";
138  }
139 
140  KALDI_LOG << "Done " << num_done << " files, " << num_err
141  << " with errors.";
142  KALDI_LOG << "Overall weighted acoustic likelihood per frame is "
143  << (tot_like/num_frames) << " over " << num_frames << " frames.";
144 
145  Output ko(stats_wxfilename, binary);
146  fmpe_stats.Write(ko.Stream(), binary);
147 
148  return (num_done != 0 ? 0 : 1);
149  } catch(const std::exception &e) {
150  std::cerr << e.what();
151  return -1;
152  }
153 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
void AccStats(const MatrixBase< BaseFloat > &feat_in, const std::vector< std::vector< int32 > > &gselect, const MatrixBase< BaseFloat > &direct_feat_deriv, const MatrixBase< BaseFloat > *indirect_feat_deriv, FmpeStats *stats) const
Definition: fmpe.cc:395
BaseFloat ComputeAmGmmFeatureDeriv(const AmDiagGmm &am_gmm, const TransitionModel &trans_model, const Posterior &posterior, const MatrixBase< BaseFloat > &features, Matrix< BaseFloat > *direct_deriv, const AccumAmDiagGmm *model_diff, Matrix< BaseFloat > *indirect_deriv)
Computes derivatives of the likelihood of these states (weighted) with respect to the feature values.
Definition: fmpe.cc:522
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:818
Allows random access to a collection of objects in an archive or script file; see The Table concept.
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an utterance.
Definition: posterior.h:43
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more details.
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
Definition: kaldi-table.h:287
void ComputeFeatures(const MatrixBase< BaseFloat > &feat_in, const std::vector< std::vector< int32 > > &gselect, Matrix< BaseFloat > *feat_out) const
Definition: fmpe.cc:370
#define KALDI_WARN
Definition: kaldi-error.h:130
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:58
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:61
#define KALDI_LOG
Definition: kaldi-error.h:133
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147