kws-search.cc File Reference
Include dependency graph for kws-search.cc:

Go to the source code of this file.

Classes

class  VectorFstToKwsLexicographicFstMapper
 
struct  ActivePath
 

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for mispronunciations detection tasks, the reference:
 

Typedefs

typedef Arc::Weight Weight
 
typedef kaldi::TableWriter< kaldi::BasicVectorHolder< double > > VectorOfDoublesWriter
 

Functions

uint64 EncodeLabel (StateId ilabel, StateId olabel)
 
StateId DecodeLabelUid (uint64 osymbol)
 
bool GenerateActivePaths (const KwsLexicographicFst &proxy, std::vector< ActivePath > *paths, KwsLexicographicFst::StateId cur_state, std::vector< KwsLexicographicArc::Label > cur_path, KwsLexicographicArc::Weight cur_weight)
 
void OutputDetailedStatistics (const std::string &kwid, const kaldi::KwsLexicographicFst &keyword, const unordered_map< uint32, uint64 > &label_decoder, VectorOfDoublesWriter *output)
 
int main (int argc, char *argv[])
 

Typedef Documentation

◆ VectorOfDoublesWriter

typedef kaldi::TableWriter< kaldi::BasicVectorHolder< double > > VectorOfDoublesWriter

Definition at line 120 of file kws-search.cc.

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 158 of file kws-search.cc.

References kaldi::DecodeLabelUid(), SequentialTableReader< Holder >::Done(), kaldi::EncodeLabel(), SequentialTableReader< Holder >::FreeCurrent(), ParseOptions::GetArg(), ParseOptions::GetOptArg(), KALDI_ERR, KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), OutputDetailedStatistics(), ParseOptions::PrintUsage(), ParseOptions::Read(), ParseOptions::Register(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), VectorFstToKwsLexicographicFstMapper::VectorFstToKwsLexicographicFstMapper(), and TableWriter< Holder >::Write().

