online2-wav-gmm-latgen-faster.cc
Go to the documentation of this file.
1 // online2bin/online2-wav-gmm-latgen-faster.cc
2 
3 // Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
#include <limits>
#include <memory>

#include "feat/wave-reader.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "fstext/fstext-lib.h"
#include "lat/lattice-functions.h"
28 
29 namespace kaldi {
30 
31 void GetDiagnosticsAndPrintOutput(const std::string &utt,
32  const fst::SymbolTable *word_syms,
33  const CompactLattice &clat,
34  int64 *tot_num_frames,
35  double *tot_like) {
36  if (clat.NumStates() == 0) {
37  KALDI_WARN << "Empty lattice.";
38  return;
39  }
40  CompactLattice best_path_clat;
41  CompactLatticeShortestPath(clat, &best_path_clat);
42 
43  Lattice best_path_lat;
44  ConvertLattice(best_path_clat, &best_path_lat);
45 
46  double likelihood;
47  LatticeWeight weight;
48  int32 num_frames;
49  std::vector<int32> alignment;
50  std::vector<int32> words;
51  GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight);
52  num_frames = alignment.size();
53  likelihood = -(weight.Value1() + weight.Value2());
54  *tot_num_frames += num_frames;
55  *tot_like += likelihood;
56  KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is "
57  << (likelihood / num_frames) << " over " << num_frames
58  << " frames.";
59 
60  if (word_syms != NULL) {
61  std::cerr << utt << ' ';
62  for (size_t i = 0; i < words.size(); i++) {
63  std::string s = word_syms->Find(words[i]);
64  if (s == "")
65  KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
66  std::cerr << s << ' ';
67  }
68  std::cerr << std::endl;
69  }
70 }
71 
72 }
73 
74 int main(int argc, char *argv[]) {
75  try {
76  using namespace kaldi;
77  using namespace fst;
78 
79  typedef kaldi::int32 int32;
80  typedef kaldi::int64 int64;
81 
82  const char *usage =
83  "Reads in wav file(s) and simulates online decoding, including\n"
84  "basis-fMLLR adaptation and endpointing. Writes lattices.\n"
85  "Models are specified via options.\n"
86  "\n"
87  "Usage: online2-wav-gmm-latgen-faster [options] <fst-in> "
88  "<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n"
89  "Run egs/rm/s5/local/run_online_decoding.sh for example\n";
90 
91  ParseOptions po(usage);
92 
93  std::string word_syms_rxfilename;
94 
95  OnlineEndpointConfig endpoint_config;
96  OnlineFeaturePipelineCommandLineConfig feature_cmdline_config;
97  OnlineGmmDecodingConfig decode_config;
98 
99  BaseFloat chunk_length_secs = 0.05;
100  bool do_endpointing = false;
101  std::string use_gpu = "no";
102 
103  po.Register("chunk-length", &chunk_length_secs,
104  "Length of chunk size in seconds, that we process.");
105  po.Register("word-symbol-table", &word_syms_rxfilename,
106  "Symbol table for words [for debug output]");
107  po.Register("do-endpointing", &do_endpointing,
108  "If true, apply endpoint detection");
109 
110  feature_cmdline_config.Register(&po);
111  decode_config.Register(&po);
112  endpoint_config.Register(&po);
113 
114  po.Read(argc, argv);
115 
116  if (po.NumArgs() != 4) {
117  po.PrintUsage();
118  return 1;
119  }
120 
121  std::string fst_rxfilename = po.GetArg(1),
122  spk2utt_rspecifier = po.GetArg(2),
123  wav_rspecifier = po.GetArg(3),
124  clat_wspecifier = po.GetArg(4);
125 
126  OnlineFeaturePipelineConfig feature_config(feature_cmdline_config);
127  OnlineFeaturePipeline pipeline_prototype(feature_config);
128  // The following object initializes the models we use in decoding.
129  OnlineGmmDecodingModels gmm_models(decode_config);
130 
131 
132  fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldiGeneric(fst_rxfilename);
133 
134  fst::SymbolTable *word_syms = NULL;
135  if (word_syms_rxfilename != "")
136  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename)))
137  KALDI_ERR << "Could not read symbol table from file "
138  << word_syms_rxfilename;
139 
140  int32 num_done = 0, num_err = 0;
141  double tot_like = 0.0;
142  int64 num_frames = 0;
143 
144  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
145  RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
146  CompactLatticeWriter clat_writer(clat_wspecifier);
147 
148  OnlineTimingStats timing_stats;
149 
150  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
151  std::string spk = spk2utt_reader.Key();
152  const std::vector<std::string> &uttlist = spk2utt_reader.Value();
153  OnlineGmmAdaptationState adaptation_state;
154  for (size_t i = 0; i < uttlist.size(); i++) {
155  std::string utt = uttlist[i];
156  if (!wav_reader.HasKey(utt)) {
157  KALDI_WARN << "Did not find audio for utterance " << utt;
158  num_err++;
159  continue;
160  }
161  const WaveData &wave_data = wav_reader.Value(utt);
162  // get the data for channel zero (if the signal is not mono, we only
163  // take the first channel).
164  SubVector<BaseFloat> data(wave_data.Data(), 0);
165 
166  SingleUtteranceGmmDecoder decoder(decode_config,
167  gmm_models,
168  pipeline_prototype,
169  *decode_fst,
170  adaptation_state);
171 
172  OnlineTimer decoding_timer(utt);
173 
174  BaseFloat samp_freq = wave_data.SampFreq();
175  int32 chunk_length = int32(samp_freq * chunk_length_secs);
176  if (chunk_length == 0) chunk_length = 1;
177 
178  int32 samp_offset = 0;
179  while (samp_offset < data.Dim()) {
180  int32 samp_remaining = data.Dim() - samp_offset;
181  int32 num_samp = chunk_length < samp_remaining ? chunk_length
182  : samp_remaining;
183 
184  SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
185  decoder.FeaturePipeline().AcceptWaveform(samp_freq, wave_part);
186 
187  samp_offset += num_samp;
188  decoding_timer.WaitUntil(samp_offset / samp_freq);
189  if (samp_offset == data.Dim()) {
190  // no more input. flush out last frames
191  decoder.FeaturePipeline().InputFinished();
192  }
193  decoder.AdvanceDecoding();
194 
195  if (do_endpointing && decoder.EndpointDetected(endpoint_config))
196  break;
197  }
198  decoder.FinalizeDecoding();
199 
200  bool end_of_utterance = true;
201  decoder.EstimateFmllr(end_of_utterance);
202  CompactLattice clat;
203  bool rescore_if_needed = true;
204  decoder.GetLattice(rescore_if_needed, end_of_utterance, &clat);
205 
206  GetDiagnosticsAndPrintOutput(utt, word_syms, clat,
207  &num_frames, &tot_like);
208 
209  decoding_timer.OutputStats(&timing_stats);
210 
211  // In an application you might avoid updating the adaptation state if
212  // you felt the utterance had low confidence. See lat/confidence.h
213  decoder.GetAdaptationState(&adaptation_state);
214 
215  // we want to output the lattice with un-scaled acoustics.
216  if (decode_config.acoustic_scale != 0.0) {
217  BaseFloat inv_acoustic_scale = 1.0 / decode_config.acoustic_scale;
218  ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat);
219  }
220  clat_writer.Write(utt, clat);
221  KALDI_LOG << "Decoded utterance " << utt;
222  num_done++;
223  }
224  }
225  timing_stats.Print();
226  KALDI_LOG << "Decoded " << num_done << " utterances, "
227  << num_err << " with errors.";
228  KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames)
229  << " per frame over " << num_frames << " frames.";
230  delete decode_fst;
231  delete word_syms; // will delete if non-NULL.
232  return (num_done != 0 ? 0 : 1);
233  } catch(const std::exception& e) {
234  std::cerr << e.what();
235  return -1;
236  }
237 } // main()
int32 words[kMaxOrder]
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
class OnlineTimer is used to test real-time decoding algorithms and evaluate how long the decoding of...
Definition: online-timing.h:88
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
Definition: kaldi-fst-io.cc:45
This class is used to read, store and give access to the models used for 3 phases of decoding (first-...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void OutputStats(OnlineTimingStats *stats)
This call, which should be made after decoding is done, writes the stats to the object that accumulat...
This file contains a class OnlineFeaturePipeline for online feature extraction, which puts together v...
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
BaseFloat SampFreq() const
Definition: wave-reader.h:126
const Matrix< BaseFloat > & Data() const
Definition: wave-reader.h:124
void Register(OptionsItf *opts)
void GetDiagnosticsAndPrintOutput(const std::string &utt, const fst::SymbolTable *word_syms, const CompactLattice &clat, int64 *tot_num_frames, double *tot_like)
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
This configuration class is to set up OnlineFeaturePipelineConfig, which in turn is the configuration...
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
OnlineFeaturePipeline is a class that's responsible for putting together the various stages of the fe...
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void CompactLatticeShortestPath(const CompactLattice &clat, CompactLattice *shortest_path)
A form of the shortest-path/best-path algorithm that's specially coded for CompactLattice.
You will instantiate this class when you want to decode a single utterance using the online-decoding ...
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Print(bool online=true)
Here, if "online == false" we take into account that the setup was used in not-really-online mode whe...
const T & Value(const std::string &key)
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
int main(int argc, char *argv[])
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
bool HasKey(const std::string &key)
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
This class's purpose is to read in Wave files.
Definition: wave-reader.h:106
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
This configuration class is responsible for storing the configuration options for OnlineFeaturePipeli...
class OnlineTimingStats stores statistics from timing of online decoding, which will enable the Print...
Definition: online-timing.h:41
#define KALDI_LOG
Definition: kaldi-error.h:153
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
void WaitUntil(double cur_utterance_length)
The call to WaitUntil(t) simulates the effect of sleeping until cur_utterance_length seconds after th...