80 using namespace kaldi;
84 typedef kaldi::int64 int64;
87 "Reads in wav file(s) and simulates online decoding with neural nets\n" 88 "(nnet3 setup), with optional iVector-based speaker adaptation and\n" 89 "optional endpointing. Note: some configuration values and inputs are\n" 90 "set via config files whose filenames are passed as options.\n" 91 "This program like online2-wav-nnet3-latgen-faster but when the FST to\n" 92 "be decoded is of type GrammarFst.\n" 94 "Usage: online2-wav-nnet3-latgen-grammar [options] <nnet3-in> <fst-in> " 95 "<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n" 96 "The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n" 97 "you want to decode utterance by utterance.\n";
101 std::string word_syms_rxfilename;
111 bool do_endpointing =
false;
114 po.Register(
"chunk-length", &chunk_length_secs,
115 "Length of chunk size in seconds, that we process. Set to <= 0 " 116 "to use all input in one chunk.");
117 po.Register(
"word-symbol-table", &word_syms_rxfilename,
118 "Symbol table for words [for debug output]");
119 po.Register(
"do-endpointing", &do_endpointing,
120 "If true, apply endpoint detection");
121 po.Register(
"online", &online,
122 "You can set this to false to disable online iVector estimation " 123 "and have all the data for each utterance used, even at " 124 "utterance start. This is useful where you just want the best " 125 "results and don't care about online operation. Setting this to " 126 "false has the same effect as setting " 127 "--use-most-recent-ivector=true and --greedy-ivector-extractor=true " 128 "in the file given to --ivector-extraction-config, and " 129 "--chunk-length=-1.");
131 "Number of threads used when initializing iVector extractor.");
141 if (po.NumArgs() != 5) {
146 std::string nnet3_rxfilename = po.GetArg(1),
147 fst_rxfilename = po.GetArg(2),
148 spk2utt_rspecifier = po.GetArg(3),
149 wav_rspecifier = po.GetArg(4),
150 clat_wspecifier = po.GetArg(5);
154 feature_info.ivector_extractor_info.use_most_recent_ivector =
true;
155 feature_info.ivector_extractor_info.greedy_ivector_extractor =
true;
156 chunk_length_secs = -1.0;
160 if (feature_info.global_cmvn_stats_rxfilename !=
"")
168 Input ki(nnet3_rxfilename, &binary);
169 trans_model.
Read(ki.Stream(), binary);
170 am_nnet.
Read(ki.Stream(), binary);
186 fst::SymbolTable *word_syms = NULL;
187 if (word_syms_rxfilename !=
"")
188 if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename)))
189 KALDI_ERR <<
"Could not read symbol table from file " 190 << word_syms_rxfilename;
192 int32 num_done = 0, num_err = 0;
193 double tot_like = 0.0;
194 int64 num_frames = 0;
202 for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
203 std::string spk = spk2utt_reader.Key();
204 const std::vector<std::string> &uttlist = spk2utt_reader.Value();
207 feature_info.ivector_extractor_info);
210 for (
size_t i = 0;
i < uttlist.size();
i++) {
211 std::string utt = uttlist[
i];
212 if (!wav_reader.HasKey(utt)) {
213 KALDI_WARN <<
"Did not find audio for utterance " << utt;
217 const WaveData &wave_data = wav_reader.Value(utt);
223 feature_pipeline.SetAdaptationState(adaptation_state);
224 feature_pipeline.SetCmvnState(cmvn_state);
228 feature_info.silence_weighting_config,
232 decoder_opts, trans_model,
233 decodable_info, fst, &feature_pipeline);
239 if (chunk_length_secs > 0) {
240 chunk_length =
int32(samp_freq * chunk_length_secs);
241 if (chunk_length == 0) chunk_length = 1;
243 chunk_length = std::numeric_limits<int32>::max();
246 int32 samp_offset = 0;
247 std::vector<std::pair<int32, BaseFloat> > delta_weights;
249 while (samp_offset < data.Dim()) {
250 int32 samp_remaining = data.Dim() - samp_offset;
251 int32 num_samp = chunk_length < samp_remaining ? chunk_length
255 feature_pipeline.AcceptWaveform(samp_freq, wave_part);
257 samp_offset += num_samp;
258 decoding_timer.WaitUntil(samp_offset / samp_freq);
259 if (samp_offset == data.Dim()) {
261 feature_pipeline.InputFinished();
264 if (silence_weighting.Active() &&
265 feature_pipeline.IvectorFeature() != NULL) {
266 silence_weighting.ComputeCurrentTraceback(decoder.Decoder());
267 silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(),
269 feature_pipeline.IvectorFeature()->UpdateFrameWeights(delta_weights);
272 decoder.AdvanceDecoding();
274 if (do_endpointing && decoder.EndpointDetected(endpoint_opts)) {
278 decoder.FinalizeDecoding();
281 bool end_of_utterance =
true;
282 decoder.GetLattice(end_of_utterance, &clat);
285 &num_frames, &tot_like);
287 decoding_timer.OutputStats(&timing_stats);
291 feature_pipeline.GetAdaptationState(&adaptation_state);
292 feature_pipeline.GetCmvnState(&cmvn_state);
299 clat_writer.Write(utt, clat);
300 KALDI_LOG <<
"Decoded utterance " << utt;
304 timing_stats.
Print(online);
306 KALDI_LOG <<
"Decoded " << num_done <<
" utterances, " 307 << num_err <<
" with errors.";
308 KALDI_LOG <<
"Overall likelihood per frame was " << (tot_like / num_frames)
309 <<
" per frame over " << num_frames <<
" frames.";
311 return (num_done != 0 ? 0 : 1);
312 }
catch(
const std::exception& e) {
313 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done in test time...
class OnlineTimer is used to test real-time decoding algorithms and evaluate how long the decoding of...
This configuration class is to set up OnlineNnet2FeaturePipelineInfo, which in turn is the configurat...
int32 frame_subsampling_factor
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
A templated class for writing objects to an archive or script file; see The Table concept...
BaseFloat SampFreq() const
const Matrix< BaseFloat > & Data() const
const Nnet & GetNnet() const
void Register(OptionsItf *opts)
void GetDiagnosticsAndPrintOutput(const std::string &utt, const fst::SymbolTable *word_syms, const CompactLattice &clat, int64 *tot_num_frames, double *tot_like)
This class is responsible for storing configuration variables, objects and options for OnlineNnet2Fea...
void Read(std::istream &is, bool binary)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Print(bool online=true)
Here, if "online == false" we take into account that the setup was used in not-really-online mode whe...
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
void Read(std::istream &is, bool binary)
Struct OnlineCmvnState stores the state of CMVN adaptation between utterances (but not the state of t...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
You will instantiate this class when you want to decode a single utterance using the online-decoding ...
void Register(OptionsItf *opts)
GrammarFst is an FST that is 'stitched together' from multiple FSTs, that can recursively incorporate...
fst::VectorFst< CompactLatticeArc > CompactLattice
This class's purpose is to read in Wave files.
OnlineNnet2FeaturePipeline is a class that's responsible for putting together the various parts of th...
class OnlineTimingStats stores statistics from timing of online decoding, which will enable the Print...
void Register(OptionsItf *opts)
When you instantiate class DecodableNnetSimpleLooped, you should give it a const reference to this cl...
void Register(OptionsItf *opts)
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Config class for the CollapseModel function.