21 #ifndef KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_ 22 #define KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_ 27 #include <unordered_map> 28 using std::unordered_map;
49 size_t hash = 0, factor = 1;
50 for (
typename std::vector<Label>::const_iterator it = vec->begin();
51 it != vec->end(); it++) {
60 size_t operator()(
const std::vector<Label> *vec1,
const std::vector<Label> *vec2)
const {
61 return (*vec1 == *vec2);
74 std::vector<Label> v; v.push_back(l);
79 StringId
IdOfSeq(
const std::vector<Label> &v) {
82 else if (v.size() == 1)
return IdOfLabel(v[0]);
89 void SeqOfId(StringId
id, std::vector<Label> *v) {
94 assert(static_cast<size_t>(
id) <
vec_.size());
99 if (prefix_len == 0)
return id;
101 std::vector<Label> v;
103 size_t sz = v.size();
104 assert(sz >= prefix_len);
105 std::vector<Label> v_noprefix(sz - prefix_len);
106 for (
size_t i = 0;
i < sz-prefix_len;
i++) v_noprefix[
i] = v[
i+prefix_len];
114 string_end = (std::numeric_limits<StringId>::max() / 2) - 1;
115 no_symbol = (std::numeric_limits<StringId>::max() / 2);
120 for (
typename std::vector<std::vector<Label>* >::iterator iter =
vec_.begin(); iter !=
vec_.end(); ++iter)
122 std::vector<std::vector<Label>* > tmp_vec;
135 typename MapType::iterator iter =
map_.find(&v);
136 if (iter !=
map_.end()) {
139 StringId this_id = (StringId)
vec_.size();
140 std::vector<Label> *v_new =
new std::vector<Label> (v);
141 vec_.push_back(v_new);
142 map_[v_new] = this_id;
148 std::vector<std::vector<Label>* >
vec_;
165 void Output(MutableFst<GallicArc<Arc> > *ofst,
bool destroy =
true);
171 void Output(MutableFst<Arc> *ofst,
bool destroy =
true);
177 int max_states = -1,
bool allow_partial =
false):
178 ifst_(ifst.
Copy()), delta_(delta), max_states_(max_states),
179 determinized_(false), allow_partial_(allow_partial),
180 is_partial_(false), equal_(delta),
181 hash_(ifst.Properties(kExpanded, false) ?
182 down_cast<const ExpandedFst<Arc>*,
183 const Fst<Arc> >(&ifst)->NumStates()/2 + 3 : 20,
185 epsilon_closure_(ifst_, max_states, &repository_, delta) { }
188 assert(!determinized_);
192 if (start_id == kNoStateId) { determinized_ =
true;
return; }
195 elem.
state = start_id;
196 elem.
weight = Weight::One();
197 elem.
string = repository_.IdOfEmpty();
198 std::vector<Element> vec;
201 assert(cur_id == 0 &&
"Do not call Determinize twice.");
203 while (!Q_.empty()) {
204 std::pair<std::vector<Element>*,
OutputStateId> cur_pair = Q_.front();
206 ProcessSubset(cur_pair);
207 if (debug_ptr && *debug_ptr) Debug();
208 if (max_states_ > 0 && output_arcs_.size() > max_states_) {
209 if (allow_partial_ ==
false) {
210 KALDI_ERR <<
"Determinization aborted since passed " << max_states_
213 KALDI_WARN <<
"Determinization terminated since passed " << max_states_
214 <<
" states, partial results will be generated";
220 determinized_ =
true;
234 for (
typename SubsetHash::iterator iter = hash_.begin();
235 iter != hash_.end(); ++iter)
260 return (state != other.
state ||
string != other.
string ||
291 size_t hash = 0, factor = 1;
292 for (
typename std::vector<Element>::const_iterator iter = subset->begin();
293 iter != subset->end(); ++iter) {
295 hash += iter->state + 103333 * iter->string;
307 const std::vector<Element> *s2)
const {
308 size_t sz = s1->size();
310 if (sz != s2->size())
return false;
311 typename std::vector<Element>::const_iterator iter1 = s1->begin(),
312 iter1_end = s1->end(), iter2 = s2->begin();
313 for (; iter1 < iter1_end; ++iter1, ++iter2) {
314 if (iter1->state != iter2->state ||
315 iter1->string != iter2->string ||
316 !
ApproxEqual(iter1->weight, iter2->weight, delta_))
330 bool operator ()(
const std::vector<Element> *s1,
const std::vector<Element> *s2)
const {
331 size_t sz = s1->size();
333 if (sz != s2->size())
return false;
334 typename std::vector<Element>::const_iterator iter1 = s1->begin(),
335 iter1_end = s1->end(), iter2=s2->begin();
336 for (; iter1 < iter1_end; ++iter1, ++iter2) {
337 if (iter1->state != iter2->state)
return false;
350 ifst_(ifst), max_states_(max_states), repository_(repository),
358 void GetEpsilonClosure(
const std::vector<Element> &input_subset,
359 std::vector<Element> *output_subset);
365 element(e), weight_to_process(w), in_queue(i) {}
390 std::deque<typename Arc::StateId>
queue_;
409 void AddOneElement(
const Element &elem,
const Weight &unprocessed_weight);
418 void ExpandOneElement(
const Element &elem,
420 const Weight &unprocessed_weight,
421 bool save_to_queue_2 =
false);
435 void ProcessFinal(
const std::vector<Element> &closed_subset, OutputStateId state) {
437 bool is_final =
false;
438 StringId final_string = 0;
439 Weight final_weight = Weight::One();
443 typename std::vector<Element>::const_iterator iter = closed_subset.begin(),
444 end = closed_subset.end();
445 for (; iter != end; ++iter) {
447 Weight this_final_weight = ifst_->Final(elem.
state);
448 if (this_final_weight != Weight::Zero()) {
450 final_string = elem.
string;
451 final_weight =
Times(elem.
weight, this_final_weight);
454 if (final_string != elem.
string) {
455 KALDI_ERR <<
"FST was not functional -> not determinizable";
457 final_weight =
Plus(final_weight,
Times(elem.
weight, this_final_weight));
466 temp_arc.
ostring = final_string;
467 temp_arc.
weight = final_weight;
468 output_arcs_[state].push_back(temp_arc);
475 void ProcessTransition(OutputStateId state, Label ilabel, std::vector<Element> *subset);
482 inline bool operator () (
const std::pair<Label, Element> &p1,
const std::pair<Label, Element> &p2) {
483 if (p1.first < p2.first)
return true;
484 else if (p1.first > p2.first)
return false;
486 return p1.second.state < p2.second.state;
502 std::vector<std::pair<Label, Element> > all_elems;
505 typename std::vector<Element>::const_iterator iter = closed_subset.begin(),
506 end = closed_subset.end();
507 for (; iter != end; ++iter) {
509 for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.
state);
510 !aiter.Done(); aiter.Next()) {
511 const Arc &arc = aiter.Value();
512 if (arc.ilabel != 0) {
513 std::pair<Label, Element> this_pr;
514 this_pr.first = arc.ilabel;
515 Element &next_elem(this_pr.second);
516 next_elem.
state = arc.nextstate;
522 std::vector<Label> seq;
523 repository_.SeqOfId(elem.
string, &seq);
524 seq.push_back(arc.olabel);
525 next_elem.
string = repository_.IdOfSeq(seq);
527 all_elems.push_back(this_pr);
533 std::sort(all_elems.begin(), all_elems.end(), pc);
535 typedef typename std::vector<std::pair<Label, Element> >::const_iterator PairIter;
536 PairIter cur = all_elems.begin(), end = all_elems.end();
537 std::vector<Element> this_subset;
540 Label ilabel = cur->first;
542 while (cur != end && cur->first == ilabel) {
543 this_subset.push_back(cur->second);
547 ProcessTransition(state, ilabel, &this_subset);
556 typedef typename SubsetHash::iterator IterType;
557 IterType iter = hash_.find(&subset);
558 if (iter == hash_.end()) {
559 std::vector<Element> *new_subset =
new std::vector<Element>(subset);
560 OutputStateId new_state_id = (OutputStateId) output_arcs_.size();
561 bool ans = hash_.insert(std::pair<
const std::vector<Element>*,
562 OutputStateId>(new_subset,
563 new_state_id)).second;
565 output_arcs_.push_back(std::vector<TempArc>());
566 if (allow_partial_ ==
false) {
568 Q_.push_front(std::pair<std::vector<Element>*, OutputStateId>(new_subset, new_state_id));
573 Q_.push_back(std::pair<std::vector<Element>*, OutputStateId>(new_subset, new_state_id));
588 void ProcessSubset(
const std::pair<std::vector<Element>*, OutputStateId> & pair) {
589 const std::vector<Element> *subset = pair.first;
590 OutputStateId state = pair.second;
592 std::vector<Element> closed_subset;
593 epsilon_closure_.GetEpsilonClosure(*subset, &closed_subset);
596 ProcessFinal(closed_subset, state);
599 ProcessTransitions(closed_subset, state);
605 std::deque<std::pair<std::vector<Element>*, OutputStateId> >
Q_;
626 float delta,
bool *debug_ptr,
int max_states,
627 bool allow_partial) {
628 ofst->SetOutputSymbols(ifst.OutputSymbols());
629 ofst->SetInputSymbols(ifst.InputSymbols());
639 MutableFst<GallicArc<typename F::Arc> > *ofst,
float delta,
640 bool *debug_ptr,
int max_states,
641 bool allow_partial) {
642 ofst->SetOutputSymbols(ifst.InputSymbols());
643 ofst->SetInputSymbols(ifst.InputSymbols());
653 std::vector<Element> *output_subset) {
655 size_t size = input_subset.size();
658 ((ifst_->Properties(kILabelSorted,
false) & kILabelSorted) != 0);
661 for (
size_t i = 0;
i < size;
i++) {
662 ExpandOneElement(input_subset[
i], sorted, input_subset[i].weight,
true);
665 size_t s = queue_2_.size();
667 *output_subset = input_subset;
671 for (
size_t i = 0;
i < size;
i++) {
675 input_subset[i].weight,
677 ecinfo_.back().element.weight = Weight::Zero();
679 if (id_to_index_.size() < input_subset[
i].state + 1) {
680 id_to_index_.resize(2 * input_subset[i].state + 1, -1);
682 id_to_index_[input_subset[
i].state] = ecinfo_.size() - 1;
688 elem.
weight = Weight::Zero();
689 for (
size_t i = 0;
i < s;
i++) {
690 elem.
state = queue_2_[
i].state;
691 elem.
string = queue_2_[
i].string;
692 AddOneElement(elem, queue_2_[
i].weight);
698 while (!queue_.empty()) {
703 int index = id_to_index_[id];
714 if (max_states_ > 0 && counter++ > max_states_) {
715 KALDI_ERR <<
"Determinization aborted since looped more than " 716 << max_states_ <<
" times during epsilon closure";
723 ExpandOneElement(elem, sorted, unprocessed_weight);
728 sort(ecinfo_.begin(), ecinfo_.end());
730 output_subset->clear();
732 size = ecinfo_.size();
733 output_subset->reserve(size);
734 for (
size_t i = 0;
i < size;
i++) {
739 output_subset->push_back(info.
element);
749 if (elem.
state < id_to_index_.size()) {
750 index = id_to_index_[elem.
state];
753 if (index >= ecinfo_.size()) {
757 else if (ecinfo_[index].element.state != elem.
state) {
765 size_t size = id_to_index_.size();
766 if (size < elem.
state + 1) {
768 id_to_index_.resize(2 * elem.
state + 1, -1);
770 id_to_index_[elem.
state] = ecinfo_.size() - 1;
771 queue_.push_back(elem.
state);
777 std::ostringstream ss;
778 ss <<
"FST was not functional -> not determinizable.";
781 std::vector<Label> tmp_seq;
783 ss <<
"\nFirst string:";
784 for (
size_t i = 0;
i < tmp_seq.size();
i++)
785 ss <<
' ' << tmp_seq[
i];
786 ss <<
"\nSecond string:";
787 repository_->SeqOfId(elem.
string, &tmp_seq);
788 for (
size_t i = 0;
i < tmp_seq.size();
i++)
789 ss <<
' ' << tmp_seq[
i];
809 queue_.push_back(elem.
state);
819 const Weight &unprocessed_weight,
820 bool save_to_queue_2) {
825 for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.
state);
826 !aiter.Done(); aiter.Next()) {
827 const Arc &arc = aiter.Value();
828 if (sorted && arc.ilabel > 0) {
833 if (arc.ilabel != 0) {
837 next_elem.
state = arc.nextstate;
838 next_elem.
weight = Weight::Zero();
839 Weight next_unprocessed_weight
840 =
Times(unprocessed_weight, arc.weight);
843 if (arc.olabel == 0) {
846 std::vector<Label> seq;
847 repository_->SeqOfId(str, &seq);
849 seq.push_back(arc.olabel);
850 next_elem.
string = repository_->IdOfSeq(seq);
852 if (save_to_queue_2) {
853 next_elem.
weight = next_unprocessed_weight;
854 queue_2_.push_back(next_elem);
856 AddOneElement(next_elem, next_unprocessed_weight);
864 assert(determinized_);
865 if (destroy) determinized_ =
false;
866 typedef GallicWeight<Label, Weight> ThisGallicWeight;
870 StateId nStates =
static_cast<StateId
>(output_arcs_.size());
871 ofst->DeleteStates();
872 ofst->SetStart(kNoStateId);
876 for (StateId s = 0;s < nStates;s++) {
882 for (StateId this_state = 0; this_state < nStates; this_state++) {
883 std::vector<TempArc> &this_vec(output_arcs_[this_state]);
884 typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
885 end = this_vec.end();
886 for (; iter != end; ++iter) {
887 const TempArc &temp_arc(*iter);
888 GallicArc<Arc> new_arc;
889 std::vector<Label> seq;
890 repository_.SeqOfId(temp_arc.
ostring, &seq);
891 StringWeight<Label, STRING_LEFT> string_weight;
892 for (
size_t i = 0;
i < seq.size();
i++) string_weight.PushBack(seq[
i]);
893 ThisGallicWeight gallic_weight(string_weight, temp_arc.
weight);
896 ofst->SetFinal(this_state, gallic_weight);
899 new_arc.ilabel = temp_arc.
ilabel;
900 new_arc.olabel = temp_arc.
ilabel;
901 new_arc.weight = gallic_weight;
902 ofst->AddArc(this_state, new_arc);
906 if (destroy) { std::vector<TempArc> temp; temp.swap(this_vec); }
908 if (destroy) { std::vector<std::vector<TempArc> > temp; temp.swap(output_arcs_); }
913 assert(determinized_);
914 if (destroy) determinized_ =
false;
919 ofst->DeleteStates();
920 if (num_states == 0) {
921 ofst->SetStart(kNoStateId);
930 for (
OutputStateId this_state = 0; this_state < num_states; this_state++) {
931 std::vector<TempArc> &this_vec(output_arcs_[this_state]);
933 typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
934 end = this_vec.end();
935 for (; iter != end; ++iter) {
936 const TempArc &temp_arc(*iter);
937 std::vector<Label> seq;
938 repository_.SeqOfId(temp_arc.
ostring, &seq);
943 for (
size_t i = 0;
i < seq.size();
i++) {
946 arc.nextstate = next_state;
947 arc.weight = (
i == 0 ? temp_arc.
weight : Weight::One());
950 ofst->AddArc(cur_state, arc);
951 cur_state = next_state;
953 ofst->SetFinal(cur_state, (seq.size() == 0 ? temp_arc.
weight : Weight::One()));
958 for (
size_t i = 0;
i+1 < seq.size();
i++) {
962 arc.nextstate = next_state;
963 arc.weight = (
i == 0 ? temp_arc.
weight : Weight::One());
964 arc.ilabel = (
i == 0 ? temp_arc.
ilabel : 0);
966 ofst->AddArc(cur_state, arc);
967 cur_state = next_state;
972 arc.weight = (seq.size() <= 1 ? temp_arc.
weight : Weight::One());
973 arc.ilabel = (seq.size() <= 1 ? temp_arc.
ilabel : 0);
974 arc.olabel = (seq.size() > 0 ? seq.back() : 0);
975 ofst->AddArc(cur_state, arc);
979 if (destroy) { std::vector<TempArc> temp; temp.swap(this_vec); }
982 std::vector<std::vector<TempArc> > temp;
983 temp.swap(output_arcs_);
984 repository_.Destroy();
995 typedef typename std::vector<Element>::iterator IterType;
997 IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
1000 while (cur_in != end) {
1003 if (cur_in != cur_out) *cur_out = *cur_in;
1005 while (cur_in != end && cur_in->state == cur_out->state) {
1006 if (cur_in->string != cur_out->string) {
1007 KALDI_ERR <<
"FST was not functional -> not determinizable";
1009 cur_out->weight =
Plus(cur_out->weight, cur_in->weight);
1015 subset->resize(num_out);
1022 std::vector<Label> seq;
1024 IterType begin = subset->begin(), iter, end = subset->end();
1027 std::vector<Label> tmp_seq;
1028 for (iter = begin; iter != end; ++iter) {
1029 if (iter == begin) {
1030 repository_.SeqOfId(iter->string, &seq);
1032 repository_.SeqOfId(iter->string, &tmp_seq);
1033 if (tmp_seq.size() < seq.size()) seq.resize(tmp_seq.size());
1034 for (
size_t i = 0;
i < seq.size();
i++)
1035 if (tmp_seq[
i] != seq[
i]) seq.resize(i);
1037 if (seq.size() == 0)
break;
1039 common_str = repository_.IdOfSeq(seq);
1044 tot_weight = iter->weight;
1045 for (++iter; iter != end; ++iter)
1046 tot_weight =
Plus(tot_weight, iter->weight);
1050 size_t prefix_len = seq.size();
1051 for (iter = begin; iter != end; ++iter) {
1052 iter->weight =
Divide(iter->weight, tot_weight);
1053 iter->string = repository_.RemovePrefix(iter->string, prefix_len);
1060 temp_arc.
ilabel = ilabel;
1061 temp_arc.
nextstate = SubsetToStateId(*subset);
1062 temp_arc.
ostring = common_str;
1063 temp_arc.
weight = tot_weight;
1064 output_arcs_[state].push_back(temp_arc);
1074 KALDI_WARN <<
"Debug function called (probably SIGUSR1 caught)";
1078 if (output_arcs_.size() <= 2) {
1081 size_t max_state = output_arcs_.size() - 2;
1084 std::vector<OutputStateId> predecessor(max_state+1, kNoStateId);
1085 for (
size_t i = 0;
i < max_state;
i++) {
1086 for (
size_t j = 0;
j < output_arcs_[
i].size();
j++) {
1091 if (nextstate <= max_state && nextstate >
i)
1092 predecessor[nextstate] =
i;
1095 std::vector<std::pair<Label, StringId> > traceback;
1099 while (cur_state != 0 && cur_state != kNoStateId) {
1101 std::pair<Label, StringId> p;
1103 for (i = 0; i < output_arcs_[last_state].size(); i++) {
1104 if (output_arcs_[last_state][i].nextstate == cur_state) {
1105 p.first = output_arcs_[last_state][
i].ilabel;
1106 p.second = output_arcs_[last_state][
i].ostring;
1107 traceback.push_back(p);
1112 cur_state = last_state;
1114 if (cur_state == kNoStateId)
1115 KALDI_WARN <<
"Traceback did not reach start state " 1116 <<
"(possibly debug-code error)";
1118 std::stringstream ss;
1119 ss <<
"Traceback follows in format " 1120 <<
"ilabel (olabel olabel) ilabel (olabel) ... :";
1121 for (ssize_t
i = traceback.size() - 1;
i >= 0;
i--) {
1122 ss <<
' ' << traceback[
i].first <<
" ( ";
1123 std::vector<Label> seq;
1124 repository_.SeqOfId(traceback[
i].second, &seq);
1125 for (
size_t j = 0;
j < seq.size();
j++)
1126 ss << seq[
j] <<
' ';
1134 #endif // KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_ EpsilonClosureInfo(const Element &e, const Weight &w, bool i)
fst::StdArc::StateId StateId
LatticeWeightTpl< FloatType > Divide(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, DivideType typ=DIVIDE_ANY)
OutputStateId SubsetToStateId(const std::vector< Element > &subset)
bool operator<(const EpsilonClosureInfo &other) const
StringId single_symbol_start
void AddOneElement(const Element &elem, const Weight &unprocessed_weight)
bool operator!=(const LatticeWeightTpl< FloatType > &wa, const LatticeWeightTpl< FloatType > &wb)
unordered_map< const std::vector< Label > *, StringId, VectorKey, VectorEqual > MapType
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
void GetEpsilonClosure(const std::vector< Element > &input_subset, std::vector< Element > *output_subset)
LatticeWeightTpl< FloatType > Plus(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
void Determinize(bool *debug_ptr)
std::deque< typename Arc::StateId > queue_
void swap(basic_filebuf< CharT, Traits > &x, basic_filebuf< CharT, Traits > &y)
bool ApproxEqual(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, float delta=kDelta)
void ProcessTransition(OutputStateId state, Label ilabel, std::vector< Element > *subset)
void ProcessSubset(const std::pair< std::vector< Element > *, OutputStateId > &pair)
StringId IdOfSeq(const std::vector< Label > &v)
EpsilonClosure epsilon_closure_
StringId IdOfSeqInternal(const std::vector< Label > &v)
StringId RemovePrefix(StringId id, size_t prefix_len)
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
std::vector< std::vector< Label > *> vec_
StringId single_symbol_range
DeterminizerStar(const Fst< Arc > &ifst, float delta=kDelta, int max_states=-1, bool allow_partial=false)
void ExpandOneElement(const Element &elem, bool sorted, const Weight &unprocessed_weight, bool save_to_queue_2=false)
void Output(MutableFst< GallicArc< Arc > > *ofst, bool destroy=true)
StringRepository< Label, StringId > StringRepositoryType
void SeqOfId(StringId id, std::vector< Label > *v)
std::vector< Element > queue_2_
std::deque< std::pair< std::vector< Element > *, OutputStateId > > Q_
void ProcessTransitions(const std::vector< Element > &closed_subset, OutputStateId state)
void ProcessFinal(const std::vector< Element > &closed_subset, OutputStateId state)
StringId IdOfLabel(Label l)
size_t operator()(const std::vector< Label > *vec1, const std::vector< Label > *vec2) const
KALDI_DISALLOW_COPY_AND_ASSIGN(StringRepository)
std::vector< std::vector< TempArc > > output_arcs_
std::vector< EpsilonClosureInfo > ecinfo_
StringRepository< Label, StringId > * repository_
size_t operator()(const std::vector< Label > *vec) const
fst::StdArc::Weight Weight
EpsilonClosure(const Fst< Arc > *ifst, int max_states, StringRepository< Label, StringId > *repository, float delta)
#define KALDI_ASSERT(cond)
bool IsEmptyString(StringId id)
Arc::StateId OutputStateId
static const StringId string_start
StringRepository< Label, StringId > repository_
bool DeterminizeStar(F &ifst, MutableFst< typename F::Arc > *ofst, float delta, bool *debug_ptr, int max_states, bool allow_partial)
This function implements the normal version of DeterminizeStar, in which the output strings are repre...
Arc::StateId InputStateId
unordered_map< const std::vector< Element > *, OutputStateId, SubsetKey, SubsetEqual > SubsetHash
std::vector< int > id_to_index_
void Copy(const CuMatrixBase< Real > &src, const CuArray< int32 > ©_from_indices, CuMatrixBase< Real > *tgt)
Copies elements from src into tgt as given by copy_from_indices.