// lattice-faster-decoder.cc
// (NOTE: this file was recovered from a documentation page; HTML navigation
// text removed.)
1 // decoder/lattice-faster-decoder.cc
2 
3 // Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
4 // 2013-2018 Johns Hopkins University (Author: Daniel Povey)
5 // 2014 Guoguo Chen
6 // 2018 Zhehuai Chen
7 
8 // See ../../COPYING for clarification regarding multiple authors
9 //
10 // Licensed under the Apache License, Version 2.0 (the "License");
11 // you may not use this file except in compliance with the License.
12 // You may obtain a copy of the License at
13 //
14 // http://www.apache.org/licenses/LICENSE-2.0
15 //
16 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
18 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
19 // MERCHANTABLITY OR NON-INFRINGEMENT.
20 // See the Apache 2 License for the specific language governing permissions and
21 // limitations under the License.
22 
#include "decoder/lattice-faster-decoder.h"
#include "lat/lattice-functions.h"
25 
26 namespace kaldi {
27 
28 // instantiate this class once for each thing you have to decode.
29 template <typename FST, typename Token>
31  const FST &fst,
32  const LatticeFasterDecoderConfig &config):
33  fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) {
34  config.Check();
35  toks_.SetSize(1000); // just so on the first frame we do something reasonable.
36 }
37 
38 
39 template <typename FST, typename Token>
41  const LatticeFasterDecoderConfig &config, FST *fst):
42  fst_(fst), delete_fst_(true), config_(config), num_toks_(0) {
43  config.Check();
44  toks_.SetSize(1000); // just so on the first frame we do something reasonable.
45 }
46 
47 
48 template <typename FST, typename Token>
50  DeleteElems(toks_.Clear());
51  ClearActiveTokens();
52  if (delete_fst_) delete fst_;
53 }
54 
55 template <typename FST, typename Token>
57  // clean up from last time:
58  DeleteElems(toks_.Clear());
59  cost_offsets_.clear();
60  ClearActiveTokens();
61  warned_ = false;
62  num_toks_ = 0;
63  decoding_finalized_ = false;
64  final_costs_.clear();
65  StateId start_state = fst_->Start();
66  KALDI_ASSERT(start_state != fst::kNoStateId);
67  active_toks_.resize(1);
68  Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL);
69  active_toks_[0].toks = start_tok;
70  toks_.Insert(start_state, start_tok);
71  num_toks_++;
72  ProcessNonemitting(config_.beam);
73 }
74 
75 // Returns true if any kind of traceback is available (not necessarily from
76 // a final state). It should only very rarely return false; this indicates
77 // an unusual search error.
78 template <typename FST, typename Token>
80  InitDecoding();
81  // We use 1-based indexing for frames in this decoder (if you view it in
82  // terms of features), but note that the decodable object uses zero-based
83  // numbering, which we have to correct for when we call it.
84  AdvanceDecoding(decodable);
85  FinalizeDecoding();
86 
87  // Returns true if we have any kind of traceback available (not necessarily
88  // to the end state; query ReachedFinal() for that).
89  return !active_toks_.empty() && active_toks_.back().toks != NULL;
90 }
91 
92 
93 // Outputs an FST corresponding to the single best path through the lattice.
94 template <typename FST, typename Token>
96  bool use_final_probs) const {
97  Lattice raw_lat;
98  GetRawLattice(&raw_lat, use_final_probs);
99  ShortestPath(raw_lat, olat);
100  return (olat->NumStates() != 0);
101 }
102 
103 
104 // Outputs an FST corresponding to the raw, state-level lattice
105 template <typename FST, typename Token>
107  Lattice *ofst,
108  bool use_final_probs) const {
109  typedef LatticeArc Arc;
110  typedef Arc::StateId StateId;
111  typedef Arc::Weight Weight;
112  typedef Arc::Label Label;
113 
114  // Note: you can't use the old interface (Decode()) if you want to
115  // get the lattice with use_final_probs = false. You'd have to do
116  // InitDecoding() and then AdvanceDecoding().
117  if (decoding_finalized_ && !use_final_probs)
118  KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
119  << "GetRawLattice() with use_final_probs == false";
120 
121  unordered_map<Token*, BaseFloat> final_costs_local;
122 
123  const unordered_map<Token*, BaseFloat> &final_costs =
124  (decoding_finalized_ ? final_costs_ : final_costs_local);
125  if (!decoding_finalized_ && use_final_probs)
126  ComputeFinalCosts(&final_costs_local, NULL, NULL);
127 
128  ofst->DeleteStates();
129  // num-frames plus one (since frames are one-based, and we have
130  // an extra frame for the start-state).
131  int32 num_frames = active_toks_.size() - 1;
132  KALDI_ASSERT(num_frames > 0);
133  const int32 bucket_count = num_toks_/2 + 3;
134  unordered_map<Token*, StateId> tok_map(bucket_count);
135  // First create all states.
136  std::vector<Token*> token_list;
137  for (int32 f = 0; f <= num_frames; f++) {
138  if (active_toks_[f].toks == NULL) {
139  KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
140  << ": not producing lattice.\n";
141  return false;
142  }
143  TopSortTokens(active_toks_[f].toks, &token_list);
144  for (size_t i = 0; i < token_list.size(); i++)
145  if (token_list[i] != NULL)
146  tok_map[token_list[i]] = ofst->AddState();
147  }
148  // The next statement sets the start state of the output FST. Because we
149  // topologically sorted the tokens, state zero must be the start-state.
150  ofst->SetStart(0);
151 
152  KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:"
153  << tok_map.bucket_count() << " load:" << tok_map.load_factor()
154  << " max:" << tok_map.max_load_factor();
155  // Now create all arcs.
156  for (int32 f = 0; f <= num_frames; f++) {
157  for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
158  StateId cur_state = tok_map[tok];
159  for (ForwardLinkT *l = tok->links;
160  l != NULL;
161  l = l->next) {
162  typename unordered_map<Token*, StateId>::const_iterator
163  iter = tok_map.find(l->next_tok);
164  StateId nextstate = iter->second;
165  KALDI_ASSERT(iter != tok_map.end());
166  BaseFloat cost_offset = 0.0;
167  if (l->ilabel != 0) { // emitting..
168  KALDI_ASSERT(f >= 0 && f < cost_offsets_.size());
169  cost_offset = cost_offsets_[f];
170  }
171  Arc arc(l->ilabel, l->olabel,
172  Weight(l->graph_cost, l->acoustic_cost - cost_offset),
173  nextstate);
174  ofst->AddArc(cur_state, arc);
175  }
176  if (f == num_frames) {
177  if (use_final_probs && !final_costs.empty()) {
178  typename unordered_map<Token*, BaseFloat>::const_iterator
179  iter = final_costs.find(tok);
180  if (iter != final_costs.end())
181  ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
182  } else {
183  ofst->SetFinal(cur_state, LatticeWeight::One());
184  }
185  }
186  }
187  }
188  return (ofst->NumStates() > 0);
189 }
190 
191 
192 // This function is now deprecated, since now we do determinization from outside
193 // the LatticeFasterDecoder class. Outputs an FST corresponding to the
194 // lattice-determinized lattice (one path per word sequence).
195 template <typename FST, typename Token>
197  bool use_final_probs) const {
198  Lattice raw_fst;
199  GetRawLattice(&raw_fst, use_final_probs);
200  Invert(&raw_fst); // make it so word labels are on the input.
201  // (in phase where we get backward-costs).
202  fst::ILabelCompare<LatticeArc> ilabel_comp;
203  ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes
204  // lattice-determinization more efficient.
205 
207  lat_opts.max_mem = config_.det_opts.max_mem;
208 
209  DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts);
210  raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed.
211  Connect(ofst); // Remove unreachable states... there might be
212  // a small number of these, in some cases.
213  // Note: if something went wrong and the raw lattice was empty,
214  // we should still get to this point in the code without warnings or failures.
215  return (ofst->NumStates() != 0);
216 }
217 
218 template <typename FST, typename Token>
220  size_t new_sz = static_cast<size_t>(static_cast<BaseFloat>(num_toks)
221  * config_.hash_ratio);
222  if (new_sz > toks_.Size()) {
223  toks_.SetSize(new_sz);
224  }
225 }
226 
227 /*
228  A note on the definition of extra_cost.
229 
230  extra_cost is used in pruning tokens, to save memory.
231 
232  extra_cost can be thought of as a beta (backward) cost assuming
233  we had set the betas on currently-active tokens to all be the negative
234  of the alphas for those tokens. (So all currently active tokens would
235  be on (tied) best paths).
236 
237  We can use the extra_cost to accurately prune away tokens that we know will
238  never appear in the lattice. If the extra_cost is greater than the desired
239  lattice beam, the token would provably never appear in the lattice, so we can
240  prune away the token.
241 
242  (Note: we don't update all the extra_costs every time we update a frame; we
243  only do it every 'config_.prune_interval' frames).
244  */
245 
246 // FindOrAddToken either locates a token in hash of toks_,
247 // or if necessary inserts a new, empty token (i.e. with no forward links)
248 // for the current frame. [note: it's inserted if necessary into hash toks_
249 // and also into the singly linked list of tokens active on this frame
250 // (whose head is at active_toks_[frame]).
251 template <typename FST, typename Token>
254  StateId state, int32 frame_plus_one, BaseFloat tot_cost,
255  Token *backpointer, bool *changed) {
256  // Returns the Token pointer. Sets "changed" (if non-NULL) to true
257  // if the token was newly created or the cost changed.
258  KALDI_ASSERT(frame_plus_one < active_toks_.size());
259  Token *&toks = active_toks_[frame_plus_one].toks;
260  Elem *e_found = toks_.Insert(state, NULL);
261  if (e_found->val == NULL) { // no such token presently.
262  const BaseFloat extra_cost = 0.0;
263  // tokens on the currently final frame have zero extra_cost
264  // as any of them could end up
265  // on the winning path.
266  Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer);
267  // NULL: no forward links yet
268  toks = new_tok;
269  num_toks_++;
270  e_found->val = new_tok;
271  if (changed) *changed = true;
272  return e_found;
273  } else {
274  Token *tok = e_found->val; // There is an existing Token for this state.
275  if (tok->tot_cost > tot_cost) { // replace old token
276  tok->tot_cost = tot_cost;
277  // SetBackpointer() just does tok->backpointer = backpointer in
278  // the case where Token == BackpointerToken, else nothing.
279  tok->SetBackpointer(backpointer);
280  // we don't allocate a new token, the old stays linked in active_toks_
281  // we only replace the tot_cost
282  // in the current frame, there are no forward links (and no extra_cost)
283  // only in ProcessNonemitting we have to delete forward links
284  // in case we visit a state for the second time
285  // those forward links, that lead to this replaced token before:
286  // they remain and will hopefully be pruned later (PruneForwardLinks...)
287  if (changed) *changed = true;
288  } else {
289  if (changed) *changed = false;
290  }
291  return e_found;
292  }
293 }
294 
295 // prunes outgoing links for all tokens in active_toks_[frame]
296 // it's called by PruneActiveTokens
297 // all links, that have link_extra_cost > lattice_beam are pruned
298 template <typename FST, typename Token>
300  int32 frame_plus_one, bool *extra_costs_changed,
301  bool *links_pruned, BaseFloat delta) {
302  // delta is the amount by which the extra_costs must change
303  // If delta is larger, we'll tend to go back less far
304  // toward the beginning of the file.
305  // extra_costs_changed is set to true if extra_cost was changed for any token
306  // links_pruned is set to true if any link in any token was pruned
307 
308  *extra_costs_changed = false;
309  *links_pruned = false;
310  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
311  if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen.
312  if (!warned_) {
313  KALDI_WARN << "No tokens alive [doing pruning].. warning first "
314  "time only for each utterance\n";
315  warned_ = true;
316  }
317  }
318 
319  // We have to iterate until there is no more change, because the links
320  // are not guaranteed to be in topological order.
321  bool changed = true; // difference new minus old extra cost >= delta ?
322  while (changed) {
323  changed = false;
324  for (Token *tok = active_toks_[frame_plus_one].toks;
325  tok != NULL; tok = tok->next) {
326  ForwardLinkT *link, *prev_link = NULL;
327  // will recompute tok_extra_cost for tok.
328  BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
329  // tok_extra_cost is the best (min) of link_extra_cost of outgoing links
330  for (link = tok->links; link != NULL; ) {
331  // See if we need to excise this link...
332  Token *next_tok = link->next_tok;
333  BaseFloat link_extra_cost = next_tok->extra_cost +
334  ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
335  - next_tok->tot_cost); // difference in brackets is >= 0
336  // link_exta_cost is the difference in score between the best paths
337  // through link source state and through link destination state
338  KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN
339  if (link_extra_cost > config_.lattice_beam) { // excise link
340  ForwardLinkT *next_link = link->next;
341  if (prev_link != NULL) prev_link->next = next_link;
342  else tok->links = next_link;
343  delete link;
344  link = next_link; // advance link but leave prev_link the same.
345  *links_pruned = true;
346  } else { // keep the link and update the tok_extra_cost if needed.
347  if (link_extra_cost < 0.0) { // this is just a precaution.
348  if (link_extra_cost < -0.01)
349  KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
350  link_extra_cost = 0.0;
351  }
352  if (link_extra_cost < tok_extra_cost)
353  tok_extra_cost = link_extra_cost;
354  prev_link = link; // move to next link
355  link = link->next;
356  }
357  } // for all outgoing links
358  if (fabs(tok_extra_cost - tok->extra_cost) > delta)
359  changed = true; // difference new minus old is bigger than delta
360  tok->extra_cost = tok_extra_cost;
361  // will be +infinity or <= lattice_beam_.
362  // infinity indicates, that no forward link survived pruning
363  } // for all Token on active_toks_[frame]
364  if (changed) *extra_costs_changed = true;
365 
366  // Note: it's theoretically possible that aggressive compiler
367  // optimizations could cause an infinite loop here for small delta and
368  // high-dynamic-range scores.
369  } // while changed
370 }
371 
372 // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
373 // on the final frame. If there are final tokens active, it uses
374 // the final-probs for pruning, otherwise it treats all tokens as final.
375 template <typename FST, typename Token>
377  KALDI_ASSERT(!active_toks_.empty());
378  int32 frame_plus_one = active_toks_.size() - 1;
379 
380  if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen.
381  KALDI_WARN << "No tokens alive at end of file";
382 
383  typedef typename unordered_map<Token*, BaseFloat>::const_iterator IterType;
384  ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
385  decoding_finalized_ = true;
386  // We call DeleteElems() as a nicety, not because it's really necessary;
387  // otherwise there would be a time, after calling PruneTokensForFrame() on the
388  // final frame, when toks_.GetList() or toks_.Clear() would contain pointers
389  // to nonexistent tokens.
390  DeleteElems(toks_.Clear());
391 
392  // Now go through tokens on this frame, pruning forward links... may have to
393  // iterate a few times until there is no more change, because the list is not
394  // in topological order. This is a modified version of the code in
395  // PruneForwardLinks, but here we also take account of the final-probs.
396  bool changed = true;
397  BaseFloat delta = 1.0e-05;
398  while (changed) {
399  changed = false;
400  for (Token *tok = active_toks_[frame_plus_one].toks;
401  tok != NULL; tok = tok->next) {
402  ForwardLinkT *link, *prev_link = NULL;
403  // will recompute tok_extra_cost. It has a term in it that corresponds
404  // to the "final-prob", so instead of initializing tok_extra_cost to infinity
405  // below we set it to the difference between the (score+final_prob) of this token,
406  // and the best such (score+final_prob).
407  BaseFloat final_cost;
408  if (final_costs_.empty()) {
409  final_cost = 0.0;
410  } else {
411  IterType iter = final_costs_.find(tok);
412  if (iter != final_costs_.end())
413  final_cost = iter->second;
414  else
415  final_cost = std::numeric_limits<BaseFloat>::infinity();
416  }
417  BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
418  // tok_extra_cost will be a "min" over either directly being final, or
419  // being indirectly final through other links, and the loop below may
420  // decrease its value:
421  for (link = tok->links; link != NULL; ) {
422  // See if we need to excise this link...
423  Token *next_tok = link->next_tok;
424  BaseFloat link_extra_cost = next_tok->extra_cost +
425  ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
426  - next_tok->tot_cost);
427  if (link_extra_cost > config_.lattice_beam) { // excise link
428  ForwardLinkT *next_link = link->next;
429  if (prev_link != NULL) prev_link->next = next_link;
430  else tok->links = next_link;
431  delete link;
432  link = next_link; // advance link but leave prev_link the same.
433  } else { // keep the link and update the tok_extra_cost if needed.
434  if (link_extra_cost < 0.0) { // this is just a precaution.
435  if (link_extra_cost < -0.01)
436  KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
437  link_extra_cost = 0.0;
438  }
439  if (link_extra_cost < tok_extra_cost)
440  tok_extra_cost = link_extra_cost;
441  prev_link = link;
442  link = link->next;
443  }
444  }
445  // prune away tokens worse than lattice_beam above best path. This step
446  // was not necessary in the non-final case because then, this case
447  // showed up as having no forward links. Here, the tok_extra_cost has
448  // an extra component relating to the final-prob.
449  if (tok_extra_cost > config_.lattice_beam)
450  tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
451  // to be pruned in PruneTokensForFrame
452 
453  if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta))
454  changed = true;
455  tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_.
456  }
457  } // while changed
458 }
459 
460 template <typename FST, typename Token>
462  if (!decoding_finalized_) {
463  BaseFloat relative_cost;
464  ComputeFinalCosts(NULL, &relative_cost, NULL);
465  return relative_cost;
466  } else {
467  // we're not allowed to call that function if FinalizeDecoding() has
468  // been called; return a cached value.
469  return final_relative_cost_;
470  }
471 }
472 
473 
474 // Prune away any tokens on this frame that have no forward links.
475 // [we don't do this in PruneForwardLinks because it would give us
476 // a problem with dangling pointers].
477 // It's called by PruneActiveTokens if any forward links have been pruned
478 template <typename FST, typename Token>
480  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
481  Token *&toks = active_toks_[frame_plus_one].toks;
482  if (toks == NULL)
483  KALDI_WARN << "No tokens alive [doing pruning]";
484  Token *tok, *next_tok, *prev_tok = NULL;
485  for (tok = toks; tok != NULL; tok = next_tok) {
486  next_tok = tok->next;
487  if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
488  // token is unreachable from end of graph; (no forward links survived)
489  // excise tok from list and delete tok.
490  if (prev_tok != NULL) prev_tok->next = tok->next;
491  else toks = tok->next;
492  delete tok;
493  num_toks_--;
494  } else { // fetch next Token
495  prev_tok = tok;
496  }
497  }
498 }
499 
500 // Go backwards through still-alive tokens, pruning them, starting not from
501 // the current frame (where we want to keep all tokens) but from the frame before
502 // that. We go backwards through the frames and stop when we reach a point
503 // where the delta-costs are not changing (and the delta controls when we consider
504 // a cost to have "not changed").
505 template <typename FST, typename Token>
507  int32 cur_frame_plus_one = NumFramesDecoded();
508  int32 num_toks_begin = num_toks_;
509  // The index "f" below represents a "frame plus one", i.e. you'd have to subtract
510  // one to get the corresponding index for the decodable object.
511  for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
512  // Reason why we need to prune forward links in this situation:
513  // (1) we have never pruned them (new TokenList)
514  // (2) we have not yet pruned the forward links to the next f,
515  // after any of those tokens have changed their extra_cost.
516  if (active_toks_[f].must_prune_forward_links) {
517  bool extra_costs_changed = false, links_pruned = false;
518  PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
519  if (extra_costs_changed && f > 0) // any token has changed extra_cost
520  active_toks_[f-1].must_prune_forward_links = true;
521  if (links_pruned) // any link was pruned
522  active_toks_[f].must_prune_tokens = true;
523  active_toks_[f].must_prune_forward_links = false; // job done
524  }
525  if (f+1 < cur_frame_plus_one && // except for last f (no forward links)
526  active_toks_[f+1].must_prune_tokens) {
527  PruneTokensForFrame(f+1);
528  active_toks_[f+1].must_prune_tokens = false;
529  }
530  }
531  KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
532  << " to " << num_toks_;
533 }
534 
535 template <typename FST, typename Token>
537  unordered_map<Token*, BaseFloat> *final_costs,
538  BaseFloat *final_relative_cost,
539  BaseFloat *final_best_cost) const {
540  KALDI_ASSERT(!decoding_finalized_);
541  if (final_costs != NULL)
542  final_costs->clear();
543  const Elem *final_toks = toks_.GetList();
544  BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
545  BaseFloat best_cost = infinity,
546  best_cost_with_final = infinity;
547 
548  while (final_toks != NULL) {
549  StateId state = final_toks->key;
550  Token *tok = final_toks->val;
551  const Elem *next = final_toks->tail;
552  BaseFloat final_cost = fst_->Final(state).Value();
553  BaseFloat cost = tok->tot_cost,
554  cost_with_final = cost + final_cost;
555  best_cost = std::min(cost, best_cost);
556  best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
557  if (final_costs != NULL && final_cost != infinity)
558  (*final_costs)[tok] = final_cost;
559  final_toks = next;
560  }
561  if (final_relative_cost != NULL) {
562  if (best_cost == infinity && best_cost_with_final == infinity) {
563  // Likely this will only happen if there are no tokens surviving.
564  // This seems the least bad way to handle it.
565  *final_relative_cost = infinity;
566  } else {
567  *final_relative_cost = best_cost_with_final - best_cost;
568  }
569  }
570  if (final_best_cost != NULL) {
571  if (best_cost_with_final != infinity) { // final-state exists.
572  *final_best_cost = best_cost_with_final;
573  } else { // no final-state exists.
574  *final_best_cost = best_cost;
575  }
576  }
577 }
578 
579 template <typename FST, typename Token>
581  int32 max_num_frames) {
582  if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
583  // if the type 'FST' is the FST base-class, then see if the FST type of fst_
584  // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding()
585  // function after casting *this to the more specific type.
586  if (fst_->Type() == "const") {
588  reinterpret_cast<LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, Token>* >(this);
589  this_cast->AdvanceDecoding(decodable, max_num_frames);
590  return;
591  } else if (fst_->Type() == "vector") {
593  reinterpret_cast<LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, Token>* >(this);
594  this_cast->AdvanceDecoding(decodable, max_num_frames);
595  return;
596  }
597  }
598 
599 
600  KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
601  "You must call InitDecoding() before AdvanceDecoding");
602  int32 num_frames_ready = decodable->NumFramesReady();
603  // num_frames_ready must be >= num_frames_decoded, or else
604  // the number of frames ready must have decreased (which doesn't
605  // make sense) or the decodable object changed between calls
606  // (which isn't allowed).
607  KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
608  int32 target_frames_decoded = num_frames_ready;
609  if (max_num_frames >= 0)
610  target_frames_decoded = std::min(target_frames_decoded,
611  NumFramesDecoded() + max_num_frames);
612  while (NumFramesDecoded() < target_frames_decoded) {
613  if (NumFramesDecoded() % config_.prune_interval == 0) {
614  PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
615  }
616  BaseFloat cost_cutoff = ProcessEmitting(decodable);
617  ProcessNonemitting(cost_cutoff);
618  }
619 }
620 
621 // FinalizeDecoding() is a version of PruneActiveTokens that we call
622 // (optionally) on the final frame. Takes into account the final-prob of
623 // tokens. This function used to be called PruneActiveTokensFinal().
624 template <typename FST, typename Token>
626  int32 final_frame_plus_one = NumFramesDecoded();
627  int32 num_toks_begin = num_toks_;
628  // PruneForwardLinksFinal() prunes final frame (with final-probs), and
629  // sets decoding_finalized_.
630  PruneForwardLinksFinal();
631  for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
632  bool b1, b2; // values not used.
633  BaseFloat dontcare = 0.0; // delta of zero means we must always update
634  PruneForwardLinks(f, &b1, &b2, dontcare);
635  PruneTokensForFrame(f + 1);
636  }
637  PruneTokensForFrame(0);
638  KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin
639  << " to " << num_toks_;
640 }
641 
643 template <typename FST, typename Token>
645  BaseFloat *adaptive_beam, Elem **best_elem) {
646  BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
647  // positive == high cost == bad.
648  size_t count = 0;
649  if (config_.max_active == std::numeric_limits<int32>::max() &&
650  config_.min_active == 0) {
651  for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
652  BaseFloat w = static_cast<BaseFloat>(e->val->tot_cost);
653  if (w < best_weight) {
654  best_weight = w;
655  if (best_elem) *best_elem = e;
656  }
657  }
658  if (tok_count != NULL) *tok_count = count;
659  if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
660  return best_weight + config_.beam;
661  } else {
662  tmp_array_.clear();
663  for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
664  BaseFloat w = e->val->tot_cost;
665  tmp_array_.push_back(w);
666  if (w < best_weight) {
667  best_weight = w;
668  if (best_elem) *best_elem = e;
669  }
670  }
671  if (tok_count != NULL) *tok_count = count;
672 
673  BaseFloat beam_cutoff = best_weight + config_.beam,
674  min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
675  max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();
676 
677  KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded()
678  << " is " << tmp_array_.size();
679 
680  if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
681  std::nth_element(tmp_array_.begin(),
682  tmp_array_.begin() + config_.max_active,
683  tmp_array_.end());
684  max_active_cutoff = tmp_array_[config_.max_active];
685  }
686  if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam.
687  if (adaptive_beam)
688  *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta;
689  return max_active_cutoff;
690  }
691  if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
692  if (config_.min_active == 0) min_active_cutoff = best_weight;
693  else {
694  std::nth_element(tmp_array_.begin(),
695  tmp_array_.begin() + config_.min_active,
696  tmp_array_.size() > static_cast<size_t>(config_.max_active) ?
697  tmp_array_.begin() + config_.max_active :
698  tmp_array_.end());
699  min_active_cutoff = tmp_array_[config_.min_active];
700  }
701  }
702  if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam.
703  if (adaptive_beam)
704  *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
705  return min_active_cutoff;
706  } else {
707  *adaptive_beam = config_.beam;
708  return beam_cutoff;
709  }
710  }
711 }
712 
713 template <typename FST, typename Token>
715  DecodableInterface *decodable) {
716  KALDI_ASSERT(active_toks_.size() > 0);
717  int32 frame = active_toks_.size() - 1; // frame is the frame-index
718  // (zero-based) used to get likelihoods
719  // from the decodable object.
720  active_toks_.resize(active_toks_.size() + 1);
721 
722  Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_
723  // in simple-decoder.h. Removes the Elems from
724  // being indexed in the hash in toks_.
725  Elem *best_elem = NULL;
726  BaseFloat adaptive_beam;
727  size_t tok_cnt;
728  BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
729  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
730  << adaptive_beam;
731 
732  PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough.
733 
734  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
735  // pruning "online" before having seen all tokens
736 
737  BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good
738  // dynamic range.
739 
740 
741  // First process the best token to get a hopefully
742  // reasonably tight bound on the next cutoff. The only
743  // products of the next block are "next_cutoff" and "cost_offset".
744  if (best_elem) {
745  StateId state = best_elem->key;
746  Token *tok = best_elem->val;
747  cost_offset = - tok->tot_cost;
748  for (fst::ArcIterator<FST> aiter(*fst_, state);
749  !aiter.Done();
750  aiter.Next()) {
751  const Arc &arc = aiter.Value();
752  if (arc.ilabel != 0) { // propagate..
753  BaseFloat new_weight = arc.weight.Value() + cost_offset -
754  decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost;
755  if (new_weight + adaptive_beam < next_cutoff)
756  next_cutoff = new_weight + adaptive_beam;
757  }
758  }
759  }
760 
761  // Store the offset on the acoustic likelihoods that we're applying.
762  // Could just do cost_offsets_.push_back(cost_offset), but we
763  // do it this way as it's more robust to future code changes.
764  cost_offsets_.resize(frame + 1, 0.0);
765  cost_offsets_[frame] = cost_offset;
766 
767  // the tokens are now owned here, in final_toks, and the hash is empty.
768  // 'owned' is a complex thing here; the point is we need to call DeleteElem
769  // on each elem 'e' to let toks_ know we're done with them.
770  for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) {
771  // loop this way because we delete "e" as we go.
772  StateId state = e->key;
773  Token *tok = e->val;
774  if (tok->tot_cost <= cur_cutoff) {
775  for (fst::ArcIterator<FST> aiter(*fst_, state);
776  !aiter.Done();
777  aiter.Next()) {
778  const Arc &arc = aiter.Value();
779  if (arc.ilabel != 0) { // propagate..
780  BaseFloat ac_cost = cost_offset -
781  decodable->LogLikelihood(frame, arc.ilabel),
782  graph_cost = arc.weight.Value(),
783  cur_cost = tok->tot_cost,
784  tot_cost = cur_cost + ac_cost + graph_cost;
785  if (tot_cost >= next_cutoff) continue;
786  else if (tot_cost + adaptive_beam < next_cutoff)
787  next_cutoff = tot_cost + adaptive_beam; // prune by best current token
788  // Note: the frame indexes into active_toks_ are one-based,
789  // hence the + 1.
790  Elem *e_next = FindOrAddToken(arc.nextstate,
791  frame + 1, tot_cost, tok, NULL);
792  // NULL: no change indicator needed
793 
794  // Add ForwardLink from tok to next_tok (put on head of list tok->links)
795  tok->links = new ForwardLinkT(e_next->val, arc.ilabel, arc.olabel,
796  graph_cost, ac_cost, tok->links);
797  }
798  } // for all arcs
799  }
800  e_tail = e->tail;
801  toks_.Delete(e); // delete Elem
802  }
803  return next_cutoff;
804 }
805 
806 // static inline
807 template <typename FST, typename Token>
809  ForwardLinkT *l = tok->links, *m;
810  while (l != NULL) {
811  m = l->next;
812  delete l;
813  l = m;
814  }
815  tok->links = NULL;
816 }
817 
818 
819 template <typename FST, typename Token>
821  KALDI_ASSERT(!active_toks_.empty());
822  int32 frame = static_cast<int32>(active_toks_.size()) - 2;
823  // Note: "frame" is the time-index we just processed, or -1 if
824  // we are processing the nonemitting transitions before the
825  // first frame (called from InitDecoding()).
826 
827  // Processes nonemitting arcs for one frame. Propagates within toks_.
828  // Note-- this queue structure is not very optimal as
829  // it may cause us to process states unnecessarily (e.g. more than once),
830  // but in the baseline code, turning this vector into a set to fix this
831  // problem did not improve overall speed.
832 
833  KALDI_ASSERT(queue_.empty());
834 
835  if (toks_.GetList() == NULL) {
836  if (!warned_) {
837  KALDI_WARN << "Error, no surviving tokens: frame is " << frame;
838  warned_ = true;
839  }
840  }
841 
842  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
843  StateId state = e->key;
844  if (fst_->NumInputEpsilons(state) != 0)
845  queue_.push_back(e);
846  }
847 
848  while (!queue_.empty()) {
849  const Elem *e = queue_.back();
850  queue_.pop_back();
851 
852  StateId state = e->key;
853  Token *tok = e->val; // would segfault if e is a NULL pointer but this can't happen.
854  BaseFloat cur_cost = tok->tot_cost;
855  if (cur_cost >= cutoff) // Don't bother processing successors.
856  continue;
857  // If "tok" has any existing forward links, delete them,
858  // because we're about to regenerate them. This is a kind
859  // of non-optimality (remember, this is the simple decoder),
860  // but since most states are emitting it's not a huge issue.
861  DeleteForwardLinks(tok); // necessary when re-visiting
862  tok->links = NULL;
863  for (fst::ArcIterator<FST> aiter(*fst_, state);
864  !aiter.Done();
865  aiter.Next()) {
866  const Arc &arc = aiter.Value();
867  if (arc.ilabel == 0) { // propagate nonemitting only...
868  BaseFloat graph_cost = arc.weight.Value(),
869  tot_cost = cur_cost + graph_cost;
870  if (tot_cost < cutoff) {
871  bool changed;
872 
873  Elem *e_new = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
874  tok, &changed);
875 
876  tok->links = new ForwardLinkT(e_new->val, 0, arc.olabel,
877  graph_cost, 0, tok->links);
878 
879  // "changed" tells us whether the new token has a different
880  // cost from before, or is new [if so, add into queue].
881  if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0)
882  queue_.push_back(e_new);
883  }
884  }
885  } // for all arcs
886  } // while queue not empty
887 }
888 
889 
890 template <typename FST, typename Token>
892  for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
893  e_tail = e->tail;
894  toks_.Delete(e);
895  }
896 }
897 
898 template <typename FST, typename Token>
899 void LatticeFasterDecoderTpl<FST, Token>::ClearActiveTokens() { // a cleanup routine, at utt end/begin
900  for (size_t i = 0; i < active_toks_.size(); i++) {
901  // Delete all tokens alive on this frame, and any forward
902  // links they may have.
903  for (Token *tok = active_toks_[i].toks; tok != NULL; ) {
904  DeleteForwardLinks(tok);
905  Token *next_tok = tok->next;
906  delete tok;
907  num_toks_--;
908  tok = next_tok;
909  }
910  }
911  active_toks_.clear();
912  KALDI_ASSERT(num_toks_ == 0);
913 }
914 
915 // static
916 template <typename FST, typename Token>
918  Token *tok_list, std::vector<Token*> *topsorted_list) {
919  unordered_map<Token*, int32> token2pos;
920  typedef typename unordered_map<Token*, int32>::iterator IterType;
921  int32 num_toks = 0;
922  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
923  num_toks++;
924  int32 cur_pos = 0;
925  // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0.
926  // This is likely to be in closer to topological order than
927  // if we had given them ascending order, because of the way
928  // new tokens are put at the front of the list.
929  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
930  token2pos[tok] = num_toks - ++cur_pos;
931 
932  unordered_set<Token*> reprocess;
933 
934  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) {
935  Token *tok = iter->first;
936  int32 pos = iter->second;
937  for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
938  if (link->ilabel == 0) {
939  // We only need to consider epsilon links, since non-epsilon links
940  // transition between frames and this function only needs to sort a list
941  // of tokens from a single frame.
942  IterType following_iter = token2pos.find(link->next_tok);
943  if (following_iter != token2pos.end()) { // another token on this frame,
944  // so must consider it.
945  int32 next_pos = following_iter->second;
946  if (next_pos < pos) { // reassign the position of the next Token.
947  following_iter->second = cur_pos++;
948  reprocess.insert(link->next_tok);
949  }
950  }
951  }
952  }
953  // In case we had previously assigned this token to be reprocessed, we can
954  // erase it from that set because it's "happy now" (we just processed it).
955  reprocess.erase(tok);
956  }
957 
958  size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles.
959  for (loop_count = 0;
960  !reprocess.empty() && loop_count < max_loop; ++loop_count) {
961  std::vector<Token*> reprocess_vec;
962  for (typename unordered_set<Token*>::iterator iter = reprocess.begin();
963  iter != reprocess.end(); ++iter)
964  reprocess_vec.push_back(*iter);
965  reprocess.clear();
966  for (typename std::vector<Token*>::iterator iter = reprocess_vec.begin();
967  iter != reprocess_vec.end(); ++iter) {
968  Token *tok = *iter;
969  int32 pos = token2pos[tok];
970  // Repeat the processing we did above (for comments, see above).
971  for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
972  if (link->ilabel == 0) {
973  IterType following_iter = token2pos.find(link->next_tok);
974  if (following_iter != token2pos.end()) {
975  int32 next_pos = following_iter->second;
976  if (next_pos < pos) {
977  following_iter->second = cur_pos++;
978  reprocess.insert(link->next_tok);
979  }
980  }
981  }
982  }
983  }
984  }
985  KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding "
986  "graph (this is not allowed!)");
987 
988  topsorted_list->clear();
989  topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between.
990  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
991  (*topsorted_list)[iter->second] = iter->first;
992 }
993 
994 // Instantiate the template for the combination of token types and FST types
995 // that we'll need.
997 template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken >;
998 template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken >;
1000 
1002 template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::BackpointerToken >;
1003 template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::BackpointerToken >;
1005 
1006 
1007 } // end namespace kaldi.
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
fst::ArcTpl< LatticeWeight > LatticeArc
Definition: kaldi-lattice.h:40
virtual int32 NumFramesReady() const
The call NumFramesReady() will return the number of frames currently available for this decodable obj...
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
Definition: decodable-itf.h:82
bool DeterminizeLatticePruned(const ExpandedFst< ArcTpl< Weight > > &ifst, double beam, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, IntType > > > *ofst, DeterminizeLatticePrunedOptions opts)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
LatticeFasterDecoderTpl(const FST &fst, const LatticeFasterDecoderConfig &config)
kaldi::int32 int32
const size_t count
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
fst::StdArc::Label Label
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
fst::StdArc::Weight Weight
This is the "normal" lattice-generating decoder.
void AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames=-1)
This will decode until there are no more frames ready in the decodable object.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
typename HashList< StateId, Token * >::Elem Elem
virtual BaseFloat LogLikelihood(int32 frame, int32 index)=0
Returns the log likelihood, which will be negated in the decoder.
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
Definition: kaldi-math.h:265