// online-server-gmm-decode-faster.cc
// Go to the documentation of this file.
// onlinebin/online-server-gmm-decode-faster.cc

// Copyright 2012 Cisco Systems (author: Matthias Paulik)
//           2012 Vassil Panayotov
//           2013 Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include <cstdlib>

#include "feat/feature-mfcc.h"
#include "online/onlinebin-util.h"

28 namespace kaldi {
29 
30 void SendPartialResult(const std::vector<int32>& words,
31  const fst::SymbolTable *word_syms,
32  const bool line_break,
33  const int32 serv_sock,
34  const sockaddr_in &client_addr) {
35  KALDI_ASSERT(word_syms != NULL);
36  std::stringstream sstream;
37  for (size_t i = 0; i < words.size(); i++) {
38  std::string word = word_syms->Find(words[i]);
39  if (word == "")
40  KALDI_ERR << "Word-id " << words[i] <<" not in symbol table.";
41  sstream << word << ' ';
42  }
43  if (line_break)
44  sstream << "\n\n";
45 
46  ssize_t sent = sendto(serv_sock, sstream.str().c_str(), sstream.str().size(),
47  0, reinterpret_cast<const sockaddr*>(&client_addr),
48  sizeof(client_addr));
49  if (sent == -1)
50  KALDI_WARN << "sendto() call failed when tried to send recognition results";
51 }
52 
53 } // namespace kaldi
54 
55 
56 int main(int argc, char *argv[]) {
57  try {
58  using namespace kaldi;
59  using namespace fst;
60 
61  typedef kaldi::int32 int32;
62 
63  // Up to delta-delta derivative features are calculated (unless LDA is used)
64  const int32 kDeltaOrder = 2;
65 
66  const char *usage =
67  "Decode speech, using feature batches received over a network connection\n\n"
68  "Utterance segmentation is done on-the-fly.\n"
69  "Feature splicing/LDA transform is used, if the optional(last) argument "
70  "is given.\n"
71  "Otherwise delta/delta-delta(2-nd order) features are produced.\n\n"
72  "Usage: online-server-gmm-decode-faster [options] model-in"
73  "fst-in word-symbol-table silence-phones udp-port [lda-matrix-in]\n\n"
74  "Example: online-server-gmm-decode-faster --rt-min=0.3 --rt-max=0.5 "
75  "--max-active=4000 --beam=12.0 --acoustic-scale=0.0769 "
76  "model HCLG.fst words.txt '1:2:3:4:5' 1234 lda-matrix";
77  ParseOptions po(usage);
78  BaseFloat acoustic_scale = 0.1;
79  int32 cmn_window = 600,
80  min_cmn_window = 100; // adds 1 second latency, only at utterance start.
81  int32 right_context = 4, left_context = 4;
82 
83  kaldi::DeltaFeaturesOptions delta_opts;
84  delta_opts.Register(&po);
85  OnlineFasterDecoderOpts decoder_opts;
86  OnlineFeatureMatrixOptions feature_reading_opts;
87  decoder_opts.Register(&po, true);
88  feature_reading_opts.Register(&po);
89 
90  po.Register("left-context", &left_context, "Number of frames of left context");
91  po.Register("right-context", &right_context, "Number of frames of right context");
92  po.Register("acoustic-scale", &acoustic_scale,
93  "Scaling factor for acoustic likelihoods");
94  po.Register("cmn-window", &cmn_window,
95  "Number of feat. vectors used in the running average CMN calculation");
96  po.Register("min-cmn-window", &min_cmn_window,
97  "Minumum CMN window used at start of decoding (adds "
98  "latency only at start)");
99 
100  po.Read(argc, argv);
101  if (po.NumArgs() != 5 && po.NumArgs() != 6) {
102  po.PrintUsage();
103  return 1;
104  }
105 
106  std::string model_rxfilename = po.GetArg(1),
107  fst_rxfilename = po.GetArg(2),
108  word_syms_filename = po.GetArg(3),
109  silence_phones_str = po.GetArg(4),
110  lda_mat_rspecifier = po.GetOptArg(6);
111  int32 udp_port = atoi(po.GetArg(5).c_str());
112 
113  Matrix<BaseFloat> lda_transform;
114  if (lda_mat_rspecifier != "") {
115  bool binary_in;
116  Input ki(lda_mat_rspecifier, &binary_in);
117  lda_transform.Read(ki.Stream(), binary_in);
118  }
119 
120  std::vector<int32> silence_phones;
121  if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones))
122  KALDI_ERR << "Invalid silence-phones string " << silence_phones_str;
123  if (silence_phones.empty())
124  KALDI_ERR << "No silence phones given!";
125 
126  TransitionModel trans_model;
127  AmDiagGmm am_gmm;
128  {
129  bool binary;
130  Input ki(model_rxfilename, &binary);
131  trans_model.Read(ki.Stream(), binary);
132  am_gmm.Read(ki.Stream(), binary);
133  }
134 
135  fst::SymbolTable *word_syms = NULL;
136  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
137  KALDI_ERR << "Could not read symbol table from file "
138  << word_syms_filename;
139 
140  fst::Fst<fst::StdArc> *decode_fst = ReadDecodeGraph(fst_rxfilename);
141 
142  // We are not properly registering/exposing MFCC and frame extraction options,
143  // because there are parts of the online decoding code, where some of these
144  // options are hardwired(ToDo: we should fix this at some point)
145  MfccOptions mfcc_opts;
146  mfcc_opts.use_energy = false;
147 
148  OnlineFasterDecoder decoder(*decode_fst, decoder_opts,
149  silence_phones, trans_model);
150  VectorFst<LatticeArc> out_fst;
151  int32 feature_dim = mfcc_opts.num_ceps; // default to 13 right now.
152  OnlineUdpInput udp_input(udp_port, feature_dim);
153  OnlineCmnInput cmn_input(&udp_input, cmn_window, min_cmn_window);
154  OnlineFeatInputItf *feat_transform = 0;
155 
156  if (lda_mat_rspecifier != "") {
157  feat_transform = new OnlineLdaInput(
158  &cmn_input, lda_transform,
159  left_context, right_context);
160  } else {
162  opts.order = kDeltaOrder;
163  feat_transform = new OnlineDeltaInput(opts, &cmn_input);
164  }
165 
166  // feature_reading_opts contains number of retries, batch size.
167  OnlineFeatureMatrix feature_matrix(feature_reading_opts,
168  feat_transform);
169 
170  OnlineDecodableDiagGmmScaled decodable(am_gmm, trans_model, acoustic_scale,
171  &feature_matrix);
172 
173  std::cerr << std::endl << "Listening on UDP port "
174  << udp_port << " ... " << std::endl;
175  bool partial_res = false;
176  while (1) {
177  OnlineFasterDecoder::DecodeState dstate = decoder.Decode(&decodable);
178  std::vector<int32> word_ids;
179  if (dstate & (decoder.kEndFeats | decoder.kEndUtt)) {
180  decoder.FinishTraceBack(&out_fst);
182  static_cast<vector<int32> *>(0),
183  &word_ids,
184  static_cast<LatticeArc::Weight*>(0));
185  SendPartialResult(word_ids, word_syms, partial_res || word_ids.size(),
186  udp_input.descriptor(), udp_input.client_addr());
187  partial_res = false;
188  } else {
189  if (decoder.PartialTraceback(&out_fst)) {
191  static_cast<vector<int32> *>(0),
192  &word_ids,
193  static_cast<LatticeArc::Weight*>(0));
194  SendPartialResult(word_ids, word_syms, false,
195  udp_input.descriptor(), udp_input.client_addr());
196  if (!partial_res)
197  partial_res = (word_ids.size() > 0);
198  }
199  }
200  }
201 
202  delete feat_transform;
203  delete word_syms;
204  delete decode_fst;
205  return 0;
206  } catch(const std::exception& e) {
207  std::cerr << e.what();
208  return -1;
209  }
210 } // main()
/*
 * Doxygen cross-reference tooltips captured with this source listing
 * (documentation residue, not part of the program):
 *
 * int32 words[kMaxOrder]
 * This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
 * Definition: chain.dox:20
 * bool PartialTraceback(fst::MutableFst< LatticeArc > *out_fst)
 * void Register(OptionsItf *opts, bool full)
 * MfccOptions contains basic options for computing MFCC features.
 * Definition: feature-mfcc.h:38
 * For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
 * Definition: graph.dox:21
 * bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
 * Split a string (e.g.
 * Definition: text-utils.h:68
 * void PrintUsage(bool print_command_line=false)
 * Prints the usage documentation [provided in the constructor].
 * kaldi::int32 int32
 * const sockaddr_in & client_addr() const
 * DecodeState Decode(DecodableInterface *decodable)
 * int main(int argc, char *argv[])
 * bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
 * GetLinearSymbolSequence gets the symbol sequence from a linear FST.
 * void Register(const std::string &name, bool *ptr, const std::string &doc)
 * std::istream & Stream()
 * Definition: kaldi-io.cc:826
 * float BaseFloat
 * Definition: kaldi-types.h:29
 * The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
 * Definition: parse-options.h:36
 * void FinishTraceBack(fst::MutableFst< LatticeArc > *fst_out)
 * void Read(std::istream &is, bool binary)
 * int Read(int argc, const char *const *argv)
 * Parses the command line options and fills the ParseOptions-registered variables.
 * #define KALDI_ERR
 * Definition: kaldi-error.h:147
 * std::string GetArg(int param) const
 * Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
 * #define KALDI_WARN
 * Definition: kaldi-error.h:150
 * int NumArgs() const
 * Number of positional parameters (c.f. argc-1).
 * fst::Fst< fst::StdArc > * ReadDecodeGraph(const std::string &filename)
 * #define KALDI_ASSERT(cond)
 * Definition: kaldi-error.h:185
 * void Register(OptionsItf *opts)
 * void Read(std::istream &in_stream, bool binary)
 * Definition: am-diag-gmm.cc:147
 * const int32 descriptor() const
 * void SendPartialResult(const std::vector< int32 > &words, const fst::SymbolTable *word_syms, const bool line_break, const int32 serv_sock, const sockaddr_in &client_addr)
 * std::string GetOptArg(int param) const
 */