online-audio-server-decode-faster.cc
Go to the documentation of this file.
1 // onlinebin/online-audio-server-decode-faster.cc
2 
3 // Copyright 2012 Cisco Systems (author: Matthias Paulik)
4 // Copyright 2013 Polish-Japanese Institute of Information Technology (author: Danijel Korzinek)
5 
6 // Modifications to the original contribution by Cisco Systems made by:
7 // Vassil Panayotov
8 
9 // See ../../COPYING for clarification regarding multiple authors
10 //
11 // Licensed under the Apache License, Version 2.0 (the "License");
12 // you may not use this file except in compliance with the License.
13 // You may obtain a copy of the License at
14 //
15 // http://www.apache.org/licenses/LICENSE-2.0
16 //
17 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
18 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
19 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
20 // MERCHANTABLITY OR NON-INFRINGEMENT.
21 // See the Apache 2 License for the specific language governing permissions and
22 // limitations under the License.
23 
24 #include "feat/feature-mfcc.h"
25 #include "feat/wave-reader.h"
30 #include "online/onlinebin-util.h"
31 #include "matrix/kaldi-vector.h"
32 #include "lat/word-align-lattice.h"
33 #include "lat/lattice-functions.h"
34 #include "lat/sausages.h"
36 
37 #include <sys/socket.h>
38 #include <sys/types.h>
39 #include <unistd.h>
40 #include <ctime>
41 #include <signal.h>
42 
43 namespace kaldi {
44 /*
45  * This class is for a very simple TCP server implementation
46  * in UNIX sockets.
47  */
// Minimal TCP server wrapper: Listen() prepares the listening socket and
// Accept() returns a descriptor for one connected client.
// NOTE(review): the method definitions further down in this file use a member
// named server_desc_; its declaration is not visible in this listing
// (original line 58 is missing) -- confirm against the real source.
48 class TcpServer {
49  public:
50  TcpServer();
51  ~TcpServer();
52 
53  bool Listen(int32 port); //start listening on a given port
54  int32 Accept(); //accept a client and return its descriptor
55 
56  private:
 // Address structure: filled in and passed to bind() by the listening code,
 // and also handed to accept() as the peer-address output buffer.
57  struct sockaddr_in h_addr_;
59 };
60 
61 //write a line of text to socket
62 bool WriteLine(int32 socket, std::string line);
63 
64 // constant used to convert a frame count into a time in seconds
65 const float kFramesPerSecond = 100.0f;
66 } // namespace kaldi
67 
// Entry point of the TCP decoding server.
//
// Parses the command line (model, decoding graph, word symbol table, silence
// phones, word-boundary file, TCP port and an optional LDA matrix), loads those
// resources, starts listening on the port and then loops: it accepts a client,
// decodes the audio arriving through the socket-backed audio source and writes
// "PARTIAL:" and "RESULT:" lines back to the client through WriteLine().
//
// Returns 1 after printing the usage text when the number of positional
// arguments is not 6 or 7, 0 when tcp_server.Listen() returns false, and -1
// when a std::exception is caught.
//
// NOTE(review): the declarations of `opts` (used for opts.Register() and for
// WordBoundaryInfo), of `det_opts`, and of the `opts` used in the
// delta-features branch are not visible in this listing (original lines 116,
// 188 and 221 are missing) -- confirm against the real source.
68 int32 main(int argc, char *argv[]) {
69  using namespace kaldi;
70  using namespace fst;
71 
72  try {
73  typedef kaldi::int32 int32;
74  typedef OnlineFeInput<Mfcc> FeInput;
75  TcpServer tcp_server;
 // SIGPIPE is ignored; presumably so that writing to a client that has gone
 // away makes WriteLine() fail instead of terminating the server -- confirm.
76  signal(SIGPIPE, SIG_IGN);
77 
78  // up to delta-delta derivative features are calculated (unless LDA is used)
79  const int32 kDeltaOrder = 2;
80 
81  const char *usage =
82  "Starts a TCP server that receives RAW audio and outputs aligned words.\n"
83  "A sample client can be found in: onlinebin/online-audio-client\n\n"
84  "Usage: online-audio-server-decode-faster [options] model-in "
85  "fst-in word-symbol-table silence-phones word_boundary_file tcp-port [lda-matrix-in]\n\n"
86  "example: online-audio-server-decode-faster --verbose=1 --rt-min=0.5 --rt-max=3.0 --max-active=6000\n"
87  "--beam=72.0 --acoustic-scale=0.0769 final.mdl graph/HCLG.fst graph/words.txt '1:2:3:4:5'\n"
88  "graph/word_boundary.int 5000 final.mat\n\n";
89 
90  ParseOptions po(usage);
91  BaseFloat acoustic_scale = 0.1;
92  int32 cmn_window = 600, min_cmn_window = 100; // adds 1 second latency, only at utterance start.
93  int32 right_context = 4, left_context = 4;
 // NOTE(review): frame_shift is registered as an option below but is not read
 // anywhere else in the visible code; word times are computed with
 // kFramesPerSecond instead.
94  BaseFloat frame_shift = 0.01;
95 
96  OnlineFasterDecoderOpts decoder_opts;
97  decoder_opts.Register(&po, true);
98  OnlineFeatureMatrixOptions feature_reading_opts;
99  feature_reading_opts.Register(&po);
100 
101  po.Register("left-context", &left_context,
102  "Number of frames of left context");
103  po.Register("right-context", &right_context,
104  "Number of frames of right context");
105  po.Register("acoustic-scale", &acoustic_scale,
106  "Scaling factor for acoustic likelihoods");
107  po.Register(
108  "cmn-window", &cmn_window,
109  "Number of feat. vectors used in the running average CMN calculation");
110  po.Register("min-cmn-window", &min_cmn_window,
111  "Minumum CMN window used at start of decoding (adds "
112  "latency only at start)");
113  po.Register("frame-shift", &frame_shift,
114  "Time in seconds between frames.\n");
115 
117  opts.Register(&po);
118 
119  po.Read(argc, argv);
 // Six positional arguments are mandatory; the seventh (LDA matrix) is optional.
120  if (po.NumArgs() < 6 || po.NumArgs() > 7) {
121  po.PrintUsage();
122  return 1;
123  }
124 
125  std::string model_rspecifier = po.GetArg(1), fst_rspecifier = po.GetArg(2),
126  word_syms_filename = po.GetArg(3), silence_phones_str = po.GetArg(4),
127  word_boundary_file = po.GetArg(5), lda_mat_rspecifier = "";
128 
129  if (po.NumArgs() == 7)
130  lda_mat_rspecifier = po.GetOptArg(7);
131 
 // NOTE(review): strtol() is called without an end pointer, so a non-numeric
 // port argument is not detected here.
132  int32 port = strtol(po.GetArg(6).c_str(), 0, 10);
133 
134  std::vector<int32> silence_phones;
135  if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones))
136  KALDI_ERR << "Invalid silence-phones string " << silence_phones_str;
137  if (silence_phones.empty())
138  KALDI_ERR << "No silence phones given!";
139 
 // NOTE(review): a failed Listen() makes the program exit with status 0 --
 // confirm that a success status is intended here.
