generate-proxy-keywords.cc
Go to the documentation of this file.
1 // kwsbin/generate-proxy-keywords.cc
2 
3 // Copyright 2012 Johns Hopkins University (Author: Guoguo Chen)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "fstext/kaldi-fst-io.h"
24 #include "fstext/fstext-utils.h"
25 #include "fstext/prune-special.h"
26 
27 namespace fst {
28 using std::vector;
29 
30 bool PrintProxyFstPath(const VectorFst<StdArc> &proxy,
31  vector<vector<StdArc::Label> > *path,
32  vector<StdArc::Weight> *weight,
33  StdArc::StateId cur_state,
34  vector<StdArc::Label> cur_path,
35  StdArc::Weight cur_weight) {
36  if (proxy.Final(cur_state) != StdArc::Weight::Zero()) {
37  // Assumes only final state has non-zero weight.
38  cur_weight = Times(proxy.Final(cur_state), cur_weight);
39  path->push_back(cur_path);
40  weight->push_back(cur_weight);
41  return true;
42  }
43 
44  for (ArcIterator<StdFst> aiter(proxy, cur_state);
45  !aiter.Done(); aiter.Next()) {
46  const StdArc &arc = aiter.Value();
47  StdArc::Weight temp_weight = Times(arc.weight, cur_weight);
48  cur_path.push_back(arc.ilabel);
49  PrintProxyFstPath(proxy, path, weight,
50  arc.nextstate, cur_path, temp_weight);
51  cur_path.pop_back();
52  }
53 
54  return true;
55 }
56 } // namespace fst
57 
58 int main(int argc, char *argv[]) {
59  try {
60  using namespace kaldi;
61  using namespace fst;
62  using std::vector;
63  using std::string;
64  typedef kaldi::int32 int32;
65  typedef kaldi::uint64 uint64;
66  typedef StdArc::StateId StateId;
67  typedef StdArc::Weight Weight;
68 
69  const char *usage =
70  "Convert the keywords into in-vocabulary words using the given phone\n"
71  "level edit distance fst (E.fst). The large lexicon (L2.fst) and\n"
72  "inverted small lexicon (L1'.fst) are also expected to be present. We\n"
73  "actually use the composed FST L2xE.fst to be more efficient. Ideally\n"
74  "we should have used L2xExL1'.fst but this is quite computationally\n"
75  "expensive at command level. Keywords.int is in the transcription\n"
76  "format. If kwlist-wspecifier is given, the program also prints out\n"
77  "the proxy fst in a format where each line is \"kwid weight proxy\".\n"
78  "\n"
79  "Usage: generate-proxy-keywords [options] <L2xE.fst> <L1'.fst> \\\n"
80  " <keyword-rspecifier> <proxy-wspecifier> [kwlist-wspecifier] \n"
81  " e.g.: generate-proxy-keywords L2xE.fst L1'.fst ark:keywords.int \\\n"
82  " ark:proxy.fsts [ark,t:proxy.kwlist.txt]\n";
83 
84  ParseOptions po(usage);
85 
86  int32 max_states = 100000;
87  int32 phone_nbest = 50;
88  int32 proxy_nbest = 100;
89  double phone_beam = 5;
90  double proxy_beam = 5;
91  po.Register("phone-nbest", &phone_nbest, "Prune KxL2xE transducer to only "
92  "contain top n phone sequences, -1 means all sequences.");
93  po.Register("proxy-nbest", &proxy_nbest, "Prune KxL2xExL1' transducer to "
94  "only contain top n proxy keywords, -1 means all proxies.");
95  po.Register("phone-beam", &phone_beam, "Prune KxL2xE transducer to the "
96  "given beam, -1 means no prune.");
97  po.Register("proxy-beam", &proxy_beam, "Prune KxL2xExL1' transducer to the "
98  "given beam, -1 means no prune.");
99  po.Register("max-states", &max_states, "Prune kxL2xExL1' transducer to the "
100  "given number of states, 0 means no prune.");
101 
102  po.Read(argc, argv);
103 
104  // Checks input options.
105  if (phone_nbest != -1 && phone_nbest <= 0) {
106  KALDI_ERR << "--phone-nbest must either be -1 or positive.";
107  exit(1);
108  }
109  if (proxy_nbest != -1 && proxy_nbest <= 0) {
110  KALDI_ERR << "--proxy-nbest must either be -1 or positive.";
111  exit(1);
112  }
113  if (phone_beam != -1 && phone_beam < 0) {
114  KALDI_ERR << "--phone-beam must either be -1 or non-negative.";
115  exit(1);
116  }
117  if (proxy_beam != -1 && proxy_beam <=0) {
118  KALDI_ERR << "--proxy-beam must either be -1 or non-negative.";
119  exit(1);
120  }
121 
122  if (po.NumArgs() < 4 || po.NumArgs() > 5) {
123  po.PrintUsage();
124  exit(1);
125  }
126 
127  std::string L2xE_filename = po.GetArg(1),
128  L1_filename = po.GetArg(2),
129  keyword_rspecifier = po.GetArg(3),
130  proxy_wspecifier = po.GetArg(4),
131  kwlist_wspecifier = po.GetOptArg(5);
132 
133  VectorFst<StdArc> *L2xE = ReadFstKaldi(L2xE_filename);
134  VectorFst<StdArc> *L1 = ReadFstKaldi(L1_filename);
135  SequentialInt32VectorReader keyword_reader(keyword_rspecifier);
136  TableWriter<VectorFstHolder> proxy_writer(proxy_wspecifier);
137  TableWriter<BasicVectorHolder<double> > kwlist_writer(kwlist_wspecifier);
138 
139  // Processing the keywords.
140  int32 n_done = 0;
141  for (; !keyword_reader.Done(); keyword_reader.Next()) {
142  std::string key = keyword_reader.Key();
143  std::vector<int32> keyword = keyword_reader.Value();
144  keyword_reader.FreeCurrent();
145 
146  KALDI_LOG << "Processing " << key;
147 
148  VectorFst<StdArc> proxy;
149  VectorFst<StdArc> tmp_proxy;
150  MakeLinearAcceptor(keyword, &proxy);
151 
152  // Composing K and L2xE. We assume L2xE is ilabel sorted.
153  KALDI_VLOG(1) << "Compose(K, L2xE)";
154  ArcSort(&proxy, OLabelCompare<StdArc>());
155  Compose(proxy, *L2xE, &tmp_proxy);
156 
157  // Processing KxL2xE.
158  KALDI_VLOG(1) << "Project(KxL2xE, PROJECT_OUTPUT)";
159  Project(&tmp_proxy, PROJECT_OUTPUT);
160  if (phone_beam >= 0) {
161  KALDI_VLOG(1) << "Prune(KxL2xE, " << phone_beam << ")";
162  Prune(&tmp_proxy, phone_beam);
163  }
164  if (phone_nbest > 0) {
165  KALDI_VLOG(1) << "ShortestPath(KxL2xE, " << phone_nbest << ")";
166  RmEpsilon(&tmp_proxy);
167  ShortestPath(tmp_proxy, &proxy, phone_nbest, true, true);
168  tmp_proxy.DeleteStates(); // Not needed for now.
169  KALDI_VLOG(1) << "Determinize(KxL2xE)";
170  Determinize(proxy, &tmp_proxy);
171  proxy.DeleteStates(); // Not needed for now.
172  }
173  KALDI_VLOG(1) << "ArcSort(KxL2xE, OLabel)";
174  proxy = tmp_proxy;
175  tmp_proxy.DeleteStates(); // Not needed for now.
176  ArcSort(&proxy, OLabelCompare<StdArc>());
177 
178 
179  // Processing KxL2xExL1'.
180  RmEpsilon(&proxy);
181  ArcSort(&proxy, OLabelCompare<StdArc>());
182  if (proxy_beam >= 0) {
183  // We only use the delayed FST when pruning is requested, because we do
184  // the optimization in pruning.
185  // Composing KxL2xE and L1'. We assume L1' is ilabel sorted.
186  KALDI_VLOG(1) << "Compose(KxL2xE, L1')";
187  ComposeFst<StdArc> lazy_compose(proxy, *L1);
188  proxy.DeleteStates();
189 
190  KALDI_VLOG(1) << "Project(KxL2xExL1', PROJECT_OUTPUT)";
191  ProjectFst<StdArc> lazy_project(lazy_compose, PROJECT_OUTPUT);
192 
193  // This will likely be the most time consuming part, we use a special
194  // pruning algorithm where we don't expand the full FST.
195  KALDI_VLOG(1) << "Prune(KxL2xExL1', " << proxy_beam << ")";
196  PruneSpecial(lazy_project, &tmp_proxy, proxy_beam, max_states);
197  } else {
198  // If no pruning is requested, we do the normal composition.
199  KALDI_VLOG(1) << "Compose(KxL2xE, L1')";
200  Compose(proxy, *L1, &tmp_proxy);
201  proxy.DeleteStates();
202 
203  KALDI_VLOG(1) << "Project(KxL2xExL1', PROJECT_OUTPUT)";
204  Project(&tmp_proxy, PROJECT_OUTPUT);
205  }
206  if (proxy_nbest > 0) {
207  KALDI_VLOG(1) << "ShortestPath(KxL2xExL1', " << proxy_nbest << ")";
208  proxy = tmp_proxy;
209  tmp_proxy.DeleteStates(); // Not needed for now.
210  RmEpsilon(&proxy);
211  ShortestPath(proxy, &tmp_proxy, proxy_nbest, true, true);
212  proxy.DeleteStates(); // Not needed for now.
213  }
214  KALDI_VLOG(1) << "RmEpsilon(KxL2xExL1')";
215  RmEpsilon(&tmp_proxy);
216  KALDI_VLOG(1) << "Determinize(KxL2xExL1')";
217  Determinize(tmp_proxy, &proxy);
218  tmp_proxy.DeleteStates();
219  KALDI_VLOG(1) << "ArcSort(KxL2xExL1', OLabel)";
220  ArcSort(&proxy, fst::OLabelCompare<StdArc>());
221 
222  // Write the proxy FST.
223  proxy_writer.Write(key, proxy);
224 
225  // Print the proxy FST with each line looks like "kwid weight proxy"
226  if (po.NumArgs() == 5) {
227  if (proxy.Properties(kAcyclic, true) == 0) {
228  KALDI_WARN << "Proxy FST has cycles, skip printing paths for " << key;
229  } else {
230  vector<vector<StdArc::Label> > path;
231  vector<StdArc::Weight> weight;
232  PrintProxyFstPath(proxy, &path, &weight, proxy.Start(),
233  vector<StdArc::Label>(), StdArc::Weight::One());
234  KALDI_ASSERT(path.size() == weight.size());
235  for (int32 i = 0; i < path.size(); i++) {
236  vector<double> kwlist;
237  kwlist.push_back(static_cast<double>(weight[i].Value()));
238  for (int32 j = 0; j < path[i].size(); j++) {
239  kwlist.push_back(static_cast<double>(path[i][j]));
240  }
241  kwlist_writer.Write(key, kwlist);
242  }
243  }
244  }
245 
246  n_done++;
247  }
248 
249  delete L1;
250  delete L2xE;
251  KALDI_LOG << "Done " << n_done << " keywords";
252  return (n_done != 0 ? 0 : 1);
253  } catch(const std::exception &e) {
254  std::cerr << e.what();
255  return -1;
256  }
257 }
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int main(int argc, char *argv[])
void PruneSpecial(const Fst< Arc > &ifst, VectorFst< Arc > *ofst, typename Arc::Weight beam, size_t max_states)
The function PruneSpecial is like the standard OpenFst function "prune", except it does not expand th...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
fst::StdArc StdArc
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
void MakeLinearAcceptor(const std::vector< I > &labels, MutableFst< Arc > *ofst)
Creates unweighted linear acceptor from symbol sequence.
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
bool PrintProxyFstPath(const VectorFst< StdArc > &proxy, vector< vector< StdArc::Label > > *path, vector< StdArc::Weight > *weight, StdArc::StateId cur_state, vector< StdArc::Label > cur_path, StdArc::Weight cur_weight)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
fst::StdArc::Weight Weight
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void ReadFstKaldi(std::istream &is, bool binary, VectorFst< Arc > *fst)
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
#define KALDI_LOG
Definition: kaldi-error.h:153
std::string GetOptArg(int param) const