online2-wav-nnet2-latgen-threaded.cc
Go to the documentation of this file.
1 // online2bin/online2-wav-nnet2-latgen-threaded.cc
2 
3 // Copyright 2014-2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "feat/wave-reader.h"
23 #include "online2/onlinebin-util.h"
24 #include "online2/online-timing.h"
26 #include "fstext/fstext-lib.h"
27 #include "lat/lattice-functions.h"
28 #include "util/kaldi-thread.h"
29 
30 namespace kaldi {
31 
32 void GetDiagnosticsAndPrintOutput(const std::string &utt,
33  const fst::SymbolTable *word_syms,
34  const CompactLattice &clat,
35  int64 *tot_num_frames,
36  double *tot_like) {
37  if (clat.NumStates() == 0) {
38  KALDI_WARN << "Empty lattice.";
39  return;
40  }
41  CompactLattice best_path_clat;
42  CompactLatticeShortestPath(clat, &best_path_clat);
43 
44  Lattice best_path_lat;
45  ConvertLattice(best_path_clat, &best_path_lat);
46 
47  double likelihood;
48  LatticeWeight weight;
49  int32 num_frames;
50  std::vector<int32> alignment;
51  std::vector<int32> words;
52  GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight);
53  num_frames = alignment.size();
54  likelihood = -(weight.Value1() + weight.Value2());
55  *tot_num_frames += num_frames;
56  *tot_like += likelihood;
57  KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is "
58  << (likelihood / num_frames) << " over " << num_frames
59  << " frames.";
60 
61  if (word_syms != NULL) {
62  std::cerr << utt << ' ';
63  for (size_t i = 0; i < words.size(); i++) {
64  std::string s = word_syms->Find(words[i]);
65  if (s == "")
66  KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
67  std::cerr << s << ' ';
68  }
69  std::cerr << std::endl;
70  }
71 }
72 
73 }
74 
75 int main(int argc, char *argv[]) {
76  try {
77  using namespace kaldi;
78  using namespace fst;
79 
80  typedef kaldi::int32 int32;
81  typedef kaldi::int64 int64;
82 
83  const char *usage =
84  "Reads in wav file(s) and simulates online decoding with neural nets\n"
85  "(nnet2 setup), with optional iVector-based speaker adaptation and\n"
86  "optional endpointing. This version uses multiple threads for decoding.\n"
87  "Note: some configuration values and inputs are set via config files\n"
88  "whose filenames are passed as options\n"
89  "\n"
90  "Usage: online2-wav-nnet2-latgen-threaded [options] <nnet2-in> <fst-in> "
91  "<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n"
92  "The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n"
93  "you want to decode utterance by utterance.\n"
94  "See egs/rm/s5/local/run_online_decoding_nnet2.sh for example\n"
95  "See also online2-wav-nnet2-latgen-faster\n";
96 
97  ParseOptions po(usage);
98 
99  std::string word_syms_rxfilename;
100 
101  OnlineEndpointConfig endpoint_config;
102 
103  // feature_config includes configuration for the iVector adaptation,
104  // as well as the basic features.
105  OnlineNnet2FeaturePipelineConfig feature_config;
106  OnlineNnet2DecodingThreadedConfig nnet2_decoding_config;
107 
108  BaseFloat chunk_length_secs = 0.05;
109  bool do_endpointing = false;
110  bool modify_ivector_config = false;
111  bool simulate_realtime_decoding = true;
112 
113  po.Register("chunk-length", &chunk_length_secs,
114  "Length of chunk size in seconds, that we provide each time to the "
115  "decoder. The actual chunk sizes it processes for various stages "
116  "of decoding are dynamically determinated, and unrelated to this");
117  po.Register("word-symbol-table", &word_syms_rxfilename,
118  "Symbol table for words [for debug output]");
119  po.Register("do-endpointing", &do_endpointing,
120  "If true, apply endpoint detection");
121  po.Register("modify-ivector-config", &modify_ivector_config,
122  "If true, modifies the iVector configuration from the config files "
123  "by setting --use-most-recent-ivector=true and --greedy-ivector-extractor=true. "
124  "This will give the best possible results, but the results may become dependent "
125  "on the speed of your machine (slower machine -> better results). Compare "
126  "to the --online option in online2-wav-nnet2-latgen-faster");
127  po.Register("simulate-realtime-decoding", &simulate_realtime_decoding,
128  "If true, simulate real-time decoding scenario by providing the "
129  "data incrementally, calling sleep() until each piece is ready. "
130  "If false, don't sleep (so it will be faster).");
131  po.Register("num-threads-startup", &g_num_threads,
132  "Number of threads used when initializing iVector extractor. ");
133 
134  feature_config.Register(&po);
135  nnet2_decoding_config.Register(&po);
136  endpoint_config.Register(&po);
137 
138  po.Read(argc, argv);
139 
140  if (po.NumArgs() != 5) {
141  po.PrintUsage();
142  return 1;
143  }
144 
145  std::string nnet2_rxfilename = po.GetArg(1),
146  fst_rxfilename = po.GetArg(2),
147  spk2utt_rspecifier = po.GetArg(3),
148  wav_rspecifier = po.GetArg(4),
149  clat_wspecifier = po.GetArg(5);
150 
151  OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
152 
153  if (modify_ivector_config) {
156  }
157 
158  Matrix<double> global_cmvn_stats;
159  if (feature_info.global_cmvn_stats_rxfilename != "")
161  &global_cmvn_stats);
162 
163  TransitionModel trans_model;
164  nnet2::AmNnet am_nnet;
165  {
166  bool binary;
167  Input ki(nnet2_rxfilename, &binary);
168  trans_model.Read(ki.Stream(), binary);
169  am_nnet.Read(ki.Stream(), binary);
170  }
171 
172  fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldiGeneric(fst_rxfilename);
173 
174  fst::SymbolTable *word_syms = NULL;
175  if (word_syms_rxfilename != "")
176  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename)))
177  KALDI_ERR << "Could not read symbol table from file "
178  << word_syms_rxfilename;
179 
180  int32 num_done = 0, num_err = 0;
181  double tot_like = 0.0;
182  int64 num_frames = 0;
183  Timer global_timer;
184 
185  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
186  RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
187  CompactLatticeWriter clat_writer(clat_wspecifier);
188 
189  OnlineTimingStats timing_stats;
190 
191  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
192  std::string spk = spk2utt_reader.Key();
193  const std::vector<std::string> &uttlist = spk2utt_reader.Value();
194 
195  OnlineIvectorExtractorAdaptationState adaptation_state(
196  feature_info.ivector_extractor_info);
197  OnlineCmvnState cmvn_state(global_cmvn_stats);
198 
199  for (size_t i = 0; i < uttlist.size(); i++) {
200  std::string utt = uttlist[i];
201  if (!wav_reader.HasKey(utt)) {
202  KALDI_WARN << "Did not find audio for utterance " << utt;
203  num_err++;
204  continue;
205  }
206  const WaveData &wave_data = wav_reader.Value(utt);
207  // get the data for channel zero (if the signal is not mono, we only
208  // take the first channel).
209  SubVector<BaseFloat> data(wave_data.Data(), 0);
210 
212  nnet2_decoding_config, trans_model, am_nnet,
213  *decode_fst, feature_info, adaptation_state, cmvn_state);
214 
215  OnlineTimer decoding_timer(utt);
216 
217  BaseFloat samp_freq = wave_data.SampFreq();
218  int32 chunk_length;
219  KALDI_ASSERT(chunk_length_secs > 0);
220  chunk_length = int32(samp_freq * chunk_length_secs);
221  if (chunk_length == 0) chunk_length = 1;
222 
223  int32 samp_offset = 0;
224  while (samp_offset < data.Dim()) {
225  int32 samp_remaining = data.Dim() - samp_offset;
226  int32 num_samp = chunk_length < samp_remaining ? chunk_length
227  : samp_remaining;
228 
229  SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
230 
231  // The endpointing code won't work if we let the waveform be given to
232  // the decoder all at once, because we'll exit this while loop, and
233  // the endpointing happens inside this while loop. The next statement
234  // is intended to prevent this from happening.
235  while (do_endpointing &&
236  decoder.NumWaveformPiecesPending() * chunk_length_secs > 2.0)
237  Sleep(0.5f);
238 
239  decoder.AcceptWaveform(samp_freq, wave_part);
240 
241  samp_offset += num_samp;
242 
243  if (simulate_realtime_decoding) {
244  // Note: the next call may actually call sleep().
245  decoding_timer.SleepUntil(samp_offset / samp_freq);
246  }
247  if (samp_offset == data.Dim()) {
248  // no more input. flush out last frames
249  decoder.InputFinished();
250  }
251 
252  if (do_endpointing && decoder.EndpointDetected(endpoint_config)) {
253  decoder.TerminateDecoding();
254  break;
255  }
256  }
257  Timer timer;
258  decoder.Wait();
259  if (simulate_realtime_decoding) {
260  KALDI_VLOG(1) << "Waited " << timer.Elapsed() << " seconds for decoder to "
261  << "finish after giving it last chunk.";
262  }
263  decoder.FinalizeDecoding();
264 
265  CompactLattice clat;
266  bool end_of_utterance = true;
267  decoder.GetLattice(end_of_utterance, &clat, NULL);
268 
269  GetDiagnosticsAndPrintOutput(utt, word_syms, clat,
270  &num_frames, &tot_like);
271 
272  decoding_timer.OutputStats(&timing_stats);
273 
274  // In an application you might avoid updating the adaptation state if
275  // you felt the utterance had low confidence. See lat/confidence.h
276  decoder.GetAdaptationState(&adaptation_state);
277  decoder.GetCmvnState(&cmvn_state);
278 
279  // we want to output the lattice with un-scaled acoustics.
280  BaseFloat inv_acoustic_scale =
281  1.0 / nnet2_decoding_config.acoustic_scale;
282  ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat);
283 
284  if (simulate_realtime_decoding) {
285  KALDI_VLOG(1) << "Adding the various end-of-utterance tasks took the "
286  << "total latency to " << timer.Elapsed() << " seconds.";
287  }
288  clat_writer.Write(utt, clat);
289  KALDI_LOG << "Decoded utterance " << utt;
290 
291  num_done++;
292  }
293  }
294  bool online = true;
295 
296  if (simulate_realtime_decoding) {
297  timing_stats.Print(online);
298  } else {
299  BaseFloat frame_shift = 0.01;
300  BaseFloat real_time_factor =
301  global_timer.Elapsed() / (frame_shift * num_frames);
302  if (num_frames > 0)
303  KALDI_LOG << "Real-time factor was " << real_time_factor
304  << " assuming frame shift of " << frame_shift;
305  }
306 
307  KALDI_LOG << "Decoded " << num_done << " utterances, "
308  << num_err << " with errors.";
309  KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames)
310  << " per frame over " << num_frames << " frames.";
311  delete decode_fst;
312  delete word_syms; // will delete if non-NULL.
313  return (num_done != 0 ? 0 : 1);
314  } catch(const std::exception& e) {
315  std::cerr << e.what();
316  return -1;
317  }
318 } // main()
int32 words[kMaxOrder]
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
class OnlineTimer is used to test real-time decoding algorithms and evaluate how long the decoding of...
Definition: online-timing.h:88
This configuration class is to set up OnlineNnet2FeaturePipelineInfo, which in turn is the configurat...
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
Definition: kaldi-fst-io.cc:45
You will instantiate this class when you want to decode a single utterance using the online-decoding ...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Sleep(float seconds)
Definition: kaldi-utils.cc:45
int32 g_num_threads
Definition: kaldi-thread.cc:25
This class stores the adaptation state from the online iVector extractor, which can help you to initi...
void OutputStats(OnlineTimingStats *stats)
This call, which should be made after decoding is done, writes the stats to the object that accumulat...
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
BaseFloat SampFreq() const
Definition: wave-reader.h:126
const Matrix< BaseFloat > & Data() const
Definition: wave-reader.h:124
void Register(OptionsItf *opts)
void GetDiagnosticsAndPrintOutput(const std::string &utt, const fst::SymbolTable *word_syms, const CompactLattice &clat, int64 *tot_num_frames, double *tot_like)
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
This file contains a different version of the feature-extraction pipeline in online-feature-pipeline...
void Write(const std::string &key, const T &value) const
This class is responsible for storing configuration variables, objects and options for OnlineNnet2Fea...
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
Definition: kaldi-lattice.h:32
void CompactLatticeShortestPath(const CompactLattice &clat, CompactLattice *shortest_path)
A form of the shortest-path/best-path algorithm that's specially coded for CompactLattice.
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Print(bool online=true)
Here, if "online == false" we take into account that the setup was used in not-really-online mode whe...
const T & Value(const std::string &key)
void SleepUntil(double cur_utterance_length)
The call to SleepUntil(t) will sleep until cur_utterance_length seconds after this object was initial...
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
void Read(std::istream &is, bool binary)
int main(int argc, char *argv[])
Struct OnlineCmvnState stores the state of CMVN adaptation between utterances (but not the state of t...
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
bool HasKey(const std::string &key)
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
This class's purpose is to read in Wave files.
Definition: wave-reader.h:106
int NumArgs() const
Number of positional parameters (c.f. argc-1).
std::string global_cmvn_stats_rxfilename
Options for online cmvn, read from config file.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
class OnlineTimingStats stores statistics from timing of online decoding, which will enable the Print...
Definition: online-timing.h:41
#define KALDI_LOG
Definition: kaldi-error.h:153
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501