140  if (!tcp_server.Listen(port))
141  return 0;
142 
143  std::cout << "Reading LDA matrix: " << lda_mat_rspecifier << "..."
144  << std::endl;
145  Matrix < BaseFloat > lda_transform;
146  if (lda_mat_rspecifier != "") {
147  bool binary_in;
148  Input ki(lda_mat_rspecifier, &binary_in);
149  lda_transform.Read(ki.Stream(), binary_in);
150  }
151 
152  std::cout << "Reading acoustic model: " << model_rspecifier << "..."
153  << std::endl;
154  TransitionModel trans_model;
155  AmDiagGmm am_gmm;
156  {
 // The transition model and the GMM acoustic model are read from the same
 // input stream, in this order.
157  bool binary;
158  Input ki(model_rspecifier, &binary);
159  trans_model.Read(ki.Stream(), binary);
160  am_gmm.Read(ki.Stream(), binary);
161  }
162 
163  std::cout << "Reading word list: " << word_syms_filename << "..."
164  << std::endl;
165  fst::SymbolTable *word_syms = NULL;
166  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
167  KALDI_ERR << "Could not read symbol table from file "
168  << word_syms_filename;
169 
170  std::cout << "Reading word boundary file: " << word_boundary_file << "..."
171  << std::endl;
172  WordBoundaryInfo info(opts, word_boundary_file);
173 
174  std::cout << "Reading FST: " << fst_rspecifier << "..." << std::endl;
175  fst::Fst < fst::StdArc > *decode_fst = ReadDecodeGraph(fst_rspecifier);
176 
177  // We are not properly registering/exposing MFCC and frame extraction options,
178  // because there are parts of the online decoding code, where some of these
179  // options are hardwired(ToDo: we should fix this at some point)
180  MfccOptions mfcc_opts;
181  mfcc_opts.use_energy = false;
182  int32 frame_length = mfcc_opts.frame_opts.frame_length_ms = 25;
183  int32 mfcc_frame_shift = mfcc_opts.frame_opts.frame_shift_ms = 10;
184 
 // The decoder batch size is raised, if necessary, to at least the size of
 // the context window (left + right context plus the current frame).