158  {
159  try {
160  using namespace kaldi;
161  using namespace fst;
162  using std::vector;
163  typedef kaldi::int32 int32;
164  typedef kaldi::uint32 uint32;
165  typedef kaldi::uint64 uint64;
166 
167  const char *usage =
168  "Search the keywords over the index. This program can be executed\n"
169  "in parallel, either on the index side or the keywords side; we use\n"
170  "a script to combine the final search results. Note that the index\n"
171  "archive has a single key \"global\".\n\n"
172  "Search has one or two outputs. The first one is mandatory and will\n"
173  "contain the seach output, i.e. list of all found keyword instances\n"
174  "The file is in the following format:\n"
175  "kw_id utt_id beg_frame end_frame neg_logprob\n"
176  " e.g.: \n"
177  "KW105-0198 7 335 376 1.91254\n\n"
178  "The second parameter is optional and allows the user to gather more\n"
179  "statistics about the individual instances from the posting list.\n"
180  "Remember \"keyword\" is an FST and as such, there can be multiple\n"
181  "paths matching in the keyword and in the lattice index in that given\n"
182  "time period. The stats output will provide all matching paths\n"
183  "each with the appropriate score. \n"
184  "The format is as follows:\n"
185  "kw_id utt_id beg_frame end_frame neg_logprob 0 w_id1 w_id2 ... 0\n"
186  " e.g.: \n"
187  "KW105-0198 7 335 376 16.01254 0 5766 5659 0\n"
188  "\n"
189  "Usage: kws-search [options] <index-rspecifier> <keywords-rspecifier> "
190  "<results-wspecifier> [<stats_wspecifier>]\n"
191  " e.g.: kws-search ark:index.idx ark:keywords.fsts "
192  "ark:results ark:stats\n";
193 
194  ParseOptions po(usage);
195 
196  int32 n_best = -1;
197  int32 keyword_nbest = -1;
198  bool strict = true;
199  double negative_tolerance = -0.1;
200  double keyword_beam = -1;
201  int32 frame_subsampling_factor = 1;
202 
203  po.Register("frame-subsampling-factor", &frame_subsampling_factor,
204  "Frame subsampling factor. (Default value 1)");
205  po.Register("nbest", &n_best, "Return the best n hypotheses.");
206  po.Register("keyword-nbest", &keyword_nbest,
207  "Pick the best n keywords if the FST contains "
208  "multiple keywords.");
209  po.Register("strict", &strict, "Affects the return status of the program.");
210  po.Register("negative-tolerance", &negative_tolerance,
211  "The program will print a warning if we get negative score "
212  "smaller than this tolerance.");
213  po.Register("keyword-beam", &keyword_beam,
214  "Prune the FST with the given beam if the FST contains "
215  "multiple keywords.");
216 
217  if (n_best < 0 && n_best != -1) {
218  KALDI_ERR << "Bad number for nbest";
219  exit(1);
220  }
221  if (keyword_nbest < 0 && keyword_nbest != -1) {
222  KALDI_ERR << "Bad number for keyword-nbest";
223  exit(1);
224  }
225  if (keyword_beam < 0 && keyword_beam != -1) {
226  KALDI_ERR << "Bad number for keyword-beam";
227  exit(1);
228  }
229 
230  po.Read(argc, argv);
231 
232  if (po.NumArgs() < 3 || po.NumArgs() > 4) {
233  po.PrintUsage();
234  exit(1);
235  }
236 
237  std::string index_rspecifier = po.GetArg(1),
238  keyword_rspecifier = po.GetArg(2),
239  result_wspecifier = po.GetArg(3),
240  stats_wspecifier = po.GetOptArg(4);
241 
242  RandomAccessTableReader<KwsLexicographicFstHolder>
243  index_reader(index_rspecifier);
244  SequentialTableReader<VectorFstHolder> keyword_reader(keyword_rspecifier);
245  VectorOfDoublesWriter result_writer(result_wspecifier);
246  VectorOfDoublesWriter stats_writer(stats_wspecifier);
247 
248 
249  // Index has key "global"
250  KwsLexicographicFst index = index_reader.Value("global");
251 
252  // First we have to remove the disambiguation symbols. But rather than
253  // removing them totally, we actually move them from input side to output
254  // side, making the output symbol a "combined" symbol of the disambiguation
255  // symbols and the utterance id's.
256  // Note that in Dogan and Murat's original paper, they simply remove the
257  // disambiguation symbol on the input symbol side, which will not allow us
258  // to do epsilon removal after composition with the keyword FST. They have
259  // to traverse the resulting FST.
260  int32 label_count = 1;
261  unordered_map<uint64, uint32> label_encoder;
262  unordered_map<uint32, uint64> label_decoder;
263  for (StateIterator<KwsLexicographicFst> siter(index);
264  !siter.Done(); siter.Next()) {
265  StateId state_id = siter.Value();
266  for (MutableArcIterator<KwsLexicographicFst>
267  aiter(&index, state_id); !aiter.Done(); aiter.Next()) {
268  KwsLexicographicArc arc = aiter.Value();
269  // Skip the non-final arcs
270  if (index.Final(arc.nextstate) == Weight::Zero())
271  continue;
272  // Encode the input and output label of the final arc, and this is the
273  // new output label for this arc; set the input label to <epsilon>
274  uint64 osymbol = EncodeLabel(arc.ilabel, arc.olabel);
275  arc.ilabel = 0;
276  if (label_encoder.find(osymbol) == label_encoder.end()) {
277  arc.olabel = label_count;
278  label_encoder[osymbol] = label_count;
279  label_decoder[label_count] = osymbol;
280  label_count++;
281  } else {
282  arc.olabel = label_encoder[osymbol];
283  }
284  aiter.SetValue(arc);
285  }
286  }
287  ArcSort(&index, fst::ILabelCompare<KwsLexicographicArc>());
288 
289  int32 n_done = 0;
290  int32 n_fail = 0;
291  for (; !keyword_reader.Done(); keyword_reader.Next()) {
292  std::string key = keyword_reader.Key();
293  VectorFst<StdArc> keyword = keyword_reader.Value();
294  keyword_reader.FreeCurrent();
295 
296  // Process the case where we have confusion for keywords
297  if (keyword_beam != -1) {
298  Prune(&keyword, keyword_beam);
299  }
300  if (keyword_nbest != -1) {
301  VectorFst<StdArc> tmp;
302  ShortestPath(keyword, &tmp, keyword_nbest, true, true);
303  keyword = tmp;
304  }
305 
306  KwsLexicographicFst keyword_fst;
307  KwsLexicographicFst result_fst;
308  Map(keyword, &keyword_fst, VectorFstToKwsLexicographicFstMapper());
309  Compose(keyword_fst, index, &result_fst);
310 
311  if (stats_wspecifier != "") {
312  KwsLexicographicFst matched_seq(result_fst);
313  OutputDetailedStatistics(key,
314  matched_seq,
315  label_decoder,
316  &stats_writer);
317  }
318 
319  Project(&result_fst, PROJECT_OUTPUT);
320  Minimize(&result_fst, (KwsLexicographicFst *) nullptr, kDelta, true);
321  ShortestPath(result_fst, &result_fst, n_best);
322  RmEpsilon(&result_fst);
323 
324  // No result found
325  if (result_fst.Start() == kNoStateId)
326  continue;
327 
328  // Got something here
329  double score;
330  int32 tbeg, tend, uid;
331  for (ArcIterator<KwsLexicographicFst>
332  aiter(result_fst, result_fst.Start()); !aiter.Done(); aiter.Next()) {
333  const KwsLexicographicArc &arc = aiter.Value();
334 
335  // We're expecting a two-state FST
336  if (result_fst.Final(arc.nextstate) != Weight::One()) {
337  KALDI_WARN << "The resulting FST does not have "
338  << "the expected structure for key " << key;
339  n_fail++;
340  continue;
341  }
342 
343  uint64 osymbol = label_decoder[arc.olabel];
344  uid = static_cast<int32>(DecodeLabelUid(osymbol));
345  tbeg = arc.weight.Value2().Value1().Value();
346  tend = arc.weight.Value2().Value2().Value();
347  score = arc.weight.Value1().Value();
348 
349  if (score < 0) {
350  if (score < negative_tolerance) {
351  KALDI_WARN << "Score out of expected range: " << score;
352  }
353  score = 0.0;
354  }
355  vector<double> result;
356  result.push_back(uid);
357  result.push_back(tbeg * frame_subsampling_factor);
358  result.push_back(tend * frame_subsampling_factor);
359  result.push_back(score);
360  result_writer.Write(key, result);
361  }
362 
363  n_done++;
364  }
365 
366  KALDI_LOG << "Done " << n_done << " keywords";
367  if (strict == true)
368  return (n_done != 0 ? 0 : 1);
369  else
370  return 0;
371  } catch(const std::exception &e) {
372  std::cerr << e.what();
373  return -1;
374  }
375 }
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
StdLStdLStdArc KwsLexicographicArc
Definition: kaldi-kws.h:45
fst::VectorFst< KwsLexicographicArc > KwsLexicographicFst
Definition: kaldi-kws.h:46
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
uint64 EncodeLabel(StateId ilabel, StateId olabel)
Definition: kws-search.cc:35
StateId DecodeLabelUid(uint64 osymbol)
Definition: kws-search.cc:42
#define KALDI_LOG
Definition: kaldi-error.h:153
void OutputDetailedStatistics(const std::string &kwid, const kaldi::KwsLexicographicFst &keyword, const unordered_map< uint32, uint64 > &label_decoder, VectorOfDoublesWriter *output)
Definition: kws-search.cc:121

