21 #ifndef KALDI_DECODER_LATTICE_BIGLM_FASTER_DECODER_H_ 22 #define KALDI_DECODER_LATTICE_BIGLM_FASTER_DECODER_H_ 27 #include "fst/fstlib.h" 58 const fst::Fst<fst::StdArc> &
fst,
59 const LatticeBiglmFasterDecoderConfig &config,
65 lm_diff_fst->
Start() != fst::kNoStateId);
87 Token *start_tok =
new Token(0.0, 0.0, NULL, NULL);
89 toks_.Insert(start_pair, start_tok);
121 bool use_final_probs =
true)
const {
122 fst::VectorFst<LatticeArc>
fst;
127 ShortestPath(fst, ofst);
134 bool use_final_probs =
true)
const {
141 ofst->DeleteStates();
146 unordered_map<Token*, StateId> tok_map(
num_toks_/2 + 3);
148 for (
int32 f = 0; f <= num_frames; f++) {
150 KALDI_WARN <<
"GetRawLattice: no tokens active on frame " << f
151 <<
": not producing lattice.\n";
155 tok_map[tok] = ofst->AddState();
160 if (f == 0 && ofst->NumStates() > 0)
161 ofst->SetStart(ofst->NumStates()-1);
164 << tok_map.bucket_count() <<
" load:" << tok_map.load_factor()
165 <<
" max:" << tok_map.max_load_factor();
167 StateId cur_state = 0;
169 for (
int32 f = 0; f <= num_frames; f++) {
175 unordered_map<Token*, StateId>::const_iterator iter =
176 tok_map.find(l->next_tok);
177 StateId nextstate = iter->second;
179 Arc arc(l->ilabel, l->olabel,
180 Weight(l->graph_cost, l->acoustic_cost),
182 ofst->AddArc(cur_state, arc);
184 if (f == num_frames) {
186 std::map<Token*, BaseFloat>::const_iterator iter =
197 return (cur_state != 0);
205 bool use_final_probs =
true)
const {
209 if (!TopSort(&raw_fst))
210 KALDI_WARN <<
"Topological sorting of state-level lattice failed " 211 "(probably your lexicon has empty words or your LM has epsilon cycles; this " 214 fst::ILabelCompare<LatticeArc> ilabel_comp;
215 ArcSort(&raw_fst, ilabel_comp);
222 raw_fst.DeleteStates();
230 return static_cast<PairId
>(fst_state) + (static_cast<PairId>(lm_state) << 32);
234 return static_cast<StateId
>(
static_cast<uint32
>(state_pair));
237 return static_cast<StateId
>(
static_cast<uint32
>(state_pair >> 32));
254 next_tok(next_tok), ilabel(ilabel), olabel(olabel),
255 graph_cost(graph_cost), acoustic_cost(acoustic_cost),
277 Token *next): tot_cost(tot_cost), extra_cost(extra_cost),
278 links(links), next(next) { }
296 TokenList(): toks(NULL), must_prune_forward_links(true),
297 must_prune_tokens(true) { }
303 size_t new_sz =
static_cast<size_t>(
static_cast<BaseFloat>(num_toks)
305 if (new_sz >
toks_.Size()) {
306 toks_.SetSize(new_sz);
316 BaseFloat tot_cost,
bool emitting,
bool *changed) {
321 Elem *e_found =
toks_.Insert(state_pair, NULL);
322 if (e_found->
val == NULL) {
327 Token *new_tok =
new Token (tot_cost, extra_cost, NULL, toks);
331 e_found->
val = new_tok;
332 if (changed) *changed =
true;
345 if (changed) *changed =
true;
347 if (changed) *changed =
false;
365 *extra_costs_changed =
false;
366 *links_pruned =
false;
370 KALDI_WARN <<
"No tokens alive [doing pruning].. warning first " 371 "time only for each utterance\n";
384 BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
386 for (link = tok->links; link != NULL; ) {
397 if (prev_link != NULL) prev_link->
next = next_link;
398 else tok->links = next_link;
401 *links_pruned =
true;
403 if (link_extra_cost < 0.0) {
404 if (link_extra_cost < -0.01)
405 KALDI_WARN <<
"Negative extra_cost: " << link_extra_cost;
406 link_extra_cost = 0.0;
408 if (link_extra_cost < tok_extra_cost)
409 tok_extra_cost = link_extra_cost;
414 if (fabs(tok_extra_cost - tok->extra_cost) > delta)
416 tok->extra_cost = tok_extra_cost;
420 if (changed) *extra_costs_changed =
true;
434 KALDI_WARN <<
"No tokens alive at end of file\n";
439 const BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
441 best_cost_nofinal = infinity;
442 unordered_map<Token*, BaseFloat> tok_to_final_cost;
443 Elem *cur_toks =
toks_.Clear();
444 for (Elem *e = cur_toks, *e_tail; e != NULL; e = e_tail) {
445 PairId state_pair = e->key;
451 tok_to_final_cost[tok] = final_cost;
452 best_cost_final = std::min(best_cost_final, tok->
tot_cost + final_cost);
453 best_cost_nofinal = std::min(best_cost_nofinal, tok->
tot_cost);
475 BaseFloat final_cost = tok_to_final_cost[tok];
476 tok_extra_cost = (tok->tot_cost + final_cost) - best_cost_final;
478 tok_extra_cost = tok->tot_cost - best_cost_nofinal;
480 for (link = tok->links; link != NULL; ) {
488 if (prev_link != NULL) prev_link->
next = next_link;
489 else tok->links = next_link;
493 if (link_extra_cost < 0.0) {
494 if (link_extra_cost < -0.01)
495 KALDI_WARN <<
"Negative extra_cost: " << link_extra_cost;
496 link_extra_cost = 0.0;
498 if (link_extra_cost < tok_extra_cost)
499 tok_extra_cost = link_extra_cost;
509 tok_extra_cost = infinity;
512 if (!
ApproxEqual(tok->extra_cost, tok_extra_cost, delta))
514 tok->extra_cost = tok_extra_cost;
521 if (tok->extra_cost != infinity) {
524 BaseFloat final_cost = tok_to_final_cost[tok];
525 if (final_cost != infinity)
542 KALDI_WARN <<
"No tokens alive [doing pruning]\n";
544 for (tok = toks; tok != NULL; tok =
next_tok) {
545 next_tok = tok->
next;
546 if (tok->
extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
549 if (prev_tok != NULL) prev_tok->
next = tok->
next;
550 else toks = tok->
next;
568 for (
int32 frame = cur_frame-1; frame >= 0; frame--) {
574 bool extra_costs_changed =
false, links_pruned =
false;
576 if (extra_costs_changed && frame > 0)
582 if (frame+1 < cur_frame &&
588 KALDI_VLOG(3) <<
"PruneActiveTokens: pruned tokens from " << num_toks_begin
604 for (
int32 frame = cur_frame-1; frame >= 0; frame--) {
611 KALDI_VLOG(3) <<
"PruneActiveTokensFinal: pruned tokens from " << num_toks_begin
617 BaseFloat *adaptive_beam, Elem **best_elem) {
618 BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
622 for (Elem *e = list_head; e != NULL; e = e->
tail, count++) {
624 if (w < best_weight) {
626 if (best_elem) *best_elem = e;
629 if (tok_count != NULL) *tok_count =
count;
630 if (adaptive_beam != NULL) *adaptive_beam =
config_.
beam;
634 for (Elem *e = list_head; e != NULL; e = e->
tail, count++) {
637 if (w < best_weight) {
639 if (best_elem) *best_elem = e;
642 if (tok_count != NULL) *tok_count =
count;
665 if (arc->olabel == 0) {
673 KALDI_WARN <<
"No arc available in LM (unlikely to be correct " 674 "if a statistical language model); will not warn again";
676 arc->weight = Weight::Zero();
680 arc->weight =
Times(arc->weight, lm_arc.weight);
681 arc->olabel = lm_arc.olabel;
682 return lm_arc.nextstate;
689 Elem *last_toks =
toks_.Clear();
690 Elem *best_elem = NULL;
696 BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
702 PairId state_pair = best_elem->
key;
706 for (fst::ArcIterator<fst::Fst<Arc> > aiter(
fst_, state);
709 Arc arc = aiter.Value();
710 if (arc.ilabel != 0) {
713 arc.weight =
Times(arc.weight,
716 if (new_weight + adaptive_beam < next_cutoff)
717 next_cutoff = new_weight + adaptive_beam;
725 for (Elem *e = last_toks, *e_tail; e != NULL; e = e_tail) {
727 PairId state_pair = e->key;
732 for (fst::ArcIterator<fst::Fst<Arc> > aiter(
fst_, state);
735 const Arc &arc_ref = aiter.Value();
736 if (arc_ref.ilabel != 0) {
738 StateId next_lm_state =
PropagateLm(lm_state, &arc);
743 if (tot_cost >= next_cutoff)
continue;
746 PairId next_pair =
ConstructPair(arc.nextstate, next_lm_state);
747 Elem *e_next =
FindOrAddToken(next_pair, frame, tot_cost,
true, NULL);
752 graph_cost, ac_cost, tok->
links);
771 BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
772 for (
const Elem *e =
toks_.GetList(); e != NULL; e = e->tail) {
775 best_cost = std::min(best_cost, static_cast<BaseFloat>(e->val->tot_cost));
779 KALDI_ERR <<
"Error in ProcessEmitting: no surviving tokens: frame is " 787 const Elem *e =
queue_.back();
790 PairId state_pair = e->
key;
794 if (cur_cost >= cutoff)
804 for (fst::ArcIterator<fst::Fst<Arc> > aiter(
fst_, state);
807 const Arc &arc_ref = aiter.Value();
808 if (arc_ref.ilabel == 0) {
810 StateId next_lm_state =
PropagateLm(lm_state, &arc);
813 if (tot_cost < cutoff) {
815 PairId next_pair =
ConstructPair(arc.nextstate, next_lm_state);
820 graph_cost, 0, tok->
links);
824 if (changed)
queue_.push_back(e_new);
842 const fst::Fst<fst::StdArc> &
fst_;
860 for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
868 for (
size_t i = 0;
i < active_toks_.size();
i++) {
871 for (
Token *tok = active_toks_[
i].toks; tok != NULL; ) {
872 tok->DeleteForwardLinks();
879 active_toks_.clear();
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
bool GetBestPath(fst::MutableFst< LatticeArc > *ofst, bool use_final_probs=true) const
virtual bool GetArc(StateId s, Label ilabel, Arc *oarc)=0
Note: ilabel must not be epsilon.
fst::ArcTpl< LatticeWeight > LatticeArc
fst::DeterministicOnDemandFst< fst::StdArc > * lm_diff_fst_
void PruneForwardLinks(int32 frame, bool *extra_costs_changed, bool *links_pruned, BaseFloat delta)
virtual Weight Final(StateId s)=0
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
ForwardLink(Token *next_tok, Label ilabel, Label olabel, BaseFloat graph_cost, BaseFloat acoustic_cost, ForwardLink *next)
static const LatticeWeightTpl One()
void PossiblyResizeHash(size_t num_toks)
bool DeterminizeLatticePruned(const ExpandedFst< ArcTpl< Weight > > &ifst, double beam, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, IntType > > > *ofst, DeterminizeLatticePrunedOptions opts)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
StateId PropagateLm(StateId lm_state, Arc *arc)
PairId ConstructPair(StateId fst_state, StateId lm_state)
static StateId PairToLmState(PairId state_pair)
void DeleteForwardLinks()
virtual bool IsLastFrame(int32 frame) const =0
Returns true if this is the last frame.
const fst::Fst< fst::StdArc > & fst_
HashList< PairId, Token * >::Elem Elem
std::vector< TokenList > active_toks_
HashList< PairId, Token * > toks_
bool Decode(DecodableInterface *decodable)
std::map< Token *, BaseFloat > final_costs_
virtual StateId Start()=0
void PruneForwardLinksFinal(int32 frame)
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
Token(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLink *links, Token *next)
LatticeBiglmFasterDecoderConfig GetOptions()
void ProcessNonemitting(int32 frame)
Elem * FindOrAddToken(PairId state_pair, int32 frame, BaseFloat tot_cost, bool emitting, bool *changed)
~LatticeBiglmFasterDecoder()
fst::VectorFst< LatticeArc > Lattice
bool ReachedFinal() const
says whether a final-state was active on the last frame.
std::vector< const Elem *> queue_
This is as LatticeFasterDecoder, but does online composition between HCLG and the "difference languag...
bool GetRawLattice(fst::MutableFst< LatticeArc > *ofst, bool use_final_probs=true) const
void SetOptions(const LatticeBiglmFasterDecoderConfig &config)
Elem * Clear()
Clears the hash and gives the head of the current list to the user; ownership is transferred to the u...
LatticeBiglmFasterDecoderConfig config_
void PruneActiveTokensFinal(int32 cur_frame)
fst::StdArc::Weight Weight
fst::DeterminizeLatticePhonePrunedOptions det_opts
bool must_prune_forward_links
BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem)
Gets the weight cutoff. Also counts the active tokens.
void DeleteElems(Elem *list)
void PruneTokensForFrame(int32 frame)
#define KALDI_ASSERT(cond)
LatticeFasterDecoderConfig LatticeBiglmFasterDecoderConfig
static StateId PairToState(PairId state_pair)
virtual BaseFloat LogLikelihood(int32 frame, int32 index)=0
Returns the log likelihood, which will be negated in the decoder.
bool GetLattice(fst::MutableFst< CompactLatticeArc > *ofst, bool use_final_probs=true) const
std::vector< BaseFloat > tmp_array_
void Delete(Elem *e)
Think of this like delete().
void PruneActiveTokens(int32 cur_frame, BaseFloat delta)
void ProcessEmitting(DecodableInterface *decodable, int32 frame)
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
LatticeBiglmFasterDecoder(const fst::Fst< fst::StdArc > &fst, const LatticeBiglmFasterDecoderConfig &config, fst::DeterministicOnDemandFst< fst::StdArc > *lm_diff_fst)