77 using namespace kaldi;
81 typedef kaldi::int64 int64;
84 "Reads in wav file(s) and simulates online decoding with neural nets\n" 85 "(nnet2 setup), with optional iVector-based speaker adaptation and\n" 86 "optional endpointing. This version uses multiple threads for decoding.\n" 87 "Note: some configuration values and inputs are set via config files\n" 88 "whose filenames are passed as options\n" 90 "Usage: online2-wav-nnet2-latgen-threaded [options] <nnet2-in> <fst-in> " 91 "<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n" 92 "The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n" 93 "you want to decode utterance by utterance.\n" 94 "See egs/rm/s5/local/run_online_decoding_nnet2.sh for example\n" 95 "See also online2-wav-nnet2-latgen-faster\n";
99 std::string word_syms_rxfilename;
109 bool do_endpointing =
false;
110 bool modify_ivector_config =
false;
111 bool simulate_realtime_decoding =
true;
113 po.Register(
"chunk-length", &chunk_length_secs,
114 "Length of chunk size in seconds, that we provide each time to the " 115 "decoder. The actual chunk sizes it processes for various stages " 116 "of decoding are dynamically determinated, and unrelated to this");
117 po.Register(
"word-symbol-table", &word_syms_rxfilename,
118 "Symbol table for words [for debug output]");
119 po.Register(
"do-endpointing", &do_endpointing,
120 "If true, apply endpoint detection");
121 po.Register(
"modify-ivector-config", &modify_ivector_config,
122 "If true, modifies the iVector configuration from the config files " 123 "by setting --use-most-recent-ivector=true and --greedy-ivector-extractor=true. " 124 "This will give the best possible results, but the results may become dependent " 125 "on the speed of your machine (slower machine -> better results). Compare " 126 "to the --online option in online2-wav-nnet2-latgen-faster");
127 po.Register(
"simulate-realtime-decoding", &simulate_realtime_decoding,
128 "If true, simulate real-time decoding scenario by providing the " 129 "data incrementally, calling sleep() until each piece is ready. " 130 "If false, don't sleep (so it will be faster).");
132 "Number of threads used when initializing iVector extractor. ");
135 nnet2_decoding_config.
Register(&po);
140 if (po.NumArgs() != 5) {
145 std::string nnet2_rxfilename = po.GetArg(1),
146 fst_rxfilename = po.GetArg(2),
147 spk2utt_rspecifier = po.GetArg(3),
148 wav_rspecifier = po.GetArg(4),
149 clat_wspecifier = po.GetArg(5);
153 if (modify_ivector_config) {
154 feature_info.ivector_extractor_info.use_most_recent_ivector =
true;
155 feature_info.ivector_extractor_info.greedy_ivector_extractor =
true;
159 if (feature_info.global_cmvn_stats_rxfilename !=
"")
167 Input ki(nnet2_rxfilename, &binary);
168 trans_model.
Read(ki.Stream(), binary);
169 am_nnet.
Read(ki.Stream(), binary);
174 fst::SymbolTable *word_syms = NULL;
175 if (word_syms_rxfilename !=
"")
176 if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename)))
177 KALDI_ERR <<
"Could not read symbol table from file " 178 << word_syms_rxfilename;
180 int32 num_done = 0, num_err = 0;
181 double tot_like = 0.0;
182 int64 num_frames = 0;
191 for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
192 std::string spk = spk2utt_reader.Key();
193 const std::vector<std::string> &uttlist = spk2utt_reader.Value();
196 feature_info.ivector_extractor_info);
199 for (
size_t i = 0;
i < uttlist.size();
i++) {
200 std::string utt = uttlist[
i];
201 if (!wav_reader.HasKey(utt)) {
202 KALDI_WARN <<
"Did not find audio for utterance " << utt;
206 const WaveData &wave_data = wav_reader.Value(utt);
212 nnet2_decoding_config, trans_model, am_nnet,
213 *decode_fst, feature_info, adaptation_state, cmvn_state);
220 chunk_length =
int32(samp_freq * chunk_length_secs);
221 if (chunk_length == 0) chunk_length = 1;
223 int32 samp_offset = 0;
224 while (samp_offset < data.Dim()) {
225 int32 samp_remaining = data.Dim() - samp_offset;
226 int32 num_samp = chunk_length < samp_remaining ? chunk_length
235 while (do_endpointing &&
236 decoder.NumWaveformPiecesPending() * chunk_length_secs > 2.0)
239 decoder.AcceptWaveform(samp_freq, wave_part);
241 samp_offset += num_samp;
243 if (simulate_realtime_decoding) {
245 decoding_timer.SleepUntil(samp_offset / samp_freq);
247 if (samp_offset == data.Dim()) {
249 decoder.InputFinished();
252 if (do_endpointing && decoder.EndpointDetected(endpoint_config)) {
253 decoder.TerminateDecoding();
259 if (simulate_realtime_decoding) {
261 <<
"finish after giving it last chunk.";
263 decoder.FinalizeDecoding();
266 bool end_of_utterance =
true;
267 decoder.GetLattice(end_of_utterance, &clat, NULL);
270 &num_frames, &tot_like);
272 decoding_timer.OutputStats(&timing_stats);
276 decoder.GetAdaptationState(&adaptation_state);
277 decoder.GetCmvnState(&cmvn_state);
284 if (simulate_realtime_decoding) {
285 KALDI_VLOG(1) <<
"Adding the various end-of-utterance tasks took the " 286 <<
"total latency to " << timer.
Elapsed() <<
" seconds.";
288 clat_writer.Write(utt, clat);
289 KALDI_LOG <<
"Decoded utterance " << utt;
296 if (simulate_realtime_decoding) {
297 timing_stats.
Print(online);
301 global_timer.
Elapsed() / (frame_shift * num_frames);
303 KALDI_LOG <<
"Real-time factor was " << real_time_factor
304 <<
" assuming frame shift of " << frame_shift;
307 KALDI_LOG <<
"Decoded " << num_done <<
" utterances, " 308 << num_err <<
" with errors.";
309 KALDI_LOG <<
"Overall likelihood per frame was " << (tot_like / num_frames)
310 <<
" per frame over " << num_frames <<
" frames.";
313 return (num_done != 0 ? 0 : 1);
314 }
catch(
const std::exception& e) {
315 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
class OnlineTimer is used to test real-time decoding algorithms and evaluate how long the decoding of...
This configuration class is to set up OnlineNnet2FeaturePipelineInfo, which in turn is the configurat...
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
You will instantiate this class when you want to decode a single utterance using the online-decoding ...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
void Sleep(float seconds)
void Read(std::istream &is, bool binary)
A templated class for writing objects to an archive or script file; see The Table concept...
BaseFloat SampFreq() const
const Matrix< BaseFloat > & Data() const
void Register(OptionsItf *opts)
void GetDiagnosticsAndPrintOutput(const std::string &utt, const fst::SymbolTable *word_syms, const CompactLattice &clat, int64 *tot_num_frames, double *tot_like)
This class is responsible for storing configuration variables, objects and options for OnlineNnet2Fea...
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Print(bool online=true)
Here, if "online == false" we take into account that the setup was used in not-really-online mode whe...
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
void Read(std::istream &is, bool binary)
Struct OnlineCmvnState stores the state of CMVN adaptation between utterances (but not the state of t...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void Register(OptionsItf *opts)
void Register(OptionsItf *opts)
fst::VectorFst< CompactLatticeArc > CompactLattice
This class's purpose is to read in Wave files.
#define KALDI_ASSERT(cond)
class OnlineTimingStats stores statistics from timing of online decoding, which will enable the Print...
double Elapsed() const
Returns time in seconds.
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...