template <typename FST, typename Token>
LatticeIncrementalDecoderTpl<FST, Token>::LatticeIncrementalDecoderTpl(
    const FST &fst, const TransitionModel &trans_model,
    const LatticeIncrementalDecoderConfig &config)
    : fst_(&fst),
      delete_fst_(false),
      num_toks_(0),
      config_(config),
      determinizer_(trans_model, config) {
  config.Check();
  toks_.SetSize(1000);  // just so on the first frame we do something reasonable.
}
template <typename FST, typename Token>
LatticeIncrementalDecoderTpl<FST, Token>::LatticeIncrementalDecoderTpl(
    const LatticeIncrementalDecoderConfig &config, FST *fst,
    const TransitionModel &trans_model)
    : fst_(fst),
      delete_fst_(true),
      num_toks_(0),
      config_(config),
      determinizer_(trans_model, config) {
  config.Check();
  toks_.SetSize(1000);  // just so on the first frame we do something reasonable.
}
template <typename FST, typename Token>
LatticeIncrementalDecoderTpl<FST, Token>::~LatticeIncrementalDecoderTpl() {
  DeleteElems(toks_.Clear());
  ClearActiveTokens();
  if (delete_fst_) delete fst_;
}
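
// --- Illustrative sketch (editor's addition, not part of the original file):
// constructing the decoder around an existing HCLG graph.  The first
// constructor keeps only a pointer to 'hclg', so the graph must outlive the
// decoder; the function and variable names below are hypothetical.
static void ExampleConstructDecoder(
    const fst::Fst<fst::StdArc> &hclg, const TransitionModel &trans_model,
    const LatticeIncrementalDecoderConfig &config) {
  LatticeIncrementalDecoderTpl<fst::Fst<fst::StdArc>,
                               decoder::BackpointerToken>
      decoder(hclg, trans_model, config);
  decoder.InitDecoding();  // required before AdvanceDecoding().
}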
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::InitDecoding() {
  // clean up from last time:
  DeleteElems(toks_.Clear());
  cost_offsets_.clear();
  ClearActiveTokens();
  warned_ = false;
  num_toks_ = 0;
  decoding_finalized_ = false;
  final_costs_.clear();
  StateId start_state = fst_->Start();
  KALDI_ASSERT(start_state != fst::kNoStateId);
  active_toks_.resize(1);
  Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL);
  active_toks_[0].toks = start_tok;
  toks_.Insert(start_state, start_tok);
  num_toks_++;

  determinizer_.Init();
  num_frames_in_lattice_ = 0;
  token2label_map_.clear();
  next_token_label_ = LatticeIncrementalDeterminizer::kTokenLabelOffset;
  ProcessNonemitting(config_.beam);
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::UpdateLatticeDeterminization() {
  if (NumFramesDecoded() - num_frames_in_lattice_ <
      config_.determinize_max_delay)
    return;

  /* Make sure the token-pruning is up to date.  If we just pruned the tokens,
     this will do very little work. */
  PruneActiveTokens(config_.lattice_beam * config_.prune_scale);

  int32 first = num_frames_in_lattice_ + config_.determinize_min_chunk_size,
      last = NumFramesDecoded(),
      fewest_tokens = std::numeric_limits<int32>::max(),
      best_frame = -1;
  for (int32 t = last; t >= first; t--) {
    /* Make sure PruneActiveTokens() has computed num_toks for all these
       frames. */
    KALDI_ASSERT(active_toks_[t].num_toks != -1);
    if (active_toks_[t].num_toks < fewest_tokens) {
      best_frame = t;
      fewest_tokens = active_toks_[t].num_toks;
    }
  }
  // Determinize the chunk ending at the frame with the fewest active tokens.
  bool use_final_probs = false;
  GetLattice(best_frame, use_final_probs);
}
template <typename FST, typename Token>
bool LatticeIncrementalDecoderTpl<FST, Token>::Decode(DecodableInterface *decodable) {
  InitDecoding();

  // We use 1-based indexing for frames in this decoder (if you view it in
  // terms of features), but the decodable object uses zero-based numbering,
  // which we have to correct for when we call it.
  while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
    if (NumFramesDecoded() % config_.prune_interval == 0) {
      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
    }
    UpdateLatticeDeterminization();

    BaseFloat cost_cutoff = ProcessEmitting(decodable);
    ProcessNonemitting(cost_cutoff);
  }
  Timer timer;
  FinalizeDecoding();
  bool use_final_probs = true;
  GetLattice(NumFramesDecoded(), use_final_probs);
  KALDI_VLOG(2) << "Delay time during and after FinalizeDecoding() "
                << "(secs): " << timer.Elapsed();

  // Returns true if we have any kind of traceback available (not necessarily
  // to the end state; query ReachedFinal() for that).
  return !active_toks_.empty() && active_toks_.back().toks != NULL;
}
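
// --- Illustrative sketch (editor's addition): driving Decode() end-to-end on
// one utterance.  Assumes a DecodableInterface 'decodable' wrapping the
// acoustic scores; the helper's name and arguments are hypothetical.
static bool ExampleDecodeUtterance(
    LatticeIncrementalDecoderTpl<fst::Fst<fst::StdArc>,
                                 decoder::BackpointerToken> *decoder,
    DecodableInterface *decodable,
    CompactLattice *clat_out) {
  bool ok = decoder->Decode(decodable);
  // Decode() has already finalized and determinized; this just retrieves the
  // lattice over all decoded frames, with final-probs included.
  *clat_out = decoder->GetLattice(decoder->NumFramesDecoded(), true);
  return ok && clat_out->NumStates() > 0;
}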
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::PossiblyResizeHash(size_t num_toks) {
  size_t new_sz = static_cast<size_t>(
      static_cast<BaseFloat>(num_toks) * config_.hash_ratio);
  if (new_sz > toks_.Size()) {
    toks_.SetSize(new_sz);
  }
}
// FindOrAddToken either locates a token in hash map "toks_", or if necessary
// inserts a new, empty token (i.e. with no forward links) for the current
// frame.
template <typename FST, typename Token>
inline Token *LatticeIncrementalDecoderTpl<FST, Token>::FindOrAddToken(
    StateId state, int32 frame_plus_one, BaseFloat tot_cost,
    Token *backpointer, bool *changed) {
  KALDI_ASSERT(frame_plus_one < active_toks_.size());
  Token *&toks = active_toks_[frame_plus_one].toks;
  Elem *e_found = toks_.Find(state);
  if (e_found == NULL) {  // no such token presently.
    const BaseFloat extra_cost = 0.0;
    // tokens on the currently final frame have zero extra_cost,
    // as any of them could end up on the winning path.
    Token *new_tok = new Token(tot_cost, extra_cost, NULL, toks, backpointer);
    // NULL: no forward links yet.
    toks = new_tok;
    num_toks_++;
    toks_.Insert(state, new_tok);
    if (changed) *changed = true;
    return new_tok;
  } else {
    Token *tok = e_found->val;       // There is an existing Token for this state.
    if (tok->tot_cost > tot_cost) {  // replace old token
      tok->tot_cost = tot_cost;
      // SetBackpointer() just does tok->backpointer = backpointer in the case
      // where Token == BackpointerToken, else nothing.
      tok->SetBackpointer(backpointer);
      if (changed) *changed = true;
    } else {
      if (changed) *changed = false;
    }
    return tok;
  }
}
// Prunes outgoing links for all tokens in active_toks_[frame]; it's called by
// PruneActiveTokens.  All links that have link_extra_cost > lattice_beam are
// pruned.
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::PruneForwardLinks(
    int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned,
    BaseFloat delta) {
  // delta is the amount by which the extra_costs must change before we set
  // *extra_costs_changed = true.  If delta is larger, we'll tend to go back
  // less far toward the beginning of the file.
  *extra_costs_changed = false;
  *links_pruned = false;
  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
  if (active_toks_[frame_plus_one].toks == NULL) {  // empty list; should not happen.
    if (!warned_) {
      KALDI_WARN << "No tokens alive [doing pruning].. warning first "
                    "time only for each utterance\n";
      warned_ = true;
    }
  }

  // We have to iterate until there is no more change, because the links
  // are not guaranteed to be in topological order.
  bool changed = true;
  while (changed) {
    changed = false;
    for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL;
         tok = tok->next) {
      ForwardLinkT *link, *prev_link = NULL;
      // will recompute tok_extra_cost for tok.
      BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
      // tok_extra_cost is the best (min) of link_extra_cost of outgoing links.
      for (link = tok->links; link != NULL;) {
        // See if we need to excise this link...
        Token *next_tok = link->next_tok;
        BaseFloat link_extra_cost =
            next_tok->extra_cost +
            ((tok->tot_cost + link->acoustic_cost + link->graph_cost) -
             next_tok->tot_cost);  // difference in brackets is >= 0
        if (link_extra_cost > config_.lattice_beam) {  // excise link
          ForwardLinkT *next_link = link->next;
          if (prev_link != NULL)
            prev_link->next = next_link;
          else
            tok->links = next_link;
          delete link;
          link = next_link;  // advance link but leave prev_link the same.
          *links_pruned = true;
        } else {  // keep the link and update tok_extra_cost if needed.
          if (link_extra_cost < 0.0) {  // this is just a precaution.
            if (link_extra_cost < -0.01)
              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
            link_extra_cost = 0.0;
          }
          if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost;
          prev_link = link;  // move to next link
          link = link->next;
        }
      }  // for all outgoing links
      if (fabs(tok_extra_cost - tok->extra_cost) > delta)
        changed = true;  // difference new minus old is bigger than delta
      tok->extra_cost = tok_extra_cost;
      // will be +infinity or <= lattice_beam_; infinity indicates that no
      // forward link survived pruning.
    }  // for all Tokens on active_toks_[frame]
    if (changed) *extra_costs_changed = true;
  }  // while changed
}
// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
// on the final frame.  If there are final tokens active, it uses
// the final-probs for pruning, otherwise it treats all tokens as final.
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::PruneForwardLinksFinal() {
  KALDI_ASSERT(!active_toks_.empty());
  int32 frame_plus_one = active_toks_.size() - 1;

  if (active_toks_[frame_plus_one].toks == NULL)  // empty list; should not happen.
    KALDI_WARN << "No tokens alive at end of file";

  typedef typename unordered_map<Token *, BaseFloat>::const_iterator IterType;
  ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
  decoding_finalized_ = true;
  // We call DeleteElems() as a nicety, not because it's really necessary;
  // otherwise there would be dangling forward pointers in the hash table.
  DeleteElems(toks_.Clear());

  // Go through tokens on this frame, pruning forward links; we may have to
  // iterate a few times until there is no more change, because the list is
  // not in topological order.  This is like PruneForwardLinks, but here we
  // also take account of the final-probs.
  bool changed = true;
  BaseFloat delta = 1.0e-05;
  while (changed) {
    changed = false;
    for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL;
         tok = tok->next) {
      ForwardLinkT *link, *prev_link = NULL;
      // tok_extra_cost has a term in it that corresponds to the final-prob,
      // so initialize it to the difference between this token's
      // (score+final_prob) and the best such (score+final_prob).
      BaseFloat final_cost;
      if (final_costs_.empty()) {
        final_cost = 0.0;
      } else {
        IterType iter = final_costs_.find(tok);
        if (iter != final_costs_.end())
          final_cost = iter->second;
        else
          final_cost = std::numeric_limits<BaseFloat>::infinity();
      }
      BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
      // tok_extra_cost is a "min" over either directly being final, or being
      // indirectly final through other links; the loop below may decrease it:
      for (link = tok->links; link != NULL;) {
        // See if we need to excise this link...
        Token *next_tok = link->next_tok;
        BaseFloat link_extra_cost =
            next_tok->extra_cost +
            ((tok->tot_cost + link->acoustic_cost + link->graph_cost) -
             next_tok->tot_cost);
        if (link_extra_cost > config_.lattice_beam) {  // excise link
          ForwardLinkT *next_link = link->next;
          if (prev_link != NULL)
            prev_link->next = next_link;
          else
            tok->links = next_link;
          delete link;
          link = next_link;  // advance link but leave prev_link the same.
        } else {  // keep the link and update tok_extra_cost if needed.
          if (link_extra_cost < 0.0) {  // this is just a precaution.
            if (link_extra_cost < -0.01)
              KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
            link_extra_cost = 0.0;
          }
          if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost;
          prev_link = link;
          link = link->next;
        }
      }
      // Prune away tokens worse than lattice_beam above best path.  This step
      // was unnecessary in the non-final case because then such tokens showed
      // up as having no forward links; here tok_extra_cost has an extra
      // component relating to the final-prob.
      if (tok_extra_cost > config_.lattice_beam)
        tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
      // to be pruned in PruneTokensForFrame

      if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) changed = true;
      tok->extra_cost = tok_extra_cost;  // will be +infinity or <= lattice_beam_.
    }
  }  // while changed
}
template <typename FST, typename Token>
BaseFloat LatticeIncrementalDecoderTpl<FST, Token>::FinalRelativeCost() const {
  BaseFloat relative_cost;
  ComputeFinalCosts(NULL, &relative_cost, NULL);
  return relative_cost;
}
// Prune away any tokens on this frame that have no forward links.
// [we don't do this in PruneForwardLinks because it would give us
// a problem with dangling pointers].
// It's called by PruneActiveTokens if any forward links have been pruned.
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::PruneTokensForFrame(
    int32 frame_plus_one) {
  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
  Token *&toks = active_toks_[frame_plus_one].toks;
  if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]";
  Token *tok, *next_tok, *prev_tok = NULL;
  int32 num_toks = 0;
  for (tok = toks; tok != NULL; tok = next_tok, num_toks++) {
    next_tok = tok->next;
    if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
      // token is unreachable from end of graph (no forward links survived);
      // excise tok from list and delete tok.
      if (prev_tok != NULL)
        prev_tok->next = tok->next;
      else
        toks = tok->next;
      delete tok;
      num_toks_--;
      num_toks--;
    } else {  // fetch next Token
      prev_tok = tok;
    }
  }
  active_toks_[frame_plus_one].num_toks = num_toks;
}
// Go backwards through still-alive tokens, pruning them if the
// forward+backward cost is more than lattice_beam away from the best path.
// It's only worth doing if the frame is some way behind the current frame.
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::PruneActiveTokens(BaseFloat delta) {
  int32 cur_frame_plus_one = NumFramesDecoded();
  int32 num_toks_begin = num_toks_;

  if (active_toks_[cur_frame_plus_one].num_toks == -1) {
    // The current frame's tokens don't get pruned so they don't get counted
    // (the count is needed by the incremental determinization code).  Fix
    // this.
    int this_frame_num_toks = 0;
    for (Token *t = active_toks_[cur_frame_plus_one].toks; t != NULL; t = t->next)
      this_frame_num_toks++;
    active_toks_[cur_frame_plus_one].num_toks = this_frame_num_toks;
  }

  // The index "f" below represents a "frame plus one", i.e. you'd have to
  // subtract one to get the corresponding index for the decodable object.
  for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
    // We need to prune forward links if (1) we have never pruned them yet, or
    // (2) tokens on the next frame have changed their extra_cost since then.
    if (active_toks_[f].must_prune_forward_links) {
      bool extra_costs_changed = false, links_pruned = false;
      PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
      if (extra_costs_changed && f > 0)  // any token has changed extra_cost
        active_toks_[f - 1].must_prune_forward_links = true;
      if (links_pruned)  // any link was pruned
        active_toks_[f].must_prune_tokens = true;
      active_toks_[f].must_prune_forward_links = false;  // job done
    }
    if (f + 1 < cur_frame_plus_one &&  // except for last f (no forward links)
        active_toks_[f + 1].must_prune_tokens) {
      PruneTokensForFrame(f + 1);
      active_toks_[f + 1].must_prune_tokens = false;
    }
  }
  KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin
                << " to " << num_toks_;
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::ComputeFinalCosts(
    unordered_map<Token *, BaseFloat> *final_costs,
    BaseFloat *final_relative_cost, BaseFloat *final_best_cost) const {
  if (decoding_finalized_) {
    // If decoding is finalized, the list of tokens is no longer available in
    // toks_, so we return the cached versions we computed at that point.
    if (final_costs) *final_costs = final_costs_;
    if (final_relative_cost) *final_relative_cost = final_relative_cost_;
    if (final_best_cost) *final_best_cost = final_best_cost_;
    return;
  }
  if (final_costs != NULL) final_costs->clear();
  const Elem *final_toks = toks_.GetList();
  BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
  BaseFloat best_cost = infinity, best_cost_with_final = infinity;

  while (final_toks != NULL) {
    StateId state = final_toks->key;
    Token *tok = final_toks->val;
    const Elem *next = final_toks->tail;
    BaseFloat final_cost = fst_->Final(state).Value();
    BaseFloat cost = tok->tot_cost, cost_with_final = cost + final_cost;
    best_cost = std::min(cost, best_cost);
    best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
    if (final_costs != NULL && final_cost != infinity)
      (*final_costs)[tok] = final_cost;
    final_toks = next;
  }
  if (final_relative_cost != NULL) {
    if (best_cost == infinity && best_cost_with_final == infinity) {
      // Likely this will only happen if there are no tokens surviving.
      // This seems the least bad way to handle it.
      *final_relative_cost = infinity;
    } else {
      *final_relative_cost = best_cost_with_final - best_cost;
    }
  }
  if (final_best_cost != NULL) {
    if (best_cost_with_final != infinity) {  // final-state exists on last frame
      *final_best_cost = best_cost_with_final;
    } else {  // no final-state exists on last frame
      *final_best_cost = best_cost;
    }
  }
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::AdvanceDecoding(
    DecodableInterface *decodable, int32 max_num_frames) {
  if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
    // If the FST type is the base-class Fst, check whether the underlying
    // type of fst_ is actually ConstFst or VectorFst; if so, call
    // AdvanceDecoding() on *this cast to the more specific type, which is
    // faster.
    if (fst_->Type() == "const") {
      LatticeIncrementalDecoderTpl<fst::ConstFst<fst::StdArc>, Token> *this_cast =
          reinterpret_cast<LatticeIncrementalDecoderTpl<fst::ConstFst<fst::StdArc>,
                                                        Token> *>(this);
      this_cast->AdvanceDecoding(decodable, max_num_frames);
      return;
    } else if (fst_->Type() == "vector") {
      LatticeIncrementalDecoderTpl<fst::VectorFst<fst::StdArc>, Token> *this_cast =
          reinterpret_cast<LatticeIncrementalDecoderTpl<fst::VectorFst<fst::StdArc>,
                                                        Token> *>(this);
      this_cast->AdvanceDecoding(decodable, max_num_frames);
      return;
    }
  }

  KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
               "You must call InitDecoding() before AdvanceDecoding");
  int32 num_frames_ready = decodable->NumFramesReady();
  // num_frames_ready must be >= num_frames_decoded, or else the number of
  // frames ready must have decreased (which doesn't make sense) or the
  // decodable object changed between calls (which isn't allowed).
  KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
  int32 target_frames_decoded = num_frames_ready;
  if (max_num_frames >= 0)
    target_frames_decoded =
        std::min(target_frames_decoded, NumFramesDecoded() + max_num_frames);
  while (NumFramesDecoded() < target_frames_decoded) {
    if (NumFramesDecoded() % config_.prune_interval == 0) {
      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
    }
    BaseFloat cost_cutoff = ProcessEmitting(decodable);
    ProcessNonemitting(cost_cutoff);
  }
  UpdateLatticeDeterminization();
}
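
// --- Illustrative sketch (editor's addition): the online calling pattern
// AdvanceDecoding() supports.  'decodable' is assumed to be an online
// decodable whose NumFramesReady() grows as feature chunks arrive; the
// function name is hypothetical.
static void ExampleOnlineChunk(
    LatticeIncrementalDecoderTpl<fst::Fst<fst::StdArc>,
                                 decoder::BackpointerToken> *decoder,
    DecodableInterface *decodable) {
  // Call once per newly-arrived chunk of features; it decodes as far as the
  // decodable object allows and internally keeps the determinization up to
  // date (subject to config.determinize_max_delay).
  decoder->AdvanceDecoding(decodable);
  // An up-to-date partial result can then be read off, without final-probs:
  const CompactLattice &partial =
      decoder->GetLattice(decoder->NumFramesDecoded(), false);
  (void)partial;  // e.g. extract a partial transcript from it here.
}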
// FinalizeDecoding() is a version of PruneActiveTokens that we call
// (optionally) on the final frame.  It takes into account the final-probs of
// tokens.
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::FinalizeDecoding() {
  int32 final_frame_plus_one = NumFramesDecoded();
  int32 num_toks_begin = num_toks_;
  // PruneForwardLinksFinal() prunes the final frame (with final-probs) and
  // sets decoding_finalized_.
  PruneForwardLinksFinal();
  for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
    bool b1, b2;  // values not used.
    BaseFloat dontcare = 0.0;  // delta of zero means we must always update.
    PruneForwardLinks(f, &b1, &b2, dontcare);
    PruneTokensForFrame(f + 1);
  }
  PruneTokensForFrame(0);
  KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin
                << " to " << num_toks_;
}
template <typename FST, typename Token>
BaseFloat LatticeIncrementalDecoderTpl<FST, Token>::GetCutoff(
    Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam,
    Elem **best_elem) {
  BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
  // positive == high cost == bad.
  size_t count = 0;
  if (config_.max_active == std::numeric_limits<int32>::max() &&
      config_.min_active == 0) {
    for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
      BaseFloat w = static_cast<BaseFloat>(e->val->tot_cost);
      if (w < best_weight) {
        best_weight = w;
        if (best_elem) *best_elem = e;
      }
    }
    if (tok_count != NULL) *tok_count = count;
    if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
    return best_weight + config_.beam;
  } else {
    tmp_array_.clear();
    for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
      BaseFloat w = e->val->tot_cost;
      tmp_array_.push_back(w);
      if (w < best_weight) {
        best_weight = w;
        if (best_elem) *best_elem = e;
      }
    }
    if (tok_count != NULL) *tok_count = count;

    BaseFloat beam_cutoff = best_weight + config_.beam,
        min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
        max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();

    KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded()
                  << " is " << tmp_array_.size();

    if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
      std::nth_element(tmp_array_.begin(),
                       tmp_array_.begin() + config_.max_active,
                       tmp_array_.end());
      max_active_cutoff = tmp_array_[config_.max_active];
    }
    if (max_active_cutoff < beam_cutoff) {  // max_active is tighter than beam.
      *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta;
      return max_active_cutoff;
    }
    if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
      if (config_.min_active == 0)
        min_active_cutoff = best_weight;
      else {
        std::nth_element(tmp_array_.begin(),
                         tmp_array_.begin() + config_.min_active,
                         tmp_array_.size() > static_cast<size_t>(config_.max_active)
                             ? tmp_array_.begin() + config_.max_active
                             : tmp_array_.end());
        min_active_cutoff = tmp_array_[config_.min_active];
      }
    }
    if (min_active_cutoff > beam_cutoff) {  // min_active is looser than beam.
      *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
      return min_active_cutoff;
    } else {
      *adaptive_beam = config_.beam;
      return beam_cutoff;
    }
  }
}
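
// --- Illustrative sketch (editor's addition, on plain data): how GetCutoff()
// turns the max-active cap into a cost cutoff using std::nth_element.  After
// the call, costs[max_active] holds the (max_active+1)-th smallest cost, so
// keeping only tokens with cost strictly below it keeps at most max_active of
// them.  The function name is hypothetical; <algorithm>, <limits> and
// <vector> are assumed available in this translation unit.
static BaseFloat ExampleMaxActiveCutoff(std::vector<BaseFloat> costs,
                                        int32 max_active) {
  if (costs.size() <= static_cast<size_t>(max_active))
    return std::numeric_limits<BaseFloat>::infinity();  // beam not tightened.
  std::nth_element(costs.begin(), costs.begin() + max_active, costs.end());
  return costs[max_active];
}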
template <typename FST, typename Token>
BaseFloat LatticeIncrementalDecoderTpl<FST, Token>::ProcessEmitting(
    DecodableInterface *decodable) {
  KALDI_ASSERT(active_toks_.size() > 0);
  int32 frame = active_toks_.size() - 1;  // frame is the frame-index (zero-based)
                                          // used to get likelihoods from the
                                          // decodable object.
  active_toks_.resize(active_toks_.size() + 1);

  Elem *final_toks = toks_.Clear();  // analogous to swapping prev_toks_ /
                                     // cur_toks_ in simple-decoder.h.
  Elem *best_elem = NULL;
  BaseFloat adaptive_beam;
  size_t tok_cnt;
  BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
                << adaptive_beam;

  PossiblyResizeHash(tok_cnt);  // This makes sure the hash is always big enough.

  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
  // pruning "online" before having seen all tokens

  BaseFloat cost_offset = 0.0;  // Used to keep probabilities in a good
                                // dynamic range.

  // First process the best token to get a hopefully reasonably tight bound on
  // the next cutoff.  The only products of this block are "next_cutoff" and
  // "cost_offset".
  if (best_elem) {
    StateId state = best_elem->key;
    Token *tok = best_elem->val;
    cost_offset = -tok->tot_cost;
    for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel != 0) {  // propagate..
        BaseFloat new_weight = arc.weight.Value() + cost_offset -
                               decodable->LogLikelihood(frame, arc.ilabel) +
                               tok->tot_cost;
        if (new_weight + adaptive_beam < next_cutoff)
          next_cutoff = new_weight + adaptive_beam;
      }
    }
  }

  // Store the offset on the acoustic likelihoods that we're applying.
  cost_offsets_.resize(frame + 1, 0.0);
  cost_offsets_[frame] = cost_offset;

  // The tokens are now owned here, in final_toks, and the hash is empty; we
  // need to call toks_.Delete(e) on each elem 'e' to let toks_ know we're
  // done with them.
  for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) {
    // loop this way because we delete "e" as we go.
    StateId state = e->key;
    Token *tok = e->val;
    if (tok->tot_cost <= cur_cutoff) {
      for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done(); aiter.Next()) {
        const Arc &arc = aiter.Value();
        if (arc.ilabel != 0) {  // propagate..
          BaseFloat ac_cost =
                        cost_offset - decodable->LogLikelihood(frame, arc.ilabel),
                    graph_cost = arc.weight.Value(), cur_cost = tok->tot_cost,
                    tot_cost = cur_cost + ac_cost + graph_cost;
          if (tot_cost >= next_cutoff)
            continue;
          else if (tot_cost + adaptive_beam < next_cutoff)
            next_cutoff = tot_cost + adaptive_beam;  // prune by best current token
          // Note: the frame indexes into active_toks_ are one-based, hence
          // the + 1.
          Token *next_tok =
              FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, NULL);
          // NULL: no change indicator needed.

          // Add ForwardLink from tok to next_tok (put on head of list
          // tok->links).
          tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel,
                                        graph_cost, ac_cost, tok->links);
        }
      }  // for all arcs
    }
    e_tail = e->tail;
    toks_.Delete(e);  // delete Elem
  }
  return next_cutoff;
}
// Deletes the elements of the singly linked list tok->links.
template <typename FST, typename Token>
inline void LatticeIncrementalDecoderTpl<FST, Token>::DeleteForwardLinks(Token *tok) {
  ForwardLinkT *l = tok->links, *m;
  while (l != NULL) {
    m = l->next;
    delete l;
    l = m;
  }
  tok->links = NULL;
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::ProcessNonemitting(BaseFloat cutoff) {
  KALDI_ASSERT(!active_toks_.empty());
  int32 frame = static_cast<int32>(active_toks_.size()) - 2;
  // Note: "frame" is the time-index we just processed, or -1 if we are
  // processing the nonemitting transitions before the first frame (called
  // from InitDecoding()).

  // Processes nonemitting arcs for one frame.  Propagates within toks_.
  KALDI_ASSERT(queue_.empty());

  if (toks_.GetList() == NULL) {
    if (!warned_) {
      KALDI_WARN << "Error, no surviving tokens: frame is " << frame;
      warned_ = true;
    }
  }

  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
    StateId state = e->key;
    if (fst_->NumInputEpsilons(state) != 0) queue_.push_back(state);
  }

  while (!queue_.empty()) {
    StateId state = queue_.back();
    queue_.pop_back();

    Token *tok = toks_.Find(state)->val;  // would segfault if state not in
                                          // toks_, but this can't happen.
    BaseFloat cur_cost = tok->tot_cost;
    if (cur_cost >= cutoff)  // Don't bother processing successors.
      continue;
    // If "tok" has any existing forward links, delete them, because we're
    // about to regenerate them.  This is a kind of non-optimality, but since
    // most states are emitting it's not a huge issue.
    DeleteForwardLinks(tok);  // necessary when re-visiting
    for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel == 0) {  // propagate nonemitting only...
        BaseFloat graph_cost = arc.weight.Value(), tot_cost = cur_cost + graph_cost;
        if (tot_cost < cutoff) {
          bool changed;
          Token *new_tok =
              FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, &changed);
          tok->links =
              new ForwardLinkT(new_tok, 0, arc.olabel, graph_cost, 0, tok->links);
          // "changed" tells us whether the new token has a different cost
          // from before, or is new [if so, add into queue].
          if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0)
            queue_.push_back(arc.nextstate);
        }
      }
    }  // for all arcs
  }  // while queue not empty
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::DeleteElems(Elem *list) {
  for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
    e_tail = e->tail;
    toks_.Delete(e);
  }
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::ClearActiveTokens() {
  // a cleanup routine, at utterance end/begin.
  for (size_t i = 0; i < active_toks_.size(); i++) {
    // Delete all tokens alive on this frame, and any forward links they may
    // have.
    for (Token *tok = active_toks_[i].toks; tok != NULL;) {
      DeleteForwardLinks(tok);
      Token *next_tok = tok->next;
      delete tok;
      num_toks_--;
      tok = next_tok;
    }
  }
  active_toks_.clear();
  KALDI_ASSERT(num_toks_ == 0);
}
template <typename FST, typename Token>
const CompactLattice &LatticeIncrementalDecoderTpl<FST, Token>::GetLattice(
    int32 num_frames_to_include,
    bool use_final_probs) {
  KALDI_ASSERT(num_frames_to_include >= num_frames_in_lattice_ &&
               num_frames_to_include <= NumFramesDecoded());

  if (num_frames_in_lattice_ > 0 &&
      determinizer_.GetLattice().NumStates() == 0) {
    /* Something went wrong, the lattice is empty and will continue to be
       empty.  User-level code should detect and deal with this. */
    num_frames_in_lattice_ = num_frames_to_include;
    return determinizer_.GetLattice();
  }

  if (decoding_finalized_ && !use_final_probs) {
    // This is not supported.
    KALDI_ERR << "You cannot get the lattice without final-probs after "
                 "calling FinalizeDecoding().";
  }
  if (use_final_probs && num_frames_to_include != NumFramesDecoded()) {
    /* This is because we only remember the relation between HCLG states and
       Tokens for the current frame. */
    KALDI_ERR << "use-final-probs may not be true if you are not "
                 "getting a lattice for all frames decoded so far.";
  }

  if (num_frames_to_include > num_frames_in_lattice_) {
    /* Make sure the token-pruning is up to date.  If we just pruned the
       tokens, this will do very little work. */
    PruneActiveTokens(config_.lattice_beam * config_.prune_scale);

    if (determinizer_.GetLattice().NumStates() == 0 ||
        determinizer_.GetLattice().Final(0) != CompactLatticeWeight::Zero()) {
      num_frames_in_lattice_ = 0;
      determinizer_.Init();
    }

    Lattice chunk_lat;

    unordered_map<Label, LatticeArc::StateId> token_label2state;
    if (num_frames_in_lattice_ != 0) {
      determinizer_.InitializeRawLatticeChunk(&chunk_lat,
                                              &token_label2state);
    }

    // tok2state_map maps from Token* to the state-id in chunk_lat.
    unordered_map<Token*, StateId> &tok2state_map(temp_token_map_);
    tok2state_map.clear();

    unordered_map<Token*, Label> &next_token2label_map(token2label_map_temp_);
    next_token2label_map.clear();

    { // First deal with the last frame in the chunk, the one numbered
      // `num_frames_to_include`.
      int32 frame = num_frames_to_include;
      // Allocate state-ids for all tokens on this frame.
      for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
        /* If we included the final-costs at this stage, they would cause
           non-final states to be pruned out from the end of the lattice. */
        BaseFloat final_cost;
        if (decoding_finalized_) {
          if (final_costs_.empty()) {
            final_cost = 0.0;  /* No final-state survived, so treat all as
                                  final with probability One(). */
          } else {
            auto iter = final_costs_.find(tok);
            if (iter == final_costs_.end())
              final_cost = std::numeric_limits<BaseFloat>::infinity();
            else
              final_cost = iter->second;
          }
        } else {
          /* This is a `fake` final-cost used to guide pruning.  It's as if we
             set the betas (backward-probs) on the final frame to the
             negatives of the corresponding alphas, so all tokens on the last
             frame would be on a best path.  We want the final_cost to
             correspond to the beta (backward-prob), so we get it as
             final_cost = extra_cost - tot_cost
             [the tot_cost is the forward/alpha cost]. */
          final_cost = tok->extra_cost - tok->tot_cost;
        }

        StateId state = chunk_lat.AddState();
        tok2state_map[tok] = state;
        if (final_cost < std::numeric_limits<BaseFloat>::infinity()) {
          next_token2label_map[tok] = AllocateNewTokenLabel();
          StateId token_final_state = chunk_lat.AddState();
          LatticeArc::Label ilabel = 0,
              olabel = (next_token2label_map[tok] = AllocateNewTokenLabel());
          chunk_lat.AddArc(state,
                           LatticeArc(ilabel, olabel,
                                      LatticeWeight::One(),
                                      token_final_state));
          chunk_lat.SetFinal(token_final_state,
                             LatticeWeight(final_cost, 0.0));
        }
      }
    }

    // Go in reverse order over the remaining frames so we can create arcs as
    // we go; their destination states will already exist.
    for (int32 frame = num_frames_to_include;
         frame >= num_frames_in_lattice_; frame--) {
      BaseFloat cost_offset = (frame < cost_offsets_.size() ?
                               cost_offsets_[frame] : 0.0);

      // For the first frame of the chunk, the states need to be the ones
      // created by InitializeRawLatticeChunk() (where not pruned away).
      if (frame == num_frames_in_lattice_ && num_frames_in_lattice_ != 0) {
        for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
          auto iter = token2label_map_.find(tok);
          KALDI_ASSERT(iter != token2label_map_.end());
          Label token_label = iter->second;
          auto iter2 = token_label2state.find(token_label);
          if (iter2 != token_label2state.end()) {
            StateId state = iter2->second;
            tok2state_map[tok] = state;
          } else {
            // Some states may have been pruned out, but we still allocate a
            // state, as it may be needed as the source state of arcs.
            StateId state = chunk_lat.AddState();
            tok2state_map[tok] = state;
          }
        }
      } else if (frame != num_frames_to_include) {  // We already created
                                                    // states for the last
                                                    // frame.
        for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
          StateId state = chunk_lat.AddState();
          tok2state_map[tok] = state;
        }
      }
      for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
        auto iter = tok2state_map.find(tok);
        KALDI_ASSERT(iter != tok2state_map.end());
        StateId cur_state = iter->second;
        for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) {
          auto next_iter = tok2state_map.find(l->next_tok);
          if (next_iter == tok2state_map.end()) {
            // Emitting arcs from the last frame we're including -- ignore
            // these.
            KALDI_ASSERT(frame == num_frames_to_include);
            continue;
          }
          StateId next_state = next_iter->second;
          BaseFloat this_offset = (l->ilabel != 0 ? cost_offset : 0);
          LatticeArc arc(l->ilabel, l->olabel,
                         LatticeWeight(l->graph_cost,
                                       l->acoustic_cost - this_offset),
                         next_state);
          // Note: the epsilons get redundantly included at the end and
          // beginning of successive chunks; these get removed in the
          // determinization.
          chunk_lat.AddArc(cur_state, arc);
        }
      }
    }
    if (num_frames_in_lattice_ == 0) {
      // This block locates the start token.  NOTE: we use the fact that in
      // the linked list of tokens, things are added at the head, so the start
      // state must be at the tail.
      Token *tok = active_toks_[0].toks;
      if (tok == NULL) {
        KALDI_WARN << "No tokens exist on start frame";
        return determinizer_.GetLattice();  // will be empty.
      }
      while (tok->next != NULL)
        tok = tok->next;
      Token *start_token = tok;
      auto iter = tok2state_map.find(start_token);
      KALDI_ASSERT(iter != tok2state_map.end());
      StateId start_state = iter->second;
      chunk_lat.SetStart(start_state);
    }
    token2label_map_.swap(next_token2label_map);

    // We are ignoring the return status, which says whether it finished
    // before the beam.
    determinizer_.AcceptRawLatticeChunk(&chunk_lat);

    num_frames_in_lattice_ = num_frames_to_include;

    if (determinizer_.GetLattice().NumStates() == 0)
      return determinizer_.GetLattice();  // Something went wrong, lattice
                                          // is empty.
  }

  unordered_map<Token*, BaseFloat> token2final_cost;
  unordered_map<Label, BaseFloat> token_label2final_cost;
  if (use_final_probs) {
    ComputeFinalCosts(&token2final_cost, NULL, NULL);
    for (const auto &p: token2final_cost) {
      Token *tok = p.first;
      BaseFloat cost = p.second;
      auto iter = token2label_map_.find(tok);
      if (iter != token2label_map_.end()) {
        /* Some tokens may not have survived the pruned determinization. */
        Label token_label = iter->second;
        bool ret = token_label2final_cost.insert({token_label, cost}).second;
        KALDI_ASSERT(ret);  /* Make sure it was inserted. */
      }
    }
  }
  /* Note: these final-probs won't affect the next chunk, only the lattice
     returned from GetLattice().  They are kind of temporaries. */
  determinizer_.SetFinalCosts(token_label2final_cost.empty() ? NULL :
                              &token_label2final_cost);

  return determinizer_.GetLattice();
}
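
// --- Illustrative sketch (editor's addition): the intended calling pattern
// for GetLattice().  Mid-utterance calls must pass use_final_probs = false;
// only a call covering all decoded frames may pass true.  The function name
// is hypothetical.
static void ExampleGetLattices(
    LatticeIncrementalDecoderTpl<fst::Fst<fst::StdArc>,
                                 decoder::BackpointerToken> *decoder) {
  // Mid-utterance: lattice up to the current frame, no final-probs.
  const CompactLattice &partial =
      decoder->GetLattice(decoder->NumFramesDecoded(), false);
  (void)partial;  // use it (read-only) for partial results.
  // End of utterance: finalize, then get the full lattice with final-probs.
  decoder->FinalizeDecoding();
  const CompactLattice &full =
      decoder->GetLattice(decoder->NumFramesDecoded(), true);
  (void)full;
}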
template <typename FST, typename Token>
int32 LatticeIncrementalDecoderTpl<FST, Token>::GetNumToksForFrame(int32 frame) {
  int32 r = 0;
  for (Token *tok = active_toks_[frame].toks; tok; tok = tok->next) r++;
  return r;
}
// This is a helper for InitializeRawLatticeChunk(); it expands a
// CompactLatticeArc into a chain of LatticeArcs, with the transition-id
// string on the input side.
static void AddCompactLatticeArcToLattice(
    const CompactLatticeArc &clat_arc,
    LatticeArc::StateId src_state,
    Lattice *lat) {
  const std::vector<int32> &string = clat_arc.weight.String();
  size_t N = string.size();
  if (N == 0) {
    LatticeArc arc;
    arc.ilabel = 0;
    arc.olabel = clat_arc.ilabel;
    arc.nextstate = clat_arc.nextstate;
    arc.weight = clat_arc.weight.Weight();
    lat->AddArc(src_state, arc);
  } else {
    LatticeArc::StateId cur_state = src_state;
    for (size_t i = 0; i < N; i++) {
      LatticeArc arc;
      arc.ilabel = string[i];
      arc.olabel = (i == 0 ? clat_arc.ilabel : 0);
      arc.nextstate = (i + 1 == N ? clat_arc.nextstate : lat->AddState());
      arc.weight = (i == 0 ? clat_arc.weight.Weight() : LatticeWeight::One());
      lat->AddArc(cur_state, arc);
      cur_state = arc.nextstate;
    }
  }
}
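
// --- Illustrative note (editor's addition): what the helper above produces.
// A CompactLatticeArc carries a word label plus a string of transition-ids in
// its weight; the helper expands it into a chain of LatticeArcs, putting the
// word olabel and the weight on the first arc of the chain.  For example
// (hypothetical values), an arc with olabel w and string {t1, t2, t3} from
// state s to state d becomes:
//
//   s --(t1:w, weight)--> n1 --(t2:0, One)--> n2 --(t3:0, One)--> d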
void LatticeIncrementalDeterminizer::Init() {
  non_final_redet_states_.clear();
  clat_.DeleteStates();
  final_arcs_.clear();
  forward_costs_.clear();
  arcs_in_.clear();
}

CompactLattice::StateId LatticeIncrementalDeterminizer::AddStateToClat() {
  CompactLattice::StateId ans = clat_.AddState();
  forward_costs_.push_back(std::numeric_limits<BaseFloat>::infinity());
  KALDI_ASSERT(forward_costs_.size() == ans + 1);
  arcs_in_.resize(ans + 1);
  return ans;
}
void LatticeIncrementalDeterminizer::AddArcToClat(
    CompactLattice::StateId state,
    const CompactLatticeArc &arc) {
  BaseFloat forward_cost = forward_costs_[state] +
      ConvertToCost(arc.weight);
  if (forward_cost == std::numeric_limits<BaseFloat>::infinity())
    return;
  int32 arc_idx = clat_.NumArcs(state);
  clat_.AddArc(state, arc);
  arcs_in_[arc.nextstate].push_back({state, arc_idx});
  if (forward_cost < forward_costs_[arc.nextstate])
    forward_costs_[arc.nextstate] = forward_cost;
}
void LatticeIncrementalDeterminizer::IdentifyTokenFinalStates(
    const CompactLattice &chunk_clat,
    std::unordered_map<CompactLattice::StateId, CompactLatticeArc::Label> *token_map)
    const {
  token_map->clear();
  using StateId = CompactLattice::StateId;

  StateId num_states = chunk_clat.NumStates();
  for (StateId state = 0; state < num_states; state++) {
    for (fst::ArcIterator<CompactLattice> aiter(chunk_clat, state);
         !aiter.Done(); aiter.Next()) {
      const CompactLatticeArc &arc = aiter.Value();
      if (arc.olabel >= kTokenLabelOffset && arc.olabel < kMaxTokenLabel) {
        StateId nextstate = arc.nextstate;
        auto r = token_map->insert({nextstate, arc.olabel});
        if (!r.second)
          KALDI_ASSERT(r.first->second == arc.olabel);
      }
    }
  }
}
void LatticeIncrementalDeterminizer::GetNonFinalRedetStates() {
  using StateId = CompactLattice::StateId;
  non_final_redet_states_.clear();
  non_final_redet_states_.reserve(final_arcs_.size());

  std::vector<StateId> state_queue;
  for (const CompactLatticeArc &arc: final_arcs_) {
    // Note: we abuse the .nextstate field to store the source state.
    StateId redet_state = arc.nextstate;
    if (forward_costs_[redet_state] != std::numeric_limits<BaseFloat>::infinity()) {
      // .. if it is accessible.
      if (non_final_redet_states_.insert(redet_state).second) {
        // it was not already there.
        state_queue.push_back(redet_state);
      }
    }
  }
  // Add any states that are reachable from the states above.
  while (!state_queue.empty()) {
    StateId s = state_queue.back();
    state_queue.pop_back();
    for (fst::ArcIterator<CompactLattice> aiter(clat_, s); !aiter.Done();
         aiter.Next()) {
      const CompactLatticeArc &arc = aiter.Value();
      StateId nextstate = arc.nextstate;
      if (non_final_redet_states_.insert(nextstate).second)
        state_queue.push_back(nextstate);  // newly added
    }
  }
}
void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk(
    Lattice *olat,
    unordered_map<Label, LatticeArc::StateId> *token_label2state) {
  using namespace fst;

  olat->DeleteStates();
  LatticeArc::StateId start_state = olat->AddState();
  olat->SetStart(start_state);
  token_label2state->clear();

  // redet_state_map maps from state-ids in clat_ to state-ids in olat, for
  // the non-final redeterminized states and states reachable from them.
  unordered_map<CompactLattice::StateId, LatticeArc::StateId> redet_state_map;

  for (CompactLattice::StateId redet_state: non_final_redet_states_)
    redet_state_map[redet_state] = olat->AddState();

  for (CompactLattice::StateId redet_state: non_final_redet_states_) {
    LatticeArc::StateId lat_state = redet_state_map[redet_state];

    for (ArcIterator<CompactLattice> aiter(clat_, redet_state);
         !aiter.Done(); aiter.Next()) {
      const CompactLatticeArc &arc = aiter.Value();
      CompactLattice::StateId nextstate = arc.nextstate;
      LatticeArc::StateId lat_nextstate = olat->NumStates();
      auto r = redet_state_map.insert({nextstate, lat_nextstate});
      if (r.second) {  // was inserted.
        LatticeArc::StateId new_state = olat->AddState();
        KALDI_ASSERT(new_state == lat_nextstate);
      } else {
        // was not inserted -> was already there.
        lat_nextstate = r.first->second;
      }
      CompactLatticeArc clat_arc(arc);
      clat_arc.nextstate = lat_nextstate;
      AddCompactLatticeArcToLattice(clat_arc, lat_state, olat);
    }
    clat_.DeleteArcs(redet_state);
    clat_.SetFinal(redet_state, CompactLatticeWeight::Zero());
  }

  // Process the arcs to token-final states, which we stored in final_arcs_;
  // note, we abuse the .nextstate field of those arcs to store the source
  // state.
  for (const CompactLatticeArc &arc: final_arcs_) {
    CompactLattice::StateId src_state = arc.nextstate;
    auto iter = redet_state_map.find(src_state);
    if (forward_costs_[src_state] == std::numeric_limits<BaseFloat>::infinity())
      continue;  // Unreachable state.
    KALDI_ASSERT(iter != redet_state_map.end());
    LatticeArc::StateId src_lat_state = iter->second;
    Label token_label = arc.ilabel;
    KALDI_ASSERT(token_label >= kTokenLabelOffset &&
                 token_label < kMaxTokenLabel);
    auto r = token_label2state->insert({token_label,
                                        olat->NumStates()});
    LatticeArc::StateId dest_lat_state = r.first->second;
    if (r.second) {  // was inserted.
      LatticeArc::StateId new_state = olat->AddState();
      KALDI_ASSERT(new_state == dest_lat_state);
    }
    CompactLatticeArc new_arc;
    new_arc.nextstate = dest_lat_state;
    // The token-label is replaced with epsilon; it is no longer needed.
    new_arc.ilabel = new_arc.olabel = 0;
    new_arc.weight = arc.weight;
    AddCompactLatticeArcToLattice(new_arc, src_lat_state, olat);
  }

  // Add arcs from the start state of olat to the redeterminized states, with
  // state-labels identifying the corresponding states in clat_, and with the
  // forward costs on them so that pruned determinization behaves correctly.
  for (CompactLattice::StateId state_id: non_final_redet_states_) {
    BaseFloat forward_cost = forward_costs_[state_id];
    LatticeArc arc;
    arc.ilabel = 0;
    arc.olabel = state_id + kStateLabelOffset;
    arc.weight = LatticeWeight(forward_cost, 0);
    auto iter = redet_state_map.find(state_id);
    KALDI_ASSERT(iter != redet_state_map.end());
    arc.nextstate = iter->second;
    olat->AddArc(start_state, arc);
  }
}
void LatticeIncrementalDeterminizer::GetRawLatticeFinalCosts(
    const Lattice &raw_fst,
    std::unordered_map<Label, BaseFloat> *old_final_costs) {
  LatticeArc::StateId raw_fst_num_states = raw_fst.NumStates();
  for (LatticeArc::StateId s = 0; s < raw_fst_num_states; s++) {
    for (fst::ArcIterator<Lattice> aiter(raw_fst, s); !aiter.Done();
         aiter.Next()) {
      const LatticeArc &value = aiter.Value();
      if (value.olabel >= (Label)kTokenLabelOffset &&
          value.olabel < (Label)kMaxTokenLabel) {
        LatticeWeight final_weight = raw_fst.Final(value.nextstate);
        if (final_weight != LatticeWeight::Zero() &&
            final_weight.Value2() != 0) {
          KALDI_ERR << "Label " << value.olabel << " from state " << s
                    << " looks like a token-label but its next-state "
                    << value.nextstate
                    << " has unexpected final-weight " << final_weight.Value1()
                    << ',' << final_weight.Value2();
        }
        auto r = old_final_costs->insert({value.olabel,
                                          final_weight.Value1()});
        if (!r.second && r.first->second != final_weight.Value1()) {
          // For any given token-label, all arcs in raw_fst with that olabel
          // should go to the same state, so this should be impossible.
          KALDI_ERR << "Unexpected mismatch in final-costs for tokens, "
                    << r.first->second << " vs " << final_weight.Value1();
        }
      }
    }
  }
}
bool LatticeIncrementalDeterminizer::ProcessArcsFromChunkStartState(
    const CompactLattice &chunk_clat,
    std::unordered_map<CompactLattice::StateId, CompactLattice::StateId> *state_map) {
  using StateId = CompactLattice::StateId;
  StateId clat_num_states = clat_.NumStates();

  // Process arcs leaving the start state of chunk_clat.  These arcs carry
  // state-labels (unless this is the first chunk).
  for (fst::ArcIterator<CompactLattice> aiter(chunk_clat, chunk_clat.Start());
       !aiter.Done(); aiter.Next()) {
    const CompactLatticeArc &arc = aiter.Value();
    Label label = arc.ilabel;  // ilabel == olabel in a CompactLattice arc.
    if (!(label >= kStateLabelOffset &&
          label - kStateLabelOffset < clat_num_states)) {
      // The label was not a state-label, which is only possible for the first
      // chunk; that's when we return false.
      return false;
    }
    StateId clat_state = label - kStateLabelOffset;
    StateId chunk_state = arc.nextstate;
    auto p = state_map->insert({chunk_state, clat_state});
    StateId dest_clat_state = p.first->second;
    // If dest_clat_state != clat_state, multiple redeterminized states have
    // been merged into one state of the chunk; we merge clat_state into
    // dest_clat_state, redirecting clat_state's incoming arcs to it with the
    // arc's weight composed in.
    if (clat_state != dest_clat_state) {
      CompactLatticeWeight extra_weight_in = arc.weight;

      // Make sure the state won't be repeatedly merged, and that the unused
      // state becomes unreachable.
      forward_costs_[clat_state] = (clat_state == 0 ? 0 :
                                    std::numeric_limits<BaseFloat>::infinity());
      std::vector<std::pair<StateId, int32> > arcs_in;
      arcs_in.swap(arcs_in_[clat_state]);
      for (auto p: arcs_in) {
        // Note: we have to be careful, elements of arcs_in_ may be stale.
        StateId src_state = p.first;
        int32 arc_pos = p.second;
        if (arc_pos >= (int32)clat_.NumArcs(src_state))
          continue;
        fst::MutableArcIterator<CompactLattice> aiter(&clat_, src_state);
        aiter.Seek(arc_pos);
        if (aiter.Value().nextstate != clat_state)
          continue;
        CompactLatticeArc new_in_arc(aiter.Value());
        new_in_arc.nextstate = dest_clat_state;
        new_in_arc.weight = fst::Times(new_in_arc.weight, extra_weight_in);
        aiter.SetValue(new_in_arc);

        BaseFloat new_forward_cost = forward_costs_[src_state] +
            ConvertToCost(new_in_arc.weight);
        if (new_forward_cost < forward_costs_[dest_clat_state])
          forward_costs_[dest_clat_state] = new_forward_cost;
        arcs_in_[dest_clat_state].push_back(p);
      }
    }
  }
  return true;  // this is not the first chunk.
}
void LatticeIncrementalDeterminizer::TransferArcsToClat(
    const CompactLattice &chunk_clat,
    bool is_first_chunk,
    const std::unordered_map<CompactLattice::StateId, CompactLattice::StateId> &state_map,
    const std::unordered_map<CompactLattice::StateId, Label> &chunk_state_to_token,
    const std::unordered_map<Label, BaseFloat> &old_final_costs) {
  using StateId = CompactLattice::StateId;
  StateId chunk_num_states = chunk_clat.NumStates();

  // Transfer arcs from chunk_clat to clat_.
  for (StateId chunk_state = (is_first_chunk ? 0 : 1);
       chunk_state < chunk_num_states; chunk_state++) {
    auto iter = state_map.find(chunk_state);
    if (iter == state_map.end()) {
      // Token-final states are not processed here; anything needed from them
      // is processed when we deal with the arcs entering them.
      KALDI_ASSERT(chunk_state_to_token.count(chunk_state) != 0);
      continue;
    }
    StateId clat_state = iter->second;

    // chunk_state is not token-final, so its final-prob (if any) transfers
    // directly to the corresponding state of clat_.
    clat_.SetFinal(clat_state, chunk_clat.Final(chunk_state));

    for (fst::ArcIterator<CompactLattice> aiter(chunk_clat, chunk_state);
         !aiter.Done(); aiter.Next()) {
      CompactLatticeArc arc(aiter.Value());

      auto next_iter = state_map.find(arc.nextstate);
      if (next_iter != state_map.end()) {
        // The normal case (arc.nextstate is not a token-final state).
        arc.nextstate = next_iter->second;
        KALDI_ASSERT(arc.ilabel < kTokenLabelOffset ||
                     arc.ilabel > kMaxTokenLabel);
        AddArcToClat(clat_state, arc);
        continue;
      }

      // The arc enters a token-final state.  We don't add such arcs to clat_;
      // we record them in final_arcs_ instead, after absorbing the
      // destination state's final-prob and removing the final-cost that was
      // temporarily added for pruning purposes.
      KALDI_ASSERT(chunk_clat.Final(arc.nextstate) != CompactLatticeWeight::Zero() &&
                   arc.olabel >= (Label)kTokenLabelOffset &&
                   arc.olabel < (Label)kMaxTokenLabel &&
                   chunk_state_to_token.count(arc.nextstate) != 0 &&
                   old_final_costs.count(arc.olabel) != 0);

      arc.weight = fst::Times(arc.weight,
                              chunk_clat.Final(arc.nextstate));

      auto cost_iter = old_final_costs.find(arc.olabel);
      KALDI_ASSERT(cost_iter != old_final_costs.end());
      BaseFloat old_final_cost = cost_iter->second;

      // Remove the `old` final-cost from the arc's weight.
      arc.weight.SetWeight(fst::Times(arc.weight.Weight(),
                                      LatticeWeight(-old_final_cost, 0)));
      // In a slight abuse of the Arc data structure, the nextstate is set to
      // the source state.
      arc.nextstate = clat_state;
      final_arcs_.push_back(arc);
    }
  }
}
bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk(
    Lattice *raw_fst) {
  using namespace fst;
  using StateId = CompactLatticeArc::StateId;

  // old_final_costs maps from a token-label to the final-cost in raw_fst of
  // the corresponding token-final state; it's needed to cancel out those
  // final-costs when we append the next chunk.
  std::unordered_map<Label, BaseFloat> old_final_costs;
  GetRawLatticeFinalCosts(*raw_fst, &old_final_costs);

  CompactLattice chunk_clat;
  bool determinized_till_beam = DeterminizeLatticePhonePrunedWrapper(
      trans_model_, raw_fst, config_.lattice_beam, &chunk_clat,
      config_.det_opts);

  TopSortCompactLatticeIfNeeded(&chunk_clat);

  // chunk_state_to_token maps from states in chunk_clat that have arcs with
  // token-labels entering them, to those token-labels.
  std::unordered_map<StateId, Label> chunk_state_to_token;
  IdentifyTokenFinalStates(chunk_clat,
                           &chunk_state_to_token);

  StateId chunk_num_states = chunk_clat.NumStates();
  if (chunk_num_states == 0) {
    // This will be an error, but user-level calling code can detect it from
    // the lattice being empty.
    KALDI_WARN << "Empty lattice, something went wrong.";
    clat_.DeleteStates();
    return false;
  }

  StateId start_state = chunk_clat.Start();  // would be 0.
  KALDI_ASSERT(start_state == 0);

  // state_map maps from (non-initial, non-token-final) states in chunk_clat
  // to states in clat_.
  std::unordered_map<StateId, StateId> state_map;

  bool is_first_chunk = ProcessArcsFromChunkStartState(chunk_clat, &state_map);

  // Remove any existing arcs in clat_ that leave the non-final
  // redeterminized states, and make those states non-final; those arcs and
  // final-probs will be replaced with those from the new chunk.
  for (StateId clat_state: non_final_redet_states_) {
    clat_.DeleteArcs(clat_state);
    clat_.SetFinal(clat_state, CompactLatticeWeight::Zero());
  }

  // The final-arcs will be replaced by those from this chunk.
  final_arcs_.clear();

  // Allocate new states in clat_ for all states in chunk_clat that don't
  // already have a corresponding state (and are not token-final).
  for (StateId state = (is_first_chunk ? 0 : 1);
       state < chunk_num_states; state++) {
    if (chunk_state_to_token.count(state) != 0)
      continue;  // these `token-final` states don't get a state allocated.
    StateId new_clat_state = clat_.NumStates();
    if (state_map.insert({state, new_clat_state}).second) {
      // If it was inserted then we need to actually allocate that state.
      StateId s = AddStateToClat();
      KALDI_ASSERT(s == new_clat_state);
    }
  }

  if (is_first_chunk) {
    auto iter = state_map.find(start_state);
    KALDI_ASSERT(iter != state_map.end());
    StateId clat_start_state = iter->second;
    clat_.SetStart(clat_start_state);
    forward_costs_[clat_start_state] = 0.0;
  }

  TransferArcsToClat(chunk_clat, is_first_chunk,
                     state_map, chunk_state_to_token, old_final_costs);

  GetNonFinalRedetStates();

  return determinized_till_beam;
}
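
// --- Illustrative sketch (editor's addition): the call protocol the decoder
// follows with this class, condensed from GetLattice() above.  This is a
// hypothetical free function; the chunks come from the caller.
static void ExampleDeterminizerProtocol(
    LatticeIncrementalDeterminizer *determinizer,
    Lattice *first_raw_chunk) {
  determinizer->Init();
  // First chunk: the raw lattice's start state carries no state-labels.
  determinizer->AcceptRawLatticeChunk(first_raw_chunk);
  // Each later chunk starts from a skeleton containing the re-determinized
  // states of the lattice so far:
  Lattice next_raw_chunk;
  unordered_map<Label, LatticeArc::StateId> token_label2state;
  determinizer->InitializeRawLatticeChunk(&next_raw_chunk, &token_label2state);
  // ... the decoder adds the new frames' states and arcs to next_raw_chunk,
  //     connecting tokens on its first frame via token_label2state ...
  determinizer->AcceptRawLatticeChunk(&next_raw_chunk);
  determinizer->SetFinalCosts(NULL);  // NULL: treat pending tokens as final.
  const CompactLattice &clat = determinizer->GetLattice();
  (void)clat;
}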
void LatticeIncrementalDeterminizer::SetFinalCosts(
    const unordered_map<Label, BaseFloat> *token_label2final_cost) {
  if (final_arcs_.empty()) {
    KALDI_WARN << "SetFinalCosts() called when final_arcs_.empty()... possibly "
                  "means you are calling this after Finalize()?  Not allowed: "
                  "could indicate a code error.  Or possibly decoding failed "
                  "somehow.";
  }

  /* prefinal_states is the set of states that have an arc to a token-final
     state (note: we store the source state in the .nextstate field of the
     elements of final_arcs_). */
  std::unordered_set<int32> &prefinal_states(temp_);
  prefinal_states.clear();
  for (const auto &arc: final_arcs_) {
    int32 state = arc.nextstate;
    prefinal_states.insert(state);
  }

  for (int32 state: prefinal_states)
    clat_.SetFinal(state, CompactLatticeWeight::Zero());

  for (const CompactLatticeArc &arc: final_arcs_) {
    Label token_label = arc.ilabel;
    CompactLattice::StateId src_state = arc.nextstate;
    BaseFloat graph_final_cost;
    if (token_label2final_cost == NULL) {
      graph_final_cost = 0.0;
    } else {
      auto iter = token_label2final_cost->find(token_label);
      if (iter == token_label2final_cost->end())
        continue;
      else
        graph_final_cost = iter->second;
    }
    /* Include the graph final-cost in the weight and take the best over the
       (possibly several) token-final arcs leaving this state. */
    clat_.SetFinal(src_state,
                   fst::Plus(clat_.Final(src_state),
                             fst::Times(arc.weight,
                                        CompactLatticeWeight(
                                            LatticeWeight(graph_final_cost, 0),
                                            std::vector<int32>()))));
  }
}
// Instantiate the template for the combinations of token type and FST type
// that we need.
template class LatticeIncrementalDecoderTpl<fst::Fst<fst::StdArc>,
                                            decoder::BackpointerToken>;
template class LatticeIncrementalDecoderTpl<fst::VectorFst<fst::StdArc>,
                                            decoder::BackpointerToken>;
template class LatticeIncrementalDecoderTpl<fst::ConstFst<fst::StdArc>,
                                            decoder::BackpointerToken>;
template class LatticeIncrementalDecoderTpl<fst::GrammarFst,
                                            decoder::BackpointerToken>;