online-audio-server-decode-faster.cc File Reference
#include "feat/feature-mfcc.h"
#include "feat/wave-reader.h"
#include "online/online-tcp-source.h"
#include "online/online-feat-input.h"
#include "online/online-decodable.h"
#include "online/online-faster-decoder.h"
#include "online/onlinebin-util.h"
#include "matrix/kaldi-vector.h"
#include "lat/word-align-lattice.h"
#include "lat/lattice-functions.h"
#include "lat/sausages.h"
#include "lat/determinize-lattice-pruned.h"
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <ctime>
#include <signal.h>
Include dependency graph for online-audio-server-decode-faster.cc:

Go to the source code of this file.

Classes

class  TcpServer
 

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks.
 

Functions

bool WriteLine (int32 socket, std::string line)
 
int32 main (int argc, char *argv[])
 

Variables

const float kFramesPerSecond = 100.0f
 

Function Documentation

◆ main()

int32 main ( int  argc,
char *  argv[] 
)

Definition at line 68 of file online-audio-server-decode-faster.cc.

References TcpServer::Accept(), OnlineFasterDecoderOpts::batch_size, kaldi::CompactLatticeToWordAlignment(), fst::ConvertLattice(), OnlineFasterDecoder::Decode(), fst::DeterminizeLatticePruned(), OnlineFasterDecoder::FinishTraceBack(), OnlineFasterDecoder::frame(), FrameExtractionOptions::frame_length_ms, MfccOptions::frame_opts, FrameExtractionOptions::frame_shift_ms, ParseOptions::GetArg(), FasterDecoder::GetBestPath(), fst::GetLinearSymbolSequence(), ParseOptions::GetOptArg(), rnnlm::i, OnlineTcpVectorSource::IsConnected(), KALDI_ERR, OnlineFasterDecoder::kEndFeats, OnlineFasterDecoder::kEndUtt, kaldi::kFramesPerSecond, TcpServer::Listen(), DeterminizeLatticePrunedOptions::max_loop, DeterminizeLatticePrunedOptions::max_mem, ParseOptions::NumArgs(), DeltaFeaturesOptions::order, OnlineFasterDecoder::PartialTraceback(), ParseOptions::PrintUsage(), AmDiagGmm::Read(), ParseOptions::Read(), TransitionModel::Read(), Matrix< Real >::Read(), kaldi::ReadDecodeGraph(), OnlineFasterDecoderOpts::Register(), ParseOptions::Register(), WordBoundaryInfoNewOpts::Register(), OnlineFeatureMatrixOptions::Register(), OnlineTcpVectorSource::ResetSamples(), OnlineTcpVectorSource::SamplesProcessed(), kaldi::SplitStringToIntegers(), Input::Stream(), MfccOptions::use_energy, kaldi::WordAlignLattice(), and kaldi::WriteLine().