185  int32 window_size = right_context + left_context + 1;
186  decoder_opts.batch_size = std::max(decoder_opts.batch_size, window_size);
187 
189  det_opts.max_mem = 50000000;
190  det_opts.max_loop = 0;
191 
192  VectorFst < LatticeArc > out_fst;
193  Lattice out_lat;
194  CompactLattice det_lat, aligned_lat;
195  OnlineTcpVectorSource* au_src = NULL;
196  int32 client_socket = -1;
197 
 // Outer loop: one iteration per utterance. No statement in this loop breaks
 // out of it (the `break`s below belong to the inner loop), so the cleanup
 // code after the loop is not reached by normal control flow.
198  while (true) {
 // (Re)connect: wait for a new client when there is none yet or when the
 // current one has disconnected.
199  if (au_src == NULL || !au_src->IsConnected()) {
200  if (au_src) {
201  std::cout << "Client disconnected!" << std::endl;
202  delete au_src;
203  }
204  client_socket = tcp_server.Accept();
205  au_src = new OnlineTcpVectorSource(client_socket);
206  }
207 
208  // re-initializing the decoder for each utterance
209  OnlineFasterDecoder decoder(*decode_fst, decoder_opts, silence_phones,
210  trans_model);
211 
 // Feature pipeline: audio source -> MFCC -> CMN -> (LDA or deltas).
212  Mfcc mfcc(mfcc_opts);
213  FeInput fe_input(au_src, &mfcc, frame_length * (16000 / 1000),
214  mfcc_frame_shift * (16000 / 1000)); //we always assume 16 kHz Fs on input
215  OnlineCmnInput cmn_input(&fe_input, cmn_window, min_cmn_window);
216  OnlineFeatInputItf *feat_transform = 0;
217  if (lda_mat_rspecifier != "") {
218  feat_transform = new OnlineLdaInput(&cmn_input, lda_transform,
219  left_context, right_context);
220  } else {
222  opts.order = kDeltaOrder;
223  feat_transform = new OnlineDeltaInput(opts, &cmn_input);
224  }
225 
226  // feature_reading_opts contains number of retries, batch size.
227  OnlineFeatureMatrix feature_matrix(feature_reading_opts, feat_transform);
228 
229  OnlineDecodableDiagGmmScaled decodable(am_gmm, trans_model,
230  acoustic_scale, &feature_matrix);
231 
232  clock_t start = clock();
233  int32 decoder_offset = 0;
234 
 // Inner loop: repeated Decode() calls for the current utterance; left when
 // the client disconnects or when the decoder reports kEndFeats.
235  while (1) {
236  if (!au_src->IsConnected())
237  break;
238 
239  OnlineFasterDecoder::DecodeState dstate = decoder.Decode(&decodable);
240 
241  if (!au_src->IsConnected()) {
242  break;
243  }
244 
 // End of features or end of utterance: produce a word-aligned final result.
245  if (dstate & (decoder.kEndFeats | decoder.kEndUtt)) {
246  std::vector<int32> word_ids, times, lengths;
247 
248  decoder.FinishTraceBack(&out_fst);
249  decoder.GetBestPath(&out_fst);
250 
251  ConvertLattice(out_fst, &out_lat);
252 
253  Invert(&out_lat);
254  //TopSort(&out_lat);
255  //ArcSort(&out_lat, ILabelCompare<LatticeArc>());
256 
257  DeterminizeLatticePruned(out_lat, 10.0f, &det_lat, det_opts);
258 
259  WordAlignLattice(det_lat, trans_model, info, 0, &aligned_lat);
260 
261  CompactLatticeToWordAlignment(aligned_lat, &word_ids, &times,
262  &lengths);
263 
264  //count number of non-sil words
265  int32 words_num = 0;
266  for (size_t i = 0; i < word_ids.size(); i++)
267  if (word_ids[i] != 0)
268  words_num++;
269 
 // A RESULT header plus one line per word is sent only when at least one
 // entry with a non-zero word id was found.
270  if (words_num > 0) {
271 
 // dur: processor time (clock()) spent since `start`, in seconds;
 // input_dur: samples processed divided by the assumed 16000 Hz rate.
272  float dur = (clock() - start) / (float) CLOCKS_PER_SEC;
273  float input_dur = au_src->SamplesProcessed() / 16000.0;
274 
275  start = clock();
276  au_src->ResetSamples();
277 
278  std::stringstream sstr;
279  sstr << "RESULT:NUM=" << words_num << ",FORMAT=WSE,RECO-DUR=" << dur
280  << ",INPUT-DUR=" << input_dur;
281 
282  WriteLine(client_socket, sstr.str());
283 
284  for (size_t i = 0; i < word_ids.size(); i++) {
285  if (word_ids[i] == 0)
286  continue; //skip silences...
287 
288  std::string word = word_syms->Find(word_ids[i]);
289  if (word.empty())
290  word = "???";
291 
 // NOTE: this float `start` (word start time) shadows the clock_t `start`
 // declared before the inner loop.
292  float start = (times[i] + decoder_offset) / kFramesPerSecond;
293  float len = lengths[i] / kFramesPerSecond;
294 
 // One line per word: "<word>,<start time>,<end time>".
295  std::stringstream wstr;
296  wstr << word << "," << start << "," << (start + len);
297 
298  WriteLine(client_socket, wstr.str());
299  }
300  }
301 
302  if (dstate == decoder.kEndFeats) {
303  WriteLine(client_socket, "RESULT:DONE");
304  break;
305  }
306 
 // Remember the decoder's frame position so that the times of the next
 // segment are offset by it.
307  decoder_offset = decoder.frame();
308  } else {
 // Otherwise: send any words available from a partial traceback.
309  std::vector<int32> word_ids;
310  if (decoder.PartialTraceback(&out_fst)) {
311  GetLinearSymbolSequence(out_fst, static_cast<vector<int32> *>(0),
312  &word_ids,
313  static_cast<LatticeArc::Weight*>(0));
314  for (size_t i = 0; i < word_ids.size(); i++) {
315  if (word_ids[i] != 0) {
316  WriteLine(client_socket,
317  "PARTIAL:" + word_syms->Find(word_ids[i]));
318  }
319  }
320  }
321  }
322  }
323  delete feat_transform;
324  }
325 
326  std::cout << "Deinitizalizing..." << std::endl;
327 
328  delete word_syms;
329  delete decode_fst;
330  return 0;
331 
332  } catch (const std::exception& e) {
333  std::cerr << e.what();
334  return -1;
335  }
336 } // main()
337 
338 namespace kaldi {
339 // IMPLEMENTATION OF THE CLASSES/METHODS ABOVE MAIN
// NOTE(review): the line that opens this definition (original line 340) is
// missing from this listing; presumably this is the TcpServer constructor
// body -- confirm against the real source.
// -1 is the "no socket" value: the code that closes server_desc_ only does so
// when it differs from -1.
341  server_desc_ = -1;
342 }
343 
// NOTE(review): the line that opens this definition (original line 344) is
// missing from this listing; given the use of `port` and the bool returns this
// is presumably the body of TcpServer::Listen(int32 port) -- confirm.
// Creates a TCP socket, enables SO_REUSEADDR on it, binds it to `port` and
// starts listening. Each failing step logs through KALDI_ERR and is followed
// by `return false`; on success a message is printed and true is returned.
 // Address to bind to: INADDR_ANY, the given port in network byte order, IPv4.
