20 #ifndef KALDI_DECODER_BIGLM_FASTER_DECODER_H_ 21 #define KALDI_DECODER_BIGLM_FASTER_DECODER_H_ 25 #include "fst/fstlib.h" 73 fst_(fst), lm_diff_fst_(lm_diff_fst), opts_(opts), warned_noarc_(false) {
77 lm_diff_fst->
Start() != fst::kNoStateId);
84 ClearToks(toks_.Clear());
89 ClearToks(toks_.Clear());
90 PairId start_pair = ConstructPair(fst_.Start(), lm_diff_fst_->Start());
91 Arc dummy_arc(0, 0, Weight::One(), fst_.Start());
93 toks_.Insert(start_pair,
new Token(dummy_arc, NULL));
94 ProcessNonemitting(std::numeric_limits<float>::max());
96 BaseFloat weight_cutoff = ProcessEmitting(decodable, frame);
97 ProcessNonemitting(weight_cutoff);
102 for (
const Elem *e = toks_.GetList(); e != NULL; e = e->
tail) {
103 PairId state_pair = e->key;
104 StateId state = PairToState(state_pair),
105 lm_state = PairToLmState(state_pair);
107 Times(e->val->weight_,
108 Times(fst_.Final(state), lm_diff_fst_->Final(lm_state)));
109 if (this_weight != Weight::Zero())
116 bool use_final_probs =
true) {
123 fst_out->DeleteStates();
124 Token *best_tok = NULL;
125 Weight best_final = Weight::Zero();
128 bool is_final = ReachedFinal();
130 for (
const Elem *e = toks_.GetList(); e != NULL; e = e->
tail)
131 if (best_tok == NULL || *best_tok < *(e->val) )
134 Weight best_weight = Weight::Zero();
135 for (
const Elem *e = toks_.GetList(); e != NULL; e = e->
tail) {
136 Weight fst_final = fst_.Final(PairToState(e->key)),
137 lm_final = lm_diff_fst_->Final(PairToLmState(e->key)),
138 final =
Times(fst_final, lm_final);
139 Weight this_weight =
Times(e->val->weight_,
final);
140 if (this_weight != Weight::Zero() &&
141 this_weight.Value() < best_weight.Value()) {
142 best_weight = this_weight;
148 if (best_tok == NULL)
return false;
150 std::vector<LatticeArc> arcs_reverse;
152 for (
Token *tok = best_tok; tok != NULL; tok = tok->
prev_) {
153 BaseFloat tot_cost = tok->weight_.Value() -
154 (tok->prev_ ? tok->prev_->weight_.Value() : 0.0),
155 graph_cost = tok->arc_.weight.Value(),
156 ac_cost = tot_cost - graph_cost;
160 tok->arc_.nextstate);
161 arcs_reverse.push_back(l_arc);
163 KALDI_ASSERT(arcs_reverse.back().nextstate == fst_.Start());
164 arcs_reverse.pop_back();
166 StateId cur_state = fst_out->AddState();
167 fst_out->SetStart(cur_state);
168 for (ssize_t
i = static_cast<ssize_t>(arcs_reverse.size())-1;
i >= 0;
i--) {
170 arc.nextstate = fst_out->AddState();
171 fst_out->AddArc(cur_state, arc);
172 cur_state = arc.nextstate;
174 if (is_final && use_final_probs) {
175 fst_out->SetFinal(cur_state,
LatticeWeight(best_final.Value(), 0.0));
185 return static_cast<PairId
>(fst_state) + (static_cast<PairId>(lm_state) << 32);
189 return static_cast<StateId
>(
static_cast<uint32
>(state_pair));
192 return static_cast<StateId
>(
static_cast<uint32
>(state_pair >> 32));
204 inline Token(
const Arc &arc, Weight &ac_weight,
Token *prev):
205 arc_(arc), prev_(prev), ref_count_(1) {
210 weight_ =
Times(arc.weight, ac_weight);
214 arc_(arc), prev_(prev), ref_count_(1) {
219 weight_ = arc.weight;
223 return weight_.Value() > other.
weight_.Value();
229 if (prev_ != NULL) TokenDelete(prev_);
244 BaseFloat *adaptive_beam, Elem **best_elem) {
247 if (opts_.max_active == std::numeric_limits<int32>::max() &&
248 opts_.min_active == 0) {
249 for (Elem *e = list_head; e != NULL; e = e->
tail, count++) {
251 if (w < best_weight) {
253 if (best_elem) *best_elem = e;
256 if (tok_count != NULL) *tok_count =
count;
257 if (adaptive_beam != NULL) *adaptive_beam = opts_.beam;
258 return best_weight + opts_.beam;
261 for (Elem *e = list_head; e != NULL; e = e->
tail, count++) {
263 tmp_array_.push_back(w);
264 if (w < best_weight) {
266 if (best_elem) *best_elem = e;
269 if (tok_count != NULL) *tok_count =
count;
271 BaseFloat beam_cutoff = best_weight + opts_.beam,
272 min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
273 max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();
275 if (tmp_array_.size() >
static_cast<size_t>(opts_.max_active)) {
276 std::nth_element(tmp_array_.begin(),
277 tmp_array_.begin() + opts_.max_active,
279 max_active_cutoff = tmp_array_[opts_.max_active];
281 if (tmp_array_.size() >
static_cast<size_t>(opts_.min_active)) {
282 if (opts_.min_active == 0) min_active_cutoff = best_weight;
284 std::nth_element(tmp_array_.begin(),
285 tmp_array_.begin() + opts_.min_active,
286 tmp_array_.size() >
static_cast<size_t>(opts_.max_active) ?
287 tmp_array_.begin() + opts_.max_active :
289 min_active_cutoff = tmp_array_[opts_.min_active];
293 if (max_active_cutoff < beam_cutoff) {
295 *adaptive_beam = max_active_cutoff - best_weight + opts_.beam_delta;
296 return max_active_cutoff;
297 }
else if (min_active_cutoff > beam_cutoff) {
299 *adaptive_beam = min_active_cutoff - best_weight + opts_.beam_delta;
300 return min_active_cutoff;
302 *adaptive_beam = opts_.beam;
309 size_t new_sz =
static_cast<size_t>(
static_cast<BaseFloat>(num_toks)
311 if (new_sz > toks_.Size()) {
312 toks_.SetSize(new_sz);
318 if (arc->olabel == 0) {
322 bool ans = lm_diff_fst_->GetArc(lm_state, arc->olabel, &lm_arc);
324 if (!warned_noarc_) {
325 warned_noarc_ =
true;
326 KALDI_WARN <<
"No arc available in LM (unlikely to be correct " 327 "if a statistical language model); will not warn again";
329 arc->weight = Weight::Zero();
333 arc->weight =
Times(arc->weight, lm_arc.weight);
334 arc->olabel = lm_arc.olabel;
335 return lm_arc.nextstate;
342 Elem *last_toks = toks_.Clear();
345 Elem *best_elem = NULL;
346 BaseFloat weight_cutoff = GetCutoff(last_toks, &tok_cnt,
347 &adaptive_beam, &best_elem);
348 PossiblyResizeHash(tok_cnt);
358 PairId state_pair = best_elem->
key;
359 StateId state = PairToState(state_pair),
360 lm_state = PairToLmState(state_pair);
362 for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
365 Arc arc = aiter.Value();
366 if (arc.ilabel != 0) {
367 PropagateLm(lm_state, &arc);
370 new_weight = arc.weight.Value() + tok->
weight_.Value() + ac_cost;
371 if (new_weight + adaptive_beam < next_weight_cutoff)
372 next_weight_cutoff = new_weight + adaptive_beam;
380 for (Elem *e = last_toks, *e_tail; e != NULL; e = e_tail) {
382 PairId state_pair = e->key;
383 StateId state = PairToState(state_pair),
384 lm_state = PairToLmState(state_pair);
386 if (tok->
weight_.Value() < weight_cutoff) {
388 for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
391 Arc arc = aiter.Value();
392 if (arc.ilabel != 0) {
393 StateId next_lm_state = PropagateLm(lm_state, &arc);
394 Weight ac_weight(-decodable->
LogLikelihood(frame, arc.ilabel));
397 if (new_weight < next_weight_cutoff) {
398 PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
399 Token *new_tok =
new Token(arc, ac_weight, tok);
400 Elem *e_found = toks_.Insert(next_pair, new_tok);
401 if (new_weight + adaptive_beam < next_weight_cutoff)
402 next_weight_cutoff = new_weight + adaptive_beam;
403 if (e_found->
val != new_tok) {
404 if (*(e_found->
val) < *new_tok) {
405 Token::TokenDelete(e_found->
val);
406 e_found->
val = new_tok;
408 Token::TokenDelete(new_tok);
416 Token::TokenDelete(e->val);
419 return next_weight_cutoff;
426 for (
const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
428 while (!queue_.empty()) {
429 const Elem *e = queue_.back();
431 PairId state_pair = e->
key;
434 if (tok->
weight_.Value() > cutoff) {
438 StateId state = PairToState(state_pair),
439 lm_state = PairToLmState(state_pair);
440 for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
443 const Arc &arc_ref = aiter.Value();
444 if (arc_ref.ilabel == 0) {
446 StateId next_lm_state = PropagateLm(lm_state, &arc);
447 PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
449 if (new_tok->
weight_.Value() > cutoff) {
450 Token::TokenDelete(new_tok);
452 Elem *e_found = toks_.Insert(next_pair, new_tok);
453 if (e_found->
val == new_tok) {
454 queue_.push_back(e_found);
456 if ( *(e_found->
val) < *new_tok ) {
457 Token::TokenDelete(e_found->
val);
458 e_found->
val = new_tok;
459 queue_.push_back(e_found);
461 Token::TokenDelete(new_tok);
474 const fst::Fst<fst::StdArc> &
fst_;
489 for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
490 Token::TokenDelete(e->val);
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
fst::ArcTpl< LatticeWeight > LatticeArc
HashList< PairId, Token * > toks_
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
static const LatticeWeightTpl One()
void RemoveEpsLocal(MutableFst< Arc > *fst)
RemoveEpsLocal remove some (but not necessarily all) epsilons in an FST, using an algorithm that is g...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
virtual bool IsLastFrame(int32 frame) const =0
Returns true if this is the last frame.
BiglmFasterDecoder(const fst::Fst< fst::StdArc > &fst, const BiglmFasterDecoderOptions &opts, fst::DeterministicOnDemandFst< fst::StdArc > *lm_diff_fst)
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type)
virtual StateId Start()=0
const fst::Fst< fst::StdArc > & fst_
StateId PropagateLm(StateId lm_state, Arc *arc)
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
BiglmFasterDecoderOptions()
void ClearToks(Elem *list)
This is as FasterDecoder, but does online composition between HCLG and the "difference language model...
void ProcessNonemitting(BaseFloat cutoff)
PairId ConstructPair(StateId fst_state, StateId lm_state)
Token(const Arc &arc, Token *prev)
void SetOptions(const BiglmFasterDecoderOptions &opts)
bool GetBestPath(fst::MutableFst< LatticeArc > *fst_out, bool use_final_probs=true)
fst::DeterministicOnDemandFst< fst::StdArc > * lm_diff_fst_
Token(const Arc &arc, Weight &ac_weight, Token *prev)
void PossiblyResizeHash(size_t num_toks)
std::vector< BaseFloat > tmp_array_
BaseFloat ProcessEmitting(DecodableInterface *decodable, int frame)
fst::StdArc::Weight Weight
bool operator<(const Int32Pair &a, const Int32Pair &b)
BiglmFasterDecoderOptions opts_
static StateId PairToLmState(PairId state_pair)
std::vector< const Elem *> queue_
#define KALDI_ASSERT(cond)
void Decode(DecodableInterface *decodable)
BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem)
Gets the weight cutoff. Also counts the active tokens.
static StateId PairToState(PairId state_pair)
virtual BaseFloat LogLikelihood(int32 frame, int32 index)=0
Returns the log likelihood, which will be negated in the decoder.
static void TokenDelete(Token *tok)
void Delete(Elem *e)
Think of this like delete().
HashList< PairId, Token * >::Elem Elem