biglm-faster-decoder.h
// decoder/biglm-faster-decoder.h

// Copyright 2009-2011 Microsoft Corporation, Gilles Boulianne

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_DECODER_BIGLM_FASTER_DECODER_H_
#define KALDI_DECODER_BIGLM_FASTER_DECODER_H_

#include "util/stl-utils.h"
#include "util/hash-list.h"
#include "fst/fstlib.h"
#include "itf/decodable-itf.h"
#include "lat/kaldi-lattice.h"  // for CompactLatticeArc
#include "decoder/faster-decoder.h"  // for options class
#include "fstext/deterministic-fst.h"  // for DeterministicOnDemandFst

namespace kaldi {

// The options are the same as for FasterDecoder, except that min_active
// defaults to a higher value here.
struct BiglmFasterDecoderOptions: public FasterDecoderOptions {
  BiglmFasterDecoderOptions(): FasterDecoderOptions() {
    min_active = 200;
  }
};

/** This is as FasterDecoder, but does online composition between HCLG and the
    "difference language model" lm_diff_fst: a deterministic, on-demand FST
    that represents the difference in LM scores between the language model we
    want to decode with and the one that HCLG was compiled with.
 */
class BiglmFasterDecoder {
 public:
  typedef fst::StdArc Arc;
  typedef Arc::Label Label;
  typedef Arc::StateId StateId;
  // A PairId is constructed as: (StateId in fst) + ((StateId in lm_diff_fst) << 32).
  typedef uint64 PairId;
  typedef Arc::Weight Weight;

  // This constructor is the same as for FasterDecoder, except the third
  // argument (lm_diff_fst) is new; it's an FST (actually, a
  // DeterministicOnDemandFst) that represents the difference in LM scores
  // between the LM we want and the LM the decoding-graph "fst" was built with.
  // See e.g. gmm-decode-biglm-faster.cc for an example of how this is called.
  // Basically, we are using fst o lm_diff_fst (where o is composition) as the
  // decoding graph.  Instead of having everything indexed by the state in
  // "fst", we now index by the pair of states in (fst, lm_diff_fst).
  // Whenever we cross a word, we need to propagate the state within
  // lm_diff_fst.
  BiglmFasterDecoder(const fst::Fst<fst::StdArc> &fst,
                     const BiglmFasterDecoderOptions &opts,
                     fst::DeterministicOnDemandFst<fst::StdArc> *lm_diff_fst):
      fst_(fst), lm_diff_fst_(lm_diff_fst), opts_(opts), warned_noarc_(false) {
    KALDI_ASSERT(opts_.hash_ratio >= 1.0);  // less doesn't make much sense.
    KALDI_ASSERT(opts_.max_active > 1);
    KALDI_ASSERT(fst.Start() != fst::kNoStateId &&
                 lm_diff_fst->Start() != fst::kNoStateId);
    toks_.SetSize(1000);  // just so on the first frame we do something reasonable.
  }
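
  // Editor's note (illustrative sketch, not part of the original header): a
  // typical calling pattern, loosely modeled on gmm-decode-biglm-faster.cc.
  // The names hclg, lm_diff and decodable are placeholders for objects the
  // caller builds (the decoding graph, the LM-difference FST, and the
  // DecodableInterface over the current utterance):
  //
  //   BiglmFasterDecoderOptions opts;
  //   BiglmFasterDecoder decoder(hclg, opts, &lm_diff);
  //   decoder.Decode(&decodable);
  //   fst::VectorFst<LatticeArc> best_path;
  //   if (decoder.GetBestPath(&best_path)) {
  //     // best_path now holds the 1-best transition-id / word sequence.
  //   }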

  void SetOptions(const BiglmFasterDecoderOptions &opts) { opts_ = opts; }

  ~BiglmFasterDecoder() {
    ClearToks(toks_.Clear());
  }

  void Decode(DecodableInterface *decodable) {
    // clean up from last time:
    ClearToks(toks_.Clear());
    PairId start_pair = ConstructPair(fst_.Start(), lm_diff_fst_->Start());
    Arc dummy_arc(0, 0, Weight::One(), fst_.Start());  // actually, the last field
    // of the Arc (fst_.Start(), here) is never needed.
    toks_.Insert(start_pair, new Token(dummy_arc, NULL));
    ProcessNonemitting(std::numeric_limits<float>::max());
    for (int32 frame = 0; !decodable->IsLastFrame(frame - 1); frame++) {
      BaseFloat weight_cutoff = ProcessEmitting(decodable, frame);
      ProcessNonemitting(weight_cutoff);
    }
  }

  bool ReachedFinal() {
    for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
      PairId state_pair = e->key;
      StateId state = PairToState(state_pair),
          lm_state = PairToLmState(state_pair);
      Weight this_weight =
          Times(e->val->weight_,
                Times(fst_.Final(state), lm_diff_fst_->Final(lm_state)));
      if (this_weight != Weight::Zero())
        return true;
    }
    return false;
  }
114 
115  bool GetBestPath(fst::MutableFst<LatticeArc> *fst_out,
116  bool use_final_probs = true) {
117  // GetBestPath gets the decoding output. If "use_final_probs" is true
118  // AND we reached a final state, it limits itself to final states;
119  // otherwise it gets the most likely token not taking into
120  // account final-probs. fst_out will be empty (Start() == kNoStateId) if
121  // nothing was available. It returns true if it got output (thus, fst_out
122  // will be nonempty).
123  fst_out->DeleteStates();
124  Token *best_tok = NULL;
125  Weight best_final = Weight::Zero(); // set only if is_final == true. The
126  // final-prob corresponding to the best final token (i.e. the one with best
127  // weight best_weight, below).
128  bool is_final = ReachedFinal();
129  if (!is_final) {
130  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
131  if (best_tok == NULL || *best_tok < *(e->val) )
132  best_tok = e->val;
133  } else {
134  Weight best_weight = Weight::Zero();
135  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
136  Weight fst_final = fst_.Final(PairToState(e->key)),
137  lm_final = lm_diff_fst_->Final(PairToLmState(e->key)),
138  final = Times(fst_final, lm_final);
139  Weight this_weight = Times(e->val->weight_, final);
140  if (this_weight != Weight::Zero() &&
141  this_weight.Value() < best_weight.Value()) {
142  best_weight = this_weight;
143  best_final = final;
144  best_tok = e->val;
145  }
146  }
147  }
148  if (best_tok == NULL) return false; // No output.
149 
150  std::vector<LatticeArc> arcs_reverse; // arcs in reverse order.
151 
152  for (Token *tok = best_tok; tok != NULL; tok = tok->prev_) {
153  BaseFloat tot_cost = tok->weight_.Value() -
154  (tok->prev_ ? tok->prev_->weight_.Value() : 0.0),
155  graph_cost = tok->arc_.weight.Value(),
156  ac_cost = tot_cost - graph_cost;
157  LatticeArc l_arc(tok->arc_.ilabel,
158  tok->arc_.olabel,
159  LatticeWeight(graph_cost, ac_cost),
160  tok->arc_.nextstate);
161  arcs_reverse.push_back(l_arc);
162  }
163  KALDI_ASSERT(arcs_reverse.back().nextstate == fst_.Start());
164  arcs_reverse.pop_back(); // that was a "fake" token... gives no info.
165 
166  StateId cur_state = fst_out->AddState();
167  fst_out->SetStart(cur_state);
168  for (ssize_t i = static_cast<ssize_t>(arcs_reverse.size())-1; i >= 0; i--) {
169  LatticeArc arc = arcs_reverse[i];
170  arc.nextstate = fst_out->AddState();
171  fst_out->AddArc(cur_state, arc);
172  cur_state = arc.nextstate;
173  }
174  if (is_final && use_final_probs) {
175  fst_out->SetFinal(cur_state, LatticeWeight(best_final.Value(), 0.0));
176  } else {
177  fst_out->SetFinal(cur_state, LatticeWeight::One());
178  }
179  RemoveEpsLocal(fst_out);
180  return true;
181  }
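
  // Editor's note (illustrative numbers, not in the original source): the
  // per-arc cost split above works because weight_ accumulates graph plus
  // acoustic cost along the path.  For example, if a token's weight_ is 250.0
  // and its predecessor's is 240.0, this step cost 10.0 in total; if the
  // stored arc's graph weight is 3.5, the acoustic cost is recovered as
  // 10.0 - 3.5 = 6.5, and the output arc gets LatticeWeight(3.5, 6.5).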

 private:
  inline PairId ConstructPair(StateId fst_state, StateId lm_state) {
    return static_cast<PairId>(fst_state) + (static_cast<PairId>(lm_state) << 32);
  }

  static inline StateId PairToState(PairId state_pair) {
    return static_cast<StateId>(static_cast<uint32>(state_pair));
  }
  static inline StateId PairToLmState(PairId state_pair) {
    return static_cast<StateId>(static_cast<uint32>(state_pair >> 32));
  }
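
  // Editor's note (illustrative, not in the original source): the packing
  // stores the HCLG state in the low 32 bits and the LM-diff state in the
  // high 32 bits.  For example, fst_state = 5 and lm_state = 3 give
  // PairId 5 + (3 << 32) = 0x0000000300000005; PairToState() recovers 5 and
  // PairToLmState() recovers 3.  This assumes both FSTs have fewer than 2^32
  // states.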

  class Token {
   public:
    Arc arc_;  // contains only the graph part of the cost,
    // including the part in "fst" (== HCLG) plus lm_diff_fst.
    // We can work out the acoustic part from the difference between
    // "weight_" and prev_->weight_.
    Token *prev_;
    int32 ref_count_;
    Weight weight_;  // weight up to current point.
    inline Token(const Arc &arc, Weight &ac_weight, Token *prev):
        arc_(arc), prev_(prev), ref_count_(1) {
      if (prev) {
        prev->ref_count_++;
        weight_ = Times(Times(prev->weight_, arc.weight), ac_weight);
      } else {
        weight_ = Times(arc.weight, ac_weight);
      }
    }
    inline Token(const Arc &arc, Token *prev):
        arc_(arc), prev_(prev), ref_count_(1) {
      if (prev) {
        prev->ref_count_++;
        weight_ = Times(prev->weight_, arc.weight);
      } else {
        weight_ = arc.weight;
      }
    }
    inline bool operator < (const Token &other) {
      return weight_.Value() > other.weight_.Value();
      // This makes sense for log + tropical semiring.
    }

    inline ~Token() {
      KALDI_ASSERT(ref_count_ == 1);
      if (prev_ != NULL) TokenDelete(prev_);
    }
    inline static void TokenDelete(Token *tok) {
      if (tok->ref_count_ == 1) {
        delete tok;
      } else {
        tok->ref_count_--;
      }
    }
  };
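
  // Editor's note (illustrative, not in the original source): tokens form
  // reference-counted back-pointer chains.  E.g. if token B is created with
  // prev == A, A's ref_count_ becomes 2 (one for the active list, one for B).
  // TokenDelete(B) then destroys B (its count is 1), and B's destructor calls
  // TokenDelete(A), which merely drops A's count back to 1 while A stays in
  // the active list.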
  typedef HashList<PairId, Token*>::Elem Elem;

  // Gets the weight cutoff.  Also counts the active tokens.
  BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
                      BaseFloat *adaptive_beam, Elem **best_elem) {
    BaseFloat best_weight = 1.0e+10;  // positive == high cost == bad.
    size_t count = 0;
    if (opts_.max_active == std::numeric_limits<int32>::max() &&
        opts_.min_active == 0) {
      for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
        BaseFloat w = static_cast<BaseFloat>(e->val->weight_.Value());
        if (w < best_weight) {
          best_weight = w;
          if (best_elem) *best_elem = e;
        }
      }
      if (tok_count != NULL) *tok_count = count;
      if (adaptive_beam != NULL) *adaptive_beam = opts_.beam;
      return best_weight + opts_.beam;
    } else {
      tmp_array_.clear();
      for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
        BaseFloat w = e->val->weight_.Value();
        tmp_array_.push_back(w);
        if (w < best_weight) {
          best_weight = w;
          if (best_elem) *best_elem = e;
        }
      }
      if (tok_count != NULL) *tok_count = count;

      BaseFloat beam_cutoff = best_weight + opts_.beam,
          min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
          max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();

      if (tmp_array_.size() > static_cast<size_t>(opts_.max_active)) {
        std::nth_element(tmp_array_.begin(),
                         tmp_array_.begin() + opts_.max_active,
                         tmp_array_.end());
        max_active_cutoff = tmp_array_[opts_.max_active];
      }
      if (tmp_array_.size() > static_cast<size_t>(opts_.min_active)) {
        if (opts_.min_active == 0) {
          min_active_cutoff = best_weight;
        } else {
          std::nth_element(tmp_array_.begin(),
                           tmp_array_.begin() + opts_.min_active,
                           tmp_array_.size() > static_cast<size_t>(opts_.max_active) ?
                           tmp_array_.begin() + opts_.max_active :
                           tmp_array_.end());
          min_active_cutoff = tmp_array_[opts_.min_active];
        }
      }

      if (max_active_cutoff < beam_cutoff) {  // max_active is tighter than beam.
        if (adaptive_beam)
          *adaptive_beam = max_active_cutoff - best_weight + opts_.beam_delta;
        return max_active_cutoff;
      } else if (min_active_cutoff > beam_cutoff) {  // min_active is looser than beam.
        if (adaptive_beam)
          *adaptive_beam = min_active_cutoff - best_weight + opts_.beam_delta;
        return min_active_cutoff;
      } else {
        if (adaptive_beam) *adaptive_beam = opts_.beam;
        return beam_cutoff;
      }
    }
  }
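
  // Editor's note (illustrative numbers, not in the original source): suppose
  // best_weight = 100.0, opts_.beam = 16.0 and opts_.beam_delta = 0.5, so the
  // plain beam cutoff is 116.0.  If more than max_active tokens are alive and
  // the max_active-th best weight is 110.0, that tighter value is returned and
  // the adaptive beam becomes 110.0 - 100.0 + 0.5 = 10.5.  Conversely, if
  // fewer than min_active tokens fall inside the beam, the cutoff is loosened
  // to the min_active-th best weight so at least that many tokens survive.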

  void PossiblyResizeHash(size_t num_toks) {
    size_t new_sz = static_cast<size_t>(static_cast<BaseFloat>(num_toks)
                                        * opts_.hash_ratio);
    if (new_sz > toks_.Size()) {
      toks_.SetSize(new_sz);
    }
  }

  inline StateId PropagateLm(StateId lm_state,
                             Arc *arc) {  // returns new LM state.
    if (arc->olabel == 0) {
      return lm_state;  // no change in LM state if no word crossed.
    } else {  // Propagate in the LM-diff FST.
      Arc lm_arc;
      bool ans = lm_diff_fst_->GetArc(lm_state, arc->olabel, &lm_arc);
      if (!ans) {  // this case is unexpected for statistical LMs.
        if (!warned_noarc_) {
          warned_noarc_ = true;
          KALDI_WARN << "No arc available in LM (unlikely to be correct "
              "if a statistical language model); will not warn again";
        }
        arc->weight = Weight::Zero();
        return lm_state;  // doesn't really matter what we return here; will
        // be pruned.
      } else {
        arc->weight = Times(arc->weight, lm_arc.weight);
        arc->olabel = lm_arc.olabel;  // probably will be the same.
        return lm_arc.nextstate;  // return the new LM state.
      }
    }
  }
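
  // Editor's note (illustrative, not in the original source): these are
  // tropical weights, so Times() above adds costs.  E.g. if a word-crossing
  // arc in HCLG has graph cost 4.86 and the LM-diff FST assigns that word a
  // cost of -0.33 (the new LM likes it more than the old one), the propagated
  // arc ends up with cost 4.86 + (-0.33) = 4.53 and the token moves to the
  // LM-diff FST's destination state.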

  // ProcessEmitting returns the likelihood cutoff used.
  BaseFloat ProcessEmitting(DecodableInterface *decodable, int frame) {
    Elem *last_toks = toks_.Clear();
    size_t tok_cnt;
    BaseFloat adaptive_beam;
    Elem *best_elem = NULL;
    BaseFloat weight_cutoff = GetCutoff(last_toks, &tok_cnt,
                                        &adaptive_beam, &best_elem);
    PossiblyResizeHash(tok_cnt);  // This makes sure the hash is always big enough.

    // This is the cutoff we use after adding in the log-likes (i.e.
    // for the next frame).  This is a bound on the cutoff we will use
    // on the next frame.
    BaseFloat next_weight_cutoff = 1.0e+10;

    // First process the best token, to get a hopefully
    // reasonably tight bound on the next cutoff.
    if (best_elem) {
      PairId state_pair = best_elem->key;
      StateId state = PairToState(state_pair),
          lm_state = PairToLmState(state_pair);
      Token *tok = best_elem->val;
      for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
           !aiter.Done();
           aiter.Next()) {
        Arc arc = aiter.Value();
        if (arc.ilabel != 0) {  // we'd propagate..
          PropagateLm(lm_state, &arc);  // may affect "arc.weight".
          // We don't need the return value (the new LM state).
          BaseFloat ac_cost = -decodable->LogLikelihood(frame, arc.ilabel),
              new_weight = arc.weight.Value() + tok->weight_.Value() + ac_cost;
          if (new_weight + adaptive_beam < next_weight_cutoff)
            next_weight_cutoff = new_weight + adaptive_beam;
        }
      }
    }

    // The tokens are now owned here, in last_toks, and the hash is empty.
    // 'owned' is a complex thing here; the point is we need to call toks_.Delete(e)
    // on each elem 'e' to let toks_ know we're done with it.
    for (Elem *e = last_toks, *e_tail; e != NULL; e = e_tail) {  // loop this way
      // because we delete "e" as we go.
      PairId state_pair = e->key;
      StateId state = PairToState(state_pair),
          lm_state = PairToLmState(state_pair);
      Token *tok = e->val;
      if (tok->weight_.Value() < weight_cutoff) {  // not pruned.
        KALDI_ASSERT(state == tok->arc_.nextstate);
        for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
             !aiter.Done();
             aiter.Next()) {
          Arc arc = aiter.Value();
          if (arc.ilabel != 0) {  // propagate.
            StateId next_lm_state = PropagateLm(lm_state, &arc);
            Weight ac_weight(-decodable->LogLikelihood(frame, arc.ilabel));
            BaseFloat new_weight = arc.weight.Value() + tok->weight_.Value()
                + ac_weight.Value();
            if (new_weight < next_weight_cutoff) {  // not pruned..
              PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
              Token *new_tok = new Token(arc, ac_weight, tok);
              Elem *e_found = toks_.Insert(next_pair, new_tok);
              if (new_weight + adaptive_beam < next_weight_cutoff)
                next_weight_cutoff = new_weight + adaptive_beam;
              if (e_found->val != new_tok) {
                if (*(e_found->val) < *new_tok) {
                  Token::TokenDelete(e_found->val);
                  e_found->val = new_tok;
                } else {
                  Token::TokenDelete(new_tok);
                }
              }
            }
          }
        }
      }
      e_tail = e->tail;
      Token::TokenDelete(e->val);
      toks_.Delete(e);
    }
    return next_weight_cutoff;
  }

  // TODO: first time we go through this, could avoid using the queue.
  void ProcessNonemitting(BaseFloat cutoff) {
    // Processes nonemitting arcs for one frame.
    KALDI_ASSERT(queue_.empty());
    for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
      queue_.push_back(e);
    while (!queue_.empty()) {
      const Elem *e = queue_.back();
      queue_.pop_back();
      PairId state_pair = e->key;
      Token *tok = e->val;  // would segfault if state not
      // in toks_, but this can't happen.
      if (tok->weight_.Value() > cutoff) {  // Don't bother processing successors.
        continue;
      }
      KALDI_ASSERT(tok != NULL);
      StateId state = PairToState(state_pair),
          lm_state = PairToLmState(state_pair);
      for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
           !aiter.Done();
           aiter.Next()) {
        const Arc &arc_ref = aiter.Value();
        if (arc_ref.ilabel == 0) {  // propagate nonemitting only...
          Arc arc(arc_ref);
          StateId next_lm_state = PropagateLm(lm_state, &arc);
          PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
          Token *new_tok = new Token(arc, tok);
          if (new_tok->weight_.Value() > cutoff) {  // prune
            Token::TokenDelete(new_tok);
          } else {
            Elem *e_found = toks_.Insert(next_pair, new_tok);
            if (e_found->val == new_tok) {
              queue_.push_back(e_found);
            } else {
              if (*(e_found->val) < *new_tok) {
                Token::TokenDelete(e_found->val);
                e_found->val = new_tok;
                queue_.push_back(e_found);
              } else {
                Token::TokenDelete(new_tok);
              }
            }
          }
        }
      }
    }
  }

  // HashList defined in ../util/hash-list.h.  It actually allows us to maintain
  // more than one list (e.g. for current and previous frames), but only one of
  // them at a time can be indexed by PairId.
  HashList<PairId, Token*> toks_;
  const fst::Fst<fst::StdArc> &fst_;
  fst::DeterministicOnDemandFst<fst::StdArc> *lm_diff_fst_;
  BiglmFasterDecoderOptions opts_;
  bool warned_noarc_;  // set to true once we warn about a missing LM-diff arc.
  std::vector<const Elem*> queue_;  // temp variable used in ProcessNonemitting;
  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
  // These are class members to avoid internal new/delete on every frame.

  // It might seem unclear why we call ClearToks(toks_.Clear()).
  // There are two separate cleanup tasks we need to do when we start a new file:
  // one is to delete the Token objects in the list; the other is to delete
  // the Elem objects.  toks_.Clear() just clears them from the hash and gives
  // ownership to the caller, who then has to call toks_.Delete(e) for each one.
  // It was designed this way for convenience in propagating tokens from one
  // frame to the next.
  void ClearToks(Elem *list) {
    for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
      Token::TokenDelete(e->val);
      e_tail = e->tail;
      toks_.Delete(e);
    }
  }
  KALDI_DISALLOW_COPY_AND_ASSIGN(BiglmFasterDecoder);
};


}  // end namespace kaldi.


#endif  // KALDI_DECODER_BIGLM_FASTER_DECODER_H_