lattice-incremental-decoder.h
Go to the documentation of this file.
1 // decoder/lattice-incremental-decoder.h
2 
3 // Copyright 2019 Zhehuai Chen, Daniel Povey
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_
21 #define KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_
22 
23 #include "util/stl-utils.h"
24 #include "util/hash-list.h"
25 #include "fst/fstlib.h"
26 #include "itf/decodable-itf.h"
27 #include "fstext/fstext-lib.h"
29 #include "lat/kaldi-lattice.h"
30 #include "decoder/grammar-fst.h"
31 #include "lattice-faster-decoder.h"
32 
33 namespace kaldi {
107  // All the configuration values until det_opts are the same as in
108  // LatticeFasterDecoder. For clarity we repeat them rather than inheriting.
114  BaseFloat beam_delta; // has nothing to do with beam_ratio
116  BaseFloat prune_scale; // Note: we don't make this configurable on the command line,
117  // it's not a very important parameter. It affects the
118  // algorithm that prunes the tokens as we go.
119  // Most of the options inside det_opts are not actually queried by the
120  // LatticeIncrementalDecoder class itself, but by the code that calls it, for
121  // example in the function DecodeUtteranceLatticeIncremental.
123 
124  // The configuration values from this point on are specific to the
125  // incremental determinization. See where they are registered for
126  // explanation.
127  // Caution: these are only inspected in UpdateLatticeDeterminization().
128  // If you call
131 
132 
134  : beam(16.0),
135  max_active(std::numeric_limits<int32>::max()),
136  min_active(200),
137  lattice_beam(10.0),
138  prune_interval(25),
139  beam_delta(0.5),
140  hash_ratio(2.0),
141  prune_scale(0.01),
142  determinize_max_delay(60),
143  determinize_min_chunk_size(20) {
144  det_opts.minimize = false;
145  }
146  void Register(OptionsItf *opts) {
147  det_opts.Register(opts);
148  opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate.");
149  opts->Register("max-active", &max_active,
150  "Decoder max active states. Larger->slower; "
151  "more accurate");
152  opts->Register("min-active", &min_active, "Decoder minimum #active states.");
153  opts->Register("lattice-beam", &lattice_beam,
154  "Lattice generation beam. Larger->slower, "
155  "and deeper lattices");
156  opts->Register("prune-interval", &prune_interval,
157  "Interval (in frames) at "
158  "which to prune tokens");
159  opts->Register("beam-delta", &beam_delta,
160  "Increment used in decoding-- this "
161  "parameter is obscure and relates to a speedup in the way the "
162  "max-active constraint is applied. Larger is more accurate.");
163  opts->Register("hash-ratio", &hash_ratio,
164  "Setting used in decoder to "
165  "control hash behavior");
166  opts->Register("determinize-max-delay", &determinize_max_delay,
167  "Maximum frames of delay between decoding a frame and "
168  "determinizing it");
169  opts->Register("determinize-min-chunk-size", &determinize_min_chunk_size,
170  "Minimum chunk size used in determinization");
171 
172  }
173  void Check() const {
174  if (!(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 &&
175  min_active <= max_active && prune_interval > 0 &&
176  beam_delta > 0.0 && hash_ratio >= 1.0 &&
177  prune_scale > 0.0 && prune_scale < 1.0 &&
178  determinize_max_delay > determinize_min_chunk_size &&
179  determinize_min_chunk_size > 0))
180  KALDI_ERR << "Invalid options given to decoder";
181  /* Minimization of the chunks is not compatible with our algorithm (or at
182  least, would require additional complexity to implement.) */
183  if (det_opts.minimize || !det_opts.word_determinize)
184  KALDI_ERR << "Invalid determinization options given to decoder.";
185  }
186 };
187 
188 
189 
197  public:
198  using Label = typename LatticeArc::Label; /* Actually the same labels appear
199  in both lattice and compact
200  lattice, so we don't use the
201  specific type all the time but
202  just say 'Label' */
204  const TransitionModel &trans_model,
205  const LatticeIncrementalDecoderConfig &config):
206  trans_model_(trans_model), config_(config) { }
207 
208  // Resets the lattice determinization data for new utterance
209  void Init();
210 
211  // Returns the current determinized lattice.
212  const CompactLattice &GetDeterminizedLattice() const { return clat_; }
213 
232  void InitializeRawLatticeChunk(
233  Lattice *olat,
234  unordered_map<Label, LatticeArc::StateId> *token_label2state);
235 
257  bool AcceptRawLatticeChunk(Lattice *raw_fst);
258 
259  /*
260  Sets final-probs in `clat_`. Must only be called if the final chunk
261  has not been processed. (The final chunk is whenever GetLattice() is
262  called with finalize == true).
263 
264  The reason this is a separate function from AcceptRawLatticeChunk() is that
265  there may be situations where a user wants to get the lattice with
266  final-probs in it, after previously getting it without final-probs; or
267  vice versa. By final-probs, we mean the Final() probabilities in the
268  HCLG (decoding graph; this->fst_).
269 
270  @param [in] token_label2final_cost A map from the token-label
271  corresponding to Tokens active on the final frame of the
272  lattice in the object, to the final-cost we want to use for
273  those tokens. If NULL, it means all Tokens should be treated
274  as final with probability One(). If non-NULL, and a particular
275  token-label is not a key of this map, it means that Token
276  corresponded to a state that was not final in HCLG; and
277  such tokens will be treated as non-final. However,
278  if this would result in no states in the lattice being final,
279  we will treat all Tokens as final with probability One(),
280  a warning will be printed (this should not happen.)
281  */
282  void SetFinalCosts(const unordered_map<Label, BaseFloat> *token_label2final_cost = NULL);
283 
284  const CompactLattice &GetLattice() { return clat_; }
285 
286  // kStateLabelOffset is what we add to state-ids in clat_ to produce labels
287  // to identify them in the raw lattice chunk
288  // kTokenLabelOffset is where we start allocating labels corresponding to Tokens
289  // (these correspond with raw lattice states);
290  enum { kStateLabelOffset = (int)1e8, kTokenLabelOffset = (int)2e8, kMaxTokenLabel = (int)3e8 };
291 
292  private:
293 
294  // [called from AcceptRawLatticeChunk()]
295  // Gets the final costs from token-final states in the raw lattice (see
296  // glossary for definition). These final costs will be subtracted after
297  // determinization; in the normal case they are `temporaries` used to guide
298  // pruning. NOTE: the index of the array is not the FST state that is final,
299  // but the label on arcs entering it (these will be `token-labels`). Each
300  // token-final state will have the same label on all arcs entering it.
301  //
302  // `old_final_costs` is assumed to be empty at entry.
303  void GetRawLatticeFinalCosts(const Lattice &raw_fst,
304  std::unordered_map<Label, BaseFloat> *old_final_costs);
305 
306  // Sets up non_final_redet_states_. See documentation for that variable.
307  void GetNonFinalRedetStates();
308 
333  bool ProcessArcsFromChunkStartState(
334  const CompactLattice &chunk_clat,
335  std::unordered_map<CompactLattice::StateId, CompactLattice::StateId> *state_map);
336 
361  void TransferArcsToClat(
362  const CompactLattice &chunk_clat,
363  bool is_first_chunk,
364  const std::unordered_map<CompactLattice::StateId, CompactLattice::StateId> &state_map,
365  const std::unordered_map<CompactLattice::StateId, Label> &chunk_state_to_token,
366  const std::unordered_map<Label, BaseFloat> &old_final_costs);
367 
368 
369 
374  void AddArcToClat(CompactLattice::StateId state,
375  const CompactLatticeArc &arc);
376  CompactLattice::StateId AddStateToClat();
377 
378 
379  // Identifies token-final states in `chunk_clat`; see glossary above for
380  // definition of `token-final`. This function outputs a map from such states
381  // in chunk_clat, to the `token-label` on arcs entering them. (It is not
382  // possible that the same state would have multiple arcs entering it with
383  // different token-labels, or some arcs entering with one token-label and some
384  // another, or be both initial and have such arcs; this is true due to how we
385  // construct the raw lattice.)
386  void IdentifyTokenFinalStates(
387  const CompactLattice &chunk_clat,
388  std::unordered_map<CompactLattice::StateId, CompactLatticeArc::Label> *token_map) const;
389 
390  // trans_model_ is needed by DeterminizeLatticePhonePrunedWrapper() which this
391  // class calls.
393  // config_ is needed by DeterminizeLatticePhonePrunedWrapper() which this
394  // class calls.
396 
397 
398  // Contains the set of redeterminized-states which are not final in the
399  // canonical appended lattice. Since the final ones don't physically appear
400  // in clat_, this means the set of redeterminized-states which are physically
401  // in clat_. In code terms, this means set of .first elements in final_arcs,
402  // plus whatever other states in clat_ are reachable from such states.
403  std::unordered_set<CompactLattice::StateId> non_final_redet_states_;
404 
405 
406  // clat_ is the appended lattice (containing all chunks processed so
407  // far), except its `final-arcs` (i.e. arcs which in the canonical
408  // lattice would go to final-states) are not present (they are stored
409  // separately in final_arcs_) and states which in the canonical lattice
410  // should have final-arcs leaving them will instead have a final-prob.
412 
413 
414  // arcs_in_ is indexed by (state-id in clat_), and is a list of
415  // arcs that come into this state, in the form (prev-state,
416  // arc-index). CAUTION: not all these input-arc records will always
417  // be valid (some may be out-of-date, and may refer to an out-of-range
418  // arc or an arc that does not point to this state). But all
419  // input arcs will always be listed.
420  std::vector<std::vector<std::pair<CompactLattice::StateId, int32> > > arcs_in_;
421 
422  // final_arcs_ contains arcs which would appear in the canonical appended
423  // lattice but for implementation reasons are not physically present in clat_.
424  // These are arcs to final states in the canonical appended lattice. The
425  // .first elements are the source states in clat_ (these will all be elements
426  // of non_final_redet_states_); the .nextstate elements of the arcs does not
427  // contain a physical state, but contain state-labels allocated by
428  // AllocateNewStateLabel().
429  std::vector<CompactLatticeArc> final_arcs_;
430 
431  // forward_costs_, indexed by the state-id in clat_, stores the alpha
432  // (forward) costs, i.e. the minimum cost from the start state to each state
433  // in clat_. This is relevant for pruned determinization. The BaseFloat can
434  // be thought of as the sum of a Value1() + Value2() in a LatticeWeight.
435  std::vector<BaseFloat> forward_costs_;
436 
437  // temporary used in a function, kept here to avoid excessive reallocation.
438  std::unordered_set<int32> temp_;
439 
441 };
442 
443 
464 template <typename FST, typename Token = decoder::StdToken>
466  public:
467  using Arc = typename FST::Arc;
468  using Label = typename Arc::Label;
469  using StateId = typename Arc::StateId;
470  using Weight = typename Arc::Weight;
472 
473  // Instantiate this class once for each thing you have to decode.
474  // This version of the constructor does not take ownership of
475  // 'fst'.
476  LatticeIncrementalDecoderTpl(const FST &fst, const TransitionModel &trans_model,
477  const LatticeIncrementalDecoderConfig &config);
478 
479  // This version of the constructor takes ownership of the fst, and will delete
480  // it when this object is destroyed.
482  FST *fst, const TransitionModel &trans_model);
483 
484  void SetOptions(const LatticeIncrementalDecoderConfig &config) { config_ = config; }
485 
486  const LatticeIncrementalDecoderConfig &GetOptions() const { return config_; }
487 
489 
510  bool Decode(DecodableInterface *decodable);
511 
514  bool ReachedFinal() const {
515  return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
516  }
517 
557  const CompactLattice &GetLattice(int32 num_frames_to_include,
558  bool use_final_probs = false);
559 
560  /*
561  Returns the number of frames in the currently-determinized part of the
562  lattice which will be a number in [0, NumFramesDecoded()]. It will
563  be the largest number that GetLattice() was called with, but note
564  that GetLattice() may be called from UpdateLatticeDeterminization().
565 
566  Made available in case the user wants to give that same number to
567  GetLattice().
568  */
569  int NumFramesInLattice() const { return num_frames_in_lattice_; }
570 
577  void InitDecoding();
578 
586  void AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames = -1);
587 
588 
597  BaseFloat FinalRelativeCost() const;
598 
600  inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
601 
606  void FinalizeDecoding();
607 
608  protected:
609  /* Some protected things are needed in LatticeIncrementalOnlineDecoderTpl. */
610 
613  inline static void DeleteForwardLinks(Token *tok);
614  struct TokenList {
615  Token *toks;
618  int32 num_toks; /* Note: you can only trust `num_toks` if must_prune_tokens
619  * == false, because it is only set in
620  * PruneTokensForFrame(). */
622  : toks(NULL), must_prune_forward_links(true), must_prune_tokens(true),
623  num_toks(-1) {}
624  };
626  void PossiblyResizeHash(size_t num_toks);
627  inline Token *FindOrAddToken(StateId state, int32 frame_plus_one,
628  BaseFloat tot_cost, Token *backpointer, bool *changed);
629  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
630  bool *links_pruned, BaseFloat delta);
631  void ComputeFinalCosts(unordered_map<Token *, BaseFloat> *final_costs,
632  BaseFloat *final_relative_cost,
633  BaseFloat *final_best_cost) const;
634  void PruneForwardLinksFinal();
635  void PruneTokensForFrame(int32 frame_plus_one);
636  void PruneActiveTokens(BaseFloat delta);
637  BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam,
638  Elem **best_elem);
639  BaseFloat ProcessEmitting(DecodableInterface *decodable);
640  void ProcessNonemitting(BaseFloat cost_cutoff);
641 
643  std::vector<TokenList> active_toks_; // indexed by frame.
644  std::vector<StateId> queue_; // temp variable used in ProcessNonemitting,
645  std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
646  const FST *fst_;
648  std::vector<BaseFloat> cost_offsets_;
650  bool warned_;
652 
653  unordered_map<Token *, BaseFloat> final_costs_;
656 
657  /***********************
658  Variables below this point relate to the incremental
659  determinization.
660  *********************/
665 
666 
667  /* Just a temporary used in a function; stored here to avoid reallocation. */
668  unordered_map<Token*, StateId> temp_token_map_;
669 
673 
674  // A map from Token to its token_label. Will contain an entry for
675  // each Token in active_toks_[num_frames_in_lattice_].
676  unordered_map<Token*, Label> token2label_map_;
677 
678  // A temporary used in a function, kept here to avoid reallocation.
679  unordered_map<Token*, Label> token2label_map_temp_;
680 
681  // we allocate a unique id for each Token
683 
684  inline Label AllocateNewTokenLabel() { return next_token_label_++; }
685 
686 
687  // There are various cleanup tasks... the toks_ structure contains
688  // singly linked lists of Token pointers, where Elem is the list type.
689  // It also indexes them in a hash, indexed by state (this hash is only
690  // maintained for the most recent frame). toks_.Clear()
691  // deletes them from the hash and returns the list of Elems. The
692  // function DeleteElems calls toks_.Delete(elem) for each elem in
693  // the list, which returns ownership of the Elem to the toks_ structure
694  // for reuse, but does not delete the Token pointer. The Token pointers
695  // are reference-counted and are ultimately deleted in PruneTokensForFrame,
696  // but are also linked together on each frame by their own linked-list,
697  // using the "next" pointer. We delete them manually.
698  void DeleteElems(Elem *list);
699 
700  void ClearActiveTokens();
701 
702 
703  // Returns the number of active tokens on frame `frame`. Can be used as part
704  // of a heuristic to decide which frame to determinize until, if you are not
705  // at the end of an utterance.
706  int32 GetNumToksForFrame(int32 frame);
707 
717  void UpdateLatticeDeterminization();
718 
719 
721 };
722 
725 
726 
727 } // end namespace kaldi.
728 
729 #endif
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
std::unordered_set< CompactLattice::StateId > non_final_redet_states_
unordered_map< Token *, StateId > temp_token_map_
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
Definition: decodable-itf.h:82
unordered_map< Token *, Label > token2label_map_
fst::DeterminizeLatticePhonePrunedOptions det_opts
This is an extension to the "normal" lattice-generating decoder.
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
unordered_map< Token *, Label > token2label_map_temp_
void SetOptions(const LatticeIncrementalDecoderConfig &config)
kaldi::int32 int32
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type)
Definition: kaldi-utils.h:121
typename HashList< StateId, decoder::BackpointerToken *>::Elem Elem
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
const CompactLattice & GetDeterminizedLattice() const
This class is used inside LatticeIncrementalDecoderTpl; it handles some of the details of incremental...
int32 num_frames_in_lattice_
num_frames_in_lattice_ is the highest `num_frames_to_include_` argument for any prior call to GetLatt...
const LatticeIncrementalDecoderConfig & config_
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
#define KALDI_ERR
Definition: kaldi-error.h:147
LatticeIncrementalDeterminizer(const TransitionModel &trans_model, const LatticeIncrementalDecoderConfig &config)
LatticeIncrementalDecoderTpl< fst::StdFst, decoder::StdToken > LatticeIncrementalDecoder
fst::StdArc::Label Label
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
fst::StdArc::Weight Weight
The normal decoder, lattice-faster-decoder.h, sometimes has an issue when doing real-time application...
int32 NumFramesDecoded() const
Returns the number of frames decoded so far.
const LatticeIncrementalDecoderConfig & GetOptions() const
fst::ArcTpl< CompactLatticeWeight > CompactLatticeArc
Definition: kaldi-lattice.h:42
LatticeIncrementalDeterminizer determinizer_
Much of the incremental determinization algorithm is encapsulated in the determinize_ object...
std::vector< std::vector< std::pair< CompactLattice::StateId, int32 > > > arcs_in_
bool ReachedFinal() const
says whether a final-state was active on the last frame.
std::vector< CompactLatticeArc > final_arcs_
unordered_map< Token *, BaseFloat > final_costs_