online2-tcp-nnet3-decode-faster.cc
Go to the documentation of this file.
1 // online2bin/online2-tcp-nnet3-decode-faster.cc
2 
3 // Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 // 2016 Api.ai (Author: Ilya Platonov)
5 // 2018 Polish-Japanese Academy of Information Technology (Author: Danijel Korzinek)
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
#include "feat/wave-reader.h"
#include "online2/online-nnet3-decoding.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"
#include "fstext/fstext-lib.h"
#include "lat/lattice-functions.h"
#include "util/kaldi-thread.h"
#include "nnet3/nnet-utils.h"

#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <poll.h>
#include <signal.h>
#include <arpa/inet.h>
#include <unistd.h>

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
42 namespace kaldi {
43 
44 class TcpServer {
45  public:
46  explicit TcpServer(int read_timeout);
47  ~TcpServer();
48 
49  bool Listen(int32 port); // start listening on a given port
50  int32 Accept(); // accept a client and return its descriptor
51 
52  bool ReadChunk(size_t len); // get more data and return false if end-of-stream
53 
54  Vector<BaseFloat> GetChunk(); // get the data read by above method
55 
56  bool Write(const std::string &msg); // write to accepted client
57  bool WriteLn(const std::string &msg, const std::string &eol = "\n"); // write line to accepted client
58 
59  void Disconnect();
60 
61  private:
62  struct ::sockaddr_in h_addr_;
64  int16 *samp_buf_;
66  pollfd client_set_[1];
68 };
69 
70 std::string LatticeToString(const Lattice &lat, const fst::SymbolTable &word_syms) {
71  LatticeWeight weight;
72  std::vector<int32> alignment;
73  std::vector<int32> words;
74  GetLinearSymbolSequence(lat, &alignment, &words, &weight);
75 
76  std::ostringstream msg;
77  for (size_t i = 0; i < words.size(); i++) {
78  std::string s = word_syms.Find(words[i]);
79  if (s.empty()) {
80  KALDI_WARN << "Word-id " << words[i] << " not in symbol table.";
81  msg << "<#" << std::to_string(i) << "> ";
82  } else
83  msg << s << " ";
84  }
85  return msg.str();
86 }
87 
88 std::string GetTimeString(int32 t_beg, int32 t_end, BaseFloat time_unit) {
89  char buffer[100];
90  double t_beg2 = t_beg * time_unit;
91  double t_end2 = t_end * time_unit;
92  snprintf(buffer, 100, "%.2f %.2f", t_beg2, t_end2);
93  return std::string(buffer);
94 }
95 
97  std::vector<int32> times;
98  LatticeStateTimes(lat, &times);
99  return times.back();
100 }
101 
102 std::string LatticeToString(const CompactLattice &clat, const fst::SymbolTable &word_syms) {
103  if (clat.NumStates() == 0) {
104  KALDI_WARN << "Empty lattice.";
105  return "";
106  }
107  CompactLattice best_path_clat;
108  CompactLatticeShortestPath(clat, &best_path_clat);
109 
110  Lattice best_path_lat;
111  ConvertLattice(best_path_clat, &best_path_lat);
112  return LatticeToString(best_path_lat, word_syms);
113 }
114 }
115 
116 int main(int argc, char *argv[]) {
117  try {
118  using namespace kaldi;
119  using namespace fst;
120 
121  typedef kaldi::int32 int32;
122  typedef kaldi::int64 int64;
123 
124  const char *usage =
125  "Reads in audio from a network socket and performs online\n"
126  "decoding with neural nets (nnet3 setup), with iVector-based\n"
127  "speaker adaptation and endpointing.\n"
128  "Note: some configuration values and inputs are set via config\n"
129  "files whose filenames are passed as options\n"
130  "\n"
131  "Usage: online2-tcp-nnet3-decode-faster [options] <nnet3-in> "
132  "<fst-in> <word-symbol-table>\n";
133 
134  ParseOptions po(usage);
135 
136 
137  // feature_opts includes configuration for the iVector adaptation,
138  // as well as the basic features.
141  LatticeFasterDecoderConfig decoder_opts;
142  OnlineEndpointConfig endpoint_opts;
143 
144  BaseFloat chunk_length_secs = 0.18;
145  BaseFloat output_period = 1;
146  BaseFloat samp_freq = 16000.0;
147  int port_num = 5050;
148  int read_timeout = 3;
149  bool produce_time = false;
150 
151  po.Register("samp-freq", &samp_freq,
152  "Sampling frequency of the input signal (coded as 16-bit slinear).");
153  po.Register("chunk-length", &chunk_length_secs,
154  "Length of chunk size in seconds, that we process.");
155  po.Register("output-period", &output_period,
156  "How often in seconds, do we check for changes in output.");
157  po.Register("num-threads-startup", &g_num_threads,
158  "Number of threads used when initializing iVector extractor.");
159  po.Register("read-timeout", &read_timeout,
160  "Number of seconds of timout for TCP audio data to appear on the stream. Use -1 for blocking.");
161  po.Register("port-num", &port_num,
162  "Port number the server will listen on.");
163  po.Register("produce-time", &produce_time,
164  "Prepend begin/end times between endpoints (e.g. '5.46 6.81 <text_output>', in seconds)");
165 
166  feature_opts.Register(&po);
167  decodable_opts.Register(&po);
168  decoder_opts.Register(&po);
169  endpoint_opts.Register(&po);
170 
171  po.Read(argc, argv);
172 
173  if (po.NumArgs() != 3) {
174  po.PrintUsage();
175  return 1;
176  }
177 
178  std::string nnet3_rxfilename = po.GetArg(1),
179  fst_rxfilename = po.GetArg(2),
180  word_syms_filename = po.GetArg(3);
181 
182  OnlineNnet2FeaturePipelineInfo feature_info(feature_opts);
183 
184  BaseFloat frame_shift = feature_info.FrameShiftInSeconds();
185  int32 frame_subsampling = decodable_opts.frame_subsampling_factor;
186 
187  KALDI_VLOG(1) << "Loading AM...";
188 
189  TransitionModel trans_model;
190  nnet3::AmNnetSimple am_nnet;
191  {
192  bool binary;
193  Input ki(nnet3_rxfilename, &binary);
194  trans_model.Read(ki.Stream(), binary);
195  am_nnet.Read(ki.Stream(), binary);
196  SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
197  SetDropoutTestMode(true, &(am_nnet.GetNnet()));
199  }
200 
201  // this object contains precomputed stuff that is used by all decodable
202  // objects. It takes a pointer to am_nnet because if it has iVectors it has
203  // to modify the nnet to accept iVectors at intervals.
204  nnet3::DecodableNnetSimpleLoopedInfo decodable_info(decodable_opts,
205  &am_nnet);
206 
207  KALDI_VLOG(1) << "Loading FST...";
208 
209  fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldiGeneric(fst_rxfilename);
210 
211  fst::SymbolTable *word_syms = NULL;
212  if (!word_syms_filename.empty())
213  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
214  KALDI_ERR << "Could not read symbol table from file "
215  << word_syms_filename;
216 
217  signal(SIGPIPE, SIG_IGN); // ignore SIGPIPE to avoid crashing when socket forcefully disconnected
218 
219  TcpServer server(read_timeout);
220 
221  server.Listen(port_num);
222 
223  while (true) {
224 
225  server.Accept();
226 
227  int32 samp_count = 0;// this is used for output refresh rate
228  size_t chunk_len = static_cast<size_t>(chunk_length_secs * samp_freq);
229  int32 check_period = static_cast<int32>(samp_freq * output_period);
230  int32 check_count = check_period;
231 
232  int32 frame_offset = 0;
233 
234  bool eos = false;
235 
236  OnlineNnet2FeaturePipeline feature_pipeline(feature_info);
237  SingleUtteranceNnet3Decoder decoder(decoder_opts, trans_model,
238  decodable_info,
239  *decode_fst, &feature_pipeline);
240 
241  while (!eos) {
242 
243  decoder.InitDecoding(frame_offset);
244  OnlineSilenceWeighting silence_weighting(
245  trans_model,
246  feature_info.silence_weighting_config,
247  decodable_opts.frame_subsampling_factor);
248  std::vector<std::pair<int32, BaseFloat>> delta_weights;
249 
250  while (true) {
251  eos = !server.ReadChunk(chunk_len);
252 
253  if (eos) {
254  feature_pipeline.InputFinished();
255  decoder.AdvanceDecoding();
256  decoder.FinalizeDecoding();
257  frame_offset += decoder.NumFramesDecoded();
258  if (decoder.NumFramesDecoded() > 0) {
259  CompactLattice lat;
260  decoder.GetLattice(true, &lat);
261  std::string msg = LatticeToString(lat, *word_syms);
262 
263  // get time-span from previous endpoint to end of audio,
264  if (produce_time) {
265  int32 t_beg = frame_offset - decoder.NumFramesDecoded();
266  int32 t_end = frame_offset;
267  msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
268  }
269 
270  KALDI_VLOG(1) << "EndOfAudio, sending message: " << msg;
271  server.WriteLn(msg);
272  } else
273  server.Write("\n");
274  server.Disconnect();
275  break;
276  }
277 
278  Vector<BaseFloat> wave_part = server.GetChunk();
279  feature_pipeline.AcceptWaveform(samp_freq, wave_part);
280  samp_count += chunk_len;
281 
282  if (silence_weighting.Active() &&
283  feature_pipeline.IvectorFeature() != NULL) {
284  silence_weighting.ComputeCurrentTraceback(decoder.Decoder());
285  silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(),
286  frame_offset * decodable_opts.frame_subsampling_factor,
287  &delta_weights);
288  feature_pipeline.UpdateFrameWeights(delta_weights);
289  }
290 
291  decoder.AdvanceDecoding();
292 
293  if (samp_count > check_count) {
294  if (decoder.NumFramesDecoded() > 0) {
295  Lattice lat;
296  decoder.GetBestPath(false, &lat);
297  TopSort(&lat); // for LatticeStateTimes(),
298  std::string msg = LatticeToString(lat, *word_syms);
299 
300  // get time-span after previous endpoint,
301  if (produce_time) {
302  int32 t_beg = frame_offset;
303  int32 t_end = frame_offset + GetLatticeTimeSpan(lat);
304  msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
305  }
306 
307  KALDI_VLOG(1) << "Temporary transcript: " << msg;
308  server.WriteLn(msg, "\r");
309  }
310  check_count += check_period;
311  }
312 
313  if (decoder.EndpointDetected(endpoint_opts)) {
314  decoder.FinalizeDecoding();
315  frame_offset += decoder.NumFramesDecoded();
316  CompactLattice lat;
317  decoder.GetLattice(true, &lat);
318  std::string msg = LatticeToString(lat, *word_syms);
319 
320  // get time-span between endpoints,
321  if (produce_time) {
322  int32 t_beg = frame_offset - decoder.NumFramesDecoded();
323  int32 t_end = frame_offset;
324  msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
325  }
326 
327  KALDI_VLOG(1) << "Endpoint, sending message: " << msg;
328  server.WriteLn(msg);
329  break; // while (true)
330  }
331  }
332  }
333  }
334  } catch (const std::exception &e) {
335  std::cerr << e.what();
336  return -1;
337  }
338 } // main()
339 
340 
341 namespace kaldi {
342 TcpServer::TcpServer(int read_timeout) {
343  server_desc_ = -1;
344  client_desc_ = -1;
345  samp_buf_ = NULL;
346  buf_len_ = 0;
347  read_timeout_ = 1000 * read_timeout;
348 }
349 
350 bool TcpServer::Listen(int32 port) {
351  h_addr_.sin_addr.s_addr = INADDR_ANY;
352  h_addr_.sin_port = htons(port);
353  h_addr_.sin_family = AF_INET;
354 
355  server_desc_ = socket(AF_INET, SOCK_STREAM, 0);
356 
357  if (server_desc_ == -1) {
358  KALDI_ERR << "Cannot create TCP socket!";
359  return false;
360  }
361 
362  int32 flag = 1;
363  int32 len = sizeof(int32);
364  if (setsockopt(server_desc_, SOL_SOCKET, SO_REUSEADDR, &flag, len) == -1) {
365  KALDI_ERR << "Cannot set socket options!";
366  return false;
367  }
368 
369  if (bind(server_desc_, (struct sockaddr *) &h_addr_, sizeof(h_addr_)) == -1) {
370  KALDI_ERR << "Cannot bind to port: " << port << " (is it taken?)";
371  return false;
372  }
373 
374  if (listen(server_desc_, 1) == -1) {
375  KALDI_ERR << "Cannot listen on port!";
376  return false;
377  }
378 
379  KALDI_LOG << "TcpServer: Listening on port: " << port;
380 
381  return true;
382 
383 }
384 
386  Disconnect();
387  if (server_desc_ != -1)
388  close(server_desc_);
389  delete[] samp_buf_;
390 }
391 
393  KALDI_LOG << "Waiting for client...";
394 
395  socklen_t len;
396 
397  len = sizeof(struct sockaddr);
398  client_desc_ = accept(server_desc_, (struct sockaddr *) &h_addr_, &len);
399 
400  struct sockaddr_storage addr;
401  char ipstr[20];
402 
403  len = sizeof addr;
404  getpeername(client_desc_, (struct sockaddr *) &addr, &len);
405 
406  struct sockaddr_in *s = (struct sockaddr_in *) &addr;
407  inet_ntop(AF_INET, &s->sin_addr, ipstr, sizeof ipstr);
408 
409  client_set_[0].fd = client_desc_;
410  client_set_[0].events = POLLIN;
411 
412  KALDI_LOG << "Accepted connection from: " << ipstr;
413 
414  return client_desc_;
415 }
416 
417 bool TcpServer::ReadChunk(size_t len) {
418  if (buf_len_ != len) {
419  buf_len_ = len;
420  delete[] samp_buf_;
421  samp_buf_ = new int16[len];
422  }
423 
424  ssize_t ret;
425  int poll_ret;
426  size_t to_read = len;
427  has_read_ = 0;
428  while (to_read > 0) {
429  poll_ret = poll(client_set_, 1, read_timeout_);
430  if (poll_ret == 0) {
431  KALDI_WARN << "Socket timeout! Disconnecting...";
432  break;
433  }
434  if (poll_ret < 0) {
435  KALDI_WARN << "Socket error! Disconnecting...";
436  break;
437  }
438  ret = read(client_desc_, static_cast<void *>(samp_buf_ + has_read_), to_read * sizeof(int16));
439  if (ret <= 0) {
440  KALDI_WARN << "Stream over...";
441  break;
442  }
443  to_read -= ret / sizeof(int16);
444  has_read_ += ret / sizeof(int16);
445  }
446 
447  return has_read_ > 0;
448 }
449 
451  Vector<BaseFloat> buf;
452 
453  buf.Resize(static_cast<MatrixIndexT>(has_read_));
454 
455  for (int i = 0; i < has_read_; i++)
456  buf(i) = static_cast<BaseFloat>(samp_buf_[i]);
457 
458  return buf;
459 }
460 
461 bool TcpServer::Write(const std::string &msg) {
462 
463  const char *p = msg.c_str();
464  size_t to_write = msg.size();
465  size_t wrote = 0;
466  while (to_write > 0) {
467  ssize_t ret = write(client_desc_, static_cast<const void *>(p + wrote), to_write);
468  if (ret <= 0)
469  return false;
470 
471  to_write -= ret;
472  wrote += ret;
473  }
474 
475  return true;
476 }
477 
478 bool TcpServer::WriteLn(const std::string &msg, const std::string &eol) {
479  if (Write(msg))
480  return Write(eol);
481  else return false;
482 }
483 
485  if (client_desc_ != -1) {
486  close(client_desc_);
487  client_desc_ = -1;
488  }
489 }
490 } // namespace kaldi
bool Write(const std::string &msg)
int32 words[kMaxOrder]
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
This configuration class is to set up OnlineNnet2FeaturePipelineInfo, which in turn is the configurat...
void UpdateFrameWeights(const std::vector< std::pair< int32, BaseFloat > > &delta_weights)
If you are downweighting silence, you can call OnlineSilenceWeighting::GetDeltaWeights and supply the...
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
Definition: kaldi-fst-io.cc:45
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int32 g_num_threads
Definition: kaldi-thread.cc:25
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
void InputFinished()
If you call InputFinished(), it tells the class you won't be providing any more waveform.
kaldi::int32 int32
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
OnlineIvectorFeature * IvectorFeature()
This function returns the iVector-extracting part of the feature pipeline (or NULL if iVectors are no...
virtual int32 NumFramesReady() const
Returns the number of frames ready.
const Nnet & GetNnet() const
void Register(OptionsItf *opts)
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
This file contains a different version of the feature-extraction pipeline in online-feature-pipeline...
This class is responsible for storing configuration variables, objects and options for OnlineNnet2Fea...
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
This file contains some miscellaneous functions dealing with class Nnet.
void InitDecoding(int32 frame_offset=0)
Initializes the decoding and sets the frame offset of the underlying decodable object.
void CompactLatticeShortestPath(const CompactLattice &clat, CompactLattice *shortest_path)
A form of the shortest-path/best-path algorithm that's specially coded for CompactLattice.
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
void AcceptWaveform(BaseFloat sampling_rate, const VectorBase< BaseFloat > &waveform)
Accept more data to process.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void ComputeCurrentTraceback(const LatticeFasterOnlineDecoderTpl< FST > &decoder)
int32 GetLatticeTimeSpan(const Lattice &lat)
void Read(std::istream &is, bool binary)
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
void GetLattice(bool end_of_utterance, CompactLattice *clat) const
Gets the lattice.
int main(int argc, char *argv[])
bool EndpointDetected(const OnlineEndpointConfig &config)
This function calls EndpointDetected from online-endpoint.h, with the required arguments.
You will instantiate this class when you want to decode a single utterance using the online-decoding ...
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
void GetBestPath(bool end_of_utterance, Lattice *best_path) const
Outputs an FST corresponding to the single best path through the current lattice. ...
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
OnlineNnet2FeaturePipeline is a class that's responsible for putting together the various parts of th...
OnlineSilenceWeightingConfig silence_weighting_config
Config for weighting silence in iVector adaptation.
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void AdvanceDecoding()
Advances the decoding as far as we can.
std::string LatticeToString(const Lattice &lat, const fst::SymbolTable &word_syms)
std::string GetTimeString(int32 t_beg, int32 t_end, BaseFloat time_unit)
#define KALDI_LOG
Definition: kaldi-error.h:153
bool WriteLn(const std::string &msg, const std::string &eol="\n")
When you instantiate class DecodableNnetSimpleLooped, you should give it a const reference to this cl...
const LatticeFasterOnlineDecoderTpl< FST > & Decoder() const
void FinalizeDecoding()
Finalizes the decoding.
Config class for the CollapseModel function.
Definition: nnet-utils.h:240
void GetDeltaWeights(int32 num_frames_ready, int32 first_decoder_frame, std::vector< std::pair< int32, BaseFloat > > *delta_weights)