template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::LatticeFasterDecoderTpl(
    const FST &fst, const LatticeFasterDecoderConfig &config):
    fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) {
  config.Check();
  toks_.SetSize(1000);  // just so on the first frame we do something reasonable.
}
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::LatticeFasterDecoderTpl(
    const LatticeFasterDecoderConfig &config, FST *fst):
    fst_(fst), delete_fst_(true), config_(config), num_toks_(0) {
  config.Check();
  toks_.SetSize(1000);  // just so on the first frame we do something reasonable.
}
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::~LatticeFasterDecoderTpl() {
  DeleteElems(toks_.Clear());
  ClearActiveTokens();
  if (delete_fst_) delete fst_;
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::InitDecoding() {
  // Clean up from last time:
  DeleteElems(toks_.Clear());
  cost_offsets_.clear();
  ClearActiveTokens();
  warned_ = false;
  num_toks_ = 0;
  decoding_finalized_ = false;
  final_costs_.clear();
  StateId start_state = fst_->Start();
  KALDI_ASSERT(start_state != fst::kNoStateId);
  active_toks_.resize(1);
  Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL);
  active_toks_[0].toks = start_tok;
  toks_.Insert(start_state, start_tok);
  num_toks_++;
  ProcessNonemitting(config_.beam);
}
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::Decode(DecodableInterface *decodable) {
  InitDecoding();
  AdvanceDecoding(decodable);
  FinalizeDecoding();
  // Returns true if we have any kind of traceback available (not necessarily
  // to the end state; query ReachedFinal() for that).
  return !active_toks_.empty() && active_toks_.back().toks != NULL;
}
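// Example (not part of the original file): a minimal sketch of driving the
// decoder in batch mode.  The DecodableInterface object and the decoding
// graph are assumed to be set up elsewhere by the caller; LatticeFasterDecoder
// is the standard typedef of this template over fst::StdFst.  The beam values
// below are illustrative, not recommendations.
static void DecodeUtteranceSketch(const fst::StdFst &decode_fst,
                                  DecodableInterface *decodable,
                                  Lattice *best_path) {
  LatticeFasterDecoderConfig config;
  config.beam = 15.0;          // main pruning beam for the Viterbi search
  config.lattice_beam = 8.0;   // beam used when pruning lattice forward links
  LatticeFasterDecoder decoder(decode_fst, config);
  if (decoder.Decode(decodable))
    decoder.GetBestPath(best_path, /*use_final_probs=*/true);
}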
// Outputs an FST corresponding to the single best path through the lattice.
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetBestPath(Lattice *olat,
                                                      bool use_final_probs) const {
  Lattice raw_lat;
  GetRawLattice(&raw_lat, use_final_probs);
  ShortestPath(raw_lat, olat);
  return (olat->NumStates() != 0);
}
// Outputs an FST corresponding to the raw, state-level lattice.
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetRawLattice(
    Lattice *ofst, bool use_final_probs) const {
  typedef LatticeArc Arc;
  typedef Arc::StateId StateId;
  typedef Arc::Weight Weight;

  // Note: you cannot use the old interface (Decode()) if you want the lattice
  // with use_final_probs == false.  You'd have to do InitDecoding() and then
  // AdvanceDecoding().
  if (decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "GetRawLattice() with use_final_probs == false";

  unordered_map<Token*, BaseFloat> final_costs_local;

  const unordered_map<Token*, BaseFloat> &final_costs =
      (decoding_finalized_ ? final_costs_ : final_costs_local);
  if (!decoding_finalized_ && use_final_probs)
    ComputeFinalCosts(&final_costs_local, NULL, NULL);

  ofst->DeleteStates();
  // num-frames plus one (since frames are one-based, and we have
  // an extra frame for the start-state).
  int32 num_frames = active_toks_.size() - 1;
  KALDI_ASSERT(num_frames > 0);
  const int32 bucket_count = num_toks_ / 2 + 3;
  unordered_map<Token*, StateId> tok_map(bucket_count);
  // First create all states.
  std::vector<Token*> token_list;
  for (int32 f = 0; f <= num_frames; f++) {
    if (active_toks_[f].toks == NULL) {
      KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
                 << ": not producing lattice.\n";
      return false;
    }
    TopSortTokens(active_toks_[f].toks, &token_list);
    for (size_t i = 0; i < token_list.size(); i++)
      if (token_list[i] != NULL)
        tok_map[token_list[i]] = ofst->AddState();
  }
  // The next statement sets the start state of the output FST.  Because we
  // topologically sorted the tokens, state zero must be the start-state.
  ofst->SetStart(0);

  KALDI_VLOG(4) << "init:" << num_toks_ / 2 + 3 << " buckets:"
                << tok_map.bucket_count() << " load:" << tok_map.load_factor()
                << " max:" << tok_map.max_load_factor();
  // Now create all arcs.
  for (int32 f = 0; f <= num_frames; f++) {
    for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
      StateId cur_state = tok_map[tok];
      for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) {
        typename unordered_map<Token*, StateId>::const_iterator
            iter = tok_map.find(l->next_tok);
        StateId nextstate = iter->second;
        KALDI_ASSERT(iter != tok_map.end());
        BaseFloat cost_offset = 0.0;
        if (l->ilabel != 0) {  // emitting arc: undo the per-frame acoustic offset.
          KALDI_ASSERT(f >= 0 && f < cost_offsets_.size());
          cost_offset = cost_offsets_[f];
        }
        Arc arc(l->ilabel, l->olabel,
                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
                nextstate);
        ofst->AddArc(cur_state, arc);
      }
      if (f == num_frames) {
        if (use_final_probs && !final_costs.empty()) {
          typename unordered_map<Token*, BaseFloat>::const_iterator
              iter = final_costs.find(tok);
          if (iter != final_costs.end())
            ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
        } else {
          ofst->SetFinal(cur_state, LatticeWeight::One());
        }
      }
    }
  }
  return (ofst->NumStates() > 0);
}
// Outputs an FST corresponding to the lattice-determinized lattice
// (one path per word sequence).
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetLattice(CompactLattice *ofst,
                                                     bool use_final_probs) const {
  Lattice raw_fst;
  GetRawLattice(&raw_fst, use_final_probs);
  Invert(&raw_fst);  // make it so word labels are on the input.
  fst::ILabelCompare<LatticeArc> ilabel_comp;
  ArcSort(&raw_fst, ilabel_comp);  // sort on ilabel; makes
                                   // lattice-determinization more efficient.
  fst::DeterminizeLatticePrunedOptions lat_opts;
  lat_opts.max_mem = config_.det_opts.max_mem;

  DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts);
  raw_fst.DeleteStates();  // Free memory -- raw_fst no longer needed.
  Connect(ofst);  // Remove unreachable states.
  return (ofst->NumStates() != 0);
}
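// Example (not part of the original file): roughly the same determinization
// steps, performed from caller code on a decoder that has already finished
// decoding.  The `decoder` object and the lattice-beam value passed in are
// assumptions of the sketch, and default determinization options are used.
static void DeterminizeSketch(const LatticeFasterDecoder &decoder,
                              BaseFloat lattice_beam,
                              CompactLattice *clat) {
  Lattice raw_lat;
  decoder.GetRawLattice(&raw_lat, /*use_final_probs=*/true);
  fst::Invert(&raw_lat);  // put word labels on the input side
  fst::ILabelCompare<LatticeArc> ilabel_comp;
  fst::ArcSort(&raw_lat, ilabel_comp);
  fst::DeterminizeLatticePrunedOptions lat_opts;
  fst::DeterminizeLatticePruned(raw_lat, lattice_beam, clat, lat_opts);
  fst::Connect(clat);
}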
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PossiblyResizeHash(size_t num_toks) {
  size_t new_sz = static_cast<size_t>(static_cast<BaseFloat>(num_toks)
                                      * config_.hash_ratio);
  if (new_sz > toks_.Size()) {
    toks_.SetSize(new_sz);
  }
}
// FindOrAddToken either locates a token in the hash toks_, or if necessary
// inserts a new, empty token (i.e. with no forward links) for the current
// frame; it is also linked into the list of tokens active on this frame.
template <typename FST, typename Token>
inline typename LatticeFasterDecoderTpl<FST, Token>::Elem*
LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(
    StateId state, int32 frame_plus_one, BaseFloat tot_cost,
    Token *backpointer, bool *changed) {
  // Sets "changed" (if non-NULL) to true if the token was newly created or
  // its cost changed.
  KALDI_ASSERT(frame_plus_one < active_toks_.size());
  Token *&toks = active_toks_[frame_plus_one].toks;
  Elem *e_found = toks_.Insert(state, NULL);
  if (e_found->val == NULL) {  // no such token presently.
    const BaseFloat extra_cost = 0.0;
    // Tokens on the currently final frame have zero extra_cost,
    // as any of them could end up on the winning path.
    Token *new_tok = new Token(tot_cost, extra_cost, NULL, toks, backpointer);
    // NULL: no forward links yet.
    toks = new_tok;
    num_toks_++;
    e_found->val = new_tok;
    if (changed) *changed = true;
    return e_found;
  } else {
    Token *tok = e_found->val;  // there is an existing Token for this state.
    if (tok->tot_cost > tot_cost) {  // replace the old token's cost.
      tok->tot_cost = tot_cost;
      // SetBackpointer() just does tok->backpointer = backpointer in the case
      // where Token == BackpointerToken, else nothing.
      tok->SetBackpointer(backpointer);
      if (changed) *changed = true;
    } else {
      if (changed) *changed = false;
    }
    return e_found;
  }
}
// Prunes outgoing links for all tokens in active_toks_[frame_plus_one]; it is
// called by PruneActiveTokens.  All links whose link_extra_cost exceeds
// config_.lattice_beam are pruned.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneForwardLinks(
    int32 frame_plus_one, bool *extra_costs_changed,
    bool *links_pruned, BaseFloat delta) {
  // delta is the amount by which an extra_cost must change before we propagate
  // the change; *links_pruned is set to true if any link was pruned.
  *extra_costs_changed = false;
  *links_pruned = false;
  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
  if (active_toks_[frame_plus_one].toks == NULL) {  // empty list; should not happen.
    if (!warned_) {
      KALDI_WARN << "No tokens alive [doing pruning].. warning first "
          "time only for each utterance\n";
      warned_ = true;
    }
  }

  // We have to iterate until there is no more change, because the links
  // are not guaranteed to be in topological order.
  bool changed = true;
  while (changed) {
    changed = false;
    for (Token *tok = active_toks_[frame_plus_one].toks;
         tok != NULL; tok = tok->next) {
      ForwardLinkT *link, *prev_link = NULL;
      // Recompute tok_extra_cost for tok: the best (minimum) of the
      // link_extra_cost of its outgoing links.
      BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
      for (link = tok->links; link != NULL; ) {
        // See if we need to excise this link...
        Token *next_tok = link->next_tok;
        BaseFloat link_extra_cost = next_tok->extra_cost +
            ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
             - next_tok->tot_cost);  // difference in brackets is >= 0
        if (link_extra_cost > config_.lattice_beam) {  // excise link
          ForwardLinkT *next_link = link->next;
          if (prev_link != NULL) prev_link->next = next_link;
          else tok->links = next_link;
          delete link;
          link = next_link;  // advance link but leave prev_link the same.
          *links_pruned = true;
        } else {  // keep the link and update tok_extra_cost if needed.
          if (link_extra_cost < 0.0) {  // this is just a precaution.
            if (link_extra_cost < -0.01)
              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
            link_extra_cost = 0.0;
          }
          if (link_extra_cost < tok_extra_cost)
            tok_extra_cost = link_extra_cost;
          prev_link = link;
          link = link->next;
        }
      }  // for all outgoing links
      if (fabs(tok_extra_cost - tok->extra_cost) > delta)
        changed = true;  // difference (new - old) is bigger than delta
      tok->extra_cost = tok_extra_cost;  // +infinity means no link survived.
    }  // for all tokens on active_toks_[frame_plus_one]
    if (changed) *extra_costs_changed = true;
  }  // while changed
}
// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
// on the final frame; it also takes account of the final-probs.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneForwardLinksFinal() {
  KALDI_ASSERT(!active_toks_.empty());
  int32 frame_plus_one = active_toks_.size() - 1;

  if (active_toks_[frame_plus_one].toks == NULL)  // empty list; should not happen.
    KALDI_WARN << "No tokens alive at end of file";

  typedef typename unordered_map<Token*, BaseFloat>::const_iterator IterType;
  ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
  decoding_finalized_ = true;
  // We call DeleteElems() as a nicety, so the hash does not keep pointers to
  // tokens that PruneTokensForFrame() is about to delete on the final frame.
  DeleteElems(toks_.Clear());

  // Go through tokens on this frame, pruning forward links; we may have to
  // iterate a few times until there is no more change, because the list is
  // not in topological order.
  bool changed = true;
  BaseFloat delta = 1.0e-05;
  while (changed) {
    changed = false;
    for (Token *tok = active_toks_[frame_plus_one].toks;
         tok != NULL; tok = tok->next) {
      ForwardLinkT *link, *prev_link = NULL;
      // Recompute tok_extra_cost.  It has a term corresponding to the
      // final-prob, so we initialize it to the difference between this token's
      // (score + final_prob) and the best such (score + final_prob).
      BaseFloat final_cost;
      if (final_costs_.empty()) {
        final_cost = 0.0;
      } else {
        IterType iter = final_costs_.find(tok);
        if (iter != final_costs_.end())
          final_cost = iter->second;
        else
          final_cost = std::numeric_limits<BaseFloat>::infinity();
      }
      BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
      for (link = tok->links; link != NULL; ) {
        // See if we need to excise this link...
        Token *next_tok = link->next_tok;
        BaseFloat link_extra_cost = next_tok->extra_cost +
            ((tok->tot_cost + link->acoustic_cost + link->graph_cost)
             - next_tok->tot_cost);
        if (link_extra_cost > config_.lattice_beam) {  // excise link
          ForwardLinkT *next_link = link->next;
          if (prev_link != NULL) prev_link->next = next_link;
          else tok->links = next_link;
          delete link;
          link = next_link;  // advance link but leave prev_link the same.
        } else {  // keep the link and update tok_extra_cost if needed.
          if (link_extra_cost < 0.0) {  // this is just a precaution.
            if (link_extra_cost < -0.01)
              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
            link_extra_cost = 0.0;
          }
          if (link_extra_cost < tok_extra_cost)
            tok_extra_cost = link_extra_cost;
          prev_link = link;
          link = link->next;
        }
      }
      // Prune away tokens worse than lattice_beam above the best path; they
      // will be removed in PruneTokensForFrame.
      if (tok_extra_cost > config_.lattice_beam)
        tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();

      if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta))
        changed = true;
      tok->extra_cost = tok_extra_cost;  // will be +infinity or <= lattice_beam.
    }
  }  // while changed
}
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::FinalRelativeCost() const {
  if (!decoding_finalized_) {
    BaseFloat relative_cost;
    ComputeFinalCosts(NULL, &relative_cost, NULL);
    return relative_cost;
  } else {
    // We're not allowed to call ComputeFinalCosts() after FinalizeDecoding()
    // has been called, so return the cached value.
    return final_relative_cost_;
  }
}
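// Example (not part of the original file): a caller-side check of whether the
// best traceback reaches a final state, using FinalRelativeCost().  The
// helper name is hypothetical.
static bool ReachedFinalSketch(const LatticeFasterDecoder &decoder) {
  BaseFloat relative_cost = decoder.FinalRelativeCost();
  // Infinity means no surviving token on the last frame was in a final state;
  // a finite value is the extra cost of insisting on ending in a final state.
  return relative_cost != std::numeric_limits<BaseFloat>::infinity();
}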
// Prune away any tokens on this frame that have no forward links.  We don't do
// this in PruneForwardLinks because it would give us dangling pointers; it is
// called by PruneActiveTokens if any forward links have been pruned.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneTokensForFrame(int32 frame_plus_one) {
  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
  Token *&toks = active_toks_[frame_plus_one].toks;
  if (toks == NULL)
    KALDI_WARN << "No tokens alive [doing pruning]";
  Token *tok, *next_tok, *prev_tok = NULL;
  for (tok = toks; tok != NULL; tok = next_tok) {
    next_tok = tok->next;
    if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
      // Token is unreachable from the end of the graph (no forward links
      // survived): excise tok from the list and delete it.
      if (prev_tok != NULL) prev_tok->next = tok->next;
      else toks = tok->next;
      delete tok;
      num_toks_--;
    } else {  // keep the token; move to the next one.
      prev_tok = tok;
    }
  }
}
// Go backwards through still-alive tokens, pruning them if the
// forward + backward cost is more than lattice_beam away from the best path.
// delta controls when a cost change is considered large enough to keep
// propagating backwards; a larger delta means we recurse less far.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneActiveTokens(BaseFloat delta) {
  int32 cur_frame_plus_one = NumFramesDecoded();
  int32 num_toks_begin = num_toks_;
  // The index "f" below is a "frame plus one"; subtract one to get the
  // corresponding index for the decodable object.
  for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
    if (active_toks_[f].must_prune_forward_links) {
      bool extra_costs_changed = false, links_pruned = false;
      PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
      if (extra_costs_changed && f > 0)  // any token changed extra_cost
        active_toks_[f-1].must_prune_forward_links = true;
      if (links_pruned)  // any link was pruned
        active_toks_[f].must_prune_tokens = true;
      active_toks_[f].must_prune_forward_links = false;  // job done
    }
    if (f+1 < cur_frame_plus_one &&  // except for last f (no forward links)
        active_toks_[f+1].must_prune_tokens) {
      PruneTokensForFrame(f+1);
      active_toks_[f+1].must_prune_tokens = false;
    }
  }
  KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
                << " to " << num_toks_;
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ComputeFinalCosts(
    unordered_map<Token*, BaseFloat> *final_costs,
    BaseFloat *final_relative_cost,
    BaseFloat *final_best_cost) const {
  KALDI_ASSERT(!decoding_finalized_);
  if (final_costs != NULL)
    final_costs->clear();
  const Elem *final_toks = toks_.GetList();
  BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
  BaseFloat best_cost = infinity,
      best_cost_with_final = infinity;

  while (final_toks != NULL) {
    StateId state = final_toks->key;
    Token *tok = final_toks->val;
    const Elem *next = final_toks->tail;
    BaseFloat final_cost = fst_->Final(state).Value();
    BaseFloat cost = tok->tot_cost,
        cost_with_final = cost + final_cost;
    best_cost = std::min(cost, best_cost);
    best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
    if (final_costs != NULL && final_cost != infinity)
      (*final_costs)[tok] = final_cost;
    final_toks = next;
  }
  if (final_relative_cost != NULL) {
    if (best_cost == infinity && best_cost_with_final == infinity) {
      // Likely this will only happen if there are no surviving tokens;
      // this seems the least bad way to handle it.
      *final_relative_cost = infinity;
    } else {
      *final_relative_cost = best_cost_with_final - best_cost;
    }
  }
  if (final_best_cost != NULL) {
    if (best_cost_with_final != infinity) {  // final-state exists.
      *final_best_cost = best_cost_with_final;
    } else {  // no final-state exists.
      *final_best_cost = best_cost;
    }
  }
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::AdvanceDecoding(
    DecodableInterface *decodable, int32 max_num_frames) {
  if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
    // If the type 'FST' is the FST base-class, see whether the FST type of
    // fst_ is actually ConstFst or VectorFst.  If so, call AdvanceDecoding()
    // again after casting *this to the more specific type.
    if (fst_->Type() == "const") {
      LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, Token> *this_cast =
          reinterpret_cast<LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, Token>* >(this);
      this_cast->AdvanceDecoding(decodable, max_num_frames);
      return;
    } else if (fst_->Type() == "vector") {
      LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, Token> *this_cast =
          reinterpret_cast<LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, Token>* >(this);
      this_cast->AdvanceDecoding(decodable, max_num_frames);
      return;
    }
  }

  KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
               "You must call InitDecoding() before AdvanceDecoding");
  int32 num_frames_ready = decodable->NumFramesReady();
  // num_frames_ready must be >= the number of frames already decoded, or else
  // the decodable object changed between calls (which isn't allowed).
  KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
  int32 target_frames_decoded = num_frames_ready;
  if (max_num_frames >= 0)
    target_frames_decoded = std::min(target_frames_decoded,
                                     NumFramesDecoded() + max_num_frames);
  while (NumFramesDecoded() < target_frames_decoded) {
    if (NumFramesDecoded() % config_.prune_interval == 0) {
      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
    }
    BaseFloat cost_cutoff = ProcessEmitting(decodable);
    ProcessNonemitting(cost_cutoff);
  }
}
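// Example (not part of the original file): a sketch of the incremental
// interface, feeding the decoder in chunks.  `decodable` is assumed to be an
// online DecodableInterface whose NumFramesReady() grows over time; here we
// simply drain whatever is already available, 10 frames per call.
static void OnlineDecodingSketch(const fst::StdFst &decode_fst,
                                 DecodableInterface *decodable,
                                 Lattice *raw_lat) {
  LatticeFasterDecoderConfig config;
  LatticeFasterDecoder decoder(decode_fst, config);
  decoder.InitDecoding();
  while (decodable->NumFramesReady() > decoder.NumFramesDecoded()) {
    // Decode at most 10 frames per call, e.g. to interleave with other work.
    decoder.AdvanceDecoding(decodable, /*max_num_frames=*/10);
  }
  decoder.FinalizeDecoding();
  decoder.GetRawLattice(raw_lat, /*use_final_probs=*/true);
}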
// FinalizeDecoding() is a version of PruneActiveTokens that we call
// (optionally) on the final frame; it takes into account the final-probs.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::FinalizeDecoding() {
  int32 final_frame_plus_one = NumFramesDecoded();
  int32 num_toks_begin = num_toks_;
  // PruneForwardLinksFinal() prunes the final frame (with final-probs) and
  // sets decoding_finalized_.
  PruneForwardLinksFinal();
  for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
    bool b1, b2;  // values not used.
    BaseFloat dontcare = 0.0;  // delta of zero means we must always update.
    PruneForwardLinks(f, &b1, &b2, dontcare);
    PruneTokensForFrame(f + 1);
  }
  PruneTokensForFrame(0);
  KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin
                << " to " << num_toks_;
}
// Gets the weight cutoff.  Also counts the active tokens.
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::GetCutoff(
    Elem *list_head, size_t *tok_count,
    BaseFloat *adaptive_beam, Elem **best_elem) {
  BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
  // positive == high cost == bad.
  size_t count = 0;
  if (config_.max_active == std::numeric_limits<int32>::max() &&
      config_.min_active == 0) {
    for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
      BaseFloat w = static_cast<BaseFloat>(e->val->tot_cost);
      if (w < best_weight) {
        best_weight = w;
        if (best_elem) *best_elem = e;
      }
    }
    if (tok_count != NULL) *tok_count = count;
    if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
    return best_weight + config_.beam;
  } else {
    tmp_array_.clear();
    for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
      BaseFloat w = e->val->tot_cost;
      tmp_array_.push_back(w);
      if (w < best_weight) {
        best_weight = w;
        if (best_elem) *best_elem = e;
      }
    }
    if (tok_count != NULL) *tok_count = count;

    BaseFloat beam_cutoff = best_weight + config_.beam,
        min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
        max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();

    KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded()
                  << " is " << tmp_array_.size();

    if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
      std::nth_element(tmp_array_.begin(),
                       tmp_array_.begin() + config_.max_active,
                       tmp_array_.end());
      max_active_cutoff = tmp_array_[config_.max_active];
    }
    if (max_active_cutoff < beam_cutoff) {  // max_active is tighter than beam.
      *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta;
      return max_active_cutoff;
    }
    if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
      if (config_.min_active == 0) min_active_cutoff = best_weight;
      else {
        std::nth_element(tmp_array_.begin(),
                         tmp_array_.begin() + config_.min_active,
                         tmp_array_.size() > static_cast<size_t>(config_.max_active) ?
                         tmp_array_.begin() + config_.max_active :
                         tmp_array_.end());
        min_active_cutoff = tmp_array_[config_.min_active];
      }
    }
    if (min_active_cutoff > beam_cutoff) {  // min_active is looser than beam.
      *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
      return min_active_cutoff;
    } else {
      *adaptive_beam = config_.beam;
      return beam_cutoff;
    }
  }
}
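// Example (not part of the original file): a worked instance of the cutoff
// arithmetic above with made-up numbers, showing how max_active can tighten
// the effective beam.
static void AdaptiveBeamExample() {
  const BaseFloat best_weight = 100.0;        // best (lowest) token cost this frame
  const BaseFloat beam = 16.0;                // plays the role of config_.beam
  const BaseFloat beam_delta = 0.5;           // plays the role of config_.beam_delta
  const BaseFloat max_active_cutoff = 110.0;  // cost of the max_active-th best token
  BaseFloat beam_cutoff = best_weight + beam; // 116.0
  BaseFloat adaptive_beam, cutoff;
  if (max_active_cutoff < beam_cutoff) {      // max_active is the tighter limit
    adaptive_beam = max_active_cutoff - best_weight + beam_delta;  // 10.5
    cutoff = max_active_cutoff;                                    // 110.0
  } else {
    adaptive_beam = beam;
    cutoff = beam_cutoff;
  }
  KALDI_VLOG(6) << "cutoff = " << cutoff << ", adaptive_beam = " << adaptive_beam;
}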
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::ProcessEmitting(
    DecodableInterface *decodable) {
  KALDI_ASSERT(active_toks_.size() > 0);
  int32 frame = active_toks_.size() - 1;  // zero-based frame index used to get
                                          // likelihoods from the decodable object.
  active_toks_.resize(active_toks_.size() + 1);

  Elem *final_toks = toks_.Clear();  // analogous to swapping prev_toks_ /
                                     // cur_toks_ in simple-decoder.h.
  Elem *best_elem = NULL;
  BaseFloat adaptive_beam;
  size_t tok_cnt;
  BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
                << adaptive_beam;

  PossiblyResizeHash(tok_cnt);  // This makes sure the hash is always big enough.

  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
  // pruning "online" before having seen all tokens

  BaseFloat cost_offset = 0.0;  // Used to keep probabilities in a good dynamic range.

  // First process the best token to get a hopefully reasonably tight bound on
  // the next cutoff.  The only products of this block are "next_cutoff" and
  // "cost_offset".
  if (best_elem) {
    StateId state = best_elem->key;
    Token *tok = best_elem->val;
    cost_offset = - tok->tot_cost;
    for (fst::ArcIterator<FST> aiter(*fst_, state);
         !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel != 0) {  // propagate..
        BaseFloat new_weight = arc.weight.Value() + cost_offset -
            decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost;
        if (new_weight + adaptive_beam < next_cutoff)
          next_cutoff = new_weight + adaptive_beam;
      }
    }
  }

  // Store the offset on the acoustic likelihoods that we're applying.
  cost_offsets_.resize(frame + 1, 0.0);
  cost_offsets_[frame] = cost_offset;

  // The tokens are now owned here, in final_toks, and the hash is empty;
  // we call toks_.Delete(e) on each elem as we go.
  for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) {
    StateId state = e->key;
    Token *tok = e->val;
    if (tok->tot_cost <= cur_cutoff) {
      for (fst::ArcIterator<FST> aiter(*fst_, state);
           !aiter.Done();
           aiter.Next()) {
        const Arc &arc = aiter.Value();
        if (arc.ilabel != 0) {  // propagate..
          BaseFloat ac_cost = cost_offset -
              decodable->LogLikelihood(frame, arc.ilabel),
              graph_cost = arc.weight.Value(),
              cur_cost = tok->tot_cost,
              tot_cost = cur_cost + ac_cost + graph_cost;
          if (tot_cost >= next_cutoff) continue;
          else if (tot_cost + adaptive_beam < next_cutoff)
            next_cutoff = tot_cost + adaptive_beam;  // prune by best current token
          // Note: the frame indexes into active_toks_ are one-based, hence + 1.
          Elem *e_next = FindOrAddToken(arc.nextstate,
                                        frame + 1, tot_cost, tok, NULL);
          // NULL: no change indicator needed.
          // Add ForwardLink from tok to next_tok (put on head of list tok->links).
          tok->links = new ForwardLinkT(e_next->val, arc.ilabel, arc.olabel,
                                        graph_cost, ac_cost, tok->links);
        }
      }  // for all arcs
    }
    e_tail = e->tail;
    toks_.Delete(e);  // delete Elem
  }
  return next_cutoff;
}
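// Note (not part of the original file): the per-frame cost_offset is the
// negated cost of the best token, so the total costs stored in tokens stay
// near zero and keep good floating-point precision; GetRawLattice() undoes
// the offset (l->acoustic_cost - cost_offset) when writing lattice weights,
// so final lattice scores are unaffected.  A minimal illustration:
static BaseFloat OffsetAcousticCostExample(BaseFloat best_tok_cost,
                                           BaseFloat log_likelihood) {
  BaseFloat cost_offset = -best_tok_cost;             // as in ProcessEmitting()
  BaseFloat ac_cost = cost_offset - log_likelihood;   // offset acoustic cost stored on the link
  // Undoing the offset, as GetRawLattice() does, recovers -log_likelihood:
  return ac_cost - cost_offset;
}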
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::DeleteForwardLinks(Token *tok) {
  ForwardLinkT *l = tok->links, *m;
  while (l != NULL) { m = l->next; delete l; l = m; }
  tok->links = NULL;
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ProcessNonemitting(BaseFloat cutoff) {
  KALDI_ASSERT(!active_toks_.empty());
  int32 frame = static_cast<int32>(active_toks_.size()) - 2;
  // "frame" is the time-index we just processed, or -1 if we are processing
  // the nonemitting transitions before the first frame (from InitDecoding()).
  KALDI_ASSERT(queue_.empty());

  if (toks_.GetList() == NULL) {
    if (!warned_) {
      KALDI_WARN << "Error, no surviving tokens: frame is " << frame;
      warned_ = true;
    }
  }

  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
    StateId state = e->key;
    if (fst_->NumInputEpsilons(state) != 0)
      queue_.push_back(e);
  }

  while (!queue_.empty()) {
    const Elem *e = queue_.back();
    queue_.pop_back();

    StateId state = e->key;
    Token *tok = e->val;
    BaseFloat cur_cost = tok->tot_cost;
    if (cur_cost >= cutoff)  // Don't bother processing successors.
      continue;
    // If "tok" has any existing forward links, delete them, because we're
    // about to regenerate them.
    DeleteForwardLinks(tok);  // necessary when re-visiting
    tok->links = NULL;
    for (fst::ArcIterator<FST> aiter(*fst_, state);
         !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel == 0) {  // propagate nonemitting arcs only...
        BaseFloat graph_cost = arc.weight.Value(),
            tot_cost = cur_cost + graph_cost;
        if (tot_cost < cutoff) {
          bool changed;
          Elem *e_new = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
                                       tok, &changed);
          tok->links = new ForwardLinkT(e_new->val, 0, arc.olabel,
                                        graph_cost, 0, tok->links);
          // "changed" tells us whether the new token has a different cost
          // from before, or is new; if so, add it to the queue.
          if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0)
            queue_.push_back(e_new);
        }
      }
    }  // for all arcs
  }  // while queue not empty
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::DeleteElems(Elem *list) {
  for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
    e_tail = e->tail;
    toks_.Delete(e);
  }
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ClearActiveTokens() {
  // A cleanup routine, at utterance end/begin.
  for (size_t i = 0; i < active_toks_.size(); i++) {
    // Delete all tokens alive on this frame, and any forward links they have.
    for (Token *tok = active_toks_[i].toks; tok != NULL; ) {
      DeleteForwardLinks(tok);
      Token *next_tok = tok->next;
      delete tok;
      num_toks_--;
      tok = next_tok;
    }
  }
  active_toks_.clear();
  KALDI_ASSERT(num_toks_ == 0);
}
// static
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::TopSortTokens(
    Token *tok_list, std::vector<Token*> *topsorted_list) {
  unordered_map<Token*, int32> token2pos;
  typedef typename unordered_map<Token*, int32>::iterator IterType;
  int32 num_toks = 0;
  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
    num_toks++;
  int32 cur_pos = 0;
  // We assign the tokens numbers num_toks - 1, ..., 2, 1, 0.  This is likely
  // to be closer to topological order than ascending order, because of the
  // way new tokens are put at the front of the list.
  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
    token2pos[tok] = num_toks - ++cur_pos;

  unordered_set<Token*> reprocess;

  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) {
    Token *tok = iter->first;
    int32 pos = iter->second;
    for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
      if (link->ilabel == 0) {
        // We only need to consider epsilon links, since non-epsilon links
        // transition between frames and this function only sorts tokens
        // from a single frame.
        IterType following_iter = token2pos.find(link->next_tok);
        if (following_iter != token2pos.end()) {  // another token on this frame.
          int32 next_pos = following_iter->second;
          if (next_pos < pos) {  // reassign the position of the next Token.
            following_iter->second = cur_pos++;
            reprocess.insert(link->next_tok);
          }
        }
      }
    }
    // If this token was previously scheduled for reprocessing, we can erase it
    // now because we just processed it.
    reprocess.erase(tok);
  }

  size_t max_loop = 1000000, loop_count;  // max loop count to detect epsilon cycles.
  for (loop_count = 0;
       !reprocess.empty() && loop_count < max_loop; ++loop_count) {
    std::vector<Token*> reprocess_vec;
    for (typename unordered_set<Token*>::iterator iter = reprocess.begin();
         iter != reprocess.end(); ++iter)
      reprocess_vec.push_back(*iter);
    reprocess.clear();
    for (typename std::vector<Token*>::iterator iter = reprocess_vec.begin();
         iter != reprocess_vec.end(); ++iter) {
      Token *tok = *iter;
      int32 pos = token2pos[tok];
      // Repeat the processing we did above (for comments, see above).
      for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
        if (link->ilabel == 0) {
          IterType following_iter = token2pos.find(link->next_tok);
          if (following_iter != token2pos.end()) {
            int32 next_pos = following_iter->second;
            if (next_pos < pos) {
              following_iter->second = cur_pos++;
              reprocess.insert(link->next_tok);
            }
          }
        }
      }
    }
  }
  KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding "
               "graph (this is not allowed!)");

  topsorted_list->clear();
  topsorted_list->resize(cur_pos, NULL);  // create a list with NULLs in between.
  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
    (*topsorted_list)[iter->second] = iter->first;
}