faster-decoder.cc
Go to the documentation of this file.
1 // decoder/faster-decoder.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 // 2012-2013 Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "decoder/faster-decoder.h"
22 
23 namespace kaldi {
24 
25 
26 FasterDecoder::FasterDecoder(const fst::Fst<fst::StdArc> &fst,
27  const FasterDecoderOptions &opts):
28  fst_(fst), config_(opts), num_frames_decoded_(-1) {
29  KALDI_ASSERT(config_.hash_ratio >= 1.0); // less doesn't make much sense.
32  toks_.SetSize(1000); // just so on the first frame we do something reasonable.
33 }
34 
35 
37  // clean up from last time:
39  StateId start_state = fst_.Start();
40  KALDI_ASSERT(start_state != fst::kNoStateId);
41  Arc dummy_arc(0, 0, Weight::One(), start_state);
42  toks_.Insert(start_state, new Token(dummy_arc, NULL));
43  ProcessNonemitting(std::numeric_limits<float>::max());
45 }
46 
47 
49  InitDecoding();
50  AdvanceDecoding(decodable);
51 }
52 
54  int32 max_num_frames) {
56  "You must call InitDecoding() before AdvanceDecoding()");
57  int32 num_frames_ready = decodable->NumFramesReady();
58  // num_frames_ready must be >= num_frames_decoded, or else
59  // the number of frames ready must have decreased (which doesn't
60  // make sense) or the decodable object changed between calls
61  // (which isn't allowed).
62  KALDI_ASSERT(num_frames_ready >= num_frames_decoded_);
63  int32 target_frames_decoded = num_frames_ready;
64  if (max_num_frames >= 0)
65  target_frames_decoded = std::min(target_frames_decoded,
66  num_frames_decoded_ + max_num_frames);
67  while (num_frames_decoded_ < target_frames_decoded) {
68  // note: ProcessEmitting() increments num_frames_decoded_
69  double weight_cutoff = ProcessEmitting(decodable);
70  ProcessNonemitting(weight_cutoff);
71  }
72 }
73 
74 
76  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
77  if (e->val->cost_ != std::numeric_limits<double>::infinity() &&
78  fst_.Final(e->key) != Weight::Zero())
79  return true;
80  }
81  return false;
82 }
83 
84 bool FasterDecoder::GetBestPath(fst::MutableFst<LatticeArc> *fst_out,
85  bool use_final_probs) {
86  // GetBestPath gets the decoding output. If "use_final_probs" is true
87  // AND we reached a final state, it limits itself to final states;
88  // otherwise it gets the most likely token not taking into
89  // account final-probs. fst_out will be empty (Start() == kNoStateId) if
90  // nothing was available. It returns true if it got output (thus, fst_out
91  // will be nonempty).
92  fst_out->DeleteStates();
93  Token *best_tok = NULL;
94  bool is_final = ReachedFinal();
95  if (!is_final) {
96  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
97  if (best_tok == NULL || *best_tok < *(e->val) )
98  best_tok = e->val;
99  } else {
100  double infinity = std::numeric_limits<double>::infinity(),
101  best_cost = infinity;
102  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
103  double this_cost = e->val->cost_ + fst_.Final(e->key).Value();
104  if (this_cost < best_cost && this_cost != infinity) {
105  best_cost = this_cost;
106  best_tok = e->val;
107  }
108  }
109  }
110  if (best_tok == NULL) return false; // No output.
111 
112  std::vector<LatticeArc> arcs_reverse; // arcs in reverse order.
113 
114  for (Token *tok = best_tok; tok != NULL; tok = tok->prev_) {
115  BaseFloat tot_cost = tok->cost_ -
116  (tok->prev_ ? tok->prev_->cost_ : 0.0),
117  graph_cost = tok->arc_.weight.Value(),
118  ac_cost = tot_cost - graph_cost;
119  LatticeArc l_arc(tok->arc_.ilabel,
120  tok->arc_.olabel,
121  LatticeWeight(graph_cost, ac_cost),
122  tok->arc_.nextstate);
123  arcs_reverse.push_back(l_arc);
124  }
125  KALDI_ASSERT(arcs_reverse.back().nextstate == fst_.Start());
126  arcs_reverse.pop_back(); // that was a "fake" token... gives no info.
127 
128  StateId cur_state = fst_out->AddState();
129  fst_out->SetStart(cur_state);
130  for (ssize_t i = static_cast<ssize_t>(arcs_reverse.size())-1; i >= 0; i--) {
131  LatticeArc arc = arcs_reverse[i];
132  arc.nextstate = fst_out->AddState();
133  fst_out->AddArc(cur_state, arc);
134  cur_state = arc.nextstate;
135  }
136  if (is_final && use_final_probs) {
137  Weight final_weight = fst_.Final(best_tok->arc_.nextstate);
138  fst_out->SetFinal(cur_state, LatticeWeight(final_weight.Value(), 0.0));
139  } else {
140  fst_out->SetFinal(cur_state, LatticeWeight::One());
141  }
142  RemoveEpsLocal(fst_out);
143  return true;
144 }
145 
146 
147 // Gets the weight cutoff. Also counts the active tokens.
148 double FasterDecoder::GetCutoff(Elem *list_head, size_t *tok_count,
149  BaseFloat *adaptive_beam, Elem **best_elem) {
150  double best_cost = std::numeric_limits<double>::infinity();
151  size_t count = 0;
152  if (config_.max_active == std::numeric_limits<int32>::max() &&
153  config_.min_active == 0) {
154  for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
155  double w = e->val->cost_;
156  if (w < best_cost) {
157  best_cost = w;
158  if (best_elem) *best_elem = e;
159  }
160  }
161  if (tok_count != NULL) *tok_count = count;
162  if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
163  return best_cost + config_.beam;
164  } else {
165  tmp_array_.clear();
166  for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
167  double w = e->val->cost_;
168  tmp_array_.push_back(w);
169  if (w < best_cost) {
170  best_cost = w;
171  if (best_elem) *best_elem = e;
172  }
173  }
174  if (tok_count != NULL) *tok_count = count;
175  double beam_cutoff = best_cost + config_.beam,
176  min_active_cutoff = std::numeric_limits<double>::infinity(),
177  max_active_cutoff = std::numeric_limits<double>::infinity();
178 
179  if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
180  std::nth_element(tmp_array_.begin(),
181  tmp_array_.begin() + config_.max_active,
182  tmp_array_.end());
183  max_active_cutoff = tmp_array_[config_.max_active];
184  }
185  if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam.
186  if (adaptive_beam)
187  *adaptive_beam = max_active_cutoff - best_cost + config_.beam_delta;
188  return max_active_cutoff;
189  }
190  if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
191  if (config_.min_active == 0) min_active_cutoff = best_cost;
192  else {
193  std::nth_element(tmp_array_.begin(),
194  tmp_array_.begin() + config_.min_active,
195  tmp_array_.size() > static_cast<size_t>(config_.max_active) ?
196  tmp_array_.begin() + config_.max_active :
197  tmp_array_.end());
198  min_active_cutoff = tmp_array_[config_.min_active];
199  }
200  }
201  if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam.
202  if (adaptive_beam)
203  *adaptive_beam = min_active_cutoff - best_cost + config_.beam_delta;
204  return min_active_cutoff;
205  } else {
206  *adaptive_beam = config_.beam;
207  return beam_cutoff;
208  }
209  }
210 }
211 
212 void FasterDecoder::PossiblyResizeHash(size_t num_toks) {
213  size_t new_sz = static_cast<size_t>(static_cast<BaseFloat>(num_toks)
214  * config_.hash_ratio);
215  if (new_sz > toks_.Size()) {
216  toks_.SetSize(new_sz);
217  }
218 }
219 
220 // ProcessEmitting returns the likelihood cutoff used.
222  int32 frame = num_frames_decoded_;
223  Elem *last_toks = toks_.Clear();
224  size_t tok_cnt;
225  BaseFloat adaptive_beam;
226  Elem *best_elem = NULL;
227  double weight_cutoff = GetCutoff(last_toks, &tok_cnt,
228  &adaptive_beam, &best_elem);
229  KALDI_VLOG(3) << tok_cnt << " tokens active.";
230  PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough.
231 
232  // This is the cutoff we use after adding in the log-likes (i.e.
233  // for the next frame). This is a bound on the cutoff we will use
234  // on the next frame.
235  double next_weight_cutoff = std::numeric_limits<double>::infinity();
236 
237  // First process the best token to get a hopefully
238  // reasonably tight bound on the next cutoff.
239  if (best_elem) {
240  StateId state = best_elem->key;
241  Token *tok = best_elem->val;
242  for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
243  !aiter.Done();
244  aiter.Next()) {
245  const Arc &arc = aiter.Value();
246  if (arc.ilabel != 0) { // we'd propagate..
247  BaseFloat ac_cost = - decodable->LogLikelihood(frame, arc.ilabel);
248  double new_weight = arc.weight.Value() + tok->cost_ + ac_cost;
249  if (new_weight + adaptive_beam < next_weight_cutoff)
250  next_weight_cutoff = new_weight + adaptive_beam;
251  }
252  }
253  }
254 
255  // int32 n = 0, np = 0;
256 
257  // the tokens are now owned here, in last_toks, and the hash is empty.
258  // 'owned' is a complex thing here; the point is we need to call TokenDelete
259  // on each elem 'e' to let toks_ know we're done with them.
260  for (Elem *e = last_toks, *e_tail; e != NULL; e = e_tail) { // loop this way
261  // n++;
262  // because we delete "e" as we go.
263  StateId state = e->key;
264  Token *tok = e->val;
265  if (tok->cost_ < weight_cutoff) { // not pruned.
266  // np++;
267  KALDI_ASSERT(state == tok->arc_.nextstate);
268  for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
269  !aiter.Done();
270  aiter.Next()) {
271  Arc arc = aiter.Value();
272  if (arc.ilabel != 0) { // propagate..
273  BaseFloat ac_cost = - decodable->LogLikelihood(frame, arc.ilabel);
274  double new_weight = arc.weight.Value() + tok->cost_ + ac_cost;
275  if (new_weight < next_weight_cutoff) { // not pruned..
276  Token *new_tok = new Token(arc, ac_cost, tok);
277  Elem *e_found = toks_.Insert(arc.nextstate, new_tok);
278  if (new_weight + adaptive_beam < next_weight_cutoff)
279  next_weight_cutoff = new_weight + adaptive_beam;
280  if (e_found->val != new_tok) {
281  if (*(e_found->val) < *new_tok) {
282  Token::TokenDelete(e_found->val);
283  e_found->val = new_tok;
284  } else {
285  Token::TokenDelete(new_tok);
286  }
287  }
288  }
289  }
290  }
291  }
292  e_tail = e->tail;
293  Token::TokenDelete(e->val);
294  toks_.Delete(e);
295  }
297  return next_weight_cutoff;
298 }
299 
300 // TODO: first time we go through this, could avoid using the queue.
302  // Processes nonemitting arcs for one frame.
303  KALDI_ASSERT(queue_.empty());
304  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
305  queue_.push_back(e);
306  while (!queue_.empty()) {
307  const Elem* e = queue_.back();
308  queue_.pop_back();
309  StateId state = e->key;
310  Token *tok = e->val; // would segfault if state not
311  // in toks_ but this can't happen.
312  if (tok->cost_ > cutoff) { // Don't bother processing successors.
313  continue;
314  }
315  KALDI_ASSERT(tok != NULL && state == tok->arc_.nextstate);
316  for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
317  !aiter.Done();
318  aiter.Next()) {
319  const Arc &arc = aiter.Value();
320  if (arc.ilabel == 0) { // propagate nonemitting only...
321  Token *new_tok = new Token(arc, tok);
322  if (new_tok->cost_ > cutoff) { // prune
323  Token::TokenDelete(new_tok);
324  } else {
325  Elem *e_found = toks_.Insert(arc.nextstate, new_tok);
326  if (e_found->val == new_tok) {
327  queue_.push_back(e_found);
328  } else {
329  if (*(e_found->val) < *new_tok) {
330  Token::TokenDelete(e_found->val);
331  e_found->val = new_tok;
332  queue_.push_back(e_found);
333  } else {
334  Token::TokenDelete(new_tok);
335  }
336  }
337  }
338  }
339  }
340  }
341 }
342 
344  for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
345  Token::TokenDelete(e->val);
346  e_tail = e->tail;
347  toks_.Delete(e);
348  }
349 }
350 
351 } // end namespace kaldi.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
fst::ArcTpl< LatticeWeight > LatticeArc
Definition: kaldi-lattice.h:40
Elem * Insert(I key, T val)
Insert inserts a new element into the hashtable/stored list.
virtual int32 NumFramesReady() const
The call NumFramesReady() will return the number of frames currently available for this decodable obj...
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
Definition: decodable-itf.h:82
void ClearToks(Elem *list)
FasterDecoderOptions config_
static const LatticeWeightTpl One()
void RemoveEpsLocal(MutableFst< Arc > *fst)
RemoveEpsLocal remove some (but not necessarily all) epsilons in an FST, using an algorithm that is g...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void InitDecoding()
As a new alternative to Decode(), you can call InitDecoding and then (possibly multiple times) Advanc...
void AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames=-1)
This will decode until there are no more frames ready in the decodable object, but if max_num_frames ...
void Decode(DecodableInterface *decodable)
kaldi::int32 int32
void PossiblyResizeHash(size_t num_toks)
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
Definition: kaldi-lattice.h:32
bool GetBestPath(fst::MutableFst< LatticeArc > *fst_out, bool use_final_probs=true)
GetBestPath gets the decoding traceback.
static void TokenDelete(Token *tok)
const fst::Fst< fst::StdArc > & fst_
const size_t count
double GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem)
Gets the weight cutoff. Also counts the active tokens.
void SetSize(size_t sz)
SetSize tells the object how many hash buckets to allocate (should typically be at least twice the nu...
Definition: hash-list-inl.h:37
void ProcessNonemitting(double cutoff)
const Elem * GetList() const
Gives the head of the current list to the user.
Definition: hash-list-inl.h:61
Elem * Clear()
Clears the hash and gives the head of the current list to the user; ownership is transferred to the u...
Definition: hash-list-inl.h:46
HashList< StateId, Token * > toks_
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
std::vector< BaseFloat > tmp_array_
Arc::StateId StateId
virtual BaseFloat LogLikelihood(int32 frame, int32 index)=0
Returns the log likelihood, which will be negated in the decoder.
FasterDecoder(const fst::Fst< fst::StdArc > &fst, const FasterDecoderOptions &config)
double ProcessEmitting(DecodableInterface *decodable)
void Delete(Elem *e)
Think of this like delete().
Definition: hash-list-inl.h:66
std::vector< const Elem *> queue_
bool ReachedFinal() const
Returns true if a final state was active on the last frame.
size_t Size()
Returns current number of hash buckets.
Definition: hash-list.h:113