online2-wav-nnet3-latgen-faster.cc
Go to the documentation of this file.
1 // online2bin/online2-wav-nnet3-latgen-faster.cc
2 
3 // Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 // 2016 Api.ai (Author: Ilya Platonov)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
#include <limits>

#include "feat/wave-reader.h"
#include "online2/online-nnet3-decoding.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"
#include "fstext/fstext-lib.h"
#include "lat/lattice-functions.h"
#include "util/kaldi-thread.h"
#include "nnet3/nnet-utils.h"
31 
32 namespace kaldi {
33 
34 void GetDiagnosticsAndPrintOutput(const std::string &utt,
35  const fst::SymbolTable *word_syms,
36  const CompactLattice &clat,
37  int64 *tot_num_frames,
38  double *tot_like) {
39  if (clat.NumStates() == 0) {
40  KALDI_WARN << "Empty lattice.";
41  return;
42  }
43  CompactLattice best_path_clat;
44  CompactLatticeShortestPath(clat, &best_path_clat);
45 
46  Lattice best_path_lat;
47  ConvertLattice(best_path_clat, &best_path_lat);
48 
49  double likelihood;
50  LatticeWeight weight;
51  int32 num_frames;
52  std::vector<int32> alignment;
53  std::vector<int32> words;
54  GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight);
55  num_frames = alignment.size();
56  likelihood = -(weight.Value1() + weight.Value2());
57  *tot_num_frames += num_frames;
58  *tot_like += likelihood;
59  KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is "
60  << (likelihood / num_frames) << " over " << num_frames
61  << " frames, = " << (-weight.Value1() / num_frames)
62  << ',' << (weight.Value2() / num_frames);
63 
64  if (word_syms != NULL) {
65  std::cerr << utt << ' ';
66  for (size_t i = 0; i < words.size(); i++) {
67  std::string s = word_syms->Find(words[i]);
68  if (s == "")
69  KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
70  std::cerr << s << ' ';
71  }
72  std::cerr << std::endl;
73  }
74 }
75 
76 }
77 
78 int main(int argc, char *argv[]) {
79  try {
80  using namespace kaldi;
81  using namespace fst;
82 
83  typedef kaldi::int32 int32;
84  typedef kaldi::int64 int64;
85 
86  const char *usage =
87  "Reads in wav file(s) and simulates online decoding with neural nets\n"
88  "(nnet3 setup), with optional iVector-based speaker adaptation and\n"
89  "optional endpointing. Note: some configuration values and inputs are\n"
90  "set via config files whose filenames are passed as options\n"
91  "\n"
92  "Usage: online2-wav-nnet3-latgen-faster [options] <nnet3-in> <fst-in> "
93  "<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n"
94  "The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n"
95  "you want to decode utterance by utterance.\n";
96 
97  ParseOptions po(usage);
98 
99  std::string word_syms_rxfilename;
100 
101  // feature_opts includes configuration for the iVector adaptation,
102  // as well as the basic features.
105  LatticeFasterDecoderConfig decoder_opts;
106  OnlineEndpointConfig endpoint_opts;
107 
108  BaseFloat chunk_length_secs = 0.18;
109  bool do_endpointing = false;
110  bool online = true;
111 
112  po.Register("chunk-length", &chunk_length_secs,
113  "Length of chunk size in seconds, that we process. Set to <= 0 "
114  "to use all input in one chunk.");
115  po.Register("word-symbol-table", &word_syms_rxfilename,
116  "Symbol table for words [for debug output]");
117  po.Register("do-endpointing", &do_endpointing,
118  "If true, apply endpoint detection");
119  po.Register("online", &online,
120  "You can set this to false to disable online iVector estimation "
121  "and have all the data for each utterance used, even at "
122  "utterance start. This is useful where you just want the best "
123  "results and don't care about online operation. Setting this to "
124  "false has the same effect as setting "
125  "--use-most-recent-ivector=true and --greedy-ivector-extractor=true "
126  "in the file given to --ivector-extraction-config, and "
127  "--chunk-length=-1.");
128  po.Register("num-threads-startup", &g_num_threads,
129  "Number of threads used when initializing iVector extractor.");
130 
131  feature_opts.Register(&po);
132  decodable_opts.Register(&po);
133  decoder_opts.Register(&po);
134  endpoint_opts.Register(&po);
135 
136 
137  po.Read(argc, argv);
138 
139  if (po.NumArgs() != 5) {
140  po.PrintUsage();
141  return 1;
142  }
143 
144  std::string nnet3_rxfilename = po.GetArg(1),
145  fst_rxfilename = po.GetArg(2),
146  spk2utt_rspecifier = po.GetArg(3),
147  wav_rspecifier = po.GetArg(4),
148  clat_wspecifier = po.GetArg(5);
149 
150  OnlineNnet2FeaturePipelineInfo feature_info(feature_opts);
151  if (!online) {
154  chunk_length_secs = -1.0;
155  }
156 
157  Matrix<double> global_cmvn_stats;
158  if (feature_info.global_cmvn_stats_rxfilename != "")
160  &global_cmvn_stats);
161 
162  TransitionModel trans_model;
163  nnet3::AmNnetSimple am_nnet;
164  {
165  bool binary;
166  Input ki(nnet3_rxfilename, &binary);
167  trans_model.Read(ki.Stream(), binary);
168  am_nnet.Read(ki.Stream(), binary);
169  SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
170  SetDropoutTestMode(true, &(am_nnet.GetNnet()));
172  }
173 
174  // this object contains precomputed stuff that is used by all decodable
175  // objects. It takes a pointer to am_nnet because if it has iVectors it has
176  // to modify the nnet to accept iVectors at intervals.
177  nnet3::DecodableNnetSimpleLoopedInfo decodable_info(decodable_opts,
178  &am_nnet);
179 
180 
181  fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldiGeneric(fst_rxfilename);
182 
183  fst::SymbolTable *word_syms = NULL;
184  if (word_syms_rxfilename != "")
185  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename)))
186  KALDI_ERR << "Could not read symbol table from file "
187  << word_syms_rxfilename;
188 
189  int32 num_done = 0, num_err = 0;
190  double tot_like = 0.0;
191  int64 num_frames = 0;
192 
193  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
194  RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
195  CompactLatticeWriter clat_writer(clat_wspecifier);
196 
197  OnlineTimingStats timing_stats;
198 
199  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
200  std::string spk = spk2utt_reader.Key();
201  const std::vector<std::string> &uttlist = spk2utt_reader.Value();
202 
203  OnlineIvectorExtractorAdaptationState adaptation_state(
204  feature_info.ivector_extractor_info);
205  OnlineCmvnState cmvn_state(global_cmvn_stats);
206 
207  for (size_t i = 0; i < uttlist.size(); i++) {
208  std::string utt = uttlist[i];
209  if (!wav_reader.HasKey(utt)) {
210  KALDI_WARN << "Did not find audio for utterance " << utt;
211  num_err++;
212  continue;
213  }
214  const WaveData &wave_data = wav_reader.Value(utt);
215  // get the data for channel zero (if the signal is not mono, we only
216  // take the first channel).
217  SubVector<BaseFloat> data(wave_data.Data(), 0);
218 
219  OnlineNnet2FeaturePipeline feature_pipeline(feature_info);
220  feature_pipeline.SetAdaptationState(adaptation_state);
221  feature_pipeline.SetCmvnState(cmvn_state);
222 
223  OnlineSilenceWeighting silence_weighting(
224  trans_model,
225  feature_info.silence_weighting_config,
226  decodable_opts.frame_subsampling_factor);
227 
228  SingleUtteranceNnet3Decoder decoder(decoder_opts, trans_model,
229  decodable_info,
230  *decode_fst, &feature_pipeline);
231  OnlineTimer decoding_timer(utt);
232 
233  BaseFloat samp_freq = wave_data.SampFreq();
234  int32 chunk_length;
235  if (chunk_length_secs > 0) {
236  chunk_length = int32(samp_freq * chunk_length_secs);
237  if (chunk_length == 0) chunk_length = 1;
238  } else {
239  chunk_length = std::numeric_limits<int32>::max();
240  }
241 
242  int32 samp_offset = 0;
243  std::vector<std::pair<int32, BaseFloat> > delta_weights;
244 
245  while (samp_offset < data.Dim()) {
246  int32 samp_remaining = data.Dim() - samp_offset;
247  int32 num_samp = chunk_length < samp_remaining ? chunk_length
248  : samp_remaining;
249 
250  SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
251  feature_pipeline.AcceptWaveform(samp_freq, wave_part);
252 
253  samp_offset += num_samp;
254  decoding_timer.WaitUntil(samp_offset / samp_freq);
255  if (samp_offset == data.Dim()) {
256  // no more input. flush out last frames
257  feature_pipeline.InputFinished();
258  }
259 
260  if (silence_weighting.Active() &&
261  feature_pipeline.IvectorFeature() != NULL) {
262  silence_weighting.ComputeCurrentTraceback(decoder.Decoder());
263  silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(),
264  &delta_weights);
265  feature_pipeline.IvectorFeature()->UpdateFrameWeights(delta_weights);
266  }
267 
268  decoder.AdvanceDecoding();
269 
270  if (do_endpointing && decoder.EndpointDetected(endpoint_opts)) {
271  break;
272  }
273  }
274  decoder.FinalizeDecoding();
275 
276  CompactLattice clat;
277  bool end_of_utterance = true;
278  decoder.GetLattice(end_of_utterance, &clat);
279 
280  GetDiagnosticsAndPrintOutput(utt, word_syms, clat,
281  &num_frames, &tot_like);
282 
283  decoding_timer.OutputStats(&timing_stats);
284 
285  // In an application you might avoid updating the adaptation state if
286  // you felt the utterance had low confidence. See lat/confidence.h
287  feature_pipeline.GetAdaptationState(&adaptation_state);
288  feature_pipeline.GetCmvnState(&cmvn_state);
289 
290  // we want to output the lattice with un-scaled acoustics.
291  BaseFloat inv_acoustic_scale =
292  1.0 / decodable_opts.acoustic_scale;
293  ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat);
294 
295  clat_writer.Write(utt, clat);
296  KALDI_LOG << "Decoded utterance " << utt;
297  num_done++;
298  }
299  }
300  timing_stats.Print(online);
301 
302  KALDI_LOG << "Decoded " << num_done << " utterances, "
303  << num_err << " with errors.";
304  KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames)
305  << " per frame over " << num_frames << " frames.";
306  delete decode_fst;
307  delete word_syms; // will delete if non-NULL.
308  return (num_done != 0 ? 0 : 1);
309  } catch(const std::exception& e) {
310  std::cerr << e.what();
311  return -1;
312  }
313 } // main()
int32 words[kMaxOrder]
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
class OnlineTimer is used to test real-time decoding algorithms and evaluate how long the decoding of...
Definition: online-timing.h:88
This configuration class is to set up OnlineNnet2FeaturePipelineInfo, which in turn is the configurat...
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
Definition: kaldi-fst-io.cc:45
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int32 g_num_threads
Definition: kaldi-thread.cc:25
This class stores the adaptation state from the online iVector extractor, which can help you to initi...
void OutputStats(OnlineTimingStats *stats)
This call, which should be made after decoding is done, writes the stats to the object that accumulat...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
BaseFloat SampFreq() const
Definition: wave-reader.h:126
const Matrix< BaseFloat > & Data() const
Definition: wave-reader.h:124
const Nnet & GetNnet() const
void Register(OptionsItf *opts)
void GetDiagnosticsAndPrintOutput(const std::string &utt, const fst::SymbolTable *word_syms, const CompactLattice &clat, int64 *tot_num_frames, double *tot_like)
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
This file contains a different version of the feature-extraction pipeline in online-feature-pipeline...
void Write(const std::string &key, const T &value) const
This class is responsible for storing configuration variables, objects and options for OnlineNnet2Fea...
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
This file contains some miscellaneous functions dealing with class Nnet.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
Definition: kaldi-lattice.h:32
void CompactLatticeShortestPath(const CompactLattice &clat, CompactLattice *shortest_path)
A form of the shortest-path/best-path algorithm that's specially coded for CompactLattice.
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Print(bool online=true)
Here, if "online == false" we take into account that the setup was used in not-really-online mode whe...
void ComputeCurrentTraceback(const LatticeFasterOnlineDecoderTpl< FST > &decoder)
const T & Value(const std::string &key)
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
void Read(std::istream &is, bool binary)
Struct OnlineCmvnState stores the state of CMVN adaptation between utterances (but not the state of t...
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
void GetLattice(bool end_of_utterance, CompactLattice *clat) const
Gets the lattice.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
bool EndpointDetected(const OnlineEndpointConfig &config)
This function calls EndpointDetected from online-endpoint.h, with the required arguments.
You will instantiate this class when you want to decode a single utterance using the online-decoding ...
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
bool HasKey(const std::string &key)
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
This class's purpose is to read in Wave files.
Definition: wave-reader.h:106
int NumArgs() const
Number of positional parameters (c.f. argc-1).
std::string global_cmvn_stats_rxfilename
Options for online cmvn, read from config file.
OnlineNnet2FeaturePipeline is a class that's responsible for putting together the various parts of th...
OnlineSilenceWeightingConfig silence_weighting_config
Config for weighting silence in iVector adaptation.
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void AdvanceDecoding()
Advances the decoding as far as we can.
int main(int argc, char *argv[])
class OnlineTimingStats stores statistics from timing of online decoding, which will enable the Print...
Definition: online-timing.h:41
#define KALDI_LOG
Definition: kaldi-error.h:153
When you instantiate class DecodableNnetSimpleLooped, you should give it a const reference to this cl...
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
const LatticeFasterOnlineDecoderTpl< FST > & Decoder() const
void WaitUntil(double cur_utterance_length)
The call to WaitUntil(t) simulates the effect of sleeping until cur_utterance_length seconds after th...
void FinalizeDecoding()
Finalizes the decoding.
Config class for the CollapseModel function.
Definition: nnet-utils.h:240
void GetDeltaWeights(int32 num_frames_ready, int32 first_decoder_frame, std::vector< std::pair< int32, BaseFloat > > *delta_weights)