simple-decoder.cc
Go to the documentation of this file.
1 // decoder/simple-decoder.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 // 2012-2013 Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "decoder/simple-decoder.h"
23 #include <algorithm>
24 
25 namespace kaldi {
26 
30 }
31 
32 
// Decode(): decodes a whole utterance in one call. It resets the search with
// InitDecoding(), then consumes every frame the decodable object currently
// reports as ready via AdvanceDecoding() (no max_num_frames cap).
// Returns true iff at least one token is still active afterwards. This does
// NOT mean a final state was reached; ReachedFinal() checks that.
34  InitDecoding();
35  AdvanceDecoding(decodable);
36  return (!cur_toks_.empty());
37 }
38 
// InitDecoding(): resets the search state so that decoding can proceed
// incrementally with AdvanceDecoding().
40  // clean up from last time:
43  // initialize decoding:
// Seed the search with a single token on the FST start state. The dummy arc
// (ilabel 0, olabel 0, weight One, nextstate = start state) has zero acoustic
// cost and no predecessor. GetBestPath() asserts that the oldest arc of every
// traceback points at fst_.Start() and then discards it.
44  StateId start_state = fst_.Start();
45  KALDI_ASSERT(start_state != fst::kNoStateId);
46  StdArc dummy_arc(0, 0, StdWeight::One(), start_state);
47  cur_toks_[start_state] = new Token(dummy_arc, 0.0, NULL);
50 }
51 
53  int32 max_num_frames) {
55  "You must call InitDecoding() before AdvanceDecoding()");
// AdvanceDecoding(): decodes frames num_frames_decoded_ .. target-1, where
// target is decodable->NumFramesReady(). When max_num_frames >= 0, target is
// additionally capped at num_frames_decoded_ + max_num_frames; a negative
// max_num_frames means "no cap".
56  int32 num_frames_ready = decodable->NumFramesReady();
57  // num_frames_ready must be >= num_frames_decoded_. Otherwise either
58  // the number of frames ready has decreased (which doesn't
59  // make sense) or the decodable object changed between calls
60  // (which isn't allowed).
61  KALDI_ASSERT(num_frames_ready >= num_frames_decoded_);
62  int32 target_frames_decoded = num_frames_ready;
63  if (max_num_frames >= 0)
64  target_frames_decoded = std::min(target_frames_decoded,
65  num_frames_decoded_ + max_num_frames);
66  while (num_frames_decoded_ < target_frames_decoded) {
67  // note: ProcessEmitting() increments num_frames_decoded_
// The tokens of the last frame become prev_toks_; ProcessEmitting() then
// expands them over one frame into cur_toks_.
69  cur_toks_.swap(prev_toks_);
70  ProcessEmitting(decodable);
73  }
74 }
75 
// ReachedFinal(): true iff some active token has a cost other than +infinity
// and sits on an FST state whose final weight is not Zero, i.e. a final state.
77  for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
78  iter != cur_toks_.end();
79  ++iter) {
80  if (iter->second->cost_ != std::numeric_limits<BaseFloat>::infinity() &&
81  fst_.Final(iter->first) != StdWeight::Zero())
82  return true;
83  }
84  return false;
85 }
86 
// FinalRelativeCost(): how much the best cost gets worse once final weights
// are taken into account. The result is
//   min over tokens of (cost_ + final cost of the token's state)
//   - min over tokens of cost_.
// It returns infinity when there are no active tokens, or when no active
// state is final (non-final states have infinite final cost). If the
// difference is NaN, it logs a warning and returns infinity.
88  // as a special case, if there are no active tokens at all (e.g. some kind of
89  // pruning failure), return infinity.
90  double infinity = std::numeric_limits<double>::infinity();
91  if (cur_toks_.empty())
92  return infinity;
93  double best_cost = infinity,
94  best_cost_with_final = infinity;
95  for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
96  iter != cur_toks_.end();
97  ++iter) {
98  // Note: std::min here plays the role of the tropical-semiring Plus,
99  // which takes the minimum cost.
100  best_cost = std::min(best_cost, iter->second->cost_);
101  best_cost_with_final = std::min(best_cost_with_final,
102  iter->second->cost_ +
103  fst_.Final(iter->first).Value());
104  }
105  BaseFloat extra_cost = best_cost_with_final - best_cost;
106  if (extra_cost != extra_cost) { // NaN (x != x only for NaN). This shouldn't happen; it indicates some
107  // kind of error, most likely (e.g. inf - inf when every token cost is infinite).
108  KALDI_WARN << "Found NaN (likely search failure in decoding)";
109  return infinity;
110  }
111  // Note: extra_cost will be infinity if no states were final.
112  return extra_cost;
113 }
114 
115 // Outputs an FST corresponding to the single best path
116 // through the lattice.
117 bool SimpleDecoder::GetBestPath(Lattice *fst_out, bool use_final_probs) const {
118  fst_out->DeleteStates();
119  Token *best_tok = NULL;
120  bool is_final = ReachedFinal();
121  if (!is_final) {
122  for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
123  iter != cur_toks_.end();
124  ++iter)
125  if (best_tok == NULL || *best_tok < *(iter->second) )
126  best_tok = iter->second;
127  } else {
128  double infinity =std::numeric_limits<double>::infinity(),
129  best_cost = infinity;
130  for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
131  iter != cur_toks_.end();
132  ++iter) {
133  double this_cost = iter->second->cost_ + fst_.Final(iter->first).Value();
134  if (this_cost != infinity && this_cost < best_cost) {
135  best_cost = this_cost;
136  best_tok = iter->second;
137  }
138  }
139  }
140  if (best_tok == NULL) return false; // No output.
141 
142  std::vector<LatticeArc> arcs_reverse; // arcs in reverse order.
143  for (Token *tok = best_tok; tok != NULL; tok = tok->prev_)
144  arcs_reverse.push_back(tok->arc_);
145  KALDI_ASSERT(arcs_reverse.back().nextstate == fst_.Start());
146  arcs_reverse.pop_back(); // that was a "fake" token... gives no info.
147 
148  StateId cur_state = fst_out->AddState();
149  fst_out->SetStart(cur_state);
150  for (ssize_t i = static_cast<ssize_t>(arcs_reverse.size())-1; i >= 0; i--) {
151  LatticeArc arc = arcs_reverse[i];
152  arc.nextstate = fst_out->AddState();
153  fst_out->AddArc(cur_state, arc);
154  cur_state = arc.nextstate;
155  }
156  if (is_final && use_final_probs)
157  fst_out->SetFinal(cur_state,
158  LatticeWeight(fst_.Final(best_tok->arc_.nextstate).Value(),
159  0.0));
160  else
161  fst_out->SetFinal(cur_state, LatticeWeight::One());
162  fst::RemoveEpsLocal(fst_out);
163  return true;
164 }
165 
166 
// ProcessEmitting(): expands every token in prev_toks_ along the emitting
// (ilabel != 0) arcs of its state and collects the results in cur_toks_.
// Each arc is scored with acoustic cost -LogLikelihood(frame, ilabel), where
// frame is the current value of num_frames_decoded_.
168  int32 frame = num_frames_decoded_;
169  // Processes emitting arcs for one frame. Propagates from
170  // prev_toks_ to cur_toks_.
// Adaptive beam: cutoff starts at +infinity and is lowered to
// (best total_cost seen so far) + beam_ as arcs are visited. An arc is
// therefore only compared against arcs visited before it, so which arcs
// survive depends on the iteration order of prev_toks_.
171  double cutoff = std::numeric_limits<BaseFloat>::infinity();
172  for (unordered_map<StateId, Token*>::iterator iter = prev_toks_.begin();
173  iter != prev_toks_.end();
174  ++iter) {
175  StateId state = iter->first;
176  Token *tok = iter->second;
177  KALDI_ASSERT(state == tok->arc_.nextstate);
178  for (fst::ArcIterator<fst::Fst<StdArc> > aiter(fst_, state);
179  !aiter.Done();
180  aiter.Next()) {
181  const StdArc &arc = aiter.Value();
182  if (arc.ilabel != 0) { // propagate..
183  BaseFloat acoustic_cost = -decodable->LogLikelihood(frame, arc.ilabel);
184  double total_cost = tok->cost_ + arc.weight.Value() + acoustic_cost;
185 
// Arcs at or beyond the cutoff are skipped before any Token is allocated.
186  if (total_cost >= cutoff) continue;
187  if (total_cost + beam_ < cutoff)
188  cutoff = total_cost + beam_;
// NOTE(review): the Token constructor is not shown here. It presumably sets
// cost_ to tok->cost_ + arc weight + acoustic_cost (= total_cost); confirm.
189  Token *new_tok = new Token(arc, acoustic_cost, tok);
// Recombination: keep at most one token per destination state. The existing
// token is replaced when it compares worse under Token::operator<;
// otherwise the new token is discarded.
190  unordered_map<StateId, Token*>::iterator find_iter
191  = cur_toks_.find(arc.nextstate);
192  if (find_iter == cur_toks_.end()) {
193  cur_toks_[arc.nextstate] = new_tok;
194  } else {
195  if ( *(find_iter->second) < *new_tok ) {
196  Token::TokenDelete(find_iter->second);
197  find_iter->second = new_tok;
198  } else {
199  Token::TokenDelete(new_tok);
200  }
201  }
202  }
203  }
204  }
206 }
207 
209  // Processes nonemitting arcs for one frame. Propagates within
210  // cur_toks_.
// Unlike ProcessEmitting(), the beam here is fixed for the whole pass:
// cutoff = (lowest cost_ among the tokens present at entry) + beam_. New
// tokens with cost_ > cutoff are dropped (ties survive).
// `queue` is used as a LIFO stack. A state is pushed whenever a new or better
// token lands on it, so its epsilon successors are re-expanded with the
// improved cost.
211  std::vector<StateId> queue;
212  double infinity = std::numeric_limits<double>::infinity();
213  double best_cost = infinity;
214  for (unordered_map<StateId, Token*>::iterator iter = cur_toks_.begin();
215  iter != cur_toks_.end();
216  ++iter) {
217  queue.push_back(iter->first);
218  best_cost = std::min(best_cost, iter->second->cost_);
219  }
220  double cutoff = best_cost + beam_;
221 
222  while (!queue.empty()) {
223  StateId state = queue.back();
224  queue.pop_back();
// operator[] would insert a NULL entry for a missing state; the assertion
// below catches that case.
225  Token *tok = cur_toks_[state];
226  KALDI_ASSERT(tok != NULL && state == tok->arc_.nextstate);
227  for (fst::ArcIterator<fst::Fst<StdArc> > aiter(fst_, state);
228  !aiter.Done();
229  aiter.Next()) {
230  const StdArc &arc = aiter.Value();
231  if (arc.ilabel == 0) { // propagate nonemitting only...
232  const BaseFloat acoustic_cost = 0.0;
// NOTE(review): new_tok->cost_ is computed by the Token constructor (not
// shown); presumably tok->cost_ + arc weight, since acoustic_cost is 0.
233  Token *new_tok = new Token(arc, acoustic_cost, tok);
234  if (new_tok->cost_ > cutoff) {
235  Token::TokenDelete(new_tok);
236  } else {
237  unordered_map<StateId, Token*>::iterator find_iter
238  = cur_toks_.find(arc.nextstate);
239  if (find_iter == cur_toks_.end()) {
240  cur_toks_[arc.nextstate] = new_tok;
241  queue.push_back(arc.nextstate);
242  } else {
243  if ( *(find_iter->second) < *new_tok ) {
244  Token::TokenDelete(find_iter->second);
245  find_iter->second = new_tok;
246  queue.push_back(arc.nextstate);
247  } else {
248  Token::TokenDelete(new_tok);
249  }
250  }
251  }
252  }
253  }
254  }
255 }
256 
257 // static
258 void SimpleDecoder::ClearToks(unordered_map<StateId, Token*> &toks) {
259  for (unordered_map<StateId, Token*>::iterator iter = toks.begin();
260  iter != toks.end(); ++iter) {
261  Token::TokenDelete(iter->second);
262  }
263  toks.clear();
264 }
265 
266 // static
267 void SimpleDecoder::PruneToks(BaseFloat beam, unordered_map<StateId, Token*> *toks) {
268  if (toks->empty()) {
269  KALDI_VLOG(2) << "No tokens to prune.\n";
270  return;
271  }
272  double best_cost = std::numeric_limits<double>::infinity();
273  for (unordered_map<StateId, Token*>::iterator iter = toks->begin();
274  iter != toks->end(); ++iter)
275  best_cost = std::min(best_cost, iter->second->cost_);
276  std::vector<StateId> retained;
277  double cutoff = best_cost + beam;
278  for (unordered_map<StateId, Token*>::iterator iter = toks->begin();
279  iter != toks->end(); ++iter) {
280  if (iter->second->cost_ < cutoff)
281  retained.push_back(iter->first);
282  else
283  Token::TokenDelete(iter->second);
284  }
285  unordered_map<StateId, Token*> tmp;
286  for (size_t i = 0; i < retained.size(); i++) {
287  tmp[retained[i]] = (*toks)[retained[i]];
288  }
289  KALDI_VLOG(2) << "Pruned to " << (retained.size()) << " toks.\n";
290  tmp.swap(*toks);
291 }
292 
293 } // end namespace kaldi.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
fst::ArcTpl< LatticeWeight > LatticeArc
Definition: kaldi-lattice.h:40
void ProcessEmitting(DecodableInterface *decodable)
virtual int32 NumFramesReady() const
The call NumFramesReady() will return the number of frames currently available for this decodable obj...
unordered_map< StateId, Token * > cur_toks_
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
Definition: decodable-itf.h:82
StdArc::StateId StateId
static const LatticeWeightTpl One()
void RemoveEpsLocal(MutableFst< Arc > *fst)
RemoveEpsLocal remove some (but not necessarily all) epsilons in an FST, using an algorithm that is g...
unordered_map< StateId, Token * > prev_toks_
static void ClearToks(unordered_map< StateId, Token *> &toks)
bool ReachedFinal() const
kaldi::int32 int32
BaseFloat FinalRelativeCost() const
*** The next functions are from the "new interface". ***
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
Definition: kaldi-lattice.h:32
const fst::Fst< fst::StdArc > & fst_
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
#define KALDI_WARN
Definition: kaldi-error.h:150
static void PruneToks(BaseFloat beam, unordered_map< StateId, Token *> *toks)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
bool Decode(DecodableInterface *decodable)
Decode this utterance.
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
static void TokenDelete(Token *tok)
void AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames=-1)
This will decode until there are no more frames ready in the decodable object, but if max_num_frames ...
virtual BaseFloat LogLikelihood(int32 frame, int32 index)=0
Returns the log likelihood, which will be negated in the decoder.
bool GetBestPath(Lattice *fst_out, bool use_final_probs=true) const
void InitDecoding()
InitDecoding initializes the decoding, and should only be used if you intend to call AdvanceDecoding(...