◆ OutputDetailedStatistics()

void OutputDetailedStatistics ( const std::string &  kwid,
const kaldi::KwsLexicographicFst &  keyword,
const unordered_map< uint32, uint64 > &  label_decoder,
VectorOfDoublesWriter *  output 
)

Definition at line 121 of file kws-search.cc.

References kaldi::DecodeLabelUid(), kaldi::GenerateActivePaths(), rnnlm::i, rnnlm::j, and TableWriter< Holder >::Write().

Referenced by main().

124  {
125  std::vector<kaldi::ActivePath> paths;
126 
127  if (keyword.Start() == fst::kNoStateId)
128  return;
129 
130  kaldi::GenerateActivePaths(keyword, &paths, keyword.Start(),
131  std::vector<kaldi::KwsLexicographicArc::Label>(),
132  kaldi::KwsLexicographicArc::Weight::One());
133 
134  for (int i = 0; i < paths.size(); ++i) {
135  std::vector<double> out;
136  double score;
137  int32 tbeg, tend, uid;
138 
139  uint64 osymbol = label_decoder.find(paths[i].last)->second;
140  uid = kaldi::DecodeLabelUid(osymbol);
141  tbeg = paths[i].weight.Value2().Value1().Value();
142  tend = paths[i].weight.Value2().Value2().Value();
143  score = paths[i].weight.Value1().Value();
144 
145  out.push_back(uid);
146  out.push_back(tbeg);
147  out.push_back(tend);
148  out.push_back(score);
149 
150  for (int j = 0; j < paths[i].path.size(); ++j) {
151  out.push_back(paths[i].path[j]);
152  }
153  output->Write(kwid, out);
154  }
155 }
kaldi::int32 int32
void Write(const std::string &key, const T &value) const
bool GenerateActivePaths(const KwsLexicographicFst &proxy, std::vector< ActivePath > *paths, KwsLexicographicFst::StateId cur_state, std::vector< KwsLexicographicArc::Label > cur_path, KwsLexicographicArc::Weight cur_weight)
Definition: kws-search.cc:90
StateId DecodeLabelUid(uint64 osymbol)
Definition: kws-search.cc:42