online-wav-gmm-decode-faster.cc
Go to the documentation of this file.
1 // onlinebin/online-wav-gmm-decode-faster.cc
2 
3 // Copyright 2012 Cisco Systems (author: Matthias Paulik)
4 
5 // Modifications to the original contribution by Cisco Systems made by:
6 // Vassil Panayotov
7 
8 // See ../../COPYING for clarification regarding multiple authors
9 //
10 // Licensed under the Apache License, Version 2.0 (the "License");
11 // you may not use this file except in compliance with the License.
12 // You may obtain a copy of the License at
13 //
14 // http://www.apache.org/licenses/LICENSE-2.0
15 //
16 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
18 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
19 // MERCHANTABILITY OR NON-INFRINGEMENT.
20 // See the Apache 2 License for the specific language governing permissions and
21 // limitations under the License.
22 
23 #include "feat/feature-mfcc.h"
24 #include "feat/wave-reader.h"
29 #include "online/onlinebin-util.h"
30 
31 int main(int argc, char *argv[]) {
32  try {
33  using namespace kaldi;
34  using namespace fst;
35 
36  typedef kaldi::int32 int32;
37  typedef OnlineFeInput<Mfcc> FeInput;
38 
39  // up to delta-delta derivative features are calculated (unless LDA is used)
40  const int32 kDeltaOrder = 2;
41 
42  const char *usage =
43  "Reads in wav file(s) and simulates online decoding.\n"
44  "Writes integerized-text and .ali files for WER computation. Utterance "
45  "segmentation is done on-the-fly.\n"
46  "Feature splicing/LDA transform is used, if the optional(last) argument "
47  "is given.\n"
48  "Otherwise delta/delta-delta(i.e. 2-nd order) features are produced.\n"
49  "Caution: the last few frames of the wav file may not be decoded properly.\n"
50  "Hence, don't use one wav file per utterance, but "
51  "rather use one wav file per show.\n\n"
52  "Usage: online-wav-gmm-decode-faster [options] wav-rspecifier model-in"
53  "fst-in word-symbol-table silence-phones transcript-wspecifier "
54  "alignments-wspecifier [lda-matrix-in]\n\n"
55  "Example: ./online-wav-gmm-decode-faster --rt-min=0.3 --rt-max=0.5 "
56  "--max-active=4000 --beam=12.0 --acoustic-scale=0.0769 "
57  "scp:wav.scp model HCLG.fst words.txt '1:2:3:4:5' ark,t:trans.txt ark,t:ali.txt";
58  ParseOptions po(usage);
59  BaseFloat acoustic_scale = 0.1;
60  int32 cmn_window = 600,
61  min_cmn_window = 100; // adds 1 second latency, only at utterance start.
62  int32 channel = -1;
63  int32 right_context = 4, left_context = 4;
64 
65  OnlineFasterDecoderOpts decoder_opts;
66  decoder_opts.Register(&po, true);
67  OnlineFeatureMatrixOptions feature_reading_opts;
68  feature_reading_opts.Register(&po);
69 
70  po.Register("left-context", &left_context, "Number of frames of left context");
71  po.Register("right-context", &right_context, "Number of frames of right context");
72  po.Register("acoustic-scale", &acoustic_scale,
73  "Scaling factor for acoustic likelihoods");
74  po.Register("cmn-window", &cmn_window,
75  "Number of feat. vectors used in the running average CMN calculation");
76  po.Register("min-cmn-window", &min_cmn_window,
77  "Minumum CMN window used at start of decoding (adds "
78  "latency only at start)");
79  po.Register("channel", &channel,
80  "Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right)");
81  po.Read(argc, argv);
82  if (po.NumArgs() != 7 && po.NumArgs() != 8) {
83  po.PrintUsage();
84  return 1;
85  }
86 
87  std::string wav_rspecifier = po.GetArg(1),
88  model_rspecifier = po.GetArg(2),
89  fst_rspecifier = po.GetArg(3),
90  word_syms_filename = po.GetArg(4),
91  silence_phones_str = po.GetArg(5),
92  words_wspecifier = po.GetArg(6),
93  alignment_wspecifier = po.GetArg(7),
94  lda_mat_rspecifier = po.GetOptArg(8);
95 
96  std::vector<int32> silence_phones;
97  if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones))
98  KALDI_ERR << "Invalid silence-phones string " << silence_phones_str;
99  if (silence_phones.empty())
100  KALDI_ERR << "No silence phones given!";
101 
102  Int32VectorWriter words_writer(words_wspecifier);
103  Int32VectorWriter alignment_writer(alignment_wspecifier);
104 
105  Matrix<BaseFloat> lda_transform;
106  if (lda_mat_rspecifier != "") {
107  bool binary_in;
108  Input ki(lda_mat_rspecifier, &binary_in);
109  lda_transform.Read(ki.Stream(), binary_in);
110  }
111 
112  TransitionModel trans_model;
113  AmDiagGmm am_gmm;
114  {
115  bool binary;
116  Input ki(model_rspecifier, &binary);
117  trans_model.Read(ki.Stream(), binary);
118  am_gmm.Read(ki.Stream(), binary);
119  }
120 
121  fst::SymbolTable *word_syms = NULL;
122  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
123  KALDI_ERR << "Could not read symbol table from file "
124  << word_syms_filename;
125 
126  fst::Fst<fst::StdArc> *decode_fst = ReadDecodeGraph(fst_rspecifier);
127 
128  // We are not properly registering/exposing MFCC and frame extraction options,
129  // because there are parts of the online decoding code, where some of these
130  // options are hardwired(ToDo: we should fix this at some point)
131  MfccOptions mfcc_opts;
132  mfcc_opts.use_energy = false;
133  int32 frame_length = mfcc_opts.frame_opts.frame_length_ms = 25;
134  int32 frame_shift = mfcc_opts.frame_opts.frame_shift_ms = 10;
135 
136  int32 window_size = right_context + left_context + 1;
137  decoder_opts.batch_size = std::max(decoder_opts.batch_size, window_size);
138 
139  OnlineFasterDecoder decoder(*decode_fst, decoder_opts,
140  silence_phones, trans_model);
141  SequentialTableReader<WaveHolder> reader(wav_rspecifier);
142  VectorFst<LatticeArc> out_fst;
143  for (; !reader.Done(); reader.Next()) {
144  std::string wav_key = reader.Key();
145  std::cerr << "File: " << wav_key << std::endl;
146  const WaveData &wav_data = reader.Value();
147  if(wav_data.SampFreq() != 16000)
148  KALDI_ERR << "Sampling rates other than 16kHz are not supported!";
149  int32 num_chan = wav_data.Data().NumRows(), this_chan = channel;
150  { // This block works out the channel (0=left, 1=right...)
151  KALDI_ASSERT(num_chan > 0); // should have been caught in
152  // reading code if no channels.
153  if (channel == -1) {
154  this_chan = 0;
155  if (num_chan != 1)
156  KALDI_WARN << "Channel not specified but you have data with "
157  << num_chan << " channels; defaulting to zero";
158  } else {
159  if (this_chan >= num_chan) {
160  KALDI_WARN << "File with id " << wav_key << " has "
161  << num_chan << " channels but you specified channel "
162  << channel << ", producing no output.";
163  continue;
164  }
165  }
166  }
167  OnlineVectorSource au_src(wav_data.Data().Row(this_chan));
168  Mfcc mfcc(mfcc_opts);
169  FeInput fe_input(&au_src, &mfcc,
170  frame_length*(wav_data.SampFreq()/1000),
171  frame_shift*(wav_data.SampFreq()/1000));
172  OnlineCmnInput cmn_input(&fe_input, cmn_window, min_cmn_window);
173  OnlineFeatInputItf *feat_transform = 0;
174  if (lda_mat_rspecifier != "") {
175  feat_transform = new OnlineLdaInput(
176  &cmn_input, lda_transform,
177  left_context, right_context);
178  } else {
180  opts.order = kDeltaOrder;
181  feat_transform = new OnlineDeltaInput(opts, &cmn_input);
182  }
183 
184  // feature_reading_opts contains number of retries, batch size.
185  OnlineFeatureMatrix feature_matrix(feature_reading_opts,
186  feat_transform);
187 
188  OnlineDecodableDiagGmmScaled decodable(am_gmm, trans_model, acoustic_scale,
189  &feature_matrix);
190  int32 start_frame = 0;
191  bool partial_res = false;
192  decoder.InitDecoding();
193  while (1) {
194  OnlineFasterDecoder::DecodeState dstate = decoder.Decode(&decodable);
195  if (dstate & (decoder.kEndFeats | decoder.kEndUtt)) {
196  std::vector<int32> word_ids;
197  decoder.FinishTraceBack(&out_fst);
199  static_cast<vector<int32> *>(0),
200  &word_ids,
201  static_cast<LatticeArc::Weight*>(0));
202  PrintPartialResult(word_ids, word_syms, partial_res || word_ids.size());
203  partial_res = false;
204 
205  decoder.GetBestPath(&out_fst);
206  std::vector<int32> tids;
208  &tids,
209  &word_ids,
210  static_cast<LatticeArc::Weight*>(0));
211  std::stringstream res_key;
212  res_key << wav_key << '_' << start_frame << '-' << decoder.frame();
213  if (!word_ids.empty())
214  words_writer.Write(res_key.str(), word_ids);
215  alignment_writer.Write(res_key.str(), tids);
216  if (dstate == decoder.kEndFeats)
217  break;
218  start_frame = decoder.frame();
219  } else {
220  std::vector<int32> word_ids;
221  if (decoder.PartialTraceback(&out_fst)) {
223  static_cast<vector<int32> *>(0),
224  &word_ids,
225  static_cast<LatticeArc::Weight*>(0));
226  PrintPartialResult(word_ids, word_syms, false);
227  if (!partial_res)
228  partial_res = (word_ids.size() > 0);
229  }
230  }
231  }
232  delete feat_transform;
233  }
234  delete word_syms;
235  delete decode_fst;
236  return 0;
237  } catch(const std::exception& e) {
238  std::cerr << e.what();
239  return -1;
240  }
241 } // main()
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
bool PartialTraceback(fst::MutableFst< LatticeArc > *out_fst)
void Register(OptionsItf *opts, bool full)
MfccOptions contains basic options for computing MFCC features.
Definition: feature-mfcc.h:38
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
Definition: text-utils.h:68
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void InitDecoding()
As a new alternative to Decode(), you can call InitDecoding and then (possibly multiple times) Advanc...
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
BaseFloat SampFreq() const
Definition: wave-reader.h:126
DecodeState Decode(DecodableInterface *decodable)
int main(int argc, char *argv[])
const Matrix< BaseFloat > & Data() const
Definition: wave-reader.h:124
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
void PrintPartialResult(const std::vector< int32 > &words, const fst::SymbolTable *word_syms, bool line_break)
bool GetBestPath(fst::MutableFst< LatticeArc > *fst_out, bool use_final_probs=true)
GetBestPath gets the decoding traceback.
std::istream & Stream()
Definition: kaldi-io.cc:826
void Read(std::istream &in, bool binary, bool add=false)
read from stream.
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
void FinishTraceBack(fst::MutableFst< LatticeArc > *fst_out)
FrameExtractionOptions frame_opts
Definition: feature-mfcc.h:39
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
This class's purpose is to read in Wave files.
Definition: wave-reader.h:106
int NumArgs() const
Number of positional parameters (c.f. argc-1).
fst::Fst< fst::StdArc > * ReadDecodeGraph(const std::string &filename)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
This templated class is intended for offline feature extraction, i.e.
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
std::string GetOptArg(int param) const