online2-wav-nnet2-latgen-threaded.cc File Reference

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for mispronunciations detection tasks, the reference:
 

Functions

void GetDiagnosticsAndPrintOutput (const std::string &utt, const fst::SymbolTable *word_syms, const CompactLattice &clat, int64 *tot_num_frames, double *tot_like)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main (int argc, char *argv[])

Definition at line 75 of file online2-wav-nnet2-latgen-threaded.cc.

References OnlineNnet2DecodingThreadedConfig::acoustic_scale, fst::AcousticLatticeScale(), WaveData::Data(), SequentialTableReader< Holder >::Done(), Timer::Elapsed(), kaldi::g_num_threads, ParseOptions::GetArg(), kaldi::GetDiagnosticsAndPrintOutput(), OnlineNnet2FeaturePipelineInfo::global_cmvn_stats_rxfilename, OnlineIvectorExtractionInfo::greedy_ivector_extractor, RandomAccessTableReader< Holder >::HasKey(), OnlineNnet2FeaturePipelineInfo::ivector_extractor_info, KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), OnlineTimer::OutputStats(), OnlineTimingStats::Print(), ParseOptions::PrintUsage(), AmNnet::Read(), ParseOptions::Read(), TransitionModel::Read(), fst::ReadFstKaldiGeneric(), kaldi::ReadKaldiObject(), ParseOptions::Register(), OnlineNnet2FeaturePipelineConfig::Register(), OnlineEndpointConfig::Register(), OnlineNnet2DecodingThreadedConfig::Register(), WaveData::SampFreq(), fst::ScaleLattice(), kaldi::Sleep(), OnlineTimer::SleepUntil(), Input::Stream(), OnlineIvectorExtractionInfo::use_most_recent_ivector, RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and TableWriter< Holder >::Write().
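Inside its per-utterance loop, main() feeds each waveform to the decoder in pieces of --chunk-length seconds (default 0.05). The chunk size in samples is samp_freq * chunk_length_secs truncated to an integer, raised to 1 if it truncates to 0, and the final piece is shortened to whatever samples remain. With --simulate-realtime-decoding=true the loop then calls OnlineTimer::SleepUntil(samp_offset / samp_freq), so each piece is delivered no earlier than it would arrive from a live microphone. The standalone sketch below reproduces only that arithmetic. PlanChunks and Chunk are hypothetical names, it uses double rather than Kaldi's BaseFloat, it does not sleep or call any Kaldi API, and it omits the extra wait (Sleep(0.5f) while more than 2 seconds of audio are pending) that main() applies only when --do-endpointing is set.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// One piece of audio handed to the decoder (hypothetical helper type, not Kaldi's).
struct Chunk {
  int32_t offset;       // index of the first sample in this piece
  int32_t num_samples;  // number of samples in this piece
  double ready_time;    // seconds of audio available once this piece has been sent
};

// Mirrors the chunking arithmetic in main(): a fixed chunk length derived from
// the sampling rate, clamped to at least one sample, with a shorter final chunk.
std::vector<Chunk> PlanChunks(int32_t num_samples, double samp_freq,
                              double chunk_length_secs) {
  assert(chunk_length_secs > 0);  // main() uses KALDI_ASSERT for the same check
  int32_t chunk_length = static_cast<int32_t>(samp_freq * chunk_length_secs);
  if (chunk_length == 0) chunk_length = 1;

  std::vector<Chunk> plan;
  int32_t offset = 0;
  while (offset < num_samples) {
    int32_t num = std::min<int32_t>(chunk_length, num_samples - offset);
    // In real-time simulation, the caller would wait until this much wall-clock
    // time has passed since the utterance started before sending the next piece.
    double ready_time = static_cast<double>(offset + num) / samp_freq;
    plan.push_back(Chunk{offset, num, ready_time});
    offset += num;
  }
  return plan;
}
```

For 2000 samples at 16 kHz with the default 0.05 s, this yields pieces of 800, 800 and 400 samples, ready at 0.05, 0.1 and 0.125 seconds respectively.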

75  {
76    try {
77      using namespace kaldi;
78      using namespace fst;
79
80      typedef kaldi::int32 int32;
81      typedef kaldi::int64 int64;
82
83      const char *usage =
84          "Reads in wav file(s) and simulates online decoding with neural nets\n"
85          "(nnet2 setup), with optional iVector-based speaker adaptation and\n"
86          "optional endpointing. This version uses multiple threads for decoding.\n"
87          "Note: some configuration values and inputs are set via config files\n"
88          "whose filenames are passed as options\n"
89          "\n"
90          "Usage: online2-wav-nnet2-latgen-threaded [options] <nnet2-in> <fst-in> "
91          "<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n"
92          "The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n"
93          "you want to decode utterance by utterance.\n"
94          "See egs/rm/s5/local/run_online_decoding_nnet2.sh for example\n"
95          "See also online2-wav-nnet2-latgen-faster\n";
96
97      ParseOptions po(usage);
98
99      std::string word_syms_rxfilename;
100
101     OnlineEndpointConfig endpoint_config;
102
103     // feature_config includes configuration for the iVector adaptation,
104     // as well as the basic features.
105     OnlineNnet2FeaturePipelineConfig feature_config;
106     OnlineNnet2DecodingThreadedConfig nnet2_decoding_config;
107
108     BaseFloat chunk_length_secs = 0.05;
109     bool do_endpointing = false;
110     bool modify_ivector_config = false;
111     bool simulate_realtime_decoding = true;
112
113     po.Register("chunk-length", &chunk_length_secs,
114                 "Length of chunk size in seconds, that we provide each time to the "
115                 "decoder. The actual chunk sizes it processes for various stages "
116                 "of decoding are dynamically determined, and unrelated to this");
117     po.Register("word-symbol-table", &word_syms_rxfilename,
118                 "Symbol table for words [for debug output]");
119     po.Register("do-endpointing", &do_endpointing,
120                 "If true, apply endpoint detection");
121     po.Register("modify-ivector-config", &modify_ivector_config,
122                 "If true, modifies the iVector configuration from the config files "
123                 "by setting --use-most-recent-ivector=true and --greedy-ivector-extractor=true. "
124                 "This will give the best possible results, but the results may become dependent "
125                 "on the speed of your machine (slower machine -> better results). Compare "
126                 "to the --online option in online2-wav-nnet2-latgen-faster");
127     po.Register("simulate-realtime-decoding", &simulate_realtime_decoding,
128                 "If true, simulate real-time decoding scenario by providing the "
129                 "data incrementally, calling sleep() until each piece is ready. "
130                 "If false, don't sleep (so it will be faster).");
131     po.Register("num-threads-startup", &g_num_threads,
132                 "Number of threads used when initializing iVector extractor. ");
133
134     feature_config.Register(&po);
135     nnet2_decoding_config.Register(&po);
136     endpoint_config.Register(&po);
137
138     po.Read(argc, argv);
139
140     if (po.NumArgs() != 5) {
141       po.PrintUsage();
142       return 1;
143     }
144
145     std::string nnet2_rxfilename = po.GetArg(1),
146         fst_rxfilename = po.GetArg(2),
147         spk2utt_rspecifier = po.GetArg(3),
148         wav_rspecifier = po.GetArg(4),
149         clat_wspecifier = po.GetArg(5);
150
151     OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
152
153     if (modify_ivector_config) {
154       feature_info.ivector_extractor_info.use_most_recent_ivector = true;
155       feature_info.ivector_extractor_info.greedy_ivector_extractor = true;
156     }
157
158     Matrix<double> global_cmvn_stats;
159     if (feature_info.global_cmvn_stats_rxfilename != "")
160       ReadKaldiObject(feature_info.global_cmvn_stats_rxfilename,
161                       &global_cmvn_stats);
162
163     TransitionModel trans_model;
164     nnet2::AmNnet am_nnet;
165     {
166       bool binary;
167       Input ki(nnet2_rxfilename, &binary);
168       trans_model.Read(ki.Stream(), binary);
169       am_nnet.Read(ki.Stream(), binary);
170     }
171
172     fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldiGeneric(fst_rxfilename);
173
174     fst::SymbolTable *word_syms = NULL;
175     if (word_syms_rxfilename != "")
176       if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename)))
177         KALDI_ERR << "Could not read symbol table from file "
178                   << word_syms_rxfilename;
179
180     int32 num_done = 0, num_err = 0;
181     double tot_like = 0.0;
182     int64 num_frames = 0;
183     Timer global_timer;
184
185     SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
186     RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
187     CompactLatticeWriter clat_writer(clat_wspecifier);
188
189     OnlineTimingStats timing_stats;
190
191     for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
192       std::string spk = spk2utt_reader.Key();
193       const std::vector<std::string> &uttlist = spk2utt_reader.Value();
194
195       OnlineIvectorExtractorAdaptationState adaptation_state(
196           feature_info.ivector_extractor_info);
197       OnlineCmvnState cmvn_state(global_cmvn_stats);
198
199       for (size_t i = 0; i < uttlist.size(); i++) {
200         std::string utt = uttlist[i];
201         if (!wav_reader.HasKey(utt)) {
202           KALDI_WARN << "Did not find audio for utterance " << utt;
203           num_err++;
204           continue;
205         }
206         const WaveData &wave_data = wav_reader.Value(utt);
207         // get the data for channel zero (if the signal is not mono, we only
208         // take the first channel).
209         SubVector<BaseFloat> data(wave_data.Data(), 0);
210
211         SingleUtteranceNnet2DecoderThreaded decoder(
212             nnet2_decoding_config, trans_model, am_nnet,
213             *decode_fst, feature_info, adaptation_state, cmvn_state);
214
215         OnlineTimer decoding_timer(utt);
216
217         BaseFloat samp_freq = wave_data.SampFreq();
218         int32 chunk_length;
219         KALDI_ASSERT(chunk_length_secs > 0);
220         chunk_length = int32(samp_freq * chunk_length_secs);
221         if (chunk_length == 0) chunk_length = 1;
222
223         int32 samp_offset = 0;
224         while (samp_offset < data.Dim()) {
225           int32 samp_remaining = data.Dim() - samp_offset;
226           int32 num_samp = chunk_length < samp_remaining ? chunk_length
227                                                          : samp_remaining;
228
229           SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
230
231           // The endpointing code won't work if we let the waveform be given to
232           // the decoder all at once, because we'll exit this while loop, and
233           // the endpointing happens inside this while loop. The next statement
234           // is intended to prevent this from happening.
235           while (do_endpointing &&
236                  decoder.NumWaveformPiecesPending() * chunk_length_secs > 2.0)
237             Sleep(0.5f);
238
239           decoder.AcceptWaveform(samp_freq, wave_part);
240
241           samp_offset += num_samp;
242
243           if (simulate_realtime_decoding) {
244             // Note: the next call may actually call sleep().
245             decoding_timer.SleepUntil(samp_offset / samp_freq);
246           }
247           if (samp_offset == data.Dim()) {
248             // no more input. flush out last frames
249             decoder.InputFinished();
250           }
251
252           if (do_endpointing && decoder.EndpointDetected(endpoint_config)) {
253             decoder.TerminateDecoding();
254             break;
255           }
256         }
257         Timer timer;
258         decoder.Wait();
259         if (simulate_realtime_decoding) {
260           KALDI_VLOG(1) << "Waited " << timer.Elapsed() << " seconds for decoder to "
261                         << "finish after giving it last chunk.";
262         }
263         decoder.FinalizeDecoding();
264
265         CompactLattice clat;
266         bool end_of_utterance = true;
267         decoder.GetLattice(end_of_utterance, &clat, NULL);
268
269         GetDiagnosticsAndPrintOutput(utt, word_syms, clat,
270                                      &num_frames, &tot_like);
271
272         decoding_timer.OutputStats(&timing_stats);
273
274         // In an application you might avoid updating the adaptation state if
275         // you felt the utterance had low confidence. See lat/confidence.h
276         decoder.GetAdaptationState(&adaptation_state);
277         decoder.GetCmvnState(&cmvn_state);
278
279         // we want to output the lattice with un-scaled acoustics.
280         BaseFloat inv_acoustic_scale =
281             1.0 / nnet2_decoding_config.acoustic_scale;
282         ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat);
283
284         if (simulate_realtime_decoding) {
285           KALDI_VLOG(1) << "Adding the various end-of-utterance tasks took the "
286                         << "total latency to " << timer.Elapsed() << " seconds.";
287         }
288         clat_writer.Write(utt, clat);
289         KALDI_LOG << "Decoded utterance " << utt;
290
291         num_done++;
292       }
293     }
294     bool online = true;
295
296     if (simulate_realtime_decoding) {
297       timing_stats.Print(online);
298     } else {
299       BaseFloat frame_shift = 0.01;
300       BaseFloat real_time_factor =
301           global_timer.Elapsed() / (frame_shift * num_frames);
302       if (num_frames > 0)
303         KALDI_LOG << "Real-time factor was " << real_time_factor
304                   << " assuming frame shift of " << frame_shift;
305     }
306
307     KALDI_LOG << "Decoded " << num_done << " utterances, "
308               << num_err << " with errors.";
309     KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames)
310               << " per frame over " << num_frames << " frames.";
311     delete decode_fst;
312     delete word_syms;  // will delete if non-NULL.
313     return (num_done != 0 ? 0 : 1);
314   } catch(const std::exception& e) {
315     std::cerr << e.what();
316     return -1;
317   }
318 } // main()
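When --simulate-realtime-decoding=false, main() reports a real-time factor: total wall-clock time divided by the audio duration, which it estimates as frame_shift * num_frames with an assumed 10 ms frame shift. That real-time-factor log is guarded by num_frames > 0, but the final "Overall likelihood per frame" log divides tot_like by num_frames without a guard, so a run in which no frames were decoded would print a non-finite value there. The sketch below computes the same two ratios only when frames exist. Summarize and DecodeSummary are hypothetical names, and std::optional requires C++17.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical summary of a decoding run; fields are empty when no frames were decoded.
struct DecodeSummary {
  std::optional<double> real_time_factor;  // wall-clock seconds per second of audio
  std::optional<double> like_per_frame;    // average log-likelihood per frame
};

// Same ratios that main() logs, but both are computed only when num_frames > 0.
DecodeSummary Summarize(double elapsed_secs, int64_t num_frames,
                        double tot_like, double frame_shift = 0.01) {
  assert(frame_shift > 0);
  DecodeSummary summary;
  if (num_frames > 0) {
    double audio_secs = frame_shift * static_cast<double>(num_frames);
    summary.real_time_factor = elapsed_secs / audio_secs;
    summary.like_per_frame = tot_like / static_cast<double>(num_frames);
  }
  return summary;
}
```

For example, 5 seconds of decoding over 1000 frames (10 seconds of audio at a 10 ms shift) gives a real-time factor of about 0.5. A value below 1.0 means decoding ran faster than real time.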
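Before writing each lattice, main() calls ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat) with inv_acoustic_scale = 1.0 / acoustic_scale, so the archive stores un-scaled acoustic costs. As we understand Kaldi's ScaleLattice, it treats each weight's (graph cost, acoustic cost) pair as a 2-vector and multiplies it by a 2x2 matrix. This sketch assumes AcousticLatticeScale(acwt) builds {{1, 0}, {0, acwt}}, which leaves the graph cost unchanged and multiplies only the acoustic cost. MakeAcousticScale and ScalePair are hypothetical stand-ins that work on a plain std::pair, not on Kaldi's LatticeWeight or CompactLatticeWeight.

```cpp
#include <cassert>
#include <utility>
#include <vector>

using ScaleMatrix = std::vector<std::vector<double>>;

// Assumed to match the matrix returned by fst::AcousticLatticeScale(acwt):
// identity on the graph cost, acwt on the acoustic cost.
ScaleMatrix MakeAcousticScale(double acwt) {
  return {{1.0, 0.0}, {0.0, acwt}};
}

// Applies a 2x2 scale to one (graph_cost, acoustic_cost) pair, viewing the
// pair as a column vector and computing scale * pair.
std::pair<double, double> ScalePair(const ScaleMatrix &scale,
                                    std::pair<double, double> w) {
  assert(scale.size() == 2 && scale[0].size() == 2 && scale[1].size() == 2);
  return {scale[0][0] * w.first + scale[0][1] * w.second,
          scale[1][0] * w.first + scale[1][1] * w.second};
}
```

With acoustic_scale = 0.5, a decoder-side weight of (3.0, 12.5) becomes (3.0, 25.0) after scaling by 1.0 / 0.5, and scaling by 0.5 again restores the original pair.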