online2-wav-nnet3-latgen-grammar.cc File Reference

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks; the reference:
 

Functions

void GetDiagnosticsAndPrintOutput (const std::string &utt, const fst::SymbolTable *word_syms, const CompactLattice &clat, int64 *tot_num_frames, double *tot_like)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main (int argc, char *argv[])

Definition at line 78 of file online2-wav-nnet3-latgen-grammar.cc.

References NnetSimpleLoopedComputationOptions::acoustic_scale, fst::AcousticLatticeScale(), OnlineSilenceWeighting::Active(), SingleUtteranceNnet3DecoderTpl< FST >::AdvanceDecoding(), kaldi::nnet3::CollapseModel(), OnlineSilenceWeighting::ComputeCurrentTraceback(), WaveData::Data(), SingleUtteranceNnet3DecoderTpl< FST >::Decoder(), SequentialTableReader< Holder >::Done(), SingleUtteranceNnet3DecoderTpl< FST >::EndpointDetected(), SingleUtteranceNnet3DecoderTpl< FST >::FinalizeDecoding(), NnetSimpleLoopedComputationOptions::frame_subsampling_factor, kaldi::g_num_threads, ParseOptions::GetArg(), OnlineSilenceWeighting::GetDeltaWeights(), kaldi::GetDiagnosticsAndPrintOutput(), SingleUtteranceNnet3DecoderTpl< FST >::GetLattice(), AmNnetSimple::GetNnet(), OnlineNnet2FeaturePipelineInfo::global_cmvn_stats_rxfilename, OnlineIvectorExtractionInfo::greedy_ivector_extractor, RandomAccessTableReader< Holder >::HasKey(), OnlineNnet2FeaturePipelineInfo::ivector_extractor_info, KALDI_ERR, KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), OnlineTimer::OutputStats(), OnlineTimingStats::Print(), ParseOptions::PrintUsage(), AmNnetSimple::Read(), ParseOptions::Read(), TransitionModel::Read(), kaldi::ReadKaldiObject(), NnetSimpleLoopedComputationOptions::Register(), LatticeFasterDecoderConfig::Register(), ParseOptions::Register(), OnlineNnet2FeaturePipelineConfig::Register(), OnlineEndpointConfig::Register(), WaveData::SampFreq(), fst::ScaleLattice(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutTestMode(), OnlineNnet2FeaturePipelineInfo::silence_weighting_config, Input::Stream(), OnlineIvectorExtractionInfo::use_most_recent_ivector, RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), OnlineTimer::WaitUntil(), and TableWriter< Holder >::Write().

78  {
79  try {
80  using namespace kaldi;
81  using namespace fst;
82 
83  typedef kaldi::int32 int32;
84  typedef kaldi::int64 int64;
85 
86  const char *usage =
87  "Reads in wav file(s) and simulates online decoding with neural nets\n"
88  "(nnet3 setup), with optional iVector-based speaker adaptation and\n"
89  "optional endpointing. Note: some configuration values and inputs are\n"
90  "set via config files whose filenames are passed as options.\n"
91  "This program is like online2-wav-nnet3-latgen-faster, but for when the FST to\n"
92  "be decoded is of type GrammarFst.\n"
93  "\n"
94  "Usage: online2-wav-nnet3-latgen-grammar [options] <nnet3-in> <fst-in> "
95  "<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n"
96  "The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n"
97  "you want to decode utterance by utterance.\n";
98 
99  ParseOptions po(usage);
100 
101  std::string word_syms_rxfilename;
102 
103  // feature_opts includes configuration for the iVector adaptation,
104  // as well as the basic features.
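105  OnlineNnet2FeaturePipelineConfig feature_opts;
106  nnet3::NnetSimpleLoopedComputationOptions decodable_opts;
107  LatticeFasterDecoderConfig decoder_opts;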
108  OnlineEndpointConfig endpoint_opts;
109 
110  BaseFloat chunk_length_secs = 0.18;
111  bool do_endpointing = false;
112  bool online = true;
113 
114  po.Register("chunk-length", &chunk_length_secs,
115  "Length of chunk size in seconds, that we process. Set to <= 0 "
116  "to use all input in one chunk.");
117  po.Register("word-symbol-table", &word_syms_rxfilename,
118  "Symbol table for words [for debug output]");
119  po.Register("do-endpointing", &do_endpointing,
120  "If true, apply endpoint detection");
121  po.Register("online", &online,
122  "You can set this to false to disable online iVector estimation "
123  "and have all the data for each utterance used, even at "
124  "utterance start. This is useful where you just want the best "
125  "results and don't care about online operation. Setting this to "
126  "false has the same effect as setting "
127  "--use-most-recent-ivector=true and --greedy-ivector-extractor=true "
128  "in the file given to --ivector-extraction-config, and "
129  "--chunk-length=-1.");
130  po.Register("num-threads-startup", &g_num_threads,
131  "Number of threads used when initializing iVector extractor.");
132 
133  feature_opts.Register(&po);
134  decodable_opts.Register(&po);
135  decoder_opts.Register(&po);
136  endpoint_opts.Register(&po);
137 
138 
139  po.Read(argc, argv);
140 
141  if (po.NumArgs() != 5) {
142  po.PrintUsage();
143  return 1;
144  }
145 
146  std::string nnet3_rxfilename = po.GetArg(1),
147  fst_rxfilename = po.GetArg(2),
148  spk2utt_rspecifier = po.GetArg(3),
149  wav_rspecifier = po.GetArg(4),
150  clat_wspecifier = po.GetArg(5);
151 
152  OnlineNnet2FeaturePipelineInfo feature_info(feature_opts);
153  if (!online) {
154  feature_info.ivector_extractor_info.use_most_recent_ivector = true;
155  feature_info.ivector_extractor_info.greedy_ivector_extractor = true;
156  chunk_length_secs = -1.0;
157  }
158 
159  Matrix<double> global_cmvn_stats;
160  if (feature_info.global_cmvn_stats_rxfilename != "")
161  ReadKaldiObject(feature_info.global_cmvn_stats_rxfilename,
162  &global_cmvn_stats);
163 
164  TransitionModel trans_model;
165  nnet3::AmNnetSimple am_nnet;
166  {
167  bool binary;
168  Input ki(nnet3_rxfilename, &binary);
169  trans_model.Read(ki.Stream(), binary);
170  am_nnet.Read(ki.Stream(), binary);
171  SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
172  SetDropoutTestMode(true, &(am_nnet.GetNnet()));
173  nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(am_nnet.GetNnet()));
174  }
175 
176  // this object contains precomputed stuff that is used by all decodable
177  // objects. It takes a pointer to am_nnet because if it has iVectors it has
178  // to modify the nnet to accept iVectors at intervals.
179  nnet3::DecodableNnetSimpleLoopedInfo decodable_info(decodable_opts,
180  &am_nnet);
181 
182 
183  fst::GrammarFst fst;
184  ReadKaldiObject(fst_rxfilename, &fst);
185 
186  fst::SymbolTable *word_syms = NULL;
187  if (word_syms_rxfilename != "")
188  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename)))
189  KALDI_ERR << "Could not read symbol table from file "
190  << word_syms_rxfilename;
191 
192  int32 num_done = 0, num_err = 0;
193  double tot_like = 0.0;
194  int64 num_frames = 0;
195 
196  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
197  RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
198  CompactLatticeWriter clat_writer(clat_wspecifier);
199 
200  OnlineTimingStats timing_stats;
201 
202  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
203  std::string spk = spk2utt_reader.Key();
204  const std::vector<std::string> &uttlist = spk2utt_reader.Value();
205 
206  OnlineIvectorExtractorAdaptationState adaptation_state(
207  feature_info.ivector_extractor_info);
208  OnlineCmvnState cmvn_state(global_cmvn_stats);
209 
210  for (size_t i = 0; i < uttlist.size(); i++) {
211  std::string utt = uttlist[i];
212  if (!wav_reader.HasKey(utt)) {
213  KALDI_WARN << "Did not find audio for utterance " << utt;
214  num_err++;
215  continue;
216  }
217  const WaveData &wave_data = wav_reader.Value(utt);
218  // get the data for channel zero (if the signal is not mono, we only
219  // take the first channel).
220  SubVector<BaseFloat> data(wave_data.Data(), 0);
221 
222  OnlineNnet2FeaturePipeline feature_pipeline(feature_info);
223  feature_pipeline.SetAdaptationState(adaptation_state);
224  feature_pipeline.SetCmvnState(cmvn_state);
225 
226  OnlineSilenceWeighting silence_weighting(
227  trans_model,
228  feature_info.silence_weighting_config,
229  decodable_opts.frame_subsampling_factor);
230 
231  SingleUtteranceNnet3DecoderTpl<fst::GrammarFst> decoder(
232  decoder_opts, trans_model,
233  decodable_info, fst, &feature_pipeline);
234 
235  OnlineTimer decoding_timer(utt);
236 
237  BaseFloat samp_freq = wave_data.SampFreq();
238  int32 chunk_length;
239  if (chunk_length_secs > 0) {
240  chunk_length = int32(samp_freq * chunk_length_secs);
241  if (chunk_length == 0) chunk_length = 1;
242  } else {
243  chunk_length = std::numeric_limits<int32>::max();
244  }
245 
246  int32 samp_offset = 0;
247  std::vector<std::pair<int32, BaseFloat> > delta_weights;
248 
249  while (samp_offset < data.Dim()) {
250  int32 samp_remaining = data.Dim() - samp_offset;
251  int32 num_samp = chunk_length < samp_remaining ? chunk_length
252  : samp_remaining;
253 
254  SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
255  feature_pipeline.AcceptWaveform(samp_freq, wave_part);
256 
257  samp_offset += num_samp;
258  decoding_timer.WaitUntil(samp_offset / samp_freq);
259  if (samp_offset == data.Dim()) {
260  // no more input. flush out last frames
261  feature_pipeline.InputFinished();
262  }
263 
264  if (silence_weighting.Active() &&
265  feature_pipeline.IvectorFeature() != NULL) {
266  silence_weighting.ComputeCurrentTraceback(decoder.Decoder());
267  silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(),
268  &delta_weights);
269  feature_pipeline.IvectorFeature()->UpdateFrameWeights(delta_weights);
270  }
271 
272  decoder.AdvanceDecoding();
273 
274  if (do_endpointing && decoder.EndpointDetected(endpoint_opts)) {
275  break;
276  }
277  }
278  decoder.FinalizeDecoding();
279 
280  CompactLattice clat;
281  bool end_of_utterance = true;
282  decoder.GetLattice(end_of_utterance, &clat);
283 
284  GetDiagnosticsAndPrintOutput(utt, word_syms, clat,
285  &num_frames, &tot_like);
286 
287  decoding_timer.OutputStats(&timing_stats);
288 
289  // In an application you might avoid updating the adaptation state if
290  // you felt the utterance had low confidence. See lat/confidence.h
291  feature_pipeline.GetAdaptationState(&adaptation_state);
292  feature_pipeline.GetCmvnState(&cmvn_state);
293 
294  // we want to output the lattice with un-scaled acoustics.
295  BaseFloat inv_acoustic_scale =
296  1.0 / decodable_opts.acoustic_scale;
297  ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat);
298 
299  clat_writer.Write(utt, clat);
300  KALDI_LOG << "Decoded utterance " << utt;
301  num_done++;
302  }
303  }
304  timing_stats.Print(online);
305 
306  KALDI_LOG << "Decoded " << num_done << " utterances, "
307  << num_err << " with errors.";
308  KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames)
309  << " per frame over " << num_frames << " frames.";
310  delete word_syms; // will delete if non-NULL.
311  return (num_done != 0 ? 0 : 1);
312  } catch(const std::exception& e) {
313  std::cerr << e.what();
314  return -1;
315  }
316 } // main()