gmm-decode-faster-regtree-fmllr.cc
Go to the documentation of this file.
1 // gmmbin/gmm-decode-faster-regtree-fmllr.cc
2 
3 // Copyright 2009-2012 Microsoft Corporation; Saarland University;
4 // Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABILITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <string>
22 #include <vector>
23 
24 #include "base/kaldi-common.h"
25 #include "util/common-utils.h"
26 #include "gmm/am-diag-gmm.h"
27 #include "hmm/transition-model.h"
31 #include "fstext/fstext-lib.h"
32 #include "decoder/faster-decoder.h"
34 #include "base/timer.h"
35 #include "lat/kaldi-lattice.h" // for {Compact}LatticeArc
36 
37 using fst::SymbolTable;
38 using fst::VectorFst;
39 using fst::StdArc;
40 using kaldi::BaseFloat;
41 using std::string;
42 using std::vector;
44 using kaldi::LatticeArc;
45 
46 struct DecodeInfo {
47  public:
50  BaseFloat scale, bool allow_partial,
51  const kaldi::Int32VectorWriter &wwriter,
52  const kaldi::Int32VectorWriter &awriter, fst::SymbolTable *wsyms)
53  : acoustic_model(am), trans_model(tm), decoder(decoder),
54  acoustic_scale(scale), allow_partial(allow_partial), words_writer(wwriter),
55  alignment_writer(awriter), word_syms(wsyms) {}
56 
64  fst::SymbolTable *word_syms;
65 
66  private:
68 };
69 
71  kaldi::DecodableInterface *decodable,
72  DecodeInfo *info,
73  const string &uttid,
74  int32 num_frames,
75  BaseFloat *total_like) {
76  decoder->Decode(decodable);
77  KALDI_LOG << "Length of file is " << num_frames;
78 
79  VectorFst<LatticeArc> decoded; // linear FST.
80  if ( (info->allow_partial || decoder->ReachedFinal())
81  && decoder->GetBestPath(&decoded) ) {
82  if (!decoder->ReachedFinal())
83  KALDI_WARN << "Decoder did not reach end-state, outputting partial "
84  "traceback.";
85 
86  vector<kaldi::int32> alignment, words;
87  LatticeWeight weight;
88  GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
89 
90  info->words_writer.Write(uttid, words);
91  if (info->alignment_writer.IsOpen())
92  info->alignment_writer.Write(uttid, alignment);
93  if (info->word_syms != NULL) {
94  std::ostringstream ss;
95  ss << uttid << ' ';
96  for (size_t i = 0; i < words.size(); i++) {
97  string s = info->word_syms->Find(words[i]);
98  if (s == "")
99  KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
100  ss << s << ' ';
101  }
102  ss << '\n';
103  KALDI_LOG << ss.str();
104  }
105 
106  BaseFloat like = -weight.Value1() -weight.Value2();
107  KALDI_LOG << "Log-like per frame = " << (like/num_frames);
108  (*total_like) += like;
109  return true;
110  } else {
111  KALDI_WARN << "Did not successfully decode utterance, length = "
112  << num_frames;
113  return false;
114  }
115 }
116 
// Entry point: decodes every utterance in a feature archive with a GMM-based
// model, applying a per-utterance regression-tree fMLLR transform when one is
// available.
//
// Positional arguments (6 or 7): model, decoding FST, regression tree,
// features rspecifier, transforms rspecifier, words wspecifier and an
// optional alignments wspecifier.
//
// Returns 0 if at least one utterance was decoded, 1 if none were, and -1 if
// a std::exception reached the handler at the bottom.
//
// NOTE(review): `acoustic_scale` and `trans_model` are used below but their
// declarations are not visible in this listing -- presumably lines lost in
// extraction; confirm against the upstream file.
117 int main(int argc, char *argv[]) {
118  try {
119  using namespace kaldi;
120  typedef kaldi::int32 int32;
121 
122  const char *usage = "Decode features using GMM-based model.\n"
123  "Usage: gmm-decode-faster-regtree-fmllr [options] model-in fst-in "
124  "regtree-in features-rspecifier transforms-rspecifier "
125  "words-wspecifier [alignments-wspecifier]\n";
126  ParseOptions po(usage);
// NOTE(review): `binary` is registered as an option below but is not read
// anywhere else in the visible code.
127  bool binary = true;
128  bool allow_partial = true;
130 
131  std::string word_syms_filename, utt2spk_rspecifier;
132  FasterDecoderOptions decoder_opts;
133  decoder_opts.Register(&po, true); // true == include obscure settings.
134  po.Register("utt2spk", &utt2spk_rspecifier, "rspecifier for utterance to "
135  "speaker map");
136  po.Register("binary", &binary, "Write output in binary mode");
137  po.Register("acoustic-scale", &acoustic_scale,
138  "Scaling factor for acoustic likelihoods");
139  po.Register("word-symbol-table", &word_syms_filename,
140  "Symbol table for words [for debug output]");
141  po.Register("allow-partial", &allow_partial,
142  "Produce output even when final state was not reached");
143  po.Read(argc, argv);
144 
// Six arguments are mandatory; the seventh (alignments) is optional.
145  if (po.NumArgs() < 6 || po.NumArgs() > 7) {
146  po.PrintUsage();
147  exit(1);
148  }
149 
150  std::string model_in_filename = po.GetArg(1),
151  fst_in_filename = po.GetArg(2),
152  regtree_filename = po.GetArg(3),
153  feature_rspecifier = po.GetArg(4),
154  xforms_rspecifier = po.GetArg(5),
155  words_wspecifier = po.GetArg(6),
156  alignment_wspecifier = po.GetOptArg(7);
157 
// The transition model and the acoustic model are read, in that order, from
// the same input stream.
159  AmDiagGmm am_gmm;
160  {
161  bool binary_read;
162  Input ki(model_in_filename, &binary_read);
163  trans_model.Read(ki.Stream(), binary_read);
164  am_gmm.Read(ki.Stream(), binary_read);
165  }
166 
// Raw pointer; released with `delete decode_fst` at the end of the try block.
167  VectorFst<StdArc> *decode_fst = fst::ReadFstKaldi(fst_in_filename);
168 
// Reading the regression tree takes the acoustic model as an argument.
169  RegressionTree regtree;
170  {
171  bool binary_read;
172  Input in(regtree_filename, &binary_read);
173  regtree.Read(in.Stream(), binary_read, am_gmm);
174  }
175 
// Transform reader built from the transforms rspecifier and the --utt2spk
// rspecifier.
176  RandomAccessRegtreeFmllrDiagGmmReaderMapped fmllr_reader(xforms_rspecifier,
177  utt2spk_rspecifier);
178 
179  Int32VectorWriter words_writer(words_wspecifier);
180 
// alignment_wspecifier comes from GetOptArg(7); DecodeUtterance() checks
// IsOpen() on this writer before writing to it.
181  Int32VectorWriter alignment_writer(alignment_wspecifier);
182 
// The word symbol table is optional and only read when a filename was given.
183  fst::SymbolTable *word_syms = NULL;
184  if (word_syms_filename != "") {
185  word_syms = fst::SymbolTable::ReadText(word_syms_filename);
186  if (!word_syms) {
187  KALDI_ERR << "Could not read symbol table from file "
188  << word_syms_filename;
189  }
190  }
191 
// Running totals over all successfully decoded utterances.
192  BaseFloat tot_like = 0.0;
193  kaldi::int64 frame_count = 0;
194  int num_success = 0, num_fail = 0;
195  FasterDecoder decoder(*decode_fst, decoder_opts);
196 
// The timer is created after the models are loaded, so the elapsed time
// reported below excludes initialization.
197  Timer timer;
198 
199  DecodeInfo decode_info(am_gmm, trans_model, &decoder, acoustic_scale,
200  allow_partial, words_writer, alignment_writer,
201  word_syms);
202 
203  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
204  for (; !feature_reader.Done(); feature_reader.Next()) {
205  string utt = feature_reader.Key();
206 
// Copy the features into a local matrix, then call FreeCurrent() on the
// reader.
207  Matrix<BaseFloat> features(feature_reader.Value());
208  feature_reader.FreeCurrent();
209  if (features.NumRows() == 0) {
210  KALDI_WARN << "Zero-length utterance: " << utt;
211  num_fail++;
212  continue;
213  }
214 
// Case 1: no transform for this utterance -- decode the unadapted features.
215  if (!fmllr_reader.HasKey(utt)) { // Decode without FMLLR if none found
216  KALDI_WARN << "No FMLLR transform for key " << utt <<
217  ", decoding without fMLLR.";
218  kaldi::DecodableAmDiagGmmScaled gmm_decodable(am_gmm, trans_model,
219  features,
220  acoustic_scale);
221  if (DecodeUtterance(&decoder, &gmm_decodable, &decode_info,
222  utt, features.NumRows(), &tot_like)) {
223  frame_count += gmm_decodable.NumFramesReady();
224  num_success++;
225  } else {
226  num_fail++;
227  }
228  continue;
229  }
230 
231  // If found, load the transforms for the current utterance.
232  RegtreeFmllrDiagGmm fmllr(fmllr_reader.Value(utt));
// Case 2: a single regression class -- apply its transform matrix to every
// row of a copy of the features and decode with the plain scaled decodable.
233  if (fmllr.NumRegClasses() == 1) {
234  Matrix<BaseFloat> xformed_features(features);
235  Matrix<BaseFloat> fmllr_matrix;
236  fmllr.GetXformMatrix(0, &fmllr_matrix);
237  for (int32 i = 0; i < xformed_features.NumRows(); i++) {
238  SubVector<BaseFloat> row(xformed_features, i);
239  ApplyAffineTransform(fmllr_matrix, &row);
240  }
241  kaldi::DecodableAmDiagGmmScaled gmm_decodable(am_gmm, trans_model,
242  xformed_features,
243  acoustic_scale);
244 
245  if (DecodeUtterance(&decoder, &gmm_decodable, &decode_info,
246  utt, xformed_features.NumRows(), &tot_like)) {
247  frame_count += gmm_decodable.NumFramesReady();
248  num_success++;
249  } else {
250  num_fail++;
251  }
// Case 3: more than one regression class -- hand the untransformed features,
// the transform set and the regression tree to the regtree-fMLLR decodable.
252  } else {
253  kaldi::DecodableAmDiagGmmRegtreeFmllr gmm_decodable(am_gmm, trans_model,
254  features, fmllr,
255  regtree,
256  acoustic_scale);
257  if (DecodeUtterance(&decoder, &gmm_decodable, &decode_info,
258  utt, features.NumRows(), &tot_like)) {
259  frame_count += gmm_decodable.NumFramesReady();
260  num_success++;
261  } else {
262  num_fail++;
263  }
264  }
265  } // end looping over all utterances
266 
// NOTE(review): frame_count is only incremented on successful decodes, so it
// is still 0 here if nothing was decoded, and both divisions below then
// divide by zero.
267  KALDI_LOG << "Average log-likelihood per frame is " << (tot_like
268  / frame_count) << " over " << frame_count << " frames.";
269 
270  double elapsed = timer.Elapsed();
271  KALDI_LOG << "Time taken [excluding initialization] " << elapsed
272  << "s: real-time factor assuming 100 frames/sec is "
273  << (elapsed * 100.0 / frame_count);
274  KALDI_LOG << "Done " << num_success << " utterances, failed for "
275  << num_fail;
276 
// These deletes are inside the try block, so they are skipped if an
// exception is thrown earlier.
277  delete word_syms;
278  delete decode_fst;
// Exit status 0 when at least one utterance was decoded, 1 otherwise.
279  if (num_success != 0)
280  return 0;
281  else
282  return 1;
283  }
284  catch(const std::exception &e) {
285  std::cerr << e.what();
286  return -1;
287  }
288 }
289 
290 
int32 words[kMaxOrder]
void Read(std::istream &in, bool binary, const AmDiagGmm &am)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
fst::ArcTpl< LatticeWeight > LatticeArc
Definition: kaldi-lattice.h:40
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
Definition: decodable-itf.h:82
const kaldi::Int32VectorWriter & alignment_writer
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
fst::StdArc StdArc
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
const kaldi::TransitionModel & trans_model
void Decode(DecodableInterface *decodable)
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
DecodeInfo(const kaldi::AmDiagGmm &am, const kaldi::TransitionModel &tm, kaldi::FasterDecoder *decoder, BaseFloat scale, bool allow_partial, const kaldi::Int32VectorWriter &wwriter, const kaldi::Int32VectorWriter &awriter, fst::SymbolTable *wsyms)
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodeInfo)
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
Definition: kaldi-lattice.h:32
bool GetBestPath(fst::MutableFst< LatticeArc > *fst_out, bool use_final_probs=true)
GetBestPath gets the decoding traceback.
An FMLLR (feature-space MLLR) transformation, also called CMLLR (constrained MLLR) is an affine trans...
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A regression tree is a clustering of Gaussian densities in an acoustic model, such that the group of ...
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
void Register(OptionsItf *opts, bool full)
bool HasKey(const std::string &key)
fst::SymbolTable * word_syms
bool DecodeUtterance(kaldi::FasterDecoder *decoder, kaldi::DecodableInterface *decodable, DecodeInfo *info, const string &uttid, int32 num_frames, BaseFloat *total_like)
kaldi::FasterDecoder * decoder
const kaldi::AmDiagGmm & acoustic_model
int NumArgs() const
Number of positional parameters (c.f. argc-1).
int main(int argc, char *argv[])
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void ReadFstKaldi(std::istream &is, bool binary, VectorFst< Arc > *fst)
virtual int32 NumFramesReady() const
The call NumFramesReady() will return the number of frames currently available for this decodable obj...
virtual int32 NumFramesReady() const
The call NumFramesReady() will return the number of frames currently available for this decodable obj...
LatticeWeightTpl< BaseFloat > LatticeWeight
const T & Value(const std::string &key)
void ApplyAffineTransform(const MatrixBase< BaseFloat > &xform, VectorBase< BaseFloat > *vec)
Applies the affine transform 'xform' to the vector 'vec' and overwrites the contents of 'vec'...
#define KALDI_LOG
Definition: kaldi-error.h:153
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
bool ReachedFinal() const
Returns true if a final state was active on the last frame.
const kaldi::Int32VectorWriter & words_writer
std::string GetOptArg(int param) const