20 #ifndef KALDI_FSTEXT_PRE_DETERMINIZE_INL_H_ 21 #define KALDI_FSTEXT_PRE_DETERMINIZE_INL_H_ 226 namespace pre_determinize_helpers {
// Fragment of HasBannedPrefixPlusDigits(symTable, prefix, bad_sym).
// The signature line, the declaration of `pos` and the return statements are
// not present in this listing. The visible logic scans every symbol in
// *symTable looking for one that consists of `prefix` followed only by
// decimal digits.
235 assert(symTable != NULL);
236 const char *prefix_ptr = prefix.c_str();
237 size_t prefix_len = strlen(prefix_ptr);
238 for (SymbolTableIterator siter(*symTable); !siter.Done(); siter.Next()) {
// NOTE(review): if Symbol() returns a std::string by value, `sym` would point
// into a destroyed temporary -- confirm against the OpenFst version in use.
239 const char *sym = siter.Symbol().c_str();
// strncmp() == 0: the first prefix_len characters of sym equal the prefix.
240 if (!strncmp(prefix_ptr, sym, prefix_len)) {
// Require at least one digit directly after the prefix.
// NOTE(review): isdigit() on a plain (possibly negative) char is undefined
// behavior in C/C++; a cast to unsigned char would be the safe form.
241 if (isdigit(sym[prefix_len])) {
// Advance pos over the run of digits; the loop ends at the first non-digit
// character or at the string terminator.
244 for (pos = prefix_len;sym[pos] !=
'\0'; pos++)
245 if (!isdigit(sym[pos]))
break;
// Reaching the terminator means everything after the prefix was digits.
246 if (sym[pos] ==
'\0') {
// bad_sym is optional: the matching symbol is copied out only when the
// caller supplied a non-NULL pointer.
247 if (bad_sym != NULL) *bad_sym = (std::string) sym;
// Fragment of CopySetToVector(s, v). The signature and the statement inside
// the loop are not present in this listing.
// The visible lines walk the set `s` and the vector `*v` in lock-step.
261 typename std::set<T>::const_iterator siter = s.begin();
262 typename std::vector<T>::iterator viter = v->begin();
263 for (; siter != s.end(); ++siter, ++viter) {
// Guards against *v having fewer elements than s; presumably *v is resized
// to s.size() on a line not shown here -- TODO confirm.
264 assert(viter != v->end());
// Fragment of InsertMember(m, S). The template header, the definition of
// `idx` and the end of the function are not present in this listing.
// Visible behavior: slot `idx` of the table *S either already holds a vector
// equal to m, or receives a newly heap-allocated copy of m.
// NOTE(review): m is taken by value, so the argument is copied on every call.
271 std::vector<T>*
InsertMember(
const std::vector<T> m, std::vector<std::vector<T>*> *S) {
// An empty member vector is not allowed.
272 assert(m.size() > 0);
// idx must be a valid index into *S (idx is presumably derived from an
// element of m on a line not shown -- TODO confirm).
274 assert(idx>=(T)0 && idx < (T)S->size());
// Slot already occupied: it must contain exactly the same vector.
275 if ( (*S)[idx] != NULL) {
276 assert( *((*S)[idx]) == m );
// Store a heap copy of m in the slot and keep the same pointer in `ret`.
// The allocation uses raw `new`; where it is released is not visible here.
281 std::vector<T> *ret = (*S)[idx] =
new std::vector<T>(m);
// Closure(fst, S, pVec): grows the state set *S by following arcs whose input
// label is 0 (epsilon), never entering a state t for which pVec[t] is true.
// Several lines of the body (including the initialization of Q, the pop from
// Q and all closing braces) are not present in this listing.
291 template<
class Arc>
void Closure(MutableFst<Arc> *
fst, std::set<typename Arc::StateId> *S,
292 const std::vector<bool> &pVec) {
// Q is used as a stack of states still to be expanded (the next state is read
// from the back).
294 std::vector<StateId> Q;
296 while (Q.size() != 0) {
297 StateId s = Q.back();
299 for (ArcIterator<MutableFst<Arc> > aiter(*fst, s); ! aiter.Done(); aiter.Next()) {
300 const Arc &arc = aiter.Value();
// Stop at the first arc with a non-zero input label. Presumably this relies
// on the arcs being sorted by input label so that all epsilon arcs come
// first -- TODO confirm with the caller.
301 if (arc.ilabel != 0)
break;
// Only destinations whose pVec entry is false are added to the closure.
303 if (!pVec[arc.nextstate]) {
// p.second reports whether arc.nextstate was newly inserted into *S.
304 std::pair< typename std::set<StateId>::iterator,
bool > p = S->insert(arc.nextstate);
// Schedule the destination for expansion (presumably only when it was newly
// inserted; the guarding line is not shown).
306 Q.push_back(arc.nextstate);
// Fragment of PreDeterminize(fst, first_new_sym, symsOut). Most of the
// signature and many body lines (typedefs, else-branches, closing braces,
// several declarations such as `s`, `arc_id`, `k`, `n`, `ptr`) are not
// present in this listing, so the comments below describe only what the
// visible statements do.
316 template<
class Arc,
class Int>
319 std::vector<Int> *symsOut) {
// ArcId is used below as the position of an arc within its source state's
// arc list (it is paired with a state id and advanced alongside the arc
// iterator).
322 typedef size_t ArcId;
325 assert(first_new_sym > 0);
// Nothing to do for an FST without a start state.
327 if (fst->Start() == kNoStateId)
return;
// The output vector must be supplied and empty; it is filled further down.
328 assert(symsOut != NULL && symsOut->size() == 0);
331 KALDI_VLOG(2) <<
"PreDeterminize: Checking FST properties";
// Both kAccessible and kCoAccessible must hold, otherwise an error is logged.
332 uint64 props = fst->Properties(kAccessible|kCoAccessible,
true);
333 if (props != (kAccessible|kCoAccessible)) {
334 KALDI_ERR <<
"PreDeterminize: FST is not trim";
339 KALDI_VLOG(2) <<
"PreDeterminize: creating single final state";
344 KALDI_VLOG(2) <<
"PreDeterminize: sorting arcs on input";
345 ILabelCompare<Arc> icomp;
// Scan all states to find the largest state id (n_states is presumably
// incremented on a line not shown).
349 StateId n_states = 0, max_state = 0;
351 for (StateIterator<MutableFst<Arc> > iter(*fst); ! iter.Done(); iter.Next()) {
352 StateId state = iter.Value();
355 if (state > max_state) max_state = state;
357 KALDI_VLOG(2) <<
"PreDeterminize: n_states = "<<(n_states)<<
", max_state ="<<(max_state);
// p_vec[s] is set when an arc enters a state that was already marked in
// seen_vec. Because the start state is pre-marked as seen, any arc entering
// the start state sets its p_vec entry.
360 std::vector<bool> p_vec(max_state+1,
false);
363 std::vector<bool> seen_vec(max_state+1,
false);
365 seen_vec[fst->Start()] =
true;
366 for (StateIterator<MutableFst<Arc> > siter(*fst); ! siter.Done(); siter.Next()) {
367 for (ArcIterator<MutableFst<Arc> > aiter(*fst, siter.Value()); ! aiter.Done(); aiter.Next()) {
368 const Arc &arc = aiter.Value();
369 assert(arc.nextstate>=0&&arc.nextstate<max_state+1);
370 if (seen_vec[arc.nextstate])
371 p_vec[arc.nextstate] =
true;
373 seen_vec[arc.nextstate] =
true;
// m_map: (source state, arc position) -> index that is later used to pick an
// entry of *symsOut for that arc.
378 std::map<std::pair<StateId, ArcId>,
size_t> m_map;
// S: table indexed by state id, all slots initially null.
// Q: work stack of (pointer to a vector of states, number) pairs.
385 std::vector<std::vector<StateId>* > S(max_state+1, (std::vector<StateId>*)(
void*)0);
386 std::vector<std::pair<std::vector<StateId>*,
size_t> > Q;
// Seed states: the start state (listed first when its p_vec entry is false)
// followed by every state whose p_vec entry is true, in increasing id order.
390 std::vector<StateId> all_seed_states;
391 if (!p_vec[fst->Start()])
392 all_seed_states.push_back(fst->Start());
393 for (StateId s = 0;s<=max_state; s++)
394 if (p_vec[s]) all_seed_states.push_back(s);
// For each seed state a state set is built (construction lines not shown) and
// queued together with the number 0.
396 for (
size_t idx = 0;idx < all_seed_states.size(); idx++) {
397 StateId s = all_seed_states[idx];
398 std::set<StateId> closure_s;
402 std::vector<StateId> closure_s_vec;
407 Q.push_back(std::pair<std::vector<StateId>*,
size_t>(ptr, 0));
// d_vec[s] records that state s has already been processed as a member of
// some set A; the assert below enforces that this happens at most once.
411 std::vector<bool> d_vec(max_state+1,
false);
414 size_t num_extra_det_states = 0;
// Main loop: take (A, n) from the back of Q and process it.
417 while (Q.size() != 0) {
420 std::pair<std::vector<StateId>*,
size_t> cur_pair(Q.back());
422 const std::vector<StateId> &A(*cur_pair.first);
423 size_t n =cur_pair.second;
426 for (
size_t idx = 0;idx < A.size(); idx++) {
427 assert(d_vec[A[idx]] ==
false &&
"This state has been seen before. Algorithm error.");
428 d_vec[A[idx]] =
true;
// arc_hash: input label -> set of ((source state, arc position), next state)
// for all arcs leaving the states in A.
433 std::map<Label, std::set<std::pair<std::pair<StateId, ArcId>, StateId> > > arc_hash;
442 for (
size_t idx = 0;idx < A.size(); idx++) {
444 assert(s>=0 && s<=max_state);
// arc_id advances together with the arc iterator, so it is the arc's position
// in the arc list of s.
446 for (ArcIterator<MutableFst<Arc> > aiter(*fst, s); ! aiter.Done(); aiter.Next(), ++arc_id) {
447 const Arc &arc = aiter.Value();
449 std::pair<std::pair<StateId, ArcId>, StateId>
450 this_pair(std::pair<StateId, ArcId>(s, arc_id), arc.nextstate);
451 bool inserted = (arc_hash[arc.ilabel].insert(this_pair)).second;
// Epsilon-input arcs (label 0): those whose destination has its p_vec entry
// set are assigned the current number n in m_map; each arc may be assigned
// only once.
458 if (arc_hash.count(0) == 1) {
459 std::set<std::pair<std::pair<StateId, ArcId>, StateId> > &eps_set = arc_hash[0];
460 typedef typename std::set<std::pair<std::pair<StateId, ArcId>, StateId> >::iterator set_iter_t;
461 for (set_iter_t siter = eps_set.begin(); siter != eps_set.end(); ++siter) {
462 const std::pair<std::pair<StateId, ArcId>, StateId> &this_pr = *siter;
463 if (p_vec[this_pr.second]) {
464 assert(m_map.count(this_pr.first) == 0);
465 m_map[this_pr.first] =
n;
// Iterate over every label t and the arcs S_t that carry it. (Whether label 0
// is skipped here depends on lines not shown.)
473 typedef typename std::map<Label, std::set<std::pair<std::pair<StateId, ArcId>, StateId> > >::iterator map_iter_t;
474 typedef typename std::set<std::pair<std::pair<StateId, ArcId>, StateId> >::iterator set_iter_t2;
475 for (map_iter_t miter = arc_hash.begin(); miter != arc_hash.end(); ++miter) {
476 Label t = miter->first;
477 std::set<std::pair<std::pair<StateId, ArcId>, StateId> > &S_t = miter->second;
// V_t collects destination states (presumably only those whose p_vec entry is
// false; the else-branch line is not shown).
479 std::set<StateId> V_t;
487 for (set_iter_t2 siter = S_t.begin(); siter != S_t.end(); ++siter) {
488 const std::pair<std::pair<StateId, ArcId>, StateId> &this_pr = *siter;
489 if (p_vec[this_pr.second]) {
// When more than one arc shares label t, an arc into a state with p_vec set
// is assigned the number k in m_map (how k is chosen and updated is on lines
// not shown).
490 if (S_t.size() > 1) {
491 assert(m_map.count(this_pr.first) == 0);
492 m_map[this_pr.first] = k;
494 num_extra_det_states++;
497 V_t.insert(this_pr.second);
// A non-empty V_t yields a new state set that is queued with the number k.
500 if (V_t.size() != 0) {
502 std::vector<StateId> closure_V_t_vec;
506 Q.push_back(std::pair<std::vector<StateId>*,
size_t>(ptr, k));
// Sanity check: every state of the FST must have been processed above.
516 for (StateIterator<MutableFst<Arc> > siter(*fst); ! siter.Done(); siter.Next()) {
517 StateId val = siter.Value();
518 assert(d_vec[val] ==
true);
// n is raised to at least the largest number stored in m_map (its declaration
// and initial value are not shown).
525 for (
typename std::map<std::pair<StateId, ArcId>,
size_t>::iterator m_iter = m_map.begin();
526 m_iter != m_map.end();
528 n = std::max(n, (int64) m_iter->second);
// Emit the new symbols: first_new_sym, first_new_sym+1, ... (n entries).
532 for (
size_t i = 0;
static_cast<int64
>(
i)<n;
i++) symsOut->push_back(first_new_sym +
i);
// h_map: (destination state, symbol index) -> intermediate state added for
// that pair, so that each such state is created only once.
536 std::map<std::pair<StateId, size_t>, StateId> h_map;
541 size_t n_states_added = 0;
// Apply m_map to the FST. The branching between the statements below (and the
// positioning of aiter at arcpos) is on lines not shown, so only the
// individual visible actions are described.
543 for (
typename std::map<std::pair<StateId, ArcId>,
size_t>::iterator m_iter = m_map.begin();
544 m_iter != m_map.end();
546 StateId state = m_iter->first.first;
547 ArcId arcpos = m_iter->first.second;
548 size_t m_a = m_iter->second;
550 MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
552 Arc arc = aiter.Value();
// Relabel the arc's input with the symbol selected by m_a.
556 arc.ilabel = (*symsOut)[m_a];
558 std::pair<StateId, size_t> pr(arc.nextstate, m_a);
559 if (!h_map.count(pr)) {
// First use of this (destination, symbol) pair: add a state with a single arc
// that has the selected symbol as input, 0 as output, weight One, and leads
// to the original destination.
561 StateId newstate = fst->AddState();
563 Arc new_arc( (*symsOut)[m_a], (Label)0, Weight::One(), arc.nextstate);
564 fst->AddArc(newstate, new_arc);
565 h_map[pr] = newstate;
// Redirect the arc to the intermediate state recorded for this pair.
567 arc.nextstate = h_map[pr];
572 KALDI_VLOG(2) <<
"Added " <<(n_states_added)<<
" new states and added/changed "<<(m_map.size())<<
" arcs";
// Final pass over the table S; the loop body is not shown (presumably it
// releases the vectors stored in S -- TODO confirm).
576 for (
size_t i = 0;
i < S.size();
i++)
// Fragment of CreateNewSymbols(input_sym_table, nSym, prefix, symsOut). The
// first line of the signature and the body of the "already present" branch
// are not present in this listing.
// Visible behavior: for i = 0 .. nSym-1 the name prefix+i is added to the
// symbol table and the returned id is appended to *symsOut.
582 std::string prefix, std::vector<Label> *symsOut) {
// The output vector must be supplied and empty.
586 assert(symsOut && symsOut->size() == 0);
587 for (
int i = 0;
i < nSym;
i++) {
// Build the symbol name by appending the decimal index to the prefix.
588 std::stringstream ss; ss << prefix <<
i;
589 std::string str = ss.str();
// Find() returning something other than -1 means the name is already in the
// table; how that case is handled is on lines not shown.
590 if (input_sym_table->Find(str) != -1) {
593 symsOut->push_back( (
Label) input_sym_table->AddSymbol(str));
// AddSelfLoops(fst, isyms, osyms): adds, to selected states of *fst, one
// self-loop per index i with input label isyms[i], output label osyms[i] and
// weight One. isyms and osyms must have equal length, contain only positive
// labels and contain no duplicates. Several body lines (typedefs, some
// conditions, closing braces) are not present in this listing.
599 template<
class Arc>
void AddSelfLoops(MutableFst<Arc> *
fst, std::vector<typename Arc::Label> &isyms,
600 std::vector<typename Arc::Label> &osyms) {
602 assert(isyms.size() == osyms.size());
606 size_t n = isyms.size();
// Label ranges, used below as a cheap range test before the set lookup.
// NOTE(review): dereferencing min_element/max_element of an empty vector is
// undefined; an early return for n == 0 may exist on a line not shown.
612 Label isyms_min = *std::min_element(isyms.begin(), isyms.end()),
613 isyms_max = *std::max_element(isyms.begin(), isyms.end()),
614 osyms_min = *std::min_element(osyms.begin(), osyms.end()),
615 osyms_max = *std::max_element(osyms.begin(), osyms.end());
616 std::set<Label> isyms_set, osyms_set;
617 for (
size_t i = 0;
i < isyms.size();
i++) {
618 assert(isyms[
i] > 0 && osyms[
i] > 0);
619 isyms_set.insert(isyms[
i]);
620 osyms_set.insert(osyms[i]);
// Equal sizes prove that neither list contains a duplicate label.
622 assert(isyms_set.size() == n && osyms_set.size() ==
n);
625 for (StateIterator<MutableFst<Arc> > siter(*fst); ! siter.Done(); siter.Next()) {
626 StateId state = siter.Value();
// A final state always gets self-loops.
627 bool this_state_needs_self_loops = (fst->Final(state) != Weight::Zero());
628 for (ArcIterator<MutableFst<Arc> > aiter(*fst, state); ! aiter.Done(); aiter.Next()) {
629 const Arc &arc = aiter.Value();
// No existing arc may already use one of isyms as its input label or one of
// osyms as its output label.
632 assert(!(arc.ilabel>=isyms_min && arc.ilabel<=isyms_max && isyms_set.count(arc.ilabel) != 0));
633 assert(!(arc.olabel>=osyms_min && arc.olabel<=osyms_max && osyms_set.count(arc.olabel) != 0));
// The flag is also raised for some arcs; the condition guarding this
// assignment is on a line not shown.
635 this_state_needs_self_loops =
true;
637 if (this_state_needs_self_loops) {
638 for (
size_t i = 0;
i <
n;
i++) {
// Self-loop: isyms[i] : osyms[i] with weight One, returning to `state`.
640 arc.ilabel = isyms[
i];
641 arc.olabel = osyms[
i];
642 arc.weight = Weight::One();
643 arc.nextstate = state;
644 fst->AddArc(state, arc);
// Fragment of DeleteISymbols(fst, isyms). The signature, the construction of
// `mod_arc`, the update of num_deleted and the final return are not present
// in this listing.
// Visible behavior: every arc whose input label is one of isyms is
// overwritten in place with `mod_arc`.
658 int64 num_deleted = 0;
// Nothing to delete when no symbols were given.
660 if (isyms.size() == 0)
return 0;
661 Label isyms_min = *std::min_element(isyms.begin(), isyms.end()),
662 isyms_max = *std::max_element(isyms.begin(), isyms.end());
// If the span [min, max] has exactly isyms.size() labels then (assuming isyms
// has no duplicates) every label in the span is in isyms, and the range test
// alone is sufficient; otherwise a set is built for exact membership tests.
663 bool isyms_consecutive = (isyms_max+1-isyms_min ==
static_cast<Label
>(isyms.size()));
664 std::set<Label> isyms_set;
665 if (!isyms_consecutive)
666 for (
size_t i = 0;
i < isyms.size();
i++)
667 isyms_set.insert(isyms[
i]);
669 for (StateIterator<MutableFst<Arc> > siter(*fst); ! siter.Done(); siter.Next()) {
670 StateId state = siter.Value();
671 for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state); ! aiter.Done(); aiter.Next()) {
672 const Arc &arc = aiter.Value();
// Cheap range test first, then the exact membership test when it is needed.
673 if (arc.ilabel >= isyms_min && arc.ilabel <= isyms_max) {
674 if (isyms_consecutive || isyms_set.count(arc.ilabel) != 0) {
// mod_arc is built on lines not shown (presumably a copy of arc with its
// input label cleared -- TODO confirm).
678 aiter.SetValue(mod_arc);
// Fragment of CreateSuperFinal(fst). The signature, the use of `iter`, the
// construction of the connecting arc and the final return are not present in
// this listing.
// Visible behavior: collect all final states; if there is exactly one and its
// final weight is One it may be returned as is; otherwise a new state with
// final weight One is added and every previous final state loses its final
// weight.
691 StateId num_states = fst->NumStates();
692 StateId num_final = 0;
693 std::vector<StateId> final_states;
// A state is final when its final weight differs from Weight::Zero().
694 for (StateId s = 0; s < num_states; s++) {
695 if (fst->Final(s) != Weight::Zero()) {
697 final_states.push_back(s);
700 if (final_states.size() == 1) {
701 if (fst->Final(final_states[0]) == Weight::One()) {
// An arc iterator over the single final state is created; how it is used is
// on lines not shown (presumably to check that the state has no outgoing
// arcs -- TODO confirm).
702 ArcIterator<MutableFst<Arc> > iter(*fst, final_states[0]);
706 return final_states[0];
// General case: add a new final state with weight One.
711 StateId final_state = fst->AddState();
712 fst->SetFinal(final_state, Weight::One());
713 for (
size_t idx = 0; idx < final_states.size(); idx++) {
714 StateId s = final_states[idx];
// Remember the old final weight, then make s non-final.
715 Weight weight = fst->Final(s);
716 fst->SetFinal(s, Weight::Zero());
// An arc ending in final_state is prepared here; its remaining fields and the
// AddArc call are on lines not shown.
720 arc.nextstate = final_state;
730 #endif // KALDI_FSTEXT_PRE_DETERMINIZE_INL_H_ fst::StdArc::StateId StateId
void CopySetToVector(const std::set< T > s, std::vector< T > *v)
void PreDeterminize(MutableFst< Arc > *fst, typename Arc::Label first_new_sym, std::vector< Int > *symsOut)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
void CreateNewSymbols(SymbolTable *input_sym_table, int nSym, std::string prefix, std::vector< Label > *symsOut)
bool HasBannedPrefixPlusDigits(SymbolTable *symTable, std::string prefix, std::string *bad_sym)
void AddSelfLoops(MutableFst< Arc > *fst, std::vector< typename Arc::Label > &isyms, std::vector< typename Arc::Label > &osyms)
AddSelfLoops is a function you will probably want to use alongside PreDeterminize, to add self-loops to any FSTs that you compose on the left hand side of the one modified by PreDeterminize.
Arc::StateId CreateSuperFinal(MutableFst< Arc > *fst)
void Closure(MutableFst< Arc > *fst, std::set< typename Arc::StateId > *S, const std::vector< bool > &pVec)
fst::StdArc::Weight Weight
std::vector< T > * InsertMember(const std::vector< T > m, std::vector< std::vector< T > *> *S)
#define KALDI_ASSERT(cond)
int64 DeleteISymbols(MutableFst< Arc > *fst, std::vector< typename Arc::Label > isyms)