online2-wav-nnet3-latgen-incremental.cc
Go to the documentation of this file.
1 // online2bin/online2-wav-nnet3-latgen-incremental.cc
2 
3 // Copyright 2019 Zhehuai Chen
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
#include <limits>
#include <string>
#include <utility>
#include <vector>

#include "feat/wave-reader.h"
#include "online2/online-nnet3-incremental-decoding.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"
#include "fstext/fstext-lib.h"
#include "lat/lattice-functions.h"
#include "util/kaldi-thread.h"
#include "nnet3/nnet-utils.h"
30 
31 namespace kaldi {
32 
33 void GetDiagnosticsAndPrintOutput(const std::string &utt,
34  const fst::SymbolTable *word_syms,
35  const CompactLattice &clat,
36  int64 *tot_num_frames,
37  double *tot_like) {
38  if (clat.NumStates() == 0) {
39  KALDI_WARN << "Empty lattice.";
40  return;
41  }
42  CompactLattice best_path_clat;
43  CompactLatticeShortestPath(clat, &best_path_clat);
44 
45  Lattice best_path_lat;
46  ConvertLattice(best_path_clat, &best_path_lat);
47 
48  double likelihood;
49  LatticeWeight weight;
50  int32 num_frames;
51  std::vector<int32> alignment;
52  std::vector<int32> words;
53  GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight);
54  num_frames = alignment.size();
55  likelihood = -(weight.Value1() + weight.Value2());
56  *tot_num_frames += num_frames;
57  *tot_like += likelihood;
58  KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is "
59  << (likelihood / num_frames) << " over " << num_frames
60  << " frames, = " << (-weight.Value1() / num_frames)
61  << ',' << (weight.Value2() / num_frames);
62 
63  if (word_syms != NULL) {
64  std::cerr << utt << ' ';
65  for (size_t i = 0; i < words.size(); i++) {
66  std::string s = word_syms->Find(words[i]);
67  if (s == "")
68  KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
69  std::cerr << s << ' ';
70  }
71  std::cerr << std::endl;
72  }
73 }
74 
75 }
76 
77 int main(int argc, char *argv[]) {
78  try {
79  using namespace kaldi;
80  using namespace fst;
81 
82  typedef kaldi::int32 int32;
83  typedef kaldi::int64 int64;
84 
85  const char *usage =
86  "Reads in wav file(s) and simulates online decoding with neural nets\n"
87  "(nnet3 setup), with optional iVector-based speaker adaptation and\n"
88  "optional endpointing. Note: some configuration values and inputs are\n"
89  "set via config files whose filenames are passed as options\n"
90  "The lattice determinization algorithm here can operate\n"
91  "incrementally.\n"
92  "\n"
93  "Usage: online2-wav-nnet3-latgen-incremental [options] <nnet3-in> <fst-in> "
94  "<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n"
95  "The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n"
96  "you want to decode utterance by utterance.\n";
97 
98  ParseOptions po(usage);
99 
100  std::string word_syms_rxfilename;
101 
102  // feature_opts includes configuration for the iVector adaptation,
103  // as well as the basic features.
106  LatticeIncrementalDecoderConfig decoder_opts;
107  OnlineEndpointConfig endpoint_opts;
108 
109  BaseFloat chunk_length_secs = 0.18;
110  bool do_endpointing = false;
111  bool online = true;
112 
113  po.Register("chunk-length", &chunk_length_secs,
114  "Length of chunk size in seconds, that we process. Set to <= 0 "
115  "to use all input in one chunk.");
116  po.Register("word-symbol-table", &word_syms_rxfilename,
117  "Symbol table for words [for debug output]");
118  po.Register("do-endpointing", &do_endpointing,
119  "If true, apply endpoint detection");
120  po.Register("online", &online,
121  "You can set this to false to disable online iVector estimation "
122  "and have all the data for each utterance used, even at "
123  "utterance start. This is useful where you just want the best "
124  "results and don't care about online operation. Setting this to "
125  "false has the same effect as setting "
126  "--use-most-recent-ivector=true and --greedy-ivector-extractor=true "
127  "in the file given to --ivector-extraction-config, and "
128  "--chunk-length=-1.");
129  po.Register("num-threads-startup", &g_num_threads,
130  "Number of threads used when initializing iVector extractor.");
131 
132  feature_opts.Register(&po);
133  decodable_opts.Register(&po);
134  decoder_opts.Register(&po);
135  endpoint_opts.Register(&po);
136 
137 
138  po.Read(argc, argv);
139 
140  if (po.NumArgs() != 5) {
141  po.PrintUsage();
142  return 1;
143  }
144 
145  std::string nnet3_rxfilename = po.GetArg(1),
146  fst_rxfilename = po.GetArg(2),
147  spk2utt_rspecifier = po.GetArg(3),
148  wav_rspecifier = po.GetArg(4),
149  clat_wspecifier = po.GetArg(5);
150 
151  OnlineNnet2FeaturePipelineInfo feature_info(feature_opts);
152 
153  if (!online) {
156  chunk_length_secs = -1.0;
157  }
158 
159  TransitionModel trans_model;
160  nnet3::AmNnetSimple am_nnet;
161  {
162  bool binary;
163  Input ki(nnet3_rxfilename, &binary);
164  trans_model.Read(ki.Stream(), binary);
165  am_nnet.Read(ki.Stream(), binary);
166  SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
167  SetDropoutTestMode(true, &(am_nnet.GetNnet()));
169  }
170 
171  // this object contains precomputed stuff that is used by all decodable
172  // objects. It takes a pointer to am_nnet because if it has iVectors it has
173  // to modify the nnet to accept iVectors at intervals.
174  nnet3::DecodableNnetSimpleLoopedInfo decodable_info(decodable_opts,
175  &am_nnet);
176 
177 
178  fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldiGeneric(fst_rxfilename);
179 
180  fst::SymbolTable *word_syms = NULL;
181  if (word_syms_rxfilename != "")
182  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename)))
183  KALDI_ERR << "Could not read symbol table from file "
184  << word_syms_rxfilename;
185 
186  int32 num_done = 0, num_err = 0;
187  double tot_like = 0.0;
188  int64 num_frames = 0;
189 
190  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
191  RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
192  CompactLatticeWriter clat_writer(clat_wspecifier);
193 
194  OnlineTimingStats timing_stats;
195 
196  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
197  std::string spk = spk2utt_reader.Key();
198  const std::vector<std::string> &uttlist = spk2utt_reader.Value();
199  OnlineIvectorExtractorAdaptationState adaptation_state(
200  feature_info.ivector_extractor_info);
201  for (size_t i = 0; i < uttlist.size(); i++) {
202  std::string utt = uttlist[i];
203  if (!wav_reader.HasKey(utt)) {
204  KALDI_WARN << "Did not find audio for utterance " << utt;
205  num_err++;
206  continue;
207  }
208  const WaveData &wave_data = wav_reader.Value(utt);
209  // get the data for channel zero (if the signal is not mono, we only
210  // take the first channel).
211  SubVector<BaseFloat> data(wave_data.Data(), 0);
212 
213  OnlineNnet2FeaturePipeline feature_pipeline(feature_info);
214  feature_pipeline.SetAdaptationState(adaptation_state);
215 
216  OnlineSilenceWeighting silence_weighting(
217  trans_model,
218  feature_info.silence_weighting_config,
219  decodable_opts.frame_subsampling_factor);
220 
221  SingleUtteranceNnet3IncrementalDecoder decoder(decoder_opts, trans_model,
222  decodable_info,
223  *decode_fst, &feature_pipeline);
224  OnlineTimer decoding_timer(utt);
225 
226  BaseFloat samp_freq = wave_data.SampFreq();
227  int32 chunk_length;
228  if (chunk_length_secs > 0) {
229  chunk_length = int32(samp_freq * chunk_length_secs);
230  if (chunk_length == 0) chunk_length = 1;
231  } else {
232  chunk_length = std::numeric_limits<int32>::max();
233  }
234 
235  int32 samp_offset = 0;
236  std::vector<std::pair<int32, BaseFloat> > delta_weights;
237 
238  while (samp_offset < data.Dim()) {
239  int32 samp_remaining = data.Dim() - samp_offset;
240  int32 num_samp = chunk_length < samp_remaining ? chunk_length
241  : samp_remaining;
242 
243  SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
244  feature_pipeline.AcceptWaveform(samp_freq, wave_part);
245 
246  samp_offset += num_samp;
247  decoding_timer.WaitUntil(samp_offset / samp_freq);
248  if (samp_offset == data.Dim()) {
249  // no more input. flush out last frames
250  feature_pipeline.InputFinished();
251  }
252 
253  if (silence_weighting.Active() &&
254  feature_pipeline.IvectorFeature() != NULL) {
255  silence_weighting.ComputeCurrentTraceback(decoder.Decoder());
256  silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(),
257  &delta_weights);
258  feature_pipeline.IvectorFeature()->UpdateFrameWeights(delta_weights);
259  }
260 
261  decoder.AdvanceDecoding();
262 
263  if (do_endpointing && decoder.EndpointDetected(endpoint_opts)) {
264  break;
265  }
266  }
267  decoder.FinalizeDecoding();
268 
269  bool use_final_probs = true;
270  CompactLattice clat = decoder.GetLattice(decoder.NumFramesDecoded(),
271  use_final_probs);
272 
273  Connect(&clat);
274  GetDiagnosticsAndPrintOutput(utt, word_syms, clat,
275  &num_frames, &tot_like);
276 
277  decoding_timer.OutputStats(&timing_stats);
278 
279  // In an application you might avoid updating the adaptation state if
280  // you felt the utterance had low confidence. See lat/confidence.h
281  feature_pipeline.GetAdaptationState(&adaptation_state);
282 
283  // we want to output the lattice with un-scaled acoustics.
284  BaseFloat inv_acoustic_scale =
285  1.0 / decodable_opts.acoustic_scale;
286  ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat);
287 
288  clat_writer.Write(utt, clat);
289  KALDI_LOG << "Decoded utterance " << utt;
290  num_done++;
291  }
292  }
293  timing_stats.Print(online);
294 
295  KALDI_LOG << "Decoded " << num_done << " utterances, "
296  << num_err << " with errors.";
297  KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames)
298  << " per frame over " << num_frames << " frames.";
299  delete decode_fst;
300  delete word_syms; // will delete if non-NULL.
301  return (num_done != 0 ? 0 : 1);
302  } catch(const std::exception& e) {
303  std::cerr << e.what();
304  return -1;
305  }
306 } // main()
int32 words[kMaxOrder]
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that suitable to be done in test time...
Definition: nnet-utils.cc:2100
class OnlineTimer is used to test real-time decoding algorithms and evaluate how long the decoding of...
Definition: online-timing.h:88
This configuration class is to set up OnlineNnet2FeaturePipelineInfo, which in turn is the configurat...
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
Definition: kaldi-fst-io.cc:45
const CompactLattice & GetLattice(int32 num_frames_to_include, bool use_final_probs=false)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int32 g_num_threads
Definition: kaldi-thread.cc:25
This class stores the adaptation state from the online iVector extractor, which can help you to initi...
void OutputStats(OnlineTimingStats *stats)
This call, which should be made after decoding is done, writes the stats to the object that accumulat...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
BaseFloat SampFreq() const
Definition: wave-reader.h:126
const Matrix< BaseFloat > & Data() const
Definition: wave-reader.h:124
const Nnet & GetNnet() const
void Register(OptionsItf *opts)
void GetDiagnosticsAndPrintOutput(const std::string &utt, const fst::SymbolTable *word_syms, const CompactLattice &clat, int64 *tot_num_frames, double *tot_like)
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
This file contains a different version of the feature-extraction pipeline in online-feature-pipeline...
void Write(const std::string &key, const T &value) const
This class is responsible for storing configuration variables, objects and options for OnlineNnet2Fea...
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
This file contains some miscellaneous functions dealing with class Nnet.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
Definition: kaldi-lattice.h:32
void CompactLatticeShortestPath(const CompactLattice &clat, CompactLattice *shortest_path)
A form of the shortest-path/best-path algorithm that's specially coded for CompactLattice.
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Print(bool online=true)
Here, if "online == false" we take into account that the setup was used in not-really-online mode whe...
void ComputeCurrentTraceback(const LatticeFasterOnlineDecoderTpl< FST > &decoder)
const T & Value(const std::string &key)
void AdvanceDecoding()
Advances the decoding as far as we can.
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
void Read(std::istream &is, bool binary)
int main(int argc, char *argv[])
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
bool HasKey(const std::string &key)
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
This class's purpose is to read in Wave files.
Definition: wave-reader.h:106
const LatticeIncrementalOnlineDecoderTpl< FST > & Decoder() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
The normal decoder, lattice-faster-decoder.h, sometimes has an issue when doing real-time application...
bool EndpointDetected(const OnlineEndpointConfig &config)
This function calls EndpointDetected from online-endpoint.h, with the required arguments.
You will instantiate this class when you want to decode a single utterance using the online-decoding ...
OnlineNnet2FeaturePipeline is a class that's responsible for putting together the various parts of th...
OnlineSilenceWeightingConfig silence_weighting_config
Config for weighting silence in iVector adaptation.
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
class OnlineTimingStats stores statistics from timing of online decoding, which will enable the Print...
Definition: online-timing.h:41
#define KALDI_LOG
Definition: kaldi-error.h:153
When you instantiate class DecodableNnetSimpleLooped, you should give it a const reference to this cl...
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
void WaitUntil(double cur_utterance_length)
The call to WaitUntil(t) simulates the effect of sleeping until cur_utterance_length seconds after th...
Config class for the CollapseModel function.
Definition: nnet-utils.h:240
void GetDeltaWeights(int32 num_frames_ready, int32 first_decoder_frame, std::vector< std::pair< int32, BaseFloat > > *delta_weights)