345  h_addr_.sin_addr.s_addr = INADDR_ANY;
346  h_addr_.sin_port = htons(port);
347  h_addr_.sin_family = AF_INET;
348 
349  server_desc_ = socket(AF_INET, SOCK_STREAM, 0);
350 
351  if (server_desc_ == -1) {
352  KALDI_ERR << "Cannot create TCP socket!";
353  return false;
354  }
355 
 // Set SO_REUSEADDR to 1 on the listening socket before binding.
356  int32 flag = 1;
357  int32 len = sizeof(int32);
358  if( setsockopt(server_desc_, SOL_SOCKET, SO_REUSEADDR, &flag, len) == -1){
359  KALDI_ERR << "Cannot set socket options!\n";
360  return false;
361  }
362 
363  if (bind(server_desc_, (struct sockaddr*) &h_addr_, sizeof(h_addr_)) == -1) {
364  KALDI_ERR << "Cannot bind to port: " << port << " (is it taken?)";
365  return false;
366  }
367 
 // The backlog passed to listen() is 1.
368  if (listen(server_desc_, 1) == -1) {
369  KALDI_ERR << "Cannot listen on port!";
370  return false;
371  }
372 
373  std::cout << "TcpServer: Listening on port: " << port << std::endl;
374 
375  return true;
376 
377 }
378 
// NOTE(review): the line that opens this definition (original line 379) is
// missing from this listing; presumably this is the TcpServer destructor
// body -- confirm against the real source.
// Closes the server socket unless server_desc_ still holds -1.
380  if (server_desc_ != -1)
381  close(server_desc_);
382 }
383 
// NOTE(review): the line that opens this definition (original line 384) is
// missing from this listing; given the returned descriptor this is presumably
// the body of TcpServer::Accept() -- confirm against the real source.
// Calls accept() on server_desc_, prints the peer's IPv4 address and returns
// the descriptor obtained from accept().
385  std::cout << "Waiting for client..." << std::endl;
386 
387  socklen_t len;
388 
389  len = sizeof(struct sockaddr);
 // accept() writes the peer address into h_addr_, i.e. the same structure that
 // holds the address used for bind().
 // NOTE(review): the result of accept() is not checked; if it is -1, the
 // getpeername()/inet_ntop() calls below work on an invalid descriptor and an
 // unfilled `addr`, and -1 is returned to the caller.
