lattice-oracle.cc
Go to the documentation of this file.
1 // latbin/lattice-oracle.cc
2 
3 // Copyright 2011 Gilles Boulianne
4 // 2013 Johns Hopkins University (author: Daniel Povey)
5 // 2015 Guoguo Chen
6 //
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #include "base/kaldi-common.h"
23 #include "util/common-utils.h"
24 #include "fstext/fstext-lib.h"
25 #include "lat/kaldi-lattice.h"
26 #include "lat/lattice-functions.h"
27 
28 namespace kaldi {
29 
30  using std::string;
31 
33 typedef std::vector<std::pair<Label, Label>> LabelPairVector;
34 
35 void ReadSymbolList(const std::string &rxfilename,
36  fst::SymbolTable *word_syms,
37  LabelPairVector *lpairs) {
38  Input ki(rxfilename);
39  std::string line;
40  KALDI_ASSERT(lpairs != NULL);
41  lpairs->clear();
42  while (getline(ki.Stream(), line)) {
43  std::string sym;
44  std::istringstream ss(line);
45  ss >> sym >> std::ws;
46  if (ss.fail() || !ss.eof()) {
47  KALDI_ERR << "Bad line in symbol list: "<< line
48  << ", file is: " << PrintableRxfilename(rxfilename);
49  }
50  fst::StdArc::Label lab = word_syms->Find(sym.c_str());
51  if (lab == -1) { // fst::kNoSymbol
52  KALDI_ERR << "Can't find symbol in symbol table: "
53  << line << ", file is: "
54  << PrintableRxfilename(rxfilename);
55  }
56  lpairs->emplace_back(lab, 0);
57  }
58 }
59 
60 // convert from Lattice to standard FST
61 // also maps wildcard symbols to epsilons
62 // then removes epsilons
64  const LabelPairVector &wildcards,
65  fst::StdVectorFst *ofst) {
66  // first convert from lattice to normal FST
67  fst::ConvertLattice(ilat, ofst);
68  // remove weights, project to output, sort according to input arg
69  fst::Map(ofst, fst::RmWeightMapper<fst::StdArc>());
70  fst::Project(ofst, fst::PROJECT_OUTPUT); // The words are on the output side
71  fst::Relabel(ofst, wildcards, wildcards);
72  fst::RmEpsilon(ofst); // Don't tolerate epsilons as they make it hard to
73  // tally errors
74  fst::ArcSort(ofst, fst::StdILabelCompare());
75 }
76 
78  const fst::StdVectorFst &fst2,
79  fst::StdVectorFst *pfst) {
80  typedef fst::StdArc StdArc;
82  typedef fst::StdArc::Label Label;
83  Weight correct_cost(0.0);
84  Weight substitution_cost(1.0);
85  Weight insertion_cost(1.0);
86  Weight deletion_cost(1.0);
87 
88  // create set of output symbols in fst1
89  std::vector<Label> fst1syms, fst2syms;
90  GetOutputSymbols(fst1, false /*no epsilons*/, &fst1syms);
91  GetInputSymbols(fst2, false /*no epsilons*/, &fst2syms);
92 
93  pfst->AddState();
94  pfst->SetStart(0);
95  for (size_t i = 0; i < fst1syms.size(); i++)
96  pfst->AddArc(0, StdArc(fst1syms[i], 0, deletion_cost, 0)); // deletions
97 
98  for (size_t i = 0; i < fst2syms.size(); i++)
99  pfst->AddArc(0, StdArc(0, fst2syms[i], insertion_cost, 0)); // insertions
100 
101  // stupid implementation O(N^2)
102  for (size_t i = 0; i < fst1syms.size(); i++) {
103  Label label1 = fst1syms[i];
104  for (size_t j = 0; j < fst2syms.size(); j++) {
105  Label label2 = fst2syms[j];
106  Weight cost(label1 == label2 ? correct_cost : substitution_cost);
107  pfst->AddArc(0, StdArc(label1, label2, cost, 0)); // substitutions
108  }
109  }
110  pfst->SetFinal(0, Weight::One());
111  ArcSort(pfst, fst::StdOLabelCompare());
112 }
113 
115  int32 *correct,
116  int32 *substitutions,
117  int32 *insertions,
118  int32 *deletions,
119  int32 *num_words) {
121  typedef fst::StdArc::Weight Weight;
122  *correct = *substitutions = *insertions = *deletions = *num_words = 0;
123 
124  // go through the first complete path in fst (there should be only one)
125  StateId src = fst.Start();
126  while (fst.Final(src)== Weight::Zero()) { // while not final
127  for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, src);
128  !aiter.Done(); aiter.Next()) {
129  fst::StdArc arc = aiter.Value();
130  if (arc.ilabel == arc.olabel && arc.ilabel != 0) {
131  (*correct)++;
132  (*num_words)++;
133  } else if (arc.ilabel == 0 && arc.olabel != 0) {
134  (*deletions)++;
135  (*num_words)++;
136  } else if (arc.ilabel != 0 && arc.olabel == 0) {
137  (*insertions)++;
138  } else if (arc.ilabel != 0 && arc.olabel != 0) {
139  (*substitutions)++;
140  (*num_words)++;
141  } else {
142  KALDI_ASSERT(arc.ilabel == 0 && arc.olabel == 0);
143  }
144  src = arc.nextstate;
145  continue; // jump to next state
146  }
147  }
148 }
149 
150 
151 bool CheckFst(const fst::StdVectorFst &fst, string name, string key) {
152 #ifdef DEBUG
153  fst::StdArc::StateId numstates = fst.NumStates();
154  std::cerr << " " << name << " has " << numstates << " states" << std::endl;
155  std::stringstream ss;
156  ss << name << key << ".fst";
157  fst.Write(ss.str());
158  return(fst.Start() == fst::kNoStateId);
159 #else
160  return true;
161 #endif
162 }
163 }
164 
165 int main(int argc, char *argv[]) {
166  try {
167  using namespace kaldi;
168  using fst::SymbolTable;
169  using fst::VectorFst;
170  using fst::StdArc;
171  typedef kaldi::int32 int32;
172  typedef kaldi::int64 int64;
173  typedef fst::StdArc::Weight Weight;
175 
176  const char *usage =
177  "Finds the path having the smallest edit-distance between a lattice\n"
178  "and a reference string.\n"
179  "\n"
180  "Usage: lattice-oracle [options] <test-lattice-rspecifier> \\\n"
181  " <reference-rspecifier> \\\n"
182  " <transcriptions-wspecifier> \\\n"
183  " [<edit-distance-wspecifier>]\n"
184  " e.g.: lattice-oracle ark:lat.1 'ark:sym2int.pl -f 2- \\\n"
185  " data/lang/words.txt <data/test/text|' ark,t:-\n"
186  "\n"
187  "Note the --write-lattices option by which you can write out the\n"
188  "optimal path as a lattice.\n"
189  "Note: you can use this program to compute the n-best oracle WER by\n"
190  "first piping the input lattices through lattice-to-nbest and then\n"
191  "nbest-to-lattice.\n";
192 
193  ParseOptions po(usage);
194 
195  std::string word_syms_filename;
196  std::string wild_syms_rxfilename;
197  std::string wildcard_symbols;
198  std::string lats_wspecifier;
199 
200  po.Register("word-symbol-table", &word_syms_filename,
201  "Symbol table for words [for debug output]");
202  po.Register("wildcard-symbols-list", &wild_syms_rxfilename, "Filename "
203  "(generally rxfilename) for file containing text-form list of "
204  "symbols that don't count as errors; this option requires "
205  "--word-symbol-table. Deprecated; use --wildcard-symbols "
206  "option.");
207  po.Register("wildcard-symbols", &wildcard_symbols,
208  "Colon-separated list of integer ids of symbols that "
209  "don't count as errors. Preferred alternative to deprecated "
210  "option --wildcard-symbols-list.");
211  po.Register("write-lattices", &lats_wspecifier, "If supplied, write the "
212  "lattice that contains only the oracle path to the given "
213  "wspecifier.");
214 
215  po.Read(argc, argv);
216 
217  if (po.NumArgs() != 3 && po.NumArgs() != 4) {
218  po.PrintUsage();
219  exit(1);
220  }
221 
222  std::string lats_rspecifier = po.GetArg(1),
223  reference_rspecifier = po.GetArg(2),
224  transcriptions_wspecifier = po.GetArg(3),
225  edit_distance_wspecifier = po.GetOptArg(4);
226 
227  // will read input as lattices
228  SequentialLatticeReader lattice_reader(lats_rspecifier);
229  RandomAccessInt32VectorReader reference_reader(reference_rspecifier);
230  Int32VectorWriter transcriptions_writer(transcriptions_wspecifier);
231  Int32Writer edit_distance_writer(edit_distance_wspecifier);
232  CompactLatticeWriter lats_writer(lats_wspecifier);
233 
234  fst::SymbolTable *word_syms = NULL;
235  if (word_syms_filename != "")
236  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
237  KALDI_ERR << "Could not read symbol table from file "
238  << word_syms_filename;
239 
240  LabelPairVector wildcards;
241  if (wild_syms_rxfilename != "") {
242  KALDI_WARN << "--wildcard-symbols-list option deprecated.";
243  KALDI_ASSERT(wildcard_symbols.empty() && "Do not use both "
244  "--wildcard-symbols and --wildcard-symbols-list options.");
245  KALDI_ASSERT(word_syms != NULL && "--wildcard-symbols-list option "
246  "requires --word-symbol-table option");
247  ReadSymbolList(wild_syms_rxfilename, word_syms, &wildcards);
248  } else {
249  std::vector<fst::StdArc::Label> wildcard_symbols_vec;
250  if (!SplitStringToIntegers(wildcard_symbols, ":", false,
251  &wildcard_symbols_vec)) {
252  KALDI_ERR << "Expected colon-separated list of integers for "
253  << "--wildcard-symbols option, got: " << wildcard_symbols;
254  }
255  for (size_t i = 0; i < wildcard_symbols_vec.size(); i++)
256  wildcards.emplace_back(wildcard_symbols_vec[i], 0);
257  }
258 
259  int32 n_done = 0, n_fail = 0;
260  int32 tot_correct = 0, tot_substitutions = 0,
261  tot_insertions = 0, tot_deletions = 0, tot_words = 0;
262 
263  for (; !lattice_reader.Done(); lattice_reader.Next()) {
264  std::string key = lattice_reader.Key();
265  const Lattice &lat = lattice_reader.Value();
266  std::cerr << "Lattice " << key << " read." << std::endl;
267 
268  // remove all weights while creating a standard FST
269  VectorFst<StdArc> lattice_fst;
270  ConvertLatticeToUnweightedAcceptor(lat, wildcards, &lattice_fst);
271  CheckFst(lattice_fst, "lattice_fst_", key);
272 
273  // TODO: map certain symbols (using an FST created with CreateMapFst())
274  if (!reference_reader.HasKey(key)) {
275  KALDI_WARN << "No reference present for utterance " << key;
276  n_fail++;
277  continue;
278  }
279  const std::vector<int32> &reference = reference_reader.Value(key);
280  VectorFst<StdArc> reference_fst;
281  MakeLinearAcceptor(reference, &reference_fst);
282 
283  // Remove any wildcards in reference.
284  fst::Relabel(&reference_fst, wildcards, wildcards);
285  CheckFst(reference_fst, "reference_fst_", key);
286 
287  // recreate edit distance fst if necessary
288  fst::StdVectorFst edit_distance_fst;
289  CreateEditDistance(lattice_fst, reference_fst, &edit_distance_fst);
290 
291  // compose with edit distance transducer
292  VectorFst<StdArc> edit_ref_fst;
293  fst::Compose(edit_distance_fst, reference_fst, &edit_ref_fst);
294  CheckFst(edit_ref_fst, "composed_", key);
295 
296  // make sure composed FST is input sorted
297  fst::ArcSort(&edit_ref_fst, fst::StdILabelCompare());
298 
299  // compose with previous result
300  VectorFst<StdArc> result_fst;
301  fst::Compose(lattice_fst, edit_ref_fst, &result_fst);
302  CheckFst(result_fst, "result_", key);
303 
304  // find out best path
305  VectorFst<StdArc> best_path;
306  fst::ShortestPath(result_fst, &best_path);
307  CheckFst(best_path, "best_path_", key);
308 
309  if (best_path.Start() == fst::kNoStateId) {
310  KALDI_WARN << "Best-path failed for key " << key;
311  n_fail++;
312  } else {
313  // count errors
314  int32 correct, substitutions, insertions, deletions, num_words;
315  CountErrors(best_path, &correct, &substitutions,
316  &insertions, &deletions, &num_words);
317  int32 tot_errs = substitutions + insertions + deletions;
318  if (edit_distance_wspecifier != "")
319  edit_distance_writer.Write(key, tot_errs);
320  KALDI_LOG << "%WER " << (100.*tot_errs) / num_words << " [ " << tot_errs
321  << " / " << num_words << ", " << insertions << " insertions, "
322  << deletions << " deletions, " << substitutions << " sub ]";
323  tot_correct += correct;
324  tot_substitutions += substitutions;
325  tot_insertions += insertions;
326  tot_deletions += deletions;
327  tot_words += num_words;
328 
329  std::vector<int32> oracle_words;
330  std::vector<int32> reference_words;
331  Weight weight;
332  GetLinearSymbolSequence(best_path, &oracle_words,
333  &reference_words, &weight);
334  KALDI_LOG << "For utterance " << key << ", best cost " << weight;
335  if (transcriptions_wspecifier != "")
336  transcriptions_writer.Write(key, oracle_words);
337  if (word_syms != NULL) {
338  std::cerr << key << " (oracle) ";
339  for (size_t i = 0; i < oracle_words.size(); i++) {
340  std::string s = word_syms->Find(oracle_words[i]);
341  if (s == "")
342  KALDI_ERR << "Word-id " << oracle_words[i]
343  << " not in symbol table.";
344  std::cerr << s << ' ';
345  }
346  std::cerr << '\n' << key << " (reference) ";
347  for (size_t i = 0; i < reference_words.size(); i++) {
348  std::string s = word_syms->Find(reference_words[i]);
349  if (s == "")
350  KALDI_ERR << "Word-id " << reference_words[i]
351  << " not in symbol table.";
352  std::cerr << s << ' ';
353  }
354  std::cerr << '\n';
355  }
356 
357  // If requested, write the lattice that only contains the oracle path.
358  if (lats_wspecifier != "") {
359  CompactLattice oracle_clat_mask;
360  MakeLinearAcceptor(oracle_words, &oracle_clat_mask);
361 
362  CompactLattice clat;
363  CompactLattice oracle_clat;
364  ConvertLattice(lat, &clat);
365  fst::Relabel(&clat, wildcards, LabelPairVector());
366  fst::ArcSort(&clat, fst::ILabelCompare<CompactLatticeArc>());
367  fst::Compose(oracle_clat_mask, clat, &oracle_clat_mask);
368  fst::ShortestPath(oracle_clat_mask, &oracle_clat);
369  fst::Project(&oracle_clat, fst::PROJECT_OUTPUT);
370  TopSortCompactLatticeIfNeeded(&oracle_clat);
371 
372  if (oracle_clat.Start() == fst::kNoStateId) {
373  KALDI_WARN << "Failed to find the oracle path in the original "
374  << "lattice: " << key;
375  } else {
376  lats_writer.Write(key, oracle_clat);
377  }
378  }
379  }
380  n_done++;
381  }
382  delete word_syms;
383  int32 tot_errs = tot_substitutions + tot_deletions + tot_insertions;
384  // Warning: the script egs/s5/*/steps/oracle_wer.sh parses the next line.
385  KALDI_LOG << "Overall %WER " << (100.*tot_errs)/tot_words << " [ "
386  << tot_errs << " / " << tot_words << ", " << tot_insertions
387  << " insertions, " << tot_deletions << " deletions, "
388  << tot_substitutions << " substitutions ]";
389  KALDI_LOG << "Scored " << n_done << " lattices, " << n_fail
390  << " not present in ref.";
391  } catch(const std::exception &e) {
392  std::cerr << e.what();
393  return -1;
394  }
395 }
fst::StdArc::StateId StateId
fst::StdArc::Label Label
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void ReadSymbolList(const std::string &rxfilename, fst::SymbolTable *word_syms, LabelPairVector *lpairs)
Lattice::StateId StateId
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
Definition: text-utils.h:68
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
fst::StdArc StdArc
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
fst::StdVectorFst StdVectorFst
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
void GetInputSymbols(const Fst< Arc > &fst, bool include_eps, std::vector< I > *symbols)
GetInputSymbols gets the list of symbols on the input of fst (including epsilon, if include_eps == tr...
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void MakeLinearAcceptor(const std::vector< I > &labels, MutableFst< Arc > *ofst)
Creates unweighted linear acceptor from symbol sequence.
void ConvertLatticeToUnweightedAcceptor(const kaldi::Lattice &ilat, const LabelPairVector &wildcards, fst::StdVectorFst *ofst)
std::istream & Stream()
Definition: kaldi-io.cc:826
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
bool CheckFst(const fst::StdVectorFst &fst, string name, string key)
const T & Value(const std::string &key)
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
fst::StdArc::Weight Weight
int main(int argc, char *argv[])
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void CreateEditDistance(const fst::StdVectorFst &fst1, const fst::StdVectorFst &fst2, fst::StdVectorFst *pfst)
Arc::Weight Weight
Definition: kws-search.cc:31
std::vector< std::pair< Label, Label > > LabelPairVector
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::string PrintableRxfilename(const std::string &rxfilename)
PrintableRxfilename turns the rxfilename into a more human-readable form for error reporting...
Definition: kaldi-io.cc:61
void TopSortCompactLatticeIfNeeded(CompactLattice *clat)
Topologically sort the compact lattice if not already topologically sorted.
#define KALDI_LOG
Definition: kaldi-error.h:153
void GetOutputSymbols(const Fst< Arc > &fst, bool include_eps, std::vector< I > *symbols)
GetOutputSymbols gets the list of symbols on the output of fst (including epsilon, if include_eps == true)
void CountErrors(const fst::StdVectorFst &fst, int32 *correct, int32 *substitutions, int32 *insertions, int32 *deletions, int32 *num_words)
std::string GetOptArg(int param) const