kws-search.cc
Go to the documentation of this file.
1 // kwsbin/kws-search.cc
2 
3 // Copyright 2012-2015 Johns Hopkins University (Authors: Guoguo Chen,
4 // Daniel Povey.
5 // Yenda Trmal)
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 
23 #include "base/kaldi-common.h"
24 #include "util/common-utils.h"
25 #include "fstext/kaldi-fst-io.h"
26 #include "kws/kaldi-kws.h"
27 
28 namespace kaldi {
29 
30 typedef KwsLexicographicArc Arc;
32 typedef Arc::StateId StateId;
33 
34 // encode ilabel, olabel pair as a single 64bit (output) symbol
35 uint64 EncodeLabel(StateId ilabel, StateId olabel) {
36  return (static_cast<int64>(olabel) << 32) + static_cast<int64>(ilabel);
37 }
38 
39 // extract the osymbol from the 64bit symbol. That represents the utterance id
40 // in this setup -- we throw away the isymbol which is typically 0 or an
41 // disambiguation symbol
42 StateId DecodeLabelUid(uint64 osymbol) {
43  return static_cast<StateId>(osymbol >> 32);
44 }
45 
46 // this is a mapper adapter that helps converting
47 // between the StdArc FST (i.e. tropical semiring FST)
48 // to the KwsLexicographic FST. Structure will be kept,
49 // the weights converted/recomputed
51  public:
56 
58 
59  ToArc operator()(const FromArc &arc) const {
60  return ToArc(arc.ilabel,
61  arc.olabel,
62  (arc.weight == FromWeight::Zero() ?
63  ToWeight::Zero() :
64  ToWeight(arc.weight.Value(),
65  StdLStdWeight::One())),
66  arc.nextstate);
67  }
68 
69  fst::MapFinalAction FinalAction() const {
70  return fst::MAP_NO_SUPERFINAL;
71  }
72 
73  fst::MapSymbolsAction InputSymbolsAction() const {
74  return fst::MAP_COPY_SYMBOLS;
75  }
76 
77  fst::MapSymbolsAction OutputSymbolsAction() const {
78  return fst::MAP_COPY_SYMBOLS;
79  }
80 
81  uint64 Properties(uint64 props) const { return props; }
82 };
83 
84 struct ActivePath {
85  std::vector<KwsLexicographicArc::Label> path;
88 };
89 
91  std::vector<ActivePath> *paths,
93  std::vector<KwsLexicographicArc::Label> cur_path,
94  KwsLexicographicArc::Weight cur_weight) {
95  for (fst::ArcIterator<KwsLexicographicFst> aiter(proxy, cur_state);
96  !aiter.Done(); aiter.Next()) {
97  const Arc &arc = aiter.Value();
98  Weight temp_weight = Times(arc.weight, cur_weight);
99 
100  cur_path.push_back(arc.ilabel);
101 
102  if ( arc.olabel != 0 ) {
103  ActivePath path;
104  path.path = cur_path;
105  path.weight = temp_weight;
106  path.last = arc.olabel;
107  paths->push_back(path);
108  } else {
109  GenerateActivePaths(proxy, paths,
110  arc.nextstate, cur_path, temp_weight);
111  }
112  cur_path.pop_back();
113  }
114 
115  return true;
116 }
117 } // namespace kaldi
118 
121 void OutputDetailedStatistics(const std::string &kwid,
122  const kaldi::KwsLexicographicFst &keyword,
123  const unordered_map<uint32, uint64> &label_decoder,
124  VectorOfDoublesWriter *output ) {
125  std::vector<kaldi::ActivePath> paths;
126 
127  if (keyword.Start() == fst::kNoStateId)
128  return;
129 
130  kaldi::GenerateActivePaths(keyword, &paths, keyword.Start(),
131  std::vector<kaldi::KwsLexicographicArc::Label>(),
132  kaldi::KwsLexicographicArc::Weight::One());
133 
134  for (int i = 0; i < paths.size(); ++i) {
135  std::vector<double> out;
136  double score;
137  int32 tbeg, tend, uid;
138 
139  uint64 osymbol = label_decoder.find(paths[i].last)->second;
140  uid = kaldi::DecodeLabelUid(osymbol);
141  tbeg = paths[i].weight.Value2().Value1().Value();
142  tend = paths[i].weight.Value2().Value2().Value();
143  score = paths[i].weight.Value1().Value();
144 
145  out.push_back(uid);
146  out.push_back(tbeg);
147  out.push_back(tend);
148  out.push_back(score);
149 
150  for (int j = 0; j < paths[i].path.size(); ++j) {
151  out.push_back(paths[i].path[j]);
152  }
153  output->Write(kwid, out);
154  }
155 }
156 
157 
158 int main(int argc, char *argv[]) {
159  try {
160  using namespace kaldi;
161  using namespace fst;
162  using std::vector;
163  typedef kaldi::int32 int32;
164  typedef kaldi::uint32 uint32;
165  typedef kaldi::uint64 uint64;
166 
167  const char *usage =
168  "Search the keywords over the index. This program can be executed\n"
169  "in parallel, either on the index side or the keywords side; we use\n"
170  "a script to combine the final search results. Note that the index\n"
171  "archive has a single key \"global\".\n\n"
172  "Search has one or two outputs. The first one is mandatory and will\n"
173  "contain the seach output, i.e. list of all found keyword instances\n"
174  "The file is in the following format:\n"
175  "kw_id utt_id beg_frame end_frame neg_logprob\n"
176  " e.g.: \n"
177  "KW105-0198 7 335 376 1.91254\n\n"
178  "The second parameter is optional and allows the user to gather more\n"
179  "statistics about the individual instances from the posting list.\n"
180  "Remember \"keyword\" is an FST and as such, there can be multiple\n"
181  "paths matching in the keyword and in the lattice index in that given\n"
182  "time period. The stats output will provide all matching paths\n"
183  "each with the appropriate score. \n"
184  "The format is as follows:\n"
185  "kw_id utt_id beg_frame end_frame neg_logprob 0 w_id1 w_id2 ... 0\n"
186  " e.g.: \n"
187  "KW105-0198 7 335 376 16.01254 0 5766 5659 0\n"
188  "\n"
189  "Usage: kws-search [options] <index-rspecifier> <keywords-rspecifier> "
190  "<results-wspecifier> [<stats_wspecifier>]\n"
191  " e.g.: kws-search ark:index.idx ark:keywords.fsts "
192  "ark:results ark:stats\n";
193 
194  ParseOptions po(usage);
195 
196  int32 n_best = -1;
197  int32 keyword_nbest = -1;
198  bool strict = true;
199  double negative_tolerance = -0.1;
200  double keyword_beam = -1;
201  int32 frame_subsampling_factor = 1;
202 
203  po.Register("frame-subsampling-factor", &frame_subsampling_factor,
204  "Frame subsampling factor. (Default value 1)");
205  po.Register("nbest", &n_best, "Return the best n hypotheses.");
206  po.Register("keyword-nbest", &keyword_nbest,
207  "Pick the best n keywords if the FST contains "
208  "multiple keywords.");
209  po.Register("strict", &strict, "Affects the return status of the program.");
210  po.Register("negative-tolerance", &negative_tolerance,
211  "The program will print a warning if we get negative score "
212  "smaller than this tolerance.");
213  po.Register("keyword-beam", &keyword_beam,
214  "Prune the FST with the given beam if the FST contains "
215  "multiple keywords.");
216 
217  if (n_best < 0 && n_best != -1) {
218  KALDI_ERR << "Bad number for nbest";
219  exit(1);
220  }
221  if (keyword_nbest < 0 && keyword_nbest != -1) {
222  KALDI_ERR << "Bad number for keyword-nbest";
223  exit(1);
224  }
225  if (keyword_beam < 0 && keyword_beam != -1) {
226  KALDI_ERR << "Bad number for keyword-beam";
227  exit(1);
228  }
229 
230  po.Read(argc, argv);
231 
232  if (po.NumArgs() < 3 || po.NumArgs() > 4) {
233  po.PrintUsage();
234  exit(1);
235  }
236 
237  std::string index_rspecifier = po.GetArg(1),
238  keyword_rspecifier = po.GetArg(2),
239  result_wspecifier = po.GetArg(3),
240  stats_wspecifier = po.GetOptArg(4);
241 
243  index_reader(index_rspecifier);
244  SequentialTableReader<VectorFstHolder> keyword_reader(keyword_rspecifier);
245  VectorOfDoublesWriter result_writer(result_wspecifier);
246  VectorOfDoublesWriter stats_writer(stats_wspecifier);
247 
248 
249  // Index has key "global"
250  KwsLexicographicFst index = index_reader.Value("global");
251 
252  // First we have to remove the disambiguation symbols. But rather than
253  // removing them totally, we actually move them from input side to output
254  // side, making the output symbol a "combined" symbol of the disambiguation
255  // symbols and the utterance id's.
256  // Note that in Dogan and Murat's original paper, they simply remove the
257  // disambiguation symbol on the input symbol side, which will not allow us
258  // to do epsilon removal after composition with the keyword FST. They have
259  // to traverse the resulting FST.
260  int32 label_count = 1;
261  unordered_map<uint64, uint32> label_encoder;
262  unordered_map<uint32, uint64> label_decoder;
263  for (StateIterator<KwsLexicographicFst> siter(index);
264  !siter.Done(); siter.Next()) {
265  StateId state_id = siter.Value();
266  for (MutableArcIterator<KwsLexicographicFst>
267  aiter(&index, state_id); !aiter.Done(); aiter.Next()) {
268  KwsLexicographicArc arc = aiter.Value();
269  // Skip the non-final arcs
270  if (index.Final(arc.nextstate) == Weight::Zero())
271  continue;
272  // Encode the input and output label of the final arc, and this is the
273  // new output label for this arc; set the input label to <epsilon>
274  uint64 osymbol = EncodeLabel(arc.ilabel, arc.olabel);
275  arc.ilabel = 0;
276  if (label_encoder.find(osymbol) == label_encoder.end()) {
277  arc.olabel = label_count;
278  label_encoder[osymbol] = label_count;
279  label_decoder[label_count] = osymbol;
280  label_count++;
281  } else {
282  arc.olabel = label_encoder[osymbol];
283  }
284  aiter.SetValue(arc);
285  }
286  }
287  ArcSort(&index, fst::ILabelCompare<KwsLexicographicArc>());
288 
289  int32 n_done = 0;
290  int32 n_fail = 0;
291  for (; !keyword_reader.Done(); keyword_reader.Next()) {
292  std::string key = keyword_reader.Key();
293  VectorFst<StdArc> keyword = keyword_reader.Value();
294  keyword_reader.FreeCurrent();
295 
296  // Process the case where we have confusion for keywords
297  if (keyword_beam != -1) {
298  Prune(&keyword, keyword_beam);
299  }
300  if (keyword_nbest != -1) {
301  VectorFst<StdArc> tmp;
302  ShortestPath(keyword, &tmp, keyword_nbest, true, true);
303  keyword = tmp;
304  }
305 
306  KwsLexicographicFst keyword_fst;
307  KwsLexicographicFst result_fst;
308  Map(keyword, &keyword_fst, VectorFstToKwsLexicographicFstMapper());
309  Compose(keyword_fst, index, &result_fst);
310 
311  if (stats_wspecifier != "") {
312  KwsLexicographicFst matched_seq(result_fst);
314  matched_seq,
315  label_decoder,
316  &stats_writer);
317  }
318 
319  Project(&result_fst, PROJECT_OUTPUT);
320  Minimize(&result_fst, (KwsLexicographicFst *) nullptr, kDelta, true);
321  ShortestPath(result_fst, &result_fst, n_best);
322  RmEpsilon(&result_fst);
323 
324  // No result found
325  if (result_fst.Start() == kNoStateId)
326  continue;
327 
328  // Got something here
329  double score;
330  int32 tbeg, tend, uid;
331  for (ArcIterator<KwsLexicographicFst>
332  aiter(result_fst, result_fst.Start()); !aiter.Done(); aiter.Next()) {
333  const KwsLexicographicArc &arc = aiter.Value();
334 
335  // We're expecting a two-state FST
336  if (result_fst.Final(arc.nextstate) != Weight::One()) {
337  KALDI_WARN << "The resulting FST does not have "
338  << "the expected structure for key " << key;
339  n_fail++;
340  continue;
341  }
342 
343  uint64 osymbol = label_decoder[arc.olabel];
344  uid = static_cast<int32>(DecodeLabelUid(osymbol));
345  tbeg = arc.weight.Value2().Value1().Value();
346  tend = arc.weight.Value2().Value2().Value();
347  score = arc.weight.Value1().Value();
348 
349  if (score < 0) {
350  if (score < negative_tolerance) {
351  KALDI_WARN << "Score out of expected range: " << score;
352  }
353  score = 0.0;
354  }
355  vector<double> result;
356  result.push_back(uid);
357  result.push_back(tbeg * frame_subsampling_factor);
358  result.push_back(tend * frame_subsampling_factor);
359  result.push_back(score);
360  result_writer.Write(key, result);
361  }
362 
363  n_done++;
364  }
365 
366  KALDI_LOG << "Done " << n_done << " keywords";
367  if (strict == true)
368  return (n_done != 0 ? 0 : 1);
369  else
370  return 0;
371  } catch(const std::exception &e) {
372  std::cerr << e.what();
373  return -1;
374  }
375 }
fst::StdArc::StateId StateId
ToArc operator()(const FromArc &arc) const
Definition: kws-search.cc:59
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int main(int argc, char *argv[])
Definition: kws-search.cc:158
Lattice::StateId StateId
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
fst::StdArc StdArc
uint64 Properties(uint64 props) const
Definition: kws-search.cc:81
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
fst::MapFinalAction FinalAction() const
Definition: kws-search.cc:69
KwsLexicographicArc::Weight weight
Definition: kws-search.cc:86
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
StdLStdLStdArc KwsLexicographicArc
Definition: kaldi-kws.h:45
fst::VectorFst< KwsLexicographicArc > KwsLexicographicFst
Definition: kaldi-kws.h:46
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
KwsLexicographicArc::Label last
Definition: kws-search.cc:87
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
fst::MapSymbolsAction InputSymbolsAction() const
Definition: kws-search.cc:73
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
kaldi::TableWriter< kaldi::BasicVectorHolder< double > > VectorOfDoublesWriter
Definition: kws-search.cc:120
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool GenerateActivePaths(const KwsLexicographicFst &proxy, std::vector< ActivePath > *paths, KwsLexicographicFst::StateId cur_state, std::vector< KwsLexicographicArc::Label > cur_path, KwsLexicographicArc::Weight cur_weight)
Definition: kws-search.cc:90
fst::StdArc::Label Label
uint64 EncodeLabel(StateId ilabel, StateId olabel)
Definition: kws-search.cc:35
int NumArgs() const
Number of positional parameters (c.f. argc-1).
std::vector< KwsLexicographicArc::Label > path
Definition: kws-search.cc:85
Arc::Weight Weight
Definition: kws-search.cc:31
fst::MapSymbolsAction OutputSymbolsAction() const
Definition: kws-search.cc:77
StdLStdLStdWeight KwsLexicographicWeight
Definition: kaldi-kws.h:44
StateId DecodeLabelUid(uint64 osymbol)
Definition: kws-search.cc:42
#define KALDI_LOG
Definition: kaldi-error.h:153
void OutputDetailedStatistics(const std::string &kwid, const kaldi::KwsLexicographicFst &keyword, const unordered_map< uint32, uint64 > &label_decoder, VectorOfDoublesWriter *output)
Definition: kws-search.cc:121
std::string GetOptArg(int param) const