390  int32 client_desc = accept(server_desc_, (struct sockaddr*) &h_addr_, &len);
391 
392  struct sockaddr_storage addr;
393  char ipstr[20];
394 
395  len = sizeof addr;
 // NOTE(review): the return value of getpeername() is ignored as well.
396  getpeername(client_desc, (struct sockaddr*) &addr, &len);
397 
 // The peer address is interpreted as IPv4 (sockaddr_in / AF_INET) only.
398  struct sockaddr_in *s = (struct sockaddr_in *) &addr;
399  inet_ntop(AF_INET, &s->sin_addr, ipstr, sizeof ipstr);
400 
401  std::cout << "TcpServer: Accepted connection from: " << ipstr << std::endl;
402 
403  return client_desc;
404 }
405 
406 bool WriteLine(int32 socket, std::string line) {
407  line = line + "\n";
408 
409  const char* p = line.c_str();
410  int32 to_write = line.size();
411  int32 wrote = 0;
412  while (to_write > 0) {
413  int32 ret = write(socket, p + wrote, to_write);
414  if (ret <= 0)
415  return false;
416 
417  to_write -= ret;
418  wrote += ret;
419  }
420 
421  return true;
422 }
423 } // namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
bool PartialTraceback(fst::MutableFst< LatticeArc > *out_fst)
void Register(OptionsItf *opts, bool full)
MfccOptions contains basic options for computing MFCC features.
Definition: feature-mfcc.h:38
bool DeterminizeLatticePruned(const ExpandedFst< ArcTpl< Weight > > &ifst, double beam, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, IntType > > > *ofst, DeterminizeLatticePrunedOptions opts)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
Definition: text-utils.h:68
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
DecodeState Decode(DecodableInterface *decodable)
bool WordAlignLattice(const CompactLattice &lat, const TransitionModel &tmodel, const WordBoundaryInfo &info, int32 max_states, CompactLattice *lat_out)
Align lattice so that each arc has the transition-ids on it that correspond to the word that is on th...
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
void Register(const std::string &name, bool *ptr, const std::string &doc)
bool GetBestPath(fst::MutableFst< LatticeArc > *fst_out, bool use_final_probs=true)
GetBestPath gets the decoding traceback.
std::istream & Stream()
Definition: kaldi-io.cc:826
void Read(std::istream &in, bool binary, bool add=false)
read from stream.
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void FinishTraceBack(fst::MutableFst< LatticeArc > *fst_out)
FrameExtractionOptions frame_opts
Definition: feature-mfcc.h:39
void Read(std::istream &is, bool binary)
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
void Register(OptionsItf *opts)
int32 main(int argc, char *argv[])
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
int NumArgs() const
Number of positional parameters (c.f. argc-1).
bool WriteLine(int32 socket, std::string line)
fst::Fst< fst::StdArc > * ReadDecodeGraph(const std::string &filename)
This templated class is intended for offline feature extraction, i.e.
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
bool CompactLatticeToWordAlignment(const CompactLattice &clat, std::vector< int32 > *words, std::vector< int32 > *begin_times, std::vector< int32 > *lengths)
This function takes a CompactLattice that should only contain a single linear sequence (e...
std::string GetOptArg(int param) const