68  {
69  using namespace kaldi;
70  using namespace fst;
71 
72  try {
73  typedef kaldi::int32 int32;
74  typedef OnlineFeInput<Mfcc> FeInput;
75  TcpServer tcp_server;
76  signal(SIGPIPE, SIG_IGN);
77 
78  // up to delta-delta derivative features are calculated (unless LDA is used)
79  const int32 kDeltaOrder = 2;
80 
81  const char *usage =
82  "Starts a TCP server that receives RAW audio and outputs aligned words.\n"
83  "A sample client can be found in: onlinebin/online-audio-client\n\n"
84  "Usage: online-audio-server-decode-faster [options] model-in "
85  "fst-in word-symbol-table silence-phones word_boundary_file tcp-port [lda-matrix-in]\n\n"
86  "example: online-audio-server-decode-faster --verbose=1 --rt-min=0.5 --rt-max=3.0 --max-active=6000\n"
87  "--beam=72.0 --acoustic-scale=0.0769 final.mdl graph/HCLG.fst graph/words.txt '1:2:3:4:5'\n"
88  "graph/word_boundary.int 5000 final.mat\n\n";
89 
90  ParseOptions po(usage);
91  BaseFloat acoustic_scale = 0.1;
92  int32 cmn_window = 600, min_cmn_window = 100; // adds 1 second latency, only at utterance start.
93  int32 right_context = 4, left_context = 4;
94  BaseFloat frame_shift = 0.01;
95 
96  OnlineFasterDecoderOpts decoder_opts;
97  decoder_opts.Register(&po, true);
98  OnlineFeatureMatrixOptions feature_reading_opts;
99  feature_reading_opts.Register(&po);
100 
101  po.Register("left-context", &left_context,
102  "Number of frames of left context");
103  po.Register("right-context", &right_context,
104  "Number of frames of right context");
105  po.Register("acoustic-scale", &acoustic_scale,
106  "Scaling factor for acoustic likelihoods");
107  po.Register(
108  "cmn-window", &cmn_window,
109  "Number of feat. vectors used in the running average CMN calculation");
110  po.Register("min-cmn-window", &min_cmn_window,
111  "Minumum CMN window used at start of decoding (adds "
112  "latency only at start)");
113  po.Register("frame-shift", &frame_shift,
114  "Time in seconds between frames.\n");
115 
116  WordBoundaryInfoNewOpts opts;
117  opts.Register(&po);
118 
119  po.Read(argc, argv);
120  if (po.NumArgs() < 6 || po.NumArgs() > 7) {
121  po.PrintUsage();
122  return 1;
123  }
124 
125  std::string model_rspecifier = po.GetArg(1), fst_rspecifier = po.GetArg(2),
126  word_syms_filename = po.GetArg(3), silence_phones_str = po.GetArg(4),
127  word_boundary_file = po.GetArg(5), lda_mat_rspecifier = "";
128 
129  if (po.NumArgs() == 7)
130  lda_mat_rspecifier = po.GetOptArg(7);
131 
132  int32 port = strtol(po.GetArg(6).c_str(), 0, 10);
133 
134  std::vector<int32> silence_phones;
135  if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones))
136  KALDI_ERR << "Invalid silence-phones string " << silence_phones_str;
137  if (silence_phones.empty())
138  KALDI_ERR << "No silence phones given!";
139 
140  if (!tcp_server.Listen(port))
141  return 0;
142 
143  std::cout << "Reading LDA matrix: " << lda_mat_rspecifier << "..."
144  << std::endl;
145  Matrix < BaseFloat > lda_transform;
146  if (lda_mat_rspecifier != "") {
147  bool binary_in;
148  Input ki(lda_mat_rspecifier, &binary_in);
149  lda_transform.Read(ki.Stream(), binary_in);
150  }
151 
152  std::cout << "Reading acoustic model: " << model_rspecifier << "..."
153  << std::endl;
154  TransitionModel trans_model;
155  AmDiagGmm am_gmm;
156  {
157  bool binary;
158  Input ki(model_rspecifier, &binary);
159  trans_model.Read(ki.Stream(), binary);
160  am_gmm.Read(ki.Stream(), binary);
161  }
162 
163  std::cout << "Reading word list: " << word_syms_filename << "..."
164  << std::endl;
165  fst::SymbolTable *word_syms = NULL;
166  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
167  KALDI_ERR << "Could not read symbol table from file "
168  << word_syms_filename;
169 
170  std::cout << "Reading word boundary file: " << word_boundary_file << "..."
171  << std::endl;
172  WordBoundaryInfo info(opts, word_boundary_file);
173 
174  std::cout << "Reading FST: " << fst_rspecifier << "..." << std::endl;
175  fst::Fst < fst::StdArc > *decode_fst = ReadDecodeGraph(fst_rspecifier);
176 
177  // We are not properly registering/exposing MFCC and frame extraction options,
178  // because there are parts of the online decoding code, where some of these
179  // options are hardwired (TODO: we should fix this at some point)
180  MfccOptions mfcc_opts;
181  mfcc_opts.use_energy = false;
182  int32 frame_length = mfcc_opts.frame_opts.frame_length_ms = 25;
183  int32 mfcc_frame_shift = mfcc_opts.frame_opts.frame_shift_ms = 10;
184 
185  int32 window_size = right_context + left_context + 1;
186  decoder_opts.batch_size = std::max(decoder_opts.batch_size, window_size);
187 
188  DeterminizeLatticePrunedOptions det_opts;
189  det_opts.max_mem = 50000000;
190  det_opts.max_loop = 0;
191 
192  VectorFst < LatticeArc > out_fst;
193  Lattice out_lat;
194  CompactLattice det_lat, aligned_lat;
195  OnlineTcpVectorSource* au_src = NULL;
196  int32 client_socket = -1;
197 
198  while (true) {
199  if (au_src == NULL || !au_src->IsConnected()) {
200  if (au_src) {
201  std::cout << "Client disconnected!" << std::endl;
202  delete au_src;
203  }
204  client_socket = tcp_server.Accept();
205  au_src = new OnlineTcpVectorSource(client_socket);
206  }
207 
208  // re-initializing the decoder for each utterance
209  OnlineFasterDecoder decoder(*decode_fst, decoder_opts, silence_phones,
210  trans_model);
211 
212  Mfcc mfcc(mfcc_opts);
213  FeInput fe_input(au_src, &mfcc, frame_length * (16000 / 1000),
214  mfcc_frame_shift * (16000 / 1000)); //we always assume 16 kHz Fs on input
215  OnlineCmnInput cmn_input(&fe_input, cmn_window, min_cmn_window);
216  OnlineFeatInputItf *feat_transform = 0;
217  if (lda_mat_rspecifier != "") {
218  feat_transform = new OnlineLdaInput(&cmn_input, lda_transform,
219  left_context, right_context);
220  } else {
221  DeltaFeaturesOptions opts;
222  opts.order = kDeltaOrder;
223  feat_transform = new OnlineDeltaInput(opts, &cmn_input);
224  }
225 
226  // feature_reading_opts contains number of retries, batch size.
227  OnlineFeatureMatrix feature_matrix(feature_reading_opts, feat_transform);
228 
229  OnlineDecodableDiagGmmScaled decodable(am_gmm, trans_model,
230  acoustic_scale, &feature_matrix);
231 
232  clock_t start = clock();
233  int32 decoder_offset = 0;
234 
235  while (1) {
236  if (!au_src->IsConnected())
237  break;
238 
239  OnlineFasterDecoder::DecodeState dstate = decoder.Decode(&decodable);
240 
241  if (!au_src->IsConnected()) {
242  break;
243  }
244 
245  if (dstate & (decoder.kEndFeats | decoder.kEndUtt)) {
246  std::vector<int32> word_ids, times, lengths;
247 
248  decoder.FinishTraceBack(&out_fst);
249  decoder.GetBestPath(&out_fst);
250 
251  ConvertLattice(out_fst, &out_lat);
252 
253  Invert(&out_lat);
254  //TopSort(&out_lat);
255  //ArcSort(&out_lat, ILabelCompare<LatticeArc>());
256 
257  DeterminizeLatticePruned(out_lat, 10.0f, &det_lat, det_opts);
258 
259  WordAlignLattice(det_lat, trans_model, info, 0, &aligned_lat);
260 
261  CompactLatticeToWordAlignment(aligned_lat, &word_ids, &times,
262  &lengths);
263 
264  //count number of non-sil words
265  int32 words_num = 0;
266  for (size_t i = 0; i < word_ids.size(); i++)
267  if (word_ids[i] != 0)
268  words_num++;
269 
270  if (words_num > 0) {
271 
272  float dur = (clock() - start) / (float) CLOCKS_PER_SEC;
273  float input_dur = au_src->SamplesProcessed() / 16000.0;
274 
275  start = clock();
276  au_src->ResetSamples();
277 
278  std::stringstream sstr;
279  sstr << "RESULT:NUM=" << words_num << ",FORMAT=WSE,RECO-DUR=" << dur
280  << ",INPUT-DUR=" << input_dur;
281 
282  WriteLine(client_socket, sstr.str());
283 
284  for (size_t i = 0; i < word_ids.size(); i++) {
285  if (word_ids[i] == 0)
286  continue; //skip silences...
287 
288  std::string word = word_syms->Find(word_ids[i]);
289  if (word.empty())
290  word = "???";
291 
292  float start = (times[i] + decoder_offset) / kFramesPerSecond;
293  float len = lengths[i] / kFramesPerSecond;
294 
295  std::stringstream wstr;
296  wstr << word << "," << start << "," << (start + len);
297 
298  WriteLine(client_socket, wstr.str());
299  }
300  }
301 
302  if (dstate == decoder.kEndFeats) {
303  WriteLine(client_socket, "RESULT:DONE");
304  break;
305  }
306 
307  decoder_offset = decoder.frame();
308  } else {
309  std::vector<int32> word_ids;
310  if (decoder.PartialTraceback(&out_fst)) {
311  GetLinearSymbolSequence(out_fst, static_cast<vector<int32> *>(0),
312  &word_ids,
313  static_cast<LatticeArc::Weight*>(0));
314  for (size_t i = 0; i < word_ids.size(); i++) {
315  if (word_ids[i] != 0) {
316  WriteLine(client_socket,
317  "PARTIAL:" + word_syms->Find(word_ids[i]));
318  }
319  }
320  }
321  }
322  }
323  delete feat_transform;
324  }
325 
326  std::cout << "Deinitizalizing..." << std::endl;
327 
328  delete word_syms;
329  delete decode_fst;
330  return 0;
331 
332  } catch (const std::exception& e) {
333  std::cerr << e.what();
334  return -1;
335  }
336 } // main()
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void Register(OptionsItf *opts, bool full)
MfccOptions contains basic options for computing MFCC features.
Definition: feature-mfcc.h:38
bool DeterminizeLatticePruned(const ExpandedFst< ArcTpl< Weight > > &ifst, double beam, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, IntType > > > *ofst, DeterminizeLatticePrunedOptions opts)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
Definition: text-utils.h:68
kaldi::int32 int32
bool WordAlignLattice(const CompactLattice &lat, const TransitionModel &tmodel, const WordBoundaryInfo &info, int32 max_states, CompactLattice *lat_out)
Align lattice so that each arc has the transition-ids on it that correspond to the word that is on th...
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
void Read(std::istream &in, bool binary, bool add=false)
read from stream.
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
FrameExtractionOptions frame_opts
Definition: feature-mfcc.h:39
void Read(std::istream &is, bool binary)
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
void Register(OptionsItf *opts)
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
#define KALDI_ERR
Definition: kaldi-error.h:147
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
bool WriteLine(int32 socket, std::string line)
fst::Fst< fst::StdArc > * ReadDecodeGraph(const std::string &filename)
This templated class is intended for offline feature extraction, i.e.
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
bool CompactLatticeToWordAlignment(const CompactLattice &clat, std::vector< int32 > *words, std::vector< int32 > *begin_times, std::vector< int32 > *lengths)
This function takes a CompactLattice that should only contain a single linear sequence (e...