// NOTE(review): this is a lossy listing. The leading numbers on each line are the
// original source line numbers, and intermediate lines (e.g. 23-27, 29-36) are absent.
21 #ifndef KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_ 22 #define KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_ 28 template<
class Weight>
// Tail of a binary operator over two TropicalWeight arguments `a` and `b`.
// It reinterprets both values as LogWeight, adds them with the log-semiring Plus,
// and returns that sum as a TropicalWeight. The result is therefore a log-add of
// the two costs, not the tropical Plus (their minimum).
37 const TropicalWeight &b) {
38 LogWeight a_log(a.Value()), b_log(b.Value());
39 return TropicalWeight(
Plus(a_log, b_log).Value());
// Template parameters of the epsilon-removing class: the arc type, and a functor
// used to add weights when computing reweighting factors. The functor defaults to
// ReweightPlusDefault<typename Arc::Weight>.
45 template<
class Arc,
class ReweightPlus = ReweightPlusDefault<
typename Arc::Weight> >
// Presumably the constructor body (its signature, original line 53, is absent).
// It does nothing for an FST without a start state.
54 if (fst_->Start() == kNoStateId)
return;
// Extra state that arcs are redirected to when they are logically removed
// (see the `nextstate = non_coacc_state_` assignments further down).
55 non_coacc_state_ = fst_->AddState();
// NumStates() is read after AddState(), so the loop below also visits
// non_coacc_state_. NumArcs(s) is re-evaluated every iteration, so arcs appended
// to s during the loop are visited as well.
57 StateId num_states = fst_->NumStates();
58 for (StateId s = 0; s < num_states; s++)
59 for (
size_t pos = 0; pos < fst_->NumArcs(s); pos++)
// NOTE(review): the loop body (original line 60) and line 56 are missing here.
// Presumably these are RemoveEps(s, pos) and count initialization; confirm
// against the full source.
61 assert(CheckNumArcs());
// Body of CanCombineArcs(a, b, c): arc `a` followed by arc `b` can be merged into
// one arc only if at most one of them has a nonzero (non-epsilon) input label and
// at most one has a nonzero output label.
74 if (a.ilabel != 0 && b.ilabel != 0)
return false;
75 if (a.olabel != 0 && b.olabel != 0)
return false;
// On success, the merged arc carries:
// - the product of the two weights,
// - whichever input/output label is nonzero,
// - b's destination.
// (The success return, original line 80, is not shown in this listing.)
76 c->weight =
Times(a.weight, b.weight);
77 c->ilabel = (a.ilabel != 0 ? a.ilabel : b.ilabel);
78 c->olabel = (a.olabel != 0 ? a.olabel : b.olabel);
79 c->nextstate = b.nextstate;
// Body of CanCombineFinal(a, final_prob, final_prob_out): arc `a` can be folded
// into the final weight of its destination only if both of its labels are zero
// (epsilon). The folded final weight is Times(a.weight, final_prob).
84 if (a.ilabel != 0 || a.olabel != 0)
return false;
86 *final_prob_out =
Times(a.weight, final_prob);
// Counts, for each state, its incoming arcs (num_arcs_in_) and outgoing
// "arcs" (num_arcs_out_).
92 StateId num_states = fst_->NumStates();
93 num_arcs_in_.resize(num_states);
94 num_arcs_out_.resize(num_states);
// The start state is counted as having one extra incoming arc. CheckNumArcs undoes
// this with a matching decrement.
95 num_arcs_in_[fst_->Start()]++;
96 for (StateId s = 0; s < num_states; s++) {
// NOTE(review): the statement guarded by this `if` (original line 98) is missing.
// Presumably it is num_arcs_out_[s]++, i.e. a nonzero final weight counts as one
// outgoing "arc" (clearing a final weight later decrements num_arcs_out_).
// Confirm against the full source.
97 if (fst_->Final(s) != Weight::Zero())
99 for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
// Every arc adds one incoming arc to its destination. The matching
// num_arcs_out_ update for s (original line 101) is not shown.
100 num_arcs_in_[aiter.Value().nextstate]++;
// Consistency check: subtracts every contribution that the counting pass adds,
// then asserts that all counters are back to zero.
// - non_coacc_state_ and arcs into it are skipped, because removed arcs have
//   already been un-counted when they were redirected.
// Undo the extra incoming count given to the start state.
107 num_arcs_in_[fst_->Start()]--;
108 StateId num_states = fst_->NumStates();
109 for (StateId s = 0; s < num_states; s++) {
110 if (s == non_coacc_state_)
continue;
// NOTE(review): the statement guarded by this `if` (original line 112) is missing.
// Presumably it is num_arcs_out_[s]--; confirm.
111 if (fst_->Final(s) != Weight::Zero())
113 for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
114 if (aiter.Value().nextstate == non_coacc_state_)
continue;
115 num_arcs_in_[aiter.Value().nextstate]--;
// All counters must now be exactly zero, or the incremental bookkeeping drifted.
119 for (StateId s = 0; s < num_states; s++) {
120 assert(num_arcs_in_[s] == 0);
121 assert(num_arcs_out_[s] == 0);
// Copies the arc at index `pos` of state `s` into *arc.
// NOTE(review): original line 128, presumably aiter.Seek(pos), is missing from
// this listing.
126 inline void GetArc(StateId s,
size_t pos,
Arc *arc)
const {
127 ArcIterator<MutableFst<Arc> > aiter(*fst_, s);
129 *arc = aiter.Value();
// Overwrites the arc at index `pos` of state `s` with `arc`, via a mutable arc
// iterator. The rest of the body (presumably Seek(pos) then SetValue(arc)) is
// missing from this listing.
132 inline void SetArc(StateId s,
size_t pos,
const Arc &arc) {
133 MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
// Multiplies the weight of arc `pos` of state `s` by `reweight`. It then
// left-divides by `reweight`:
// - every outgoing arc of that arc's destination (except arcs already redirected
//   to non_coacc_state_), and
// - the destination's final weight, if nonzero.
// Because the destination must have exactly one incoming arc (asserted below),
// this is intended to leave the weights of paths through it unchanged.
139 void Reweight(StateId s,
size_t pos, Weight reweight) {
// Division by Zero() is undefined, so a zero reweight is rejected.
145 assert(reweight != Weight::Zero());
146 MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
// NOTE(review): original line 147 (presumably aiter.Seek(pos)) and the write-back
// of the modified arc (lines 151-152) are missing from this listing.
148 Arc arc = aiter.Value();
149 assert(num_arcs_in_[arc.nextstate] == 1);
150 arc.weight =
Times(arc.weight, reweight);
153 for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, arc.nextstate);
156 Arc nextarc = aiter_next.Value();
// Arcs already redirected to non_coacc_state_ are logically removed; leave them alone.
157 if (nextarc.nextstate != non_coacc_state_) {
158 nextarc.weight =
Divide(nextarc.weight, reweight, DIVIDE_LEFT);
159 aiter_next.SetValue(nextarc);
// Compensate the destination's final weight the same way.
162 Weight
final = fst_->Final(arc.nextstate);
163 if (
final != Weight::Zero()) {
164 fst_->SetFinal(arc.nextstate,
Divide(
final, reweight, DIVIDE_LEFT));
// Body of RemoveEpsPattern1(s, pos, arc). RemoveEps calls it when the destination
// of `arc` has one incoming arc and more than one outgoing arc (counting a final
// weight). Outgoing arcs of the destination that can be merged with `arc` are
// replaced by merged arcs leaving `s`, and the same is done for a mergeable final
// weight.
174 const StateId nextstate = arc.nextstate;
// Weights are accumulated with reweight_plus_ (the ReweightPlus functor), not
// necessarily the semiring Plus.
175 Weight total_removed = Weight::Zero(),
176 total_kept = Weight::Zero();
// Merged arcs are collected here and added to `s` only after the loop over
// `nextstate`'s arcs has finished.
177 std::vector<Arc> arcs_to_add;
178 for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
181 Arc nextarc = aiter_next.Value();
182 if (nextarc.nextstate == non_coacc_state_)
continue;
184 if (CanCombineArcs(arc, nextarc, &combined)) {
185 total_removed = reweight_plus_(total_removed, nextarc.weight);
// Logically remove nextarc: update both counters and redirect it.
186 num_arcs_out_[nextstate]--;
187 num_arcs_in_[nextarc.nextstate]--;
188 nextarc.nextstate = non_coacc_state_;
189 aiter_next.SetValue(nextarc);
190 arcs_to_add.push_back(combined);
// NOTE(review): the `else` introducing line 192 is absent from this listing.
// The statement runs for arcs that cannot be merged.
192 total_kept = reweight_plus_(total_kept, nextarc.weight);
// Treat nextstate's final weight like one more outgoing arc.
197 Weight next_final = fst_->Final(nextstate);
198 if (next_final != Weight::Zero()) {
200 if (CanCombineFinal(arc, next_final, &new_final)) {
201 total_removed = reweight_plus_(total_removed, next_final);
// NOTE(review): the statement guarded here (original line 203) is missing.
// Presumably it increments num_arcs_out_[s] because `s` is gaining a final
// weight; confirm.
202 if (fst_->Final(s) == Weight::Zero())
// The folded final weight is added to s's existing final weight with the
// semiring Plus.
204 fst_->SetFinal(s,
Plus(fst_->Final(s), new_final));
205 num_arcs_out_[nextstate]--;
206 fst_->SetFinal(nextstate, Weight::Zero());
208 total_kept = reweight_plus_(total_kept, next_final);
210 if (total_removed != Weight::Zero()) {
// Everything leaving nextstate was merged, so `arc` itself is removed by
// redirecting it. Its write-back (presumably SetArc, lines 215/218) is not shown.
214 if (total_kept == Weight::Zero()) {
216 num_arcs_in_[arc.nextstate]--;
217 arc.nextstate = non_coacc_state_;
// NOTE(review): the `else` preceding line 221 is absent from this listing.
// Only part of nextstate's mass remains. `arc` is reweighted by
// total_kept / total so that it carries only the kept share.
221 Weight total = reweight_plus_(total_removed, total_kept);
222 Weight reweight =
Divide(total_kept, total, DIVIDE_LEFT);
223 Reweight(s, pos, reweight);
// Append the merged arcs to `s` and count them as incoming arcs of their
// destinations. The matching num_arcs_out_ update (line 228) is not shown.
227 for (
size_t i = 0;
i < arcs_to_add.size();
i++) {
229 num_arcs_in_[arcs_to_add[
i].nextstate]++;
230 fst_->AddArc(s, arcs_to_add[
i]);
// Body of RemoveEpsPattern2(s, pos, arc). RemoveEps calls it when the destination
// of `arc` has exactly one outgoing "arc": either a final weight or a single live
// arc. If that item can be merged with `arc`, its merged form is added to `s`.
// The original item is deleted from the destination only if `arc` was the
// destination's sole incoming arc.
240 const StateId nextstate = arc.nextstate;
241 bool can_delete_next = (num_arcs_in_[nextstate] == 1);
244 bool delete_arc =
false;
246 Weight next_final = fst_->Final(arc.nextstate);
247 if (next_final != Weight::Zero()) {
249 if (CanCombineFinal(arc, next_final, &new_final)) {
// NOTE(review): the statement guarded here (original line 251) is missing.
// Presumably it increments num_arcs_out_[s]; confirm.
250 if (fst_->Final(s) == Weight::Zero())
252 fst_->SetFinal(s,
Plus(fst_->Final(s), new_final));
// Other incoming arcs still need nextstate's final weight, so only clear it
// when `arc` is the sole incoming arc.
254 if (can_delete_next) {
255 num_arcs_out_[nextstate]--;
256 fst_->SetFinal(nextstate, Weight::Zero());
// NOTE(review): lines 257-259 (presumably the `else` branch opener) are absent.
// Here nextstate's single live outgoing arc is located by skipping arcs already
// redirected to non_coacc_state_ (the advancing statement, line 263, is not shown).
260 MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
261 assert(!aiter_next.Done());
262 while (aiter_next.Value().nextstate == non_coacc_state_) {
264 assert(!aiter_next.Done());
267 Arc nextarc = aiter_next.Value();
269 if (CanCombineArcs(arc, nextarc, &combined)) {
271 if (can_delete_next) {
272 num_arcs_out_[nextstate]--;
273 num_arcs_in_[nextarc.nextstate]--;
274 nextarc.nextstate = non_coacc_state_;
275 aiter_next.SetValue(nextarc);
// The merged arc is added to `s` whether or not the original could be deleted.
278 num_arcs_in_[combined.nextstate]++;
279 fst_->AddArc(s, combined);
// NOTE(review): the guard for this removal of `arc` (presumably `if (delete_arc)`,
// around line 283) and the write-back of `arc` are not shown in this listing.
284 num_arcs_in_[nextstate]--;
285 arc.nextstate = non_coacc_state_;
// Body of RemoveEps(s, pos): tries to eliminate arc `pos` of state `s` by
// merging it with what follows its destination state.
293 GetArc(s, pos, &arc);
294 StateId nextstate = arc.nextstate;
// Arcs already redirected to non_coacc_state_ are logically removed; nothing to do.
295 if (nextstate == non_coacc_state_)
return;
// Self-loops are never processed.
296 if (nextstate == s)
return;
// Pattern 1 applies when the destination has a single incoming arc but several
// outgoing "arcs". Pattern 2 applies when the destination has exactly one
// outgoing "arc". In both cases a final weight counts as an outgoing "arc".
298 if (num_arcs_in_[nextstate] == 1 && num_arcs_out_[nextstate] > 1) {
299 RemoveEpsPattern1(s, pos, arc);
300 }
else if (num_arcs_out_[nextstate] == 1) {
301 RemoveEpsPattern2(s, pos, arc);
fst::StdArc::StateId StateId
LatticeWeightTpl< FloatType > Divide(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, DivideType typ=DIVIDE_ANY)
std::vector< StateId > num_arcs_out_
bool CanCombineArcs(const Arc &a, const Arc &b, Arc *c)
void RemoveEpsLocal(MutableFst< Arc > *fst)
RemoveEpsLocal removes some (but not necessarily all) epsilons in an FST, using an algorithm that is g...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
LatticeWeightTpl< FloatType > Plus(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
std::vector< StateId > num_arcs_in_
void RemoveEpsLocalSpecial(MutableFst< StdArc > *fst)
As RemoveEpsLocal but takes care to preserve stochasticity when cast to LogArc.
void RemoveEpsPattern2(StateId s, size_t pos, Arc arc)
void RemoveEps(StateId s, size_t pos)
void RemoveEpsPattern1(StateId s, size_t pos, Arc arc)
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
void SetArc(StateId s, size_t pos, const Arc &arc)
Weight operator()(const Weight &a, const Weight &b)
fst::StdArc::Weight Weight
RemoveEpsLocalClass(MutableFst< Arc > *fst)
ReweightPlus reweight_plus_
void GetArc(StateId s, size_t pos, Arc *arc) const
void Reweight(StateId s, size_t pos, Weight reweight)
static bool CanCombineFinal(const Arc &a, Weight final_prob, Weight *final_prob_out)