lattice-oracle.cc File Reference
Include dependency graph for lattice-oracle.cc:

Go to the source code of this file.

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks (the reference citation that follows this sentence in the original documentation is not included in this extract).
 

Typedefs

typedef fst::StdArc::Label Label
 
typedef std::vector< std::pair< Label, Label > > LabelPairVector
 

Functions

void ReadSymbolList (const std::string &rxfilename, fst::SymbolTable *word_syms, LabelPairVector *lpairs)
 
void ConvertLatticeToUnweightedAcceptor (const kaldi::Lattice &ilat, const LabelPairVector &wildcards, fst::StdVectorFst *ofst)
 
void CreateEditDistance (const fst::StdVectorFst &fst1, const fst::StdVectorFst &fst2, fst::StdVectorFst *pfst)
 
void CountErrors (const fst::StdVectorFst &fst, int32 *correct, int32 *substitutions, int32 *insertions, int32 *deletions, int32 *num_words)
 
bool CheckFst (const fst::StdVectorFst &fst, string name, string key)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 165 of file lattice-oracle.cc.

References kaldi::CheckFst(), fst::ConvertLattice(), kaldi::ConvertLatticeToUnweightedAcceptor(), kaldi::CountErrors(), kaldi::CreateEditDistance(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), fst::GetLinearSymbolSequence(), ParseOptions::GetOptArg(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), fst::MakeLinearAcceptor(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), kaldi::ReadSymbolList(), ParseOptions::Register(), kaldi::SplitStringToIntegers(), kaldi::TopSortCompactLatticeIfNeeded(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and TableWriter< Holder >::Write().

165  {
166  try {
167  using namespace kaldi;
168  using fst::SymbolTable;
169  using fst::VectorFst;
170  using fst::StdArc;
171  typedef kaldi::int32 int32;
172  typedef kaldi::int64 int64;
173  typedef fst::StdArc::Weight Weight;
175 
176  const char *usage =
177  "Finds the path having the smallest edit-distance between a lattice\n"
178  "and a reference string.\n"
179  "\n"
180  "Usage: lattice-oracle [options] <test-lattice-rspecifier> \\\n"
181  " <reference-rspecifier> \\\n"
182  " <transcriptions-wspecifier> \\\n"
183  " [<edit-distance-wspecifier>]\n"
184  " e.g.: lattice-oracle ark:lat.1 'ark:sym2int.pl -f 2- \\\n"
185  " data/lang/words.txt <data/test/text|' ark,t:-\n"
186  "\n"
187  "Note the --write-lattices option by which you can write out the\n"
188  "optimal path as a lattice.\n"
189  "Note: you can use this program to compute the n-best oracle WER by\n"
190  "first piping the input lattices through lattice-to-nbest and then\n"
191  "nbest-to-lattice.\n";
192 
193  ParseOptions po(usage);
194 
195  std::string word_syms_filename;
196  std::string wild_syms_rxfilename;
197  std::string wildcard_symbols;
198  std::string lats_wspecifier;
199 
200  po.Register("word-symbol-table", &word_syms_filename,
201  "Symbol table for words [for debug output]");
202  po.Register("wildcard-symbols-list", &wild_syms_rxfilename, "Filename "
203  "(generally rxfilename) for file containing text-form list of "
204  "symbols that don't count as errors; this option requires "
205  "--word-symbol-table. Deprecated; use --wildcard-symbols "
206  "option.");
207  po.Register("wildcard-symbols", &wildcard_symbols,
208  "Colon-separated list of integer ids of symbols that "
209  "don't count as errors. Preferred alternative to deprecated "
210  "option --wildcard-symbols-list.");
211  po.Register("write-lattices", &lats_wspecifier, "If supplied, write the "
212  "lattice that contains only the oracle path to the given "
213  "wspecifier.");
214 
215  po.Read(argc, argv);
216 
217  if (po.NumArgs() != 3 && po.NumArgs() != 4) {
218  po.PrintUsage();
219  exit(1);
220  }
221 
222  std::string lats_rspecifier = po.GetArg(1),
223  reference_rspecifier = po.GetArg(2),
224  transcriptions_wspecifier = po.GetArg(3),
225  edit_distance_wspecifier = po.GetOptArg(4);
226 
227  // will read input as lattices
228  SequentialLatticeReader lattice_reader(lats_rspecifier);
229  RandomAccessInt32VectorReader reference_reader(reference_rspecifier);
230  Int32VectorWriter transcriptions_writer(transcriptions_wspecifier);
231  Int32Writer edit_distance_writer(edit_distance_wspecifier);
232  CompactLatticeWriter lats_writer(lats_wspecifier);
233 
234  fst::SymbolTable *word_syms = NULL;
235  if (word_syms_filename != "")
236  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
237  KALDI_ERR << "Could not read symbol table from file "
238  << word_syms_filename;
239 
240  LabelPairVector wildcards;
241  if (wild_syms_rxfilename != "") {
242  KALDI_WARN << "--wildcard-symbols-list option deprecated.";
243  KALDI_ASSERT(wildcard_symbols.empty() && "Do not use both "
244  "--wildcard-symbols and --wildcard-symbols-list options.");
245  KALDI_ASSERT(word_syms != NULL && "--wildcard-symbols-list option "
246  "requires --word-symbol-table option");
247  ReadSymbolList(wild_syms_rxfilename, word_syms, &wildcards);
248  } else {
249  std::vector<fst::StdArc::Label> wildcard_symbols_vec;
250  if (!SplitStringToIntegers(wildcard_symbols, ":", false,
251  &wildcard_symbols_vec)) {
252  KALDI_ERR << "Expected colon-separated list of integers for "
253  << "--wildcard-symbols option, got: " << wildcard_symbols;
254  }
255  for (size_t i = 0; i < wildcard_symbols_vec.size(); i++)
256  wildcards.emplace_back(wildcard_symbols_vec[i], 0);
257  }
258 
259  int32 n_done = 0, n_fail = 0;
260  int32 tot_correct = 0, tot_substitutions = 0,
261  tot_insertions = 0, tot_deletions = 0, tot_words = 0;
262 
263  for (; !lattice_reader.Done(); lattice_reader.Next()) {
264  std::string key = lattice_reader.Key();
265  const Lattice &lat = lattice_reader.Value();
266  std::cerr << "Lattice " << key << " read." << std::endl;
267 
268  // remove all weights while creating a standard FST
269  VectorFst<StdArc> lattice_fst;
270  ConvertLatticeToUnweightedAcceptor(lat, wildcards, &lattice_fst);
271  CheckFst(lattice_fst, "lattice_fst_", key);
272 
273  // TODO: map certain symbols (using an FST created with CreateMapFst())
274  if (!reference_reader.HasKey(key)) {
275  KALDI_WARN << "No reference present for utterance " << key;
276  n_fail++;
277  continue;
278  }
279  const std::vector<int32> &reference = reference_reader.Value(key);
280  VectorFst<StdArc> reference_fst;
281  MakeLinearAcceptor(reference, &reference_fst);
282 
283  // Remove any wildcards in reference.
284  fst::Relabel(&reference_fst, wildcards, wildcards);
285  CheckFst(reference_fst, "reference_fst_", key);
286 
287  // recreate edit distance fst if necessary
288  fst::StdVectorFst edit_distance_fst;
289  CreateEditDistance(lattice_fst, reference_fst, &edit_distance_fst);
290 
291  // compose with edit distance transducer
292  VectorFst<StdArc> edit_ref_fst;
293  fst::Compose(edit_distance_fst, reference_fst, &edit_ref_fst);
294  CheckFst(edit_ref_fst, "composed_", key);
295 
296  // make sure composed FST is input sorted
297  fst::ArcSort(&edit_ref_fst, fst::StdILabelCompare());
298 
299  // compose with previous result
300  VectorFst<StdArc> result_fst;
301  fst::Compose(lattice_fst, edit_ref_fst, &result_fst);
302  CheckFst(result_fst, "result_", key);
303 
304  // find out best path
305  VectorFst<StdArc> best_path;
306  fst::ShortestPath(result_fst, &best_path);
307  CheckFst(best_path, "best_path_", key);
308 
309  if (best_path.Start() == fst::kNoStateId) {
310  KALDI_WARN << "Best-path failed for key " << key;
311  n_fail++;
312  } else {
313  // count errors
314  int32 correct, substitutions, insertions, deletions, num_words;
315  CountErrors(best_path, &correct, &substitutions,
316  &insertions, &deletions, &num_words);
317  int32 tot_errs = substitutions + insertions + deletions;
318  if (edit_distance_wspecifier != "")
319  edit_distance_writer.Write(key, tot_errs);
320  KALDI_LOG << "%WER " << (100.*tot_errs) / num_words << " [ " << tot_errs
321  << " / " << num_words << ", " << insertions << " insertions, "
322  << deletions << " deletions, " << substitutions << " sub ]";
323  tot_correct += correct;
324  tot_substitutions += substitutions;
325  tot_insertions += insertions;
326  tot_deletions += deletions;
327  tot_words += num_words;
328 
329  std::vector<int32> oracle_words;
330  std::vector<int32> reference_words;
331  Weight weight;
332  GetLinearSymbolSequence(best_path, &oracle_words,
333  &reference_words, &weight);
334  KALDI_LOG << "For utterance " << key << ", best cost " << weight;
335  if (transcriptions_wspecifier != "")
336  transcriptions_writer.Write(key, oracle_words);
337  if (word_syms != NULL) {
338  std::cerr << key << " (oracle) ";
339  for (size_t i = 0; i < oracle_words.size(); i++) {
340  std::string s = word_syms->Find(oracle_words[i]);
341  if (s == "")
342  KALDI_ERR << "Word-id " << oracle_words[i]
343  << " not in symbol table.";
344  std::cerr << s << ' ';
345  }
346  std::cerr << '\n' << key << " (reference) ";
347  for (size_t i = 0; i < reference_words.size(); i++) {
348  std::string s = word_syms->Find(reference_words[i]);
349  if (s == "")
350  KALDI_ERR << "Word-id " << reference_words[i]
351  << " not in symbol table.";
352  std::cerr << s << ' ';
353  }
354  std::cerr << '\n';
355  }
356 
357  // If requested, write the lattice that only contains the oracle path.
358  if (lats_wspecifier != "") {
359  CompactLattice oracle_clat_mask;
360  MakeLinearAcceptor(oracle_words, &oracle_clat_mask);
361 
362  CompactLattice clat;
363  CompactLattice oracle_clat;
364  ConvertLattice(lat, &clat);
365  fst::Relabel(&clat, wildcards, LabelPairVector());
366  fst::ArcSort(&clat, fst::ILabelCompare<CompactLatticeArc>());
367  fst::Compose(oracle_clat_mask, clat, &oracle_clat_mask);
368  fst::ShortestPath(oracle_clat_mask, &oracle_clat);
369  fst::Project(&oracle_clat, fst::PROJECT_OUTPUT);
370  TopSortCompactLatticeIfNeeded(&oracle_clat);
371 
372  if (oracle_clat.Start() == fst::kNoStateId) {
373  KALDI_WARN << "Failed to find the oracle path in the original "
374  << "lattice: " << key;
375  } else {
376  lats_writer.Write(key, oracle_clat);
377  }
378  }
379  }
380  n_done++;
381  }
382  delete word_syms;
383  int32 tot_errs = tot_substitutions + tot_deletions + tot_insertions;
384  // Warning: the script egs/s5/*/steps/oracle_wer.sh parses the next line.
385  KALDI_LOG << "Overall %WER " << (100.*tot_errs)/tot_words << " [ "
386  << tot_errs << " / " << tot_words << ", " << tot_insertions
387  << " insertions, " << tot_deletions << " deletions, "
388  << tot_substitutions << " substitutions ]";
389  KALDI_LOG << "Scored " << n_done << " lattices, " << n_fail
390  << " not present in ref.";
391  } catch(const std::exception &e) {
392  std::cerr << e.what();
393  return -1;
394  }
395 }
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void ReadSymbolList(const std::string &rxfilename, fst::SymbolTable *word_syms, LabelPairVector *lpairs)
Lattice::StateId StateId
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
Definition: text-utils.h:68
fst::StdArc StdArc
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
fst::StdVectorFst StdVectorFst
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void MakeLinearAcceptor(const std::vector< I > &labels, MutableFst< Arc > *ofst)
Creates unweighted linear acceptor from symbol sequence.
void ConvertLatticeToUnweightedAcceptor(const kaldi::Lattice &ilat, const LabelPairVector &wildcards, fst::StdVectorFst *ofst)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
bool CheckFst(const fst::StdVectorFst &fst, string name, string key)
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
fst::StdArc::Weight Weight
void CreateEditDistance(const fst::StdVectorFst &fst1, const fst::StdVectorFst &fst2, fst::StdVectorFst *pfst)
Arc::Weight Weight
Definition: kws-search.cc:31
std::vector< std::pair< Label, Label > > LabelPairVector
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void TopSortCompactLatticeIfNeeded(CompactLattice *clat)
Topologically sort the compact lattice if not already topologically sorted.
#define KALDI_LOG
Definition: kaldi-error.h:153
void CountErrors(const fst::StdVectorFst &fst, int32 *correct, int32 *substitutions, int32 *insertions, int32 *deletions, int32 *num_words)