lattice-faster-online-decoder.cc
Go to the documentation of this file.
1 // decoder/lattice-faster-online-decoder.cc
2 
3 // Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
4 // 2013-2014 Johns Hopkins University (Author: Daniel Povey)
5 // 2014 Guoguo Chen
6 // 2014 IMSL, PKU-HKUST (author: Wei Shi)
7 // 2018 Zhehuai Chen
8 
9 // See ../../COPYING for clarification regarding multiple authors
10 //
11 // Licensed under the Apache License, Version 2.0 (the "License");
12 // you may not use this file except in compliance with the License.
13 // You may obtain a copy of the License at
14 //
15 // http://www.apache.org/licenses/LICENSE-2.0
16 //
17 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
18 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
19 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
20 // MERCHANTABLITY OR NON-INFRINGEMENT.
21 // See the Apache 2 License for the specific language governing permissions and
22 // limitations under the License.
23 
24 // see note at the top of lattice-faster-decoder.cc, about how to maintain this
25 // file in sync with lattice-faster-decoder.cc
26 
28 #include "lat/lattice-functions.h"
29 
30 namespace kaldi {
31 
32 template <typename FST>
34  bool use_final_probs) const {
35  Lattice lat1;
36  {
37  Lattice raw_lat;
38  this->GetRawLattice(&raw_lat, use_final_probs);
39  ShortestPath(raw_lat, &lat1);
40  }
41  Lattice lat2;
42  GetBestPath(&lat2, use_final_probs);
43  BaseFloat delta = 0.1;
44  int32 num_paths = 1;
45  if (!fst::RandEquivalent(lat1, lat2, num_paths, delta, rand())) {
46  KALDI_WARN << "Best-path test failed";
47  return false;
48  } else {
49  return true;
50  }
51 }
52 
53 
54 // Outputs an FST corresponding to the single best path through the lattice.
55 template <typename FST>
57  bool use_final_probs) const {
58  olat->DeleteStates();
59  BaseFloat final_graph_cost;
60  BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost);
61  if (iter.Done())
62  return false; // would have printed warning.
63  StateId state = olat->AddState();
64  olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0));
65  while (!iter.Done()) {
66  LatticeArc arc;
67  iter = TraceBackBestPath(iter, &arc);
68  arc.nextstate = state;
69  StateId new_state = olat->AddState();
70  olat->AddArc(new_state, arc);
71  state = new_state;
72  }
73  olat->SetStart(state);
74  return true;
75 }
76 
77 template <typename FST>
79  bool use_final_probs,
80  BaseFloat *final_cost_out) const {
81  if (this->decoding_finalized_ && !use_final_probs)
82  KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
83  << "BestPathEnd() with use_final_probs == false";
84  KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
85  "You cannot call BestPathEnd if no frames were decoded.");
86 
87  unordered_map<Token*, BaseFloat> final_costs_local;
88 
89  const unordered_map<Token*, BaseFloat> &final_costs =
90  (this->decoding_finalized_ ? this->final_costs_ :final_costs_local);
91  if (!this->decoding_finalized_ && use_final_probs)
92  this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
93 
94  // Singly linked list of tokens on last frame (access list through "next"
95  // pointer).
96  BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
97  BaseFloat best_final_cost = 0;
98  Token *best_tok = NULL;
99  for (Token *tok = this->active_toks_.back().toks;
100  tok != NULL; tok = tok->next) {
101  BaseFloat cost = tok->tot_cost, final_cost = 0.0;
102  if (use_final_probs && !final_costs.empty()) {
103  // if we are instructed to use final-probs, and any final tokens were
104  // active on final frame, include the final-prob in the cost of the token.
105  typename unordered_map<Token*, BaseFloat>::const_iterator
106  iter = final_costs.find(tok);
107  if (iter != final_costs.end()) {
108  final_cost = iter->second;
109  cost += final_cost;
110  } else {
111  cost = std::numeric_limits<BaseFloat>::infinity();
112  }
113  }
114  if (cost < best_cost) {
115  best_cost = cost;
116  best_tok = tok;
117  best_final_cost = final_cost;
118  }
119  }
120  if (best_tok == NULL) { // this should not happen, and is likely a code error or
121  // caused by infinities in likelihoods, but I'm not making
122  // it a fatal error for now.
123  KALDI_WARN << "No final token found.";
124  }
125  if (final_cost_out)
126  *final_cost_out = best_final_cost;
127  return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
128 }
129 
130 
131 template <typename FST>
133  BestPathIterator iter, LatticeArc *oarc) const {
134  KALDI_ASSERT(!iter.Done() && oarc != NULL);
135  Token *tok = static_cast<Token*>(iter.tok);
136  int32 cur_t = iter.frame, ret_t = cur_t;
137  if (tok->backpointer != NULL) {
138  ForwardLinkT *link;
139  for (link = tok->backpointer->links;
140  link != NULL; link = link->next) {
141  if (link->next_tok == tok) { // this is the link to "tok"
142  oarc->ilabel = link->ilabel;
143  oarc->olabel = link->olabel;
144  BaseFloat graph_cost = link->graph_cost,
145  acoustic_cost = link->acoustic_cost;
146  if (link->ilabel != 0) {
147  KALDI_ASSERT(static_cast<size_t>(cur_t) < this->cost_offsets_.size());
148  acoustic_cost -= this->cost_offsets_[cur_t];
149  ret_t--;
150  }
151  oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
152  break;
153  }
154  }
155  if (link == NULL) { // Did not find correct link.
156  KALDI_ERR << "Error tracing best-path back (likely "
157  << "bug in token-pruning algorithm)";
158  }
159  } else {
160  oarc->ilabel = 0;
161  oarc->olabel = 0;
162  oarc->weight = LatticeWeight::One(); // zero costs.
163  }
164  return BestPathIterator(tok->backpointer, ret_t);
165 }
166 
167 template <typename FST>
169  Lattice *ofst,
170  bool use_final_probs,
171  BaseFloat beam) const {
172  typedef LatticeArc Arc;
173  typedef Arc::StateId StateId;
174  typedef Arc::Weight Weight;
175  typedef Arc::Label Label;
176 
177  // Note: you can't use the old interface (Decode()) if you want to
178  // get the lattice with use_final_probs = false. You'd have to do
179  // InitDecoding() and then AdvanceDecoding().
180  if (this->decoding_finalized_ && !use_final_probs)
181  KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
182  << "GetRawLattice() with use_final_probs == false";
183 
184  unordered_map<Token*, BaseFloat> final_costs_local;
185 
186  const unordered_map<Token*, BaseFloat> &final_costs =
187  (this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
188  if (!this->decoding_finalized_ && use_final_probs)
189  this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
190 
191  ofst->DeleteStates();
192  // num-frames plus one (since frames are one-based, and we have
193  // an extra frame for the start-state).
194  int32 num_frames = this->active_toks_.size() - 1;
195  KALDI_ASSERT(num_frames > 0);
196  for (int32 f = 0; f <= num_frames; f++) {
197  if (this->active_toks_[f].toks == NULL) {
198  KALDI_WARN << "No tokens active on frame " << f
199  << ": not producing lattice.\n";
200  return false;
201  }
202  }
203  unordered_map<Token*, StateId> tok_map;
204  std::queue<std::pair<Token*, int32> > tok_queue;
205  // First initialize the queue and states. Put the initial state on the queue;
206  // this is the last token in the list active_toks_[0].toks.
207  for (Token *tok = this->active_toks_[0].toks;
208  tok != NULL; tok = tok->next) {
209  if (tok->next == NULL) {
210  tok_map[tok] = ofst->AddState();
211  ofst->SetStart(tok_map[tok]);
212  std::pair<Token*, int32> tok_pair(tok, 0); // #frame = 0
213  tok_queue.push(tok_pair);
214  }
215  }
216 
217  // Next create states for "good" tokens
218  while (!tok_queue.empty()) {
219  std::pair<Token*, int32> cur_tok_pair = tok_queue.front();
220  tok_queue.pop();
221  Token *cur_tok = cur_tok_pair.first;
222  int32 cur_frame = cur_tok_pair.second;
223  KALDI_ASSERT(cur_frame >= 0 &&
224  cur_frame <= this->cost_offsets_.size());
225 
226  typename unordered_map<Token*, StateId>::const_iterator iter =
227  tok_map.find(cur_tok);
228  KALDI_ASSERT(iter != tok_map.end());
229  StateId cur_state = iter->second;
230 
231  for (ForwardLinkT *l = cur_tok->links;
232  l != NULL;
233  l = l->next) {
234  Token *next_tok = l->next_tok;
235  if (next_tok->extra_cost < beam) {
236  // so both the current and the next token are good; create the arc
237  int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1;
238  StateId nextstate;
239  if (tok_map.find(next_tok) == tok_map.end()) {
240  nextstate = tok_map[next_tok] = ofst->AddState();
241  tok_queue.push(std::pair<Token*, int32>(next_tok, next_frame));
242  } else {
243  nextstate = tok_map[next_tok];
244  }
245  BaseFloat cost_offset = (l->ilabel != 0 ?
246  this->cost_offsets_[cur_frame] : 0);
247  Arc arc(l->ilabel, l->olabel,
248  Weight(l->graph_cost, l->acoustic_cost - cost_offset),
249  nextstate);
250  ofst->AddArc(cur_state, arc);
251  }
252  }
253  if (cur_frame == num_frames) {
254  if (use_final_probs && !final_costs.empty()) {
255  typename unordered_map<Token*, BaseFloat>::const_iterator iter =
256  final_costs.find(cur_tok);
257  if (iter != final_costs.end())
258  ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
259  } else {
260  ofst->SetFinal(cur_state, LatticeWeight::One());
261  }
262  }
263  }
264  return (ofst->NumStates() != 0);
265 }
266 
267 
268 
269 // Instantiate the template for the FST types that we'll need.
274 
275 
276 } // end namespace kaldi.
fst::StdArc::StateId StateId
fst::StdArc::Label Label
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
fst::ArcTpl< LatticeWeight > LatticeArc
Definition: kaldi-lattice.h:40
static const LatticeWeightTpl One()
Lattice::StateId StateId
kaldi::int32 int32
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
Definition: kaldi-lattice.h:32
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
LatticeFasterOnlineDecoderTpl is like LatticeFasterDecoderTpl but also supports an efficient way to get...
BestPathIterator TraceBackBestPath(BestPathIterator iter, LatticeArc *arc) const
This function can be used in conjunction with BestPathEnd() to trace back the best path one link at a...
fst::StdArc::Label Label
bool TestGetBestPath(bool use_final_probs=true) const
This function does a self-test of GetBestPath().
fst::StdArc::Weight Weight
Arc::Weight Weight
Definition: kws-search.cc:31
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
BestPathIterator BestPathEnd(bool use_final_probs, BaseFloat *final_cost=NULL) const
This function returns an iterator that can be used to trace back the best path.
bool GetBestPath(Lattice *ofst, bool use_final_probs=true) const
Outputs an FST corresponding to the single best path through the lattice.
bool GetRawLatticePruned(Lattice *ofst, bool use_final_probs, BaseFloat beam) const
Behaves the same as GetRawLattice but only processes tokens whose extra_cost is smaller than the best...