118 using namespace kaldi;
122 typedef kaldi::int64 int64;
125 "Reads in audio from a network socket and performs online\n" 126 "decoding with neural nets (nnet3 setup), with iVector-based\n" 127 "speaker adaptation and endpointing.\n" 128 "Note: some configuration values and inputs are set via config\n" 129 "files whose filenames are passed as options\n" 131 "Usage: online2-tcp-nnet3-decode-faster [options] <nnet3-in> " 132 "<fst-in> <word-symbol-table>\n";
148 int read_timeout = 3;
149 bool produce_time =
false;
151 po.Register(
"samp-freq", &samp_freq,
152 "Sampling frequency of the input signal (coded as 16-bit slinear).");
153 po.Register(
"chunk-length", &chunk_length_secs,
154 "Length of chunk size in seconds, that we process.");
155 po.Register(
"output-period", &output_period,
156 "How often in seconds, do we check for changes in output.");
158 "Number of threads used when initializing iVector extractor.");
159 po.Register(
"read-timeout", &read_timeout,
160 "Number of seconds of timout for TCP audio data to appear on the stream. Use -1 for blocking.");
161 po.Register(
"port-num", &port_num,
162 "Port number the server will listen on.");
163 po.Register(
"produce-time", &produce_time,
164 "Prepend begin/end times between endpoints (e.g. '5.46 6.81 <text_output>', in seconds)");
173 if (po.NumArgs() != 3) {
178 std::string nnet3_rxfilename = po.GetArg(1),
179 fst_rxfilename = po.GetArg(2),
180 word_syms_filename = po.GetArg(3);
184 BaseFloat frame_shift = feature_info.FrameShiftInSeconds();
193 Input ki(nnet3_rxfilename, &binary);
194 trans_model.
Read(ki.Stream(), binary);
195 am_nnet.
Read(ki.Stream(), binary);
211 fst::SymbolTable *word_syms = NULL;
212 if (!word_syms_filename.empty())
213 if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
214 KALDI_ERR <<
"Could not read symbol table from file " 215 << word_syms_filename;
217 signal(SIGPIPE, SIG_IGN);
219 TcpServer server(read_timeout);
221 server.Listen(port_num);
227 int32 samp_count = 0;
228 size_t chunk_len =
static_cast<size_t>(chunk_length_secs * samp_freq);
229 int32 check_period =
static_cast<int32
>(samp_freq * output_period);
230 int32 check_count = check_period;
232 int32 frame_offset = 0;
239 *decode_fst, &feature_pipeline);
243 decoder.InitDecoding(frame_offset);
246 feature_info.silence_weighting_config,
248 std::vector<std::pair<int32, BaseFloat>> delta_weights;
251 eos = !server.ReadChunk(chunk_len);
254 feature_pipeline.InputFinished();
255 decoder.AdvanceDecoding();
256 decoder.FinalizeDecoding();
257 frame_offset += decoder.NumFramesDecoded();
258 if (decoder.NumFramesDecoded() > 0) {
260 decoder.GetLattice(
true, &lat);
265 int32 t_beg = frame_offset - decoder.NumFramesDecoded();
266 int32 t_end = frame_offset;
267 msg =
GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) +
" " + msg;
270 KALDI_VLOG(1) <<
"EndOfAudio, sending message: " << msg;
278 Vector<BaseFloat> wave_part = server.GetChunk();
279 feature_pipeline.AcceptWaveform(samp_freq, wave_part);
280 samp_count += chunk_len;
282 if (silence_weighting.Active() &&
283 feature_pipeline.IvectorFeature() != NULL) {
284 silence_weighting.ComputeCurrentTraceback(decoder.Decoder());
285 silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(),
288 feature_pipeline.UpdateFrameWeights(delta_weights);
291 decoder.AdvanceDecoding();
293 if (samp_count > check_count) {
294 if (decoder.NumFramesDecoded() > 0) {
296 decoder.GetBestPath(
false, &lat);
302 int32 t_beg = frame_offset;
304 msg =
GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) +
" " + msg;
307 KALDI_VLOG(1) <<
"Temporary transcript: " << msg;
308 server.WriteLn(msg,
"\r");
310 check_count += check_period;
313 if (decoder.EndpointDetected(endpoint_opts)) {
314 decoder.FinalizeDecoding();
315 frame_offset += decoder.NumFramesDecoded();
317 decoder.GetLattice(
true, &lat);
322 int32 t_beg = frame_offset - decoder.NumFramesDecoded();
323 int32 t_end = frame_offset;
324 msg =
GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) +
" " + msg;
327 KALDI_VLOG(1) <<
"Endpoint, sending message: " << msg;
334 }
catch (
const std::exception &e) {
335 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
This configuration class is to set up OnlineNnet2FeaturePipelineInfo, which in turn is the configurat...
int32 frame_subsampling_factor
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
const Nnet & GetNnet() const
void Register(OptionsItf *opts)
This class is responsible for storing configuration variables, objects and options for OnlineNnet2Fea...
void Read(std::istream &is, bool binary)
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
int32 GetLatticeTimeSpan(const Lattice &lat)
void Read(std::istream &is, bool binary)
You will instantiate this class when you want to decode a single utterance using the online-decoding ...
fst::VectorFst< LatticeArc > Lattice
void Register(OptionsItf *opts)
fst::VectorFst< CompactLatticeArc > CompactLattice
OnlineNnet2FeaturePipeline is a class that's responsible for putting together the various parts of th...
std::string LatticeToString(const Lattice &lat, const fst::SymbolTable &word_syms)
std::string GetTimeString(int32 t_beg, int32 t_end, BaseFloat time_unit)
void Register(OptionsItf *opts)
When you instantiate class DecodableNnetSimpleLooped, you should give it a const reference to this cl...
void Register(OptionsItf *opts)
Config class for the CollapseModel function.