36 return (static_cast<int64>(olabel) << 32) +
static_cast<int64
>(ilabel);
43 return static_cast<StateId
>(osymbol >> 32);
60 return ToArc(arc.ilabel,
62 (arc.weight == FromWeight::Zero() ?
64 ToWeight(arc.weight.Value(),
65 StdLStdWeight::One())),
70 return fst::MAP_NO_SUPERFINAL;
74 return fst::MAP_COPY_SYMBOLS;
78 return fst::MAP_COPY_SYMBOLS;
85 std::vector<KwsLexicographicArc::Label>
path;
91 std::vector<ActivePath> *paths,
93 std::vector<KwsLexicographicArc::Label> cur_path,
95 for (fst::ArcIterator<KwsLexicographicFst> aiter(proxy, cur_state);
96 !aiter.Done(); aiter.Next()) {
97 const Arc &arc = aiter.Value();
98 Weight temp_weight =
Times(arc.weight, cur_weight);
100 cur_path.push_back(arc.ilabel);
102 if ( arc.olabel != 0 ) {
104 path.
path = cur_path;
105 path.
weight = temp_weight;
106 path.
last = arc.olabel;
107 paths->push_back(path);
110 arc.nextstate, cur_path, temp_weight);
123 const unordered_map<uint32, uint64> &label_decoder,
125 std::vector<kaldi::ActivePath> paths;
127 if (keyword.Start() == fst::kNoStateId)
131 std::vector<kaldi::KwsLexicographicArc::Label>(),
132 kaldi::KwsLexicographicArc::Weight::One());
134 for (
int i = 0;
i < paths.size(); ++
i) {
135 std::vector<double> out;
137 int32 tbeg, tend, uid;
139 uint64 osymbol = label_decoder.find(paths[
i].last)->second;
141 tbeg = paths[
i].weight.Value2().Value1().Value();
142 tend = paths[
i].weight.Value2().Value2().Value();
143 score = paths[
i].weight.Value1().Value();
148 out.push_back(score);
150 for (
int j = 0;
j < paths[
i].path.size(); ++
j) {
151 out.push_back(paths[
i].path[
j]);
153 output->
Write(kwid, out);
158 int main(
int argc,
char *argv[]) {
160 using namespace kaldi;
164 typedef kaldi::uint32 uint32;
165 typedef kaldi::uint64 uint64;
168 "Search the keywords over the index. This program can be executed\n" 169 "in parallel, either on the index side or the keywords side; we use\n" 170 "a script to combine the final search results. Note that the index\n" 171 "archive has a single key \"global\".\n\n" 172 "Search has one or two outputs. The first one is mandatory and will\n" 173 "contain the seach output, i.e. list of all found keyword instances\n" 174 "The file is in the following format:\n" 175 "kw_id utt_id beg_frame end_frame neg_logprob\n" 177 "KW105-0198 7 335 376 1.91254\n\n" 178 "The second parameter is optional and allows the user to gather more\n" 179 "statistics about the individual instances from the posting list.\n" 180 "Remember \"keyword\" is an FST and as such, there can be multiple\n" 181 "paths matching in the keyword and in the lattice index in that given\n" 182 "time period. The stats output will provide all matching paths\n" 183 "each with the appropriate score. \n" 184 "The format is as follows:\n" 185 "kw_id utt_id beg_frame end_frame neg_logprob 0 w_id1 w_id2 ... 0\n" 187 "KW105-0198 7 335 376 16.01254 0 5766 5659 0\n" 189 "Usage: kws-search [options] <index-rspecifier> <keywords-rspecifier> " 190 "<results-wspecifier> [<stats_wspecifier>]\n" 191 " e.g.: kws-search ark:index.idx ark:keywords.fsts " 192 "ark:results ark:stats\n";
197 int32 keyword_nbest = -1;
199 double negative_tolerance = -0.1;
200 double keyword_beam = -1;
201 int32 frame_subsampling_factor = 1;
203 po.
Register(
"frame-subsampling-factor", &frame_subsampling_factor,
204 "Frame subsampling factor. (Default value 1)");
205 po.
Register(
"nbest", &n_best,
"Return the best n hypotheses.");
206 po.
Register(
"keyword-nbest", &keyword_nbest,
207 "Pick the best n keywords if the FST contains " 208 "multiple keywords.");
209 po.
Register(
"strict", &strict,
"Affects the return status of the program.");
210 po.
Register(
"negative-tolerance", &negative_tolerance,
211 "The program will print a warning if we get negative score " 212 "smaller than this tolerance.");
213 po.
Register(
"keyword-beam", &keyword_beam,
214 "Prune the FST with the given beam if the FST contains " 215 "multiple keywords.");
217 if (n_best < 0 && n_best != -1) {
221 if (keyword_nbest < 0 && keyword_nbest != -1) {
222 KALDI_ERR <<
"Bad number for keyword-nbest";
225 if (keyword_beam < 0 && keyword_beam != -1) {
226 KALDI_ERR <<
"Bad number for keyword-beam";
237 std::string index_rspecifier = po.
GetArg(1),
238 keyword_rspecifier = po.
GetArg(2),
239 result_wspecifier = po.
GetArg(3),
243 index_reader(index_rspecifier);
260 int32 label_count = 1;
261 unordered_map<uint64, uint32> label_encoder;
262 unordered_map<uint32, uint64> label_decoder;
263 for (StateIterator<KwsLexicographicFst> siter(index);
264 !siter.Done(); siter.Next()) {
265 StateId state_id = siter.Value();
266 for (MutableArcIterator<KwsLexicographicFst>
267 aiter(&index, state_id); !aiter.Done(); aiter.Next()) {
270 if (index.Final(arc.nextstate) == Weight::Zero())
274 uint64 osymbol =
EncodeLabel(arc.ilabel, arc.olabel);
276 if (label_encoder.find(osymbol) == label_encoder.end()) {
277 arc.olabel = label_count;
278 label_encoder[osymbol] = label_count;
279 label_decoder[label_count] = osymbol;
282 arc.olabel = label_encoder[osymbol];
287 ArcSort(&index, fst::ILabelCompare<KwsLexicographicArc>());
291 for (; !keyword_reader.
Done(); keyword_reader.
Next()) {
292 std::string key = keyword_reader.
Key();
293 VectorFst<StdArc> keyword = keyword_reader.
Value();
297 if (keyword_beam != -1) {
298 Prune(&keyword, keyword_beam);
300 if (keyword_nbest != -1) {
301 VectorFst<StdArc> tmp;
302 ShortestPath(keyword, &tmp, keyword_nbest,
true,
true);
309 Compose(keyword_fst, index, &result_fst);
311 if (stats_wspecifier !=
"") {
319 Project(&result_fst, PROJECT_OUTPUT);
321 ShortestPath(result_fst, &result_fst, n_best);
322 RmEpsilon(&result_fst);
325 if (result_fst.Start() == kNoStateId)
330 int32 tbeg, tend, uid;
331 for (ArcIterator<KwsLexicographicFst>
332 aiter(result_fst, result_fst.Start()); !aiter.Done(); aiter.Next()) {
336 if (result_fst.Final(arc.nextstate) != Weight::One()) {
337 KALDI_WARN <<
"The resulting FST does not have " 338 <<
"the expected structure for key " << key;
343 uint64 osymbol = label_decoder[arc.olabel];
345 tbeg = arc.weight.Value2().Value1().Value();
346 tend = arc.weight.Value2().Value2().Value();
347 score = arc.weight.Value1().Value();
350 if (score < negative_tolerance) {
351 KALDI_WARN <<
"Score out of expected range: " << score;
355 vector<double> result;
356 result.push_back(uid);
357 result.push_back(tbeg * frame_subsampling_factor);
358 result.push_back(tend * frame_subsampling_factor);
359 result.push_back(score);
360 result_writer.
Write(key, result);
366 KALDI_LOG <<
"Done " << n_done <<
" keywords";
368 return (n_done != 0 ? 0 : 1);
371 }
catch(
const std::exception &e) {
372 std::cerr << e.what();
fst::StdArc::StateId StateId
ToArc operator()(const FromArc &arc) const
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
FromArc::Weight FromWeight
int main(int argc, char *argv[])
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
uint64 Properties(uint64 props) const
A templated class for writing objects to an archive or script file; see The Table concept...
fst::MapFinalAction FinalAction() const
KwsLexicographicArc::Weight weight
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
StdLStdLStdArc KwsLexicographicArc
fst::VectorFst< KwsLexicographicArc > KwsLexicographicFst
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
KwsLexicographicArc ToArc
KwsLexicographicArc::Label last
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
fst::MapSymbolsAction InputSymbolsAction() const
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
kaldi::TableWriter< kaldi::BasicVectorHolder< double > > VectorOfDoublesWriter
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool GenerateActivePaths(const KwsLexicographicFst &proxy, std::vector< ActivePath > *paths, KwsLexicographicFst::StateId cur_state, std::vector< KwsLexicographicArc::Label > cur_path, KwsLexicographicArc::Weight cur_weight)
uint64 EncodeLabel(StateId ilabel, StateId olabel)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
std::vector< KwsLexicographicArc::Label > path
fst::MapSymbolsAction OutputSymbolsAction() const
StdLStdLStdWeight KwsLexicographicWeight
StateId DecodeLabelUid(uint64 osymbol)
VectorFstToKwsLexicographicFstMapper()
KwsLexicographicWeight ToWeight
void OutputDetailedStatistics(const std::string &kwid, const kaldi::KwsLexicographicFst &keyword, const unordered_map< uint32, uint64 > &label_decoder, VectorOfDoublesWriter *output)
std::string GetOptArg(int param) const