22 #ifndef KALDI_FSTEXT_DETERMINISTIC_FST_INL_H_ 23 #define KALDI_FSTEXT_DETERMINISTIC_FST_INL_H_ 35 ArcIterator<Fst<Arc> > aiter(fst_, s);
38 const Arc &arc = aiter.Value();
39 if (arc.ilabel == 0) {
49 Weight w = fst_.Final(state);
50 if (w != Weight::Zero())
return w;
52 StateId backoff_state = GetBackoffState(state, &backoff_w);
53 if (backoff_state == kNoStateId)
return Weight::Zero();
54 else return Times(backoff_w, this->Final(backoff_state));
59 const Fst<Arc> &
fst): fst_(fst) {
62 (kILabelSorted|kIDeterministic) &&
63 "Input FST is not i-label sorted and deterministic.");
72 SortedMatcher<Fst<Arc> > sm(
fst_, MATCH_INPUT, 1);
74 if (sm.Find(ilabel)) {
75 const Arc &arc = sm.Value();
81 if (backoff_state == kNoStateId)
return false;
82 if (!this->
GetArc(backoff_state, ilabel, oarc))
return false;
83 oarc->weight =
Times(oarc->weight, backoff_w);
91 std::vector<Label> start_state;
107 seq.push_back(ilabel);
108 if (seq.size() >
n_-1) {
110 seq.erase(seq.begin());
112 std::pair<const std::vector<Label>,
StateId> new_state(
116 typedef typename MapType::iterator IterType;
117 std::pair<IterType, bool> result =
state_map_.insert(new_state);
118 if (result.second ==
true) {
121 oarc->weight = Weight::One();
122 oarc->ilabel = ilabel;
123 oarc->olabel = ilabel;
124 oarc->nextstate = result.first->second;
132 return Weight::One();
140 if (
fst1_->Start() == -1 ||
fst2_->Start() == -1) {
145 std::pair<StateId,StateId> start_pair(
fst1_->Start(),
fst2_->Start());
155 const std::pair<StateId, StateId> &pr (
state_vec_[s]);
162 typedef typename MapType::iterator IterType;
164 "This program expects epsilon-free compact lattices as input");
166 const std::pair<StateId, StateId> pr (
state_vec_[s]);
169 if (!
fst1_->GetArc(pr.first, ilabel, &arc1))
return false;
170 if (arc1.olabel == 0) {
172 std::pair<const std::pair<StateId, StateId>,
StateId> new_value(
173 std::pair<StateId, StateId>(arc1.nextstate, pr.second),
176 std::pair<IterType, bool> result =
state_map_.insert(new_value);
177 oarc->ilabel = ilabel;
179 oarc->nextstate = result.first->second;
180 oarc->weight = arc1.weight;
181 if (result.second ==
true) {
183 const std::pair<StateId, StateId> &new_pair (new_value.first);
191 if (!
fst2_->GetArc(pr.second, arc1.olabel, &arc2))
return false;
192 std::pair<const std::pair<StateId, StateId>,
StateId> new_value(
193 std::pair<StateId, StateId>(arc1.nextstate, arc2.nextstate),
195 std::pair<IterType, bool> result =
197 oarc->ilabel = ilabel;
198 oarc->olabel = arc2.olabel;
199 oarc->nextstate = result.first->second;
200 oarc->weight =
Times(arc1.weight, arc2.weight);
201 if (result.second ==
true) {
203 const std::pair<StateId, StateId> &new_pair (new_value.first);
212 const StateId p1 = 26597, p2 = 50329;
218 return static_cast<size_t>(src_state * p1 + ilabel * p2) %
219 static_cast<size_t>(num_cached_arcs_);
225 StateId num_cached_arcs): fst_(fst),
226 num_cached_arcs_(num_cached_arcs),
227 cached_arcs_(num_cached_arcs) {
240 size_t index = this->
GetIndex(s, ilabel);
247 if (
fst_->GetArc(s, ilabel, &arc)) {
260 void *lm,
Label bos_symbol,
Label eos_symbol):
261 lm_(lm), bos_symbol_(bos_symbol), eos_symbol_(eos_symbol) {
262 std::vector<Label> begin_state;
263 begin_state.push_back(bos_symbol);
277 float log_prob = -0.5;
286 float log_prob = -0.25;
287 wseq.push_back(ilabel);
291 wseq.erase(wseq.begin(), wseq.begin() + 1);
295 if (log_prob == -std::numeric_limits<float>::infinity()) {
300 std::pair<const std::vector<Label>,
StateId> new_value(
305 typedef typename MapType::iterator IterType;
306 std::pair<IterType, bool> result =
state_map_.insert(new_value);
307 if (result.second ==
true)
309 oarc->ilabel = ilabel;
310 oarc->olabel = ilabel;
311 oarc->nextstate = result.first->second;
312 oarc->weight =
Weight(-log_prob);
320 MutableFst<Arc> *fst_composed) {
323 typedef std::pair<StateId, StateId> StatePair;
324 typedef unordered_map<StatePair,
StateId,
326 typedef typename MapType::iterator IterType;
328 fst_composed->DeleteStates();
331 std::queue<StatePair> state_queue;
334 StateId s1 = fst1.Start(),
336 start_state = fst_composed->AddState();
337 StatePair start_pair(s1, s2);
338 state_queue.push(start_pair);
339 fst_composed->SetStart(start_state);
342 std::pair<const StatePair, StateId> start_map(start_pair, start_state);
343 std::pair<IterType, bool> result = state_map.insert(start_map);
346 while (!state_queue.empty()) {
347 StatePair q = state_queue.front();
348 StateId q1 = q.first,
353 Weight final_weight =
Times(fst1.Final(q1), fst2->
Final(q2));
354 if (final_weight != Weight::Zero()) {
356 fst_composed->SetFinal(state_map[q], final_weight);
360 for (ArcIterator<Fst<Arc> > aiter(fst1, q1); !aiter.Done(); aiter.Next()) {
361 const Arc &arc1 = aiter.Value();
364 StateId next_state1 = arc1.nextstate,
369 if (arc1.olabel == 0) {
372 bool match = fst2->
GetArc(q2, arc1.olabel, &arc2);
375 next_state2 = arc2.nextstate;
377 next_pair = StatePair(next_state1, next_state2);
378 IterType sitr = state_map.find(next_pair);
380 if (sitr == state_map.end()) {
381 next_state = fst_composed->AddState();
382 std::pair<const StatePair, StateId> new_state(
383 next_pair, next_state);
384 std::pair<IterType, bool> result = state_map.insert(new_state);
388 state_queue.push(next_pair);
392 next_state = sitr->second;
394 if (arc1.olabel == 0) {
395 fst_composed->AddArc(state_map[q],
Arc(arc1.ilabel, 0, arc1.weight,
398 fst_composed->AddArc(state_map[q],
Arc(arc1.ilabel, arc2.olabel,
399 Times(arc1.weight, arc2.weight), next_state));
410 MutableFst<Arc> *fst_composed) {
413 typedef std::pair<StateId, StateId> StatePair;
414 typedef unordered_map<StatePair,
StateId,
416 typedef typename MapType::iterator IterType;
418 fst_composed->DeleteStates();
422 std::queue<StatePair> state_queue;
425 StateId s_left = left->
Start(),
426 s_right = right.Start();
427 if (s_left == kNoStateId || s_right == kNoStateId)
429 StatePair start_pair(s_left, s_right);
430 StateId start_state = fst_composed->AddState();
431 state_queue.push(start_pair);
432 fst_composed->SetStart(start_state);
435 std::pair<const StatePair, StateId> start_map(start_pair, start_state);
436 std::pair<IterType, bool> result = state_map.insert(start_map);
439 while (!state_queue.empty()) {
440 StatePair q = state_queue.front();
441 StateId q_left = q.first,
446 Weight final_weight =
Times(left->
Final(q_left), right.Final(q_right));
447 if (final_weight != Weight::Zero()) {
449 fst_composed->SetFinal(state_map[q], final_weight);
452 for (ArcIterator<Fst<Arc> > aiter(right, q_right); !aiter.Done(); aiter.Next()) {
453 const Arc &arc_right = aiter.Value();
456 StateId next_state_right = arc_right.nextstate,
462 if (arc_right.ilabel == 0) {
463 next_state_left = q_left;
465 bool match = left->
GetArc(q_left, arc_right.ilabel, &arc_left);
472 std::swap(arc_left.ilabel, arc_left.olabel);
473 next_state_left = arc_left.nextstate;
475 next_pair = StatePair(next_state_left, next_state_right);
476 IterType sitr = state_map.find(next_pair);
478 if (sitr == state_map.end()) {
479 next_state = fst_composed->AddState();
480 std::pair<const StatePair, StateId> new_state(
481 next_pair, next_state);
482 std::pair<IterType, bool> result = state_map.insert(new_state);
486 state_queue.push(next_pair);
490 next_state = sitr->second;
492 if (arc_right.ilabel == 0) {
494 fst_composed->AddArc(state_map[q],
Arc(0, arc_right.olabel,
498 fst_composed->AddArc(state_map[q],
499 Arc(arc_left.ilabel, arc_right.olabel,
500 Times(arc_left.weight, arc_right.weight),
std::vector< std::vector< Label > > state_vec_
fst::StdArc::StateId StateId
ComposeDeterministicOnDemandFst(DeterministicOnDemandFst< Arc > *fst1, DeterministicOnDemandFst< Arc > *fst2)
Note: constructor does not "take ownership" of the input fst's.
virtual bool GetArc(StateId s, Label ilabel, Arc *oarc)=0
Note: ilabel must not be epsilon.
BackoffDeterministicOnDemandFst(const Fst< Arc > &fst)
virtual Weight Final(StateId s)=0
unordered_map< std::vector< Label >, StateId, kaldi::VectorHasher< Label > > MapType
virtual bool GetArc(StateId s, Label ilabel, Arc *oarc)
Note: ilabel must not be epsilon.
StateId GetBackoffState(StateId s, Weight *w)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
void ComposeDeterministicOnDemand(const Fst< Arc > &fst1, DeterministicOnDemandFst< Arc > *fst2, MutableFst< Arc > *fst_composed)
void swap(basic_filebuf< CharT, Traits > &x, basic_filebuf< CharT, Traits > &y)
virtual StateId Start()=0
std::vector< std::vector< Label > > state_vec_
LmExampleDeterministicOnDemandFst(void *lm, Label bos_symbol, Label eos_symbol)
DeterministicOnDemandFst< Arc > * fst1_
void ComposeDeterministicOnDemandInverse(const Fst< Arc > &right, DeterministicOnDemandFst< Arc > *left, MutableFst< Arc > *fst_composed)
This function does '*fst_composed = Compose(Inverse(*fst2), fst1)' Note that the arguments are revers...
bool GetArc(StateId s, Label ilabel, Arc *oarc)
Note: ilabel must not be epsilon.
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
class DeterministicOnDemandFst is an "FST-like" base-class.
virtual Weight Final(StateId s)
size_t GetIndex(StateId src_state, Label ilabel)
UnweightedNgramFst(int n)
virtual bool GetArc(StateId s, Label ilabel, Arc *oarc)
Note: ilabel must not be epsilon.
std::vector< std::pair< StateId, Arc > > cached_arcs_
bool GetArc(StateId s, Label ilabel, Arc *oarc)
Note: ilabel must not be epsilon.
DeterministicOnDemandFst< Arc > * fst_
fst::StdArc::Weight Weight
DeterministicOnDemandFst< Arc > * fst2_
#define KALDI_ASSERT(cond)
CacheDeterministicOnDemandFst(DeterministicOnDemandFst< Arc > *fst, StateId num_cached_arcs=100000)
We don't take ownership of this pointer. The argument is "really" const.
virtual bool GetArc(StateId s, Label ilabel, Arc *oarc)
Note: ilabel must not be epsilon.
virtual Weight Final(StateId s)
We don't bother caching the final-probs, just the arcs.
std::vector< std::pair< StateId, StateId > > state_vec_
A hashing function-object for pairs of ints.