lattice-faster-decoder.h
Go to the documentation of this file.
1 // decoder/lattice-faster-decoder.h
2 
3 // Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
4 // 2013-2014 Johns Hopkins University (Author: Daniel Povey)
5 // 2014 Guoguo Chen
6 // 2018 Zhehuai Chen
7 
8 // See ../../COPYING for clarification regarding multiple authors
9 //
10 // Licensed under the Apache License, Version 2.0 (the "License");
11 // you may not use this file except in compliance with the License.
12 // You may obtain a copy of the License at
13 //
14 // http://www.apache.org/licenses/LICENSE-2.0
15 //
16 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
18 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
19 // MERCHANTABLITY OR NON-INFRINGEMENT.
20 // See the Apache 2 License for the specific language governing permissions and
21 // limitations under the License.
22 
23 #ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
24 #define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
25 
26 
27 #include "util/stl-utils.h"
28 #include "util/hash-list.h"
29 #include "fst/fstlib.h"
30 #include "itf/decodable-itf.h"
31 #include "fstext/fstext-lib.h"
33 #include "lat/kaldi-lattice.h"
34 #include "decoder/grammar-fst.h"
35 
36 namespace kaldi {
37 
44  bool determinize_lattice; // not inspected by this class... used in
45  // command-line program.
48  // Note: we don't make prune_scale configurable on the command line, it's not
49  // a very important parameter. It affects the algorithm that prunes the
50  // tokens as we go.
52 
53  // Most of the options inside det_opts are not actually queried by the
54  // LatticeFasterDecoder class itself, but by the code that calls it, for
55  // example in the function DecodeUtteranceLatticeFaster.
57 
59  max_active(std::numeric_limits<int32>::max()),
60  min_active(200),
61  lattice_beam(10.0),
62  prune_interval(25),
63  determinize_lattice(true),
64  beam_delta(0.5),
65  hash_ratio(2.0),
66  prune_scale(0.1) { }
67  void Register(OptionsItf *opts) { // Registers the decoder's options with 'opts'.
68  det_opts.Register(opts); // the determinization options register themselves first
69  opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate.");
70  opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; "
71  "more accurate");
72  opts->Register("min-active", &min_active, "Decoder minimum #active states.");
73  opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, "
74  "and deeper lattices");
75  opts->Register("prune-interval", &prune_interval, "Interval (in frames) at "
76  "which to prune tokens");
77  opts->Register("determinize-lattice", &determinize_lattice, "If true, "
78  "determinize the lattice (lattice-determinization, keeping only "
79  "best pdf-sequence for each word-sequence).");
80  opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this "
81  "parameter is obscure and relates to a speedup in the way the "
82  "max-active constraint is applied. Larger is more accurate.");
83  opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
84  "control hash behavior");
85  } // prune_scale is deliberately not registered: it is not configurable on the command line
86  void Check() const { // Asserts, via KALDI_ASSERT, that the option values are valid and mutually consistent.
87  KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
88  && min_active <= max_active
89  && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
90  && prune_scale > 0.0 && prune_scale < 1.0); // prune_scale must lie strictly between 0 and 1
91  }
92 };
93 
94 namespace decoder {
95 // We will template the decoder on the token type as well as the FST type; this
96 // is a mechanism so that we can use the same underlying decoder code for
97 // versions of the decoder that support quickly getting the best path
98 // (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
99 // those that do not (LatticeFasterDecoder).
100 
101 
102 // ForwardLinks are the links from a token to a token on the next frame,
103 // or sometimes on the current frame (for input-epsilon links).
104 template <typename Token>
105 struct ForwardLink {
107 
108  Token *next_tok; // the next token [or NULL if represents final-state]
109  Label ilabel; // ilabel on arc
110  Label olabel; // olabel on arc
111  BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.)
112  BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc
113  ForwardLink *next; // next in singly-linked list of forward arcs (arcs
114  // in the state-level lattice) from a token.
115  inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, // Member-wise constructor.
116  BaseFloat graph_cost, BaseFloat acoustic_cost,
117  ForwardLink *next):
118  next_tok(next_tok), ilabel(ilabel), olabel(olabel), // each argument initializes the member of the same name
119  graph_cost(graph_cost), acoustic_cost(acoustic_cost),
120  next(next) { }
121 };
122 
123 
124 struct StdToken {
126  using Token = StdToken;
127 
128  // Standard token type for LatticeFasterDecoder. Each active HCLG
129  // (decoding-graph) state on each frame has one token.
130 
131  // tot_cost is the total (LM + acoustic) cost from the beginning of the
132  // utterance up to this point. (but see cost_offset_, which is subtracted
133  // to keep it in a good numerical range).
135 
136  // extra_cost is >= 0. After calling PruneForwardLinks, this equals the
137  // minimum difference between the cost of the best path that this link is a
138  // part of, and the cost of the absolute best path, under the assumption that
139  // any of the currently active states at the decoding front may eventually
140  // succeed (e.g. if you were to take the currently active states one by one
141  // and compute this difference, and then take the minimum).
143 
144  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
145  // use for lattice generation.
147 
148  //'next' is the next in the singly-linked list of tokens for this frame.
150 
151  // This function does nothing and should be optimized out; it's needed
152  // so we can share the regular LatticeFasterDecoderTpl code and the code
153  // for LatticeFasterOnlineDecoder that supports fast traceback.
154  inline void SetBackpointer (Token *backpointer) { } // intentionally empty: 'backpointer' is ignored
155 
156  // This constructor just ignores the 'backpointer' argument. That argument is
157  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
158  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
159  // fast way to obtain the best path).
160  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
161  Token *next, Token *backpointer):
162  tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { } // 'backpointer' is deliberately unused
163 };
164 
168 
169  // BackpointerToken is like StdToken (the standard token type for
170  // LatticeFasterDecoder) but also stores a backpointer to the best preceding
171  // token. Each active HCLG (decoding-graph) state on each frame has one token.
172 
173  // tot_cost is the total (LM + acoustic) cost from the beginning of the
174  // utterance up to this point. (but see cost_offset_, which is subtracted
175  // to keep it in a good numerical range).
177 
178  // extra_cost is >= 0. After calling PruneForwardLinks, this equals
179  // the minimum difference between the cost of the best path that this
180  // token is on, and the cost of the absolute best path, under the assumption
181  // that any of the currently active states at the decoding front may
182  // eventually succeed (e.g. if you were to take the currently active states
183  // one by one and compute this difference, and then take the minimum).
185 
186  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
187  // use for lattice generation.
189 
190  //'next' is the next in the singly-linked list of tokens for this frame.
192 
193  // Best preceding BackpointerToken (could be one on this frame, connected to
194  // this via an epsilon transition, or on a previous frame). This is only
195  // required for an efficient GetBestPath function in
196  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
197  // (the "links" list is what stores the forward links, for that).
199 
200  inline void SetBackpointer (Token *backpointer) { // Overwrites this token's stored backpointer.
201  this->backpointer = backpointer; // 'this->' is needed because the parameter shadows the member
202  }
203 
204  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links, // Member-wise constructor.
205  Token *next, Token *backpointer):
206  tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
207  backpointer(backpointer) { } // here the 'backpointer' argument is stored in the member of the same name
208 };
209 
210 } // namespace decoder
211 
212 
228 template <typename FST, typename Token = decoder::StdToken>
230  public:
231  using Arc = typename FST::Arc;
232  using Label = typename Arc::Label;
233  using StateId = typename Arc::StateId;
234  using Weight = typename Arc::Weight;
236 
237  // Instantiate this class once for each thing you have to decode.
238  // This version of the constructor does not take ownership of
239  // 'fst'.
240  LatticeFasterDecoderTpl(const FST &fst,
241  const LatticeFasterDecoderConfig &config);
242 
243  // This version of the constructor takes ownership of the fst, and will delete
244  // it when this object is destroyed.
246  FST *fst);
247 
249  config_ = config;
250  }
251 
253  return config_;
254  }
255 
257 
262  bool Decode(DecodableInterface *decodable);
263 
264 
267  bool ReachedFinal() const { // Returns true unless FinalRelativeCost() is +infinity.
268  return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity(); // exact comparison against the +infinity sentinel
269  }
270 
277  bool GetBestPath(Lattice *ofst,
278  bool use_final_probs = true) const;
279 
291  bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const;
292 
293 
294 
302  bool GetLattice(CompactLattice *ofst,
303  bool use_final_probs = true) const;
304 
309  void InitDecoding();
310 
315  void AdvanceDecoding(DecodableInterface *decodable,
316  int32 max_num_frames = -1);
317 
325  void FinalizeDecoding();
326 
335  BaseFloat FinalRelativeCost() const;
336 
337 
338  // Returns the number of frames decoded so far. The value returned changes
339  // whenever we call ProcessEmitting().
340  inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } // minus one because active_toks_ is indexed by frame-index plus one
341 
342  protected:
343  // we make things protected instead of private, as code in
344  // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
345  // internals.
346 
347  // Deletes the elements of the singly linked list tok->links.
348  inline static void DeleteForwardLinks(Token *tok);
349 
350  // head of per-frame list of Tokens (list is in topological order),
351  // and something saying whether we ever pruned it using PruneForwardLinks.
352  struct TokenList {
353  Token *toks;
356  TokenList(): toks(NULL), must_prune_forward_links(true),
357  must_prune_tokens(true) { }
358  };
359 
361  // Equivalent to:
362  // struct Elem {
363  // StateId key;
364  // Token *val;
365  // Elem *tail;
366  // };
367 
368  void PossiblyResizeHash(size_t num_toks);
369 
370  // FindOrAddToken either locates a token in hash of toks_, or if necessary
371  // inserts a new, empty token (i.e. with no forward links) for the current
372  // frame. [note: it's inserted if necessary into hash toks_ and also into the
373  // singly linked list of tokens active on this frame (whose head is at
374  // active_toks_[frame]). The frame_plus_one argument is the acoustic frame
375  // index plus one, which is used to index into the active_toks_ array.
376  // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the
377  // token was newly created or the cost changed.
378  // If Token == StdToken, the 'backpointer' argument has no purpose (and will
379  // hopefully be optimized out).
380  inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one,
381  BaseFloat tot_cost, Token *backpointer,
382  bool *changed);
383 
384  // prunes outgoing links for all tokens in active_toks_[frame]
385  // it's called by PruneActiveTokens
386  // all links, that have link_extra_cost > lattice_beam are pruned
387  // delta is the amount by which the extra_costs must change
388  // before we set *extra_costs_changed = true.
389  // If delta is larger, we'll tend to go back less far
390  // toward the beginning of the file.
391  // extra_costs_changed is set to true if extra_cost was changed for any token
392  // links_pruned is set to true if any link in any token was pruned
393  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
394  bool *links_pruned,
395  BaseFloat delta);
396 
397  // This function computes the final-costs for tokens active on the final
398  // frame. It outputs to final-costs, if non-NULL, a map from the Token*
399  // pointer to the final-prob of the corresponding state, for all Tokens
400  // that correspond to states that have final-probs. This map will be
401  // empty if there were no final-probs. It outputs to
402  // final_relative_cost, if non-NULL, the difference between the best
403  // forward-cost including the final-prob cost, and the best forward-cost
404  // without including the final-prob cost (this will usually be positive), or
405  // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which
406  // outputs this quantity]. It outputs to final_best_cost, if
407  // non-NULL, the lowest for any token t active on the final frame, of
408  // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
409  // the graph of the state corresponding to token t, or the best of
410  // forward-cost[t] if there were no final-probs active on the final frame.
411  // You cannot call this after FinalizeDecoding() has been called; in that
412  // case you should get the answer from class-member variables.
413  void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
414  BaseFloat *final_relative_cost,
415  BaseFloat *final_best_cost) const;
416 
417  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
418  // on the final frame. If there are final tokens active, it uses
419  // the final-probs for pruning, otherwise it treats all tokens as final.
420  void PruneForwardLinksFinal();
421 
422  // Prune away any tokens on this frame that have no forward links.
423  // [we don't do this in PruneForwardLinks because it would give us
424  // a problem with dangling pointers].
425  // It's called by PruneActiveTokens if any forward links have been pruned
426  void PruneTokensForFrame(int32 frame_plus_one);
427 
428 
429  // Go backwards through still-alive tokens, pruning them if the
430  // forward+backward cost is more than lat_beam away from the best path. It's
431  // possible to prove that this is "correct" in the sense that we won't lose
432  // anything outside of lat_beam, regardless of what happens in the future.
433  // delta controls when it considers a cost to have changed enough to continue
434  // going backward and propagating the change. larger delta -> will recurse
435  // less far.
436  void PruneActiveTokens(BaseFloat delta);
437 
439  BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
440  BaseFloat *adaptive_beam, Elem **best_elem);
441 
445  BaseFloat ProcessEmitting(DecodableInterface *decodable);
446 
450  void ProcessNonemitting(BaseFloat cost_cutoff);
451 
452  // HashList defined in ../util/hash-list.h. It actually allows us to maintain
453  // more than one list (e.g. for current and previous frames), but only one of
454  // them at a time can be indexed by StateId. It is indexed by frame-index
455  // plus one, where the frame-index is zero-based, as used in decodable object.
456  // That is, the emitting probs of frame t are accounted for in tokens at
457  // toks_[t+1]. The zeroth frame is for nonemitting transition at the start of
458  // the graph.
460 
461  std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
462  // frame (members of TokenList are toks, must_prune_forward_links,
463  // must_prune_tokens).
464  std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
465  std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
466 
467  // fst_ is a pointer to the FST we are decoding from.
468  const FST *fst_;
469  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
470  // object is destroyed.
472 
473  std::vector<BaseFloat> cost_offsets_; // This contains, for each
474  // frame, an offset that was added to the acoustic log-likelihoods on that
475  // frame in order to keep everything in a nice dynamic range i.e. close to
476  // zero, to reduce roundoff errors.
478  int32 num_toks_; // current total #toks allocated...
479  bool warned_;
480 
491  unordered_map<Token*, BaseFloat> final_costs_;
494 
495  // There are various cleanup tasks... the toks_ structure contains
496  // singly linked lists of Token pointers, where Elem is the list type.
497  // It also indexes them in a hash, indexed by state (this hash is only
498  // maintained for the most recent frame). toks_.Clear()
499  // deletes them from the hash and returns the list of Elems. The
500  // function DeleteElems calls toks_.Delete(elem) for each elem in
501  // the list, which returns ownership of the Elem to the toks_ structure
502  // for reuse, but does not delete the Token pointer. The Token pointers
503  // are reference-counted and are ultimately deleted in PruneTokensForFrame,
504  // but are also linked together on each frame by their own linked-list,
505  // using the "next" pointer. We delete them manually.
506  void DeleteElems(Elem *list);
507 
508  // This function takes a singly linked list of tokens for a single frame, and
509  // outputs a list of them in topological order (it will crash if no such order
510  // can be found, which will typically be due to decoding graphs with epsilon
511  // cycles, which are not allowed). Note: the output list may contain NULLs,
512  // which the caller should pass over; it just happens to be more efficient for
513  // the algorithm to output a list that contains NULLs.
514  static void TopSortTokens(Token *tok_list,
515  std::vector<Token*> *topsorted_list);
516 
517  void ClearActiveTokens();
518 
520 };
521 
523 
524 
525 
526 } // end namespace kaldi.
527 
528 #endif
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void SetOptions(const LatticeFasterDecoderConfig &config)
void SetBackpointer(Token *backpointer)
const LatticeFasterDecoderConfig & GetOptions() const
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
Definition: decodable-itf.h:82
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
bool ReachedFinal() const
says whether a final-state was active on the last frame.
LatticeFasterDecoderConfig config_
StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links, Token *next, Token *backpointer)
kaldi::int32 int32
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type)
Definition: kaldi-utils.h:121
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
void SetBackpointer(Token *backpointer)
std::vector< const Elem *> queue_
unordered_map< Token *, BaseFloat > final_costs_
For the meaning of the next 3 variables, see the comment for decoding_finalized_ above, and ComputeFinalCosts().
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
fst::StdArc::Label Label
std::vector< TokenList > active_toks_
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
fst::StdArc::Weight Weight
fst::DeterminizeLatticePhonePrunedOptions det_opts
This is the "normal" lattice-generating decoder.
std::vector< BaseFloat > tmp_array_
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
LatticeFasterDecoderTpl< fst::StdFst, decoder::StdToken > LatticeFasterDecoder
bool decoding_finalized_
decoding_finalized_ is true if someone called FinalizeDecoding().
BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links, Token *next, Token *backpointer)
typename HashList< StateId, decoder::BackpointerToken *>::Elem Elem
std::vector< BaseFloat > cost_offsets_
HashList< StateId, Token * > toks_