online2-tcp-nnet3-decode-faster.cc File Reference
#include "feat/wave-reader.h"
#include "online2/online-nnet3-decoding.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"
#include "fstext/fstext-lib.h"
#include "lat/lattice-functions.h"
#include "util/kaldi-thread.h"
#include "nnet3/nnet-utils.h"
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <poll.h>
#include <signal.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <string>
Include dependency graph for online2-tcp-nnet3-decode-faster.cc:

Go to the source code of this file.

Classes

class  TcpServer
 

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for mispronunciations detection tasks, the reference:
 

Functions

std::string LatticeToString (const Lattice &lat, const fst::SymbolTable &word_syms)
 
std::string GetTimeString (int32 t_beg, int32 t_end, BaseFloat time_unit)
 
int32 GetLatticeTimeSpan (const Lattice &lat)
 
std::string LatticeToString (const CompactLattice &clat, const fst::SymbolTable &word_syms)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 116 of file online2-tcp-nnet3-decode-faster.cc.

References OnlineNnet2FeaturePipeline::AcceptWaveform(), OnlineSilenceWeighting::Active(), SingleUtteranceNnet3DecoderTpl< FST >::AdvanceDecoding(), kaldi::nnet3::CollapseModel(), OnlineSilenceWeighting::ComputeCurrentTraceback(), SingleUtteranceNnet3DecoderTpl< FST >::Decoder(), SingleUtteranceNnet3DecoderTpl< FST >::EndpointDetected(), SingleUtteranceNnet3DecoderTpl< FST >::FinalizeDecoding(), NnetSimpleLoopedComputationOptions::frame_subsampling_factor, OnlineNnet2FeaturePipelineInfo::FrameShiftInSeconds(), kaldi::g_num_threads, ParseOptions::GetArg(), SingleUtteranceNnet3DecoderTpl< FST >::GetBestPath(), OnlineSilenceWeighting::GetDeltaWeights(), SingleUtteranceNnet3DecoderTpl< FST >::GetLattice(), kaldi::GetLatticeTimeSpan(), AmNnetSimple::GetNnet(), kaldi::GetTimeString(), SingleUtteranceNnet3DecoderTpl< FST >::InitDecoding(), OnlineNnet2FeaturePipeline::InputFinished(), OnlineNnet2FeaturePipeline::IvectorFeature(), KALDI_ERR, KALDI_VLOG, kaldi::LatticeToString(), ParseOptions::NumArgs(), SingleUtteranceNnet3DecoderTpl< FST >::NumFramesDecoded(), OnlineNnet2FeaturePipeline::NumFramesReady(), ParseOptions::PrintUsage(), AmNnetSimple::Read(), ParseOptions::Read(), TransitionModel::Read(), fst::ReadFstKaldiGeneric(), NnetSimpleLoopedComputationOptions::Register(), LatticeFasterDecoderConfig::Register(), ParseOptions::Register(), OnlineNnet2FeaturePipelineConfig::Register(), OnlineEndpointConfig::Register(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutTestMode(), OnlineNnet2FeaturePipelineInfo::silence_weighting_config, Input::Stream(), TcpServer::TcpServer(), and OnlineNnet2FeaturePipeline::UpdateFrameWeights().

116  {
117  try {
118  using namespace kaldi;
119  using namespace fst;
120 
121  typedef kaldi::int32 int32;
122  typedef kaldi::int64 int64;
123 
124  const char *usage =
125  "Reads in audio from a network socket and performs online\n"
126  "decoding with neural nets (nnet3 setup), with iVector-based\n"
127  "speaker adaptation and endpointing.\n"
128  "Note: some configuration values and inputs are set via config\n"
129  "files whose filenames are passed as options\n"
130  "\n"
131  "Usage: online2-tcp-nnet3-decode-faster [options] <nnet3-in> "
132  "<fst-in> <word-symbol-table>\n";
133 
134  ParseOptions po(usage);
135 
136 
137  // feature_opts includes configuration for the iVector adaptation,
138  // as well as the basic features.
141  LatticeFasterDecoderConfig decoder_opts;
142  OnlineEndpointConfig endpoint_opts;
143 
144  BaseFloat chunk_length_secs = 0.18;
145  BaseFloat output_period = 1;
146  BaseFloat samp_freq = 16000.0;
147  int port_num = 5050;
148  int read_timeout = 3;
149  bool produce_time = false;
150 
151  po.Register("samp-freq", &samp_freq,
152  "Sampling frequency of the input signal (coded as 16-bit slinear).");
153  po.Register("chunk-length", &chunk_length_secs,
154  "Length of chunk size in seconds, that we process.");
155  po.Register("output-period", &output_period,
156  "How often in seconds, do we check for changes in output.");
157  po.Register("num-threads-startup", &g_num_threads,
158  "Number of threads used when initializing iVector extractor.");
159  po.Register("read-timeout", &read_timeout,
160  "Number of seconds of timout for TCP audio data to appear on the stream. Use -1 for blocking.");
161  po.Register("port-num", &port_num,
162  "Port number the server will listen on.");
163  po.Register("produce-time", &produce_time,
164  "Prepend begin/end times between endpoints (e.g. '5.46 6.81 <text_output>', in seconds)");
165 
166  feature_opts.Register(&po);
167  decodable_opts.Register(&po);
168  decoder_opts.Register(&po);
169  endpoint_opts.Register(&po);
170 
171  po.Read(argc, argv);
172 
173  if (po.NumArgs() != 3) {
174  po.PrintUsage();
175  return 1;
176  }
177 
178  std::string nnet3_rxfilename = po.GetArg(1),
179  fst_rxfilename = po.GetArg(2),
180  word_syms_filename = po.GetArg(3);
181 
182  OnlineNnet2FeaturePipelineInfo feature_info(feature_opts);
183 
184  BaseFloat frame_shift = feature_info.FrameShiftInSeconds();
185  int32 frame_subsampling = decodable_opts.frame_subsampling_factor;
186 
187  KALDI_VLOG(1) << "Loading AM...";
188 
189  TransitionModel trans_model;
190  nnet3::AmNnetSimple am_nnet;
191  {
192  bool binary;
193  Input ki(nnet3_rxfilename, &binary);
194  trans_model.Read(ki.Stream(), binary);
195  am_nnet.Read(ki.Stream(), binary);
196  SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
197  SetDropoutTestMode(true, &(am_nnet.GetNnet()));
199  }
200 
201  // this object contains precomputed stuff that is used by all decodable
202  // objects. It takes a pointer to am_nnet because if it has iVectors it has
203  // to modify the nnet to accept iVectors at intervals.
204  nnet3::DecodableNnetSimpleLoopedInfo decodable_info(decodable_opts,
205  &am_nnet);
206 
207  KALDI_VLOG(1) << "Loading FST...";
208 
209  fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldiGeneric(fst_rxfilename);
210 
211  fst::SymbolTable *word_syms = NULL;
212  if (!word_syms_filename.empty())
213  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
214  KALDI_ERR << "Could not read symbol table from file "
215  << word_syms_filename;
216 
217  signal(SIGPIPE, SIG_IGN); // ignore SIGPIPE to avoid crashing when socket forcefully disconnected
218 
219  TcpServer server(read_timeout);
220 
221  server.Listen(port_num);
222 
223  while (true) {
224 
225  server.Accept();
226 
227  int32 samp_count = 0;// this is used for output refresh rate
228  size_t chunk_len = static_cast<size_t>(chunk_length_secs * samp_freq);
229  int32 check_period = static_cast<int32>(samp_freq * output_period);
230  int32 check_count = check_period;
231 
232  int32 frame_offset = 0;
233 
234  bool eos = false;
235 
236  OnlineNnet2FeaturePipeline feature_pipeline(feature_info);
237  SingleUtteranceNnet3Decoder decoder(decoder_opts, trans_model,
238  decodable_info,
239  *decode_fst, &feature_pipeline);
240 
241  while (!eos) {
242 
243  decoder.InitDecoding(frame_offset);
244  OnlineSilenceWeighting silence_weighting(
245  trans_model,
246  feature_info.silence_weighting_config,
247  decodable_opts.frame_subsampling_factor);
248  std::vector<std::pair<int32, BaseFloat>> delta_weights;
249 
250  while (true) {
251  eos = !server.ReadChunk(chunk_len);
252 
253  if (eos) {
254  feature_pipeline.InputFinished();
255  decoder.AdvanceDecoding();
256  decoder.FinalizeDecoding();
257  frame_offset += decoder.NumFramesDecoded();
258  if (decoder.NumFramesDecoded() > 0) {
259  CompactLattice lat;
260  decoder.GetLattice(true, &lat);
261  std::string msg = LatticeToString(lat, *word_syms);
262 
263  // get time-span from previous endpoint to end of audio,
264  if (produce_time) {
265  int32 t_beg = frame_offset - decoder.NumFramesDecoded();
266  int32 t_end = frame_offset;
267  msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
268  }
269 
270  KALDI_VLOG(1) << "EndOfAudio, sending message: " << msg;
271  server.WriteLn(msg);
272  } else
273  server.Write("\n");
274  server.Disconnect();
275  break;
276  }
277 
278  Vector<BaseFloat> wave_part = server.GetChunk();
279  feature_pipeline.AcceptWaveform(samp_freq, wave_part);
280  samp_count += chunk_len;
281 
282  if (silence_weighting.Active() &&
283  feature_pipeline.IvectorFeature() != NULL) {
284  silence_weighting.ComputeCurrentTraceback(decoder.Decoder());
285  silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(),
286  frame_offset * decodable_opts.frame_subsampling_factor,
287  &delta_weights);
288  feature_pipeline.UpdateFrameWeights(delta_weights);
289  }
290 
291  decoder.AdvanceDecoding();
292 
293  if (samp_count > check_count) {
294  if (decoder.NumFramesDecoded() > 0) {
295  Lattice lat;
296  decoder.GetBestPath(false, &lat);
297  TopSort(&lat); // for LatticeStateTimes(),
298  std::string msg = LatticeToString(lat, *word_syms);
299 
300  // get time-span after previous endpoint,
301  if (produce_time) {
302  int32 t_beg = frame_offset;
303  int32 t_end = frame_offset + GetLatticeTimeSpan(lat);
304  msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
305  }
306 
307  KALDI_VLOG(1) << "Temporary transcript: " << msg;
308  server.WriteLn(msg, "\r");
309  }
310  check_count += check_period;
311  }
312 
313  if (decoder.EndpointDetected(endpoint_opts)) {
314  decoder.FinalizeDecoding();
315  frame_offset += decoder.NumFramesDecoded();
316  CompactLattice lat;
317  decoder.GetLattice(true, &lat);
318  std::string msg = LatticeToString(lat, *word_syms);
319 
320  // get time-span between endpoints,
321  if (produce_time) {
322  int32 t_beg = frame_offset - decoder.NumFramesDecoded();
323  int32 t_end = frame_offset;
324  msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
325  }
326 
327  KALDI_VLOG(1) << "Endpoint, sending message: " << msg;
328  server.WriteLn(msg);
329  break; // while (true)
330  }
331  }
332  }
333  }
334  } catch (const std::exception &e) {
335  std::cerr << e.what();
336  return -1;
337  }
338 } // main()
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that suitable to be done in test time...
Definition: nnet-utils.cc:2100
This configuration class is to set up OnlineNnet2FeaturePipelineInfo, which in turn is the configurat...
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
Definition: kaldi-fst-io.cc:45
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
int32 g_num_threads
Definition: kaldi-thread.cc:25
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
kaldi::int32 int32
const Nnet & GetNnet() const
void Register(OptionsItf *opts)
This class is responsible for storing configuration variables, objects and options for OnlineNnet2Fea...
void Read(std::istream &is, bool binary)
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
int32 GetLatticeTimeSpan(const Lattice &lat)
void Read(std::istream &is, bool binary)
You will instantiate this class when you want to decode a single utterance using the online-decoding ...
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
#define KALDI_ERR
Definition: kaldi-error.h:147
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
OnlineNnet2FeaturePipeline is a class that's responsible for putting together the various parts of th...
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
std::string LatticeToString(const Lattice &lat, const fst::SymbolTable &word_syms)
std::string GetTimeString(int32 t_beg, int32 t_end, BaseFloat time_unit)
When you instantiate class DecodableNnetSimpleLooped, you should give it a const reference to this cl...
Config class for the CollapseModel function.
Definition: nnet-utils.h:240