online-faster-decoder.cc
Go to the documentation of this file.
1 // online/online-faster-decoder.cc
2 
3 // Copyright 2012 Cisco Systems (author: Matthias Paulik)
4 
5 // Modifications to the original contribution by Cisco Systems made by:
6 // Vassil Panayotov
7 
8 // See ../../COPYING for clarification regarding multiple authors
9 //
10 // Licensed under the Apache License, Version 2.0 (the "License");
11 // you may not use this file except in compliance with the License.
12 // You may obtain a copy of the License at
13 //
14 // http://www.apache.org/licenses/LICENSE-2.0
15 //
16 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
18 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
19 // MERCHANTABLITY OR NON-INFRINGEMENT.
20 // See the Apache 2 License for the specific language governing permissions and
21 // limitations under the License.
22 
23 #include "base/timer.h"
24 #include "online-faster-decoder.h"
25 #include "fstext/fstext-utils.h"
26 #include "hmm/hmm-utils.h"
27 
28 namespace kaldi {
29 
32  StateId start_state = fst_.Start();
33  KALDI_ASSERT(start_state != fst::kNoStateId);
34  Arc dummy_arc(0, 0, Weight::One(), start_state);
35  Token *dummy_token = new Token(dummy_arc, NULL);
36  toks_.Insert(start_state, dummy_token);
37  prev_immortal_tok_ = immortal_tok_ = dummy_token;
38  utt_frames_ = 0;
39  if (full)
40  frame_ = 0;
41 }
42 
43 
44 void
46  const Token *end,
47  fst::MutableFst<LatticeArc> *out_fst) const {
48  out_fst->DeleteStates();
49  if (start == NULL) return;
50  bool is_final = false;
51  double this_cost = start->cost_ + fst_.Final(start->arc_.nextstate).Value();
52  if (this_cost != std::numeric_limits<double>::infinity())
53  is_final = true;
54  std::vector<LatticeArc> arcs_reverse; // arcs in reverse order.
55  for (const Token *tok = start; tok != end; tok = tok->prev_) {
56  BaseFloat tot_cost = tok->cost_ -
57  (tok->prev_ ? tok->prev_->cost_ : 0.0),
58  graph_cost = tok->arc_.weight.Value(),
59  ac_cost = tot_cost - graph_cost;
60  LatticeArc l_arc(tok->arc_.ilabel,
61  tok->arc_.olabel,
62  LatticeWeight(graph_cost, ac_cost),
63  tok->arc_.nextstate);
64  arcs_reverse.push_back(l_arc);
65  }
66  if(arcs_reverse.back().nextstate == fst_.Start()) {
67  arcs_reverse.pop_back(); // that was a "fake" token... gives no info.
68  }
69  StateId cur_state = out_fst->AddState();
70  out_fst->SetStart(cur_state);
71  for (ssize_t i = static_cast<ssize_t>(arcs_reverse.size())-1; i >= 0; i--) {
72  LatticeArc arc = arcs_reverse[i];
73  arc.nextstate = out_fst->AddState();
74  out_fst->AddArc(cur_state, arc);
75  cur_state = arc.nextstate;
76  }
77  if (is_final) {
78  Weight final_weight = fst_.Final(start->arc_.nextstate);
79  out_fst->SetFinal(cur_state, LatticeWeight(final_weight.Value(), 0.0));
80  } else {
81  out_fst->SetFinal(cur_state, LatticeWeight::One());
82  }
83  RemoveEpsLocal(out_fst);
84 }
85 
86 
88  unordered_set<Token*> emitting;
89  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
90  Token* tok = e->val;
91  while (tok != NULL && tok->arc_.ilabel == 0) //deal with non-emitting ones ...
92  tok = tok->prev_;
93  if (tok != NULL)
94  emitting.insert(tok);
95  }
96  Token* the_one = NULL;
97  while (1) {
98  if (emitting.size() == 1) {
99  the_one = *(emitting.begin());
100  break;
101  }
102  if (emitting.size() == 0)
103  break;
104  unordered_set<Token*> prev_emitting;
105  unordered_set<Token*>::iterator it;
106  for (it = emitting.begin(); it != emitting.end(); ++it) {
107  Token* tok = *it;
108  Token* prev_token = tok->prev_;
109  while ((prev_token != NULL) && (prev_token->arc_.ilabel == 0))
110  prev_token = prev_token->prev_; //deal with non-emitting ones
111  if (prev_token == NULL)
112  continue;
113  prev_emitting.insert(prev_token);
114  } // for
115  emitting = prev_emitting;
116  } // while
117  if (the_one != NULL) {
119  immortal_tok_ = the_one;
120  return;
121  }
122 }
123 
124 
125 bool
126 OnlineFasterDecoder::PartialTraceback(fst::MutableFst<LatticeArc> *out_fst) {
129  return false; //no partial traceback at that point of time
131  return true;
132 }
133 
134 
135 void
136 OnlineFasterDecoder::FinishTraceBack(fst::MutableFst<LatticeArc> *out_fst) {
137  Token *best_tok = NULL;
138  bool is_final = ReachedFinal();
139  if (!is_final) {
140  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
141  if (best_tok == NULL || *best_tok < *(e->val) )
142  best_tok = e->val;
143  } else {
144  double best_cost = std::numeric_limits<double>::infinity();
145  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
146  double this_cost = e->val->cost_ + fst_.Final(e->key).Value();
147  if (this_cost != std::numeric_limits<double>::infinity() &&
148  this_cost < best_cost) {
149  best_cost = this_cost;
150  best_tok = e->val;
151  }
152  }
153  }
154  MakeLattice(best_tok, immortal_tok_, out_fst);
155 }
156 
157 
158 void
160  fst::MutableFst<LatticeArc> *out_fst) {
161  Token *best_tok = NULL;
162  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
163  if (best_tok == NULL || *best_tok < *(e->val) )
164  best_tok = e->val;
165  if (best_tok == NULL) {
166  out_fst->DeleteStates();
167  return;
168  }
169 
170  bool is_final = false;
171  double this_cost = best_tok->cost_ +
172  fst_.Final(best_tok->arc_.nextstate).Value();
173 
174  if (this_cost != std::numeric_limits<double>::infinity())
175  is_final = true;
176  std::vector<LatticeArc> arcs_reverse; // arcs in reverse order.
177  for (Token *tok = best_tok; (tok != NULL) && (nframes > 0); tok = tok->prev_) {
178  if (tok->arc_.ilabel != 0) // count only the non-epsilon arcs
179  --nframes;
180  BaseFloat tot_cost = tok->cost_ -
181  (tok->prev_ ? tok->prev_->cost_ : 0.0);
182  BaseFloat graph_cost = tok->arc_.weight.Value();
183  BaseFloat ac_cost = tot_cost - graph_cost;
184  LatticeArc larc(tok->arc_.ilabel,
185  tok->arc_.olabel,
186  LatticeWeight(graph_cost, ac_cost),
187  tok->arc_.nextstate);
188  arcs_reverse.push_back(larc);
189  }
190  if(arcs_reverse.back().nextstate == fst_.Start())
191  arcs_reverse.pop_back(); // that was a "fake" token... gives no info.
192  StateId cur_state = out_fst->AddState();
193  out_fst->SetStart(cur_state);
194  for (ssize_t i = static_cast<ssize_t>(arcs_reverse.size())-1; i >= 0; i--) {
195  LatticeArc arc = arcs_reverse[i];
196  arc.nextstate = out_fst->AddState();
197  out_fst->AddArc(cur_state, arc);
198  cur_state = arc.nextstate;
199  }
200  if (is_final) {
201  Weight final_weight = fst_.Final(best_tok->arc_.nextstate);
202  out_fst->SetFinal(cur_state, LatticeWeight(final_weight.Value(), 0.0));
203  } else {
204  out_fst->SetFinal(cur_state, LatticeWeight::One());
205  }
206  RemoveEpsLocal(out_fst);
207 }
208 
209 
211  fst::VectorFst<LatticeArc> trace;
213  TracebackNFrames(sil_frm, &trace);
214  std::vector<int32> isymbols;
215  fst::GetLinearSymbolSequence(trace, &isymbols,
216  static_cast<std::vector<int32>* >(0),
217  static_cast<LatticeArc::Weight*>(0));
218  std::vector<std::vector<int32> > split;
219  SplitToPhones(trans_model_, isymbols, &split);
220  for (size_t i = 0; i < split.size(); i++) {
221  int32 tid = split[i][0];
223  if (silence_set_.count(phone) == 0)
224  return false;
225  }
226  return true;
227 }
228 
229 
232  if (state_ == kEndFeats || state_ == kEndUtt) // new utterance
234  ProcessNonemitting(std::numeric_limits<float>::max());
235  int32 batch_frame = 0;
236  Timer timer;
237  double64 tstart = timer.Elapsed(), tstart_batch = tstart;
238  BaseFloat factor = -1;
239  for (; !decodable->IsLastFrame(frame_ - 1) && batch_frame < opts_.batch_size;
240  ++frame_, ++utt_frames_, ++batch_frame) {
241  if (batch_frame != 0 && (batch_frame % opts_.update_interval) == 0) {
242  // adjust the beam if needed
243  BaseFloat tend = timer.Elapsed();
244  BaseFloat elapsed = (tend - tstart) * 1000;
245  // warning: hardcoded 10ms frames assumption!
246  factor = elapsed / (opts_.rt_max * opts_.update_interval * 10);
247  BaseFloat min_factor = (opts_.rt_min / opts_.rt_max);
248  if (factor > 1 || factor < min_factor) {
249  BaseFloat update_factor = (factor > 1)?
250  -std::min(opts_.beam_update * factor, opts_.max_beam_update):
251  std::min(opts_.beam_update / factor, opts_.max_beam_update);
252  effective_beam_ += effective_beam_ * update_factor;
254  }
255  tstart = tend;
256  }
257  if (batch_frame != 0 && (frame_ % 200) == 0)
258  // one log message at every 2 seconds assuming 10ms frames
259  KALDI_VLOG(3) << "Beam: " << effective_beam_
260  << "; Speed: "
261  << ((timer.Elapsed() - tstart_batch) * 1000) / (batch_frame*10)
262  << " xRT";
263  BaseFloat weight_cutoff = ProcessEmitting(decodable);
264  ProcessNonemitting(weight_cutoff);
265  }
266  if (batch_frame == opts_.batch_size && !decodable->IsLastFrame(frame_ - 1)) {
267  if (EndOfUtterance())
268  state_ = kEndUtt;
269  else
270  state_ = kEndBatch;
271  } else {
272  state_ = kEndFeats;
273  }
274  return state_;
275 }
276 
277 } // namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
fst::ArcTpl< LatticeWeight > LatticeArc
Definition: kaldi-lattice.h:40
Elem * Insert(I key, T val)
Insert inserts a new element into the hashtable/stored list.
bool PartialTraceback(fst::MutableFst< LatticeArc > *out_fst)
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
Definition: decodable-itf.h:82
void ClearToks(Elem *list)
static const LatticeWeightTpl One()
void RemoveEpsLocal(MutableFst< Arc > *fst)
RemoveEpsLocal removes some (but not necessarily all) epsilons in an FST, using an algorithm that is g...
virtual bool IsLastFrame(int32 frame) const =0
Returns true if this is the last frame.
kaldi::int32 int32
DecodeState Decode(DecodableInterface *decodable)
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
Definition: kaldi-lattice.h:32
bool SplitToPhones(const TransitionModel &trans_model, const std::vector< int32 > &alignment, std::vector< std::vector< int32 > > *split_alignment)
SplitToPhones splits up the TransitionIds in "alignment" into their individual phones (one vector per...
Definition: hmm-utils.cc:723
const fst::Fst< fst::StdArc > & fst_
void FinishTraceBack(fst::MutableFst< LatticeArc > *fst_out)
void MakeLattice(const Token *start, const Token *end, fst::MutableFst< LatticeArc > *out_fst) const
void ProcessNonemitting(double cutoff)
const Elem * GetList() const
Gives the head of the current list to the user.
Definition: hash-list-inl.h:61
const ConstIntegerSet< int32 > silence_set_
Elem * Clear()
Clears the hash and gives the head of the current list to the user; ownership is transferred to the u...
Definition: hash-list-inl.h:46
HashList< StateId, Token * > toks_
double double64
Definition: kaldi-types.h:54
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void TracebackNFrames(int32 nframes, fst::MutableFst< LatticeArc > *out_fst)
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
Arc::StateId StateId
const TransitionModel & trans_model_
double ProcessEmitting(DecodableInterface *decodable)
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
int32 TransitionIdToPhone(int32 trans_id) const
bool ReachedFinal() const
Returns true if a final state was active on the last frame.
const OnlineFasterDecoderOpts opts_