remove-eps-local-inl.h
Go to the documentation of this file.
1 // fstext/remove-eps-local-inl.h
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 // 2014 Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABILITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
22 #define KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
23 
24 
25 namespace fst {
26 
27 
28 template<class Weight>
30  inline Weight operator () (const Weight &a, const Weight &b) {
31  return Plus(a, b);
32  }
33 };
34 
36  inline TropicalWeight operator () (const TropicalWeight &a,
37  const TropicalWeight &b) {
38  LogWeight a_log(a.Value()), b_log(b.Value());
39  return TropicalWeight(Plus(a_log, b_log).Value());
40  }
41 };
42 
43 
44 
45 template<class Arc, class ReweightPlus = ReweightPlusDefault<typename Arc::Weight> >
47  typedef typename Arc::StateId StateId;
48  typedef typename Arc::Label Label;
49  typedef typename Arc::Weight Weight;
50 
51  public:
52  RemoveEpsLocalClass(MutableFst<Arc> *fst):
53  fst_(fst) {
54  if (fst_->Start() == kNoStateId) return; // empty.
55  non_coacc_state_ = fst_->AddState();
56  InitNumArcs();
57  StateId num_states = fst_->NumStates();
58  for (StateId s = 0; s < num_states; s++)
59  for (size_t pos = 0; pos < fst_->NumArcs(s); pos++)
60  RemoveEps(s, pos);
61  assert(CheckNumArcs());
62  Connect(fst); // remove inaccessible states.
63  }
64  private:
// The FST being rewritten in place; set from the constructor argument.
65  MutableFst<Arc> *fst_;
// Sink state added by the constructor.  Arcs whose nextstate equals this
// state are treated as deleted by every method of this class.
66  StateId non_coacc_state_; // use this to delete arcs: make it nextstate
// Per-state counters, indexed by state id and kept up to date as arcs are
// added, redirected to the sink, or final weights are moved.
67  std::vector<StateId> num_arcs_in_; // The number of arcs into the state, plus one
68  // if it's the start state.
69  std::vector<StateId> num_arcs_out_; // The number of arcs out of the state, plus
70  // one if it's a final state.
// Functor used to sum weights when accumulating the removed/kept totals.
71  ReweightPlus reweight_plus_;
72 
73  bool CanCombineArcs(const Arc &a, const Arc &b, Arc *c) {
74  if (a.ilabel != 0 && b.ilabel != 0) return false;
75  if (a.olabel != 0 && b.olabel != 0) return false;
76  c->weight = Times(a.weight, b.weight);
77  c->ilabel = (a.ilabel != 0 ? a.ilabel : b.ilabel);
78  c->olabel = (a.olabel != 0 ? a.olabel : b.olabel);
79  c->nextstate = b.nextstate;
80  return true;
81  }
82 
83  static bool CanCombineFinal(const Arc &a, Weight final_prob, Weight *final_prob_out) {
84  if (a.ilabel != 0 || a.olabel != 0) return false;
85  else {
86  *final_prob_out = Times(a.weight, final_prob);
87  return true;
88  }
89  }
90 
91  void InitNumArcs() { // init num transitions in/out of each state.
92  StateId num_states = fst_->NumStates();
93  num_arcs_in_.resize(num_states);
94  num_arcs_out_.resize(num_states);
95  num_arcs_in_[fst_->Start()]++; // count start as trans in.
96  for (StateId s = 0; s < num_states; s++) {
97  if (fst_->Final(s) != Weight::Zero())
98  num_arcs_out_[s]++; // count final as transition.
99  for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
100  num_arcs_in_[aiter.Value().nextstate]++;
101  num_arcs_out_[s]++;
102  }
103  }
104  }
105 
106  bool CheckNumArcs() { // check num arcs in/out of each state, at end. Debug.
107  num_arcs_in_[fst_->Start()]--; // count start as trans in.
108  StateId num_states = fst_->NumStates();
109  for (StateId s = 0; s < num_states; s++) {
110  if (s == non_coacc_state_) continue;
111  if (fst_->Final(s) != Weight::Zero())
112  num_arcs_out_[s]--; // count final as transition.
113  for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
114  if (aiter.Value().nextstate == non_coacc_state_) continue;
115  num_arcs_in_[aiter.Value().nextstate]--;
116  num_arcs_out_[s]--;
117  }
118  }
119  for (StateId s = 0; s < num_states; s++) {
120  assert(num_arcs_in_[s] == 0);
121  assert(num_arcs_out_[s] == 0);
122  }
123  return true; // always does this. so we can assert it w/o warnings.
124  }
125 
126  inline void GetArc(StateId s, size_t pos, Arc *arc) const {
127  ArcIterator<MutableFst<Arc> > aiter(*fst_, s);
128  aiter.Seek(pos);
129  *arc = aiter.Value();
130  }
131 
132  inline void SetArc(StateId s, size_t pos, const Arc &arc) {
133  MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
134  aiter.Seek(pos);
135  aiter.SetValue(arc);
136  }
137 
138 
// Pushes the factor `reweight` across one state: the arc at (s, pos) is
// multiplied by it (Times), and every live arc and the final weight of that
// arc's destination are divided by it (Divide with DIVIDE_LEFT).
//   s, pos    - identify the arc to be multiplied.
//   reweight  - the factor; must not be Weight::Zero().
139  void Reweight(StateId s, size_t pos, Weight reweight) {
140  // Reweight is called from RemoveEpsPattern1; it is a step we
141  // do to preserve stochasticity. This function multiplies the
142  // arc at (s, pos) by reweight and divides all the arcs [+final-prob]
143  // out of the next state by the same. This is only valid if
144  // the next state has only one arc in and is not the start state.
145  assert(reweight != Weight::Zero());
146  MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
147  aiter.Seek(pos);
148  Arc arc = aiter.Value();
// The destination must have this arc as its only way in, otherwise dividing
// its outgoing weights would change paths that enter through other arcs.
149  assert(num_arcs_in_[arc.nextstate] == 1);
150  arc.weight = Times(arc.weight, reweight);
151  aiter.SetValue(arc);
152 
// Divide every outgoing arc of the destination, skipping arcs that were
// already deleted (redirected to non_coacc_state_).
153  for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, arc.nextstate);
154  !aiter_next.Done();
155  aiter_next.Next()) {
156  Arc nextarc = aiter_next.Value();
157  if (nextarc.nextstate != non_coacc_state_) {
158  nextarc.weight = Divide(nextarc.weight, reweight, DIVIDE_LEFT);
159  aiter_next.SetValue(nextarc);
160  }
161  }
// The final weight, if any, is treated like one more outgoing arc.
162  Weight final = fst_->Final(arc.nextstate);
163  if (final != Weight::Zero()) {
164  fst_->SetFinal(arc.nextstate, Divide(final, reweight, DIVIDE_LEFT));
165  }
166  }
167 
168  // RemoveEpsPattern1 applies where this arc, which is not a
169  // self-loop, enters a state which has only one input transition
170  // [and is not the start state], and has multiple output
171  // transitions [counting being the final-state as a final-transition].
172 
// Handles the arc `arc` stored at (s, pos) when its destination has exactly
// one incoming transition and more than one outgoing transition (that is
// the condition under which RemoveEps calls this function).
//
// Every live arc out of the destination that CanCombineArcs can merge with
// `arc` is deleted there (redirected to non_coacc_state_) and the merged arc
// is queued to be added to s.  The destination's final weight is moved onto
// s when CanCombineFinal allows it.  Weights of what was removed and what
// was kept are accumulated with reweight_plus_.  If nothing is left at the
// destination, the arc at (s, pos) itself is deleted; otherwise it is
// rescaled via Reweight().  The in/out counters are updated at every step.
173  void RemoveEpsPattern1(StateId s, size_t pos, Arc arc) {
174  const StateId nextstate = arc.nextstate;
175  Weight total_removed = Weight::Zero(),
176  total_kept = Weight::Zero(); // totals out of nextstate.
177  std::vector<Arc> arcs_to_add; // to add to state s.
178  for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
179  !aiter_next.Done();
180  aiter_next.Next()) {
181  Arc nextarc = aiter_next.Value();
182  if (nextarc.nextstate == non_coacc_state_) continue; // deleted.
183  Arc combined;
184  if (CanCombineArcs(arc, nextarc, &combined)) {
// Mergeable: count its weight as removed, delete it from nextstate by
// pointing it at the sink, and remember the merged arc for later.
185  total_removed = reweight_plus_(total_removed, nextarc.weight);
186  num_arcs_out_[nextstate]--;
187  num_arcs_in_[nextarc.nextstate]--;
188  nextarc.nextstate = non_coacc_state_;
189  aiter_next.SetValue(nextarc);
190  arcs_to_add.push_back(combined);
191  } else {
192  total_kept = reweight_plus_(total_kept, nextarc.weight);
193  }
194  }
195 
196  { // now final-state.
197  Weight next_final = fst_->Final(nextstate);
198  if (next_final != Weight::Zero()) {
199  Weight new_final;
200  if (CanCombineFinal(arc, next_final, &new_final)) {
// Move the final weight from nextstate onto s (added to any final weight s
// already has); s gains an "out" count only if it was not final before.
201  total_removed = reweight_plus_(total_removed, next_final);
202  if (fst_->Final(s) == Weight::Zero())
203  num_arcs_out_[s]++; // final is counted as arc.
204  fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
205  num_arcs_out_[nextstate]--;
206  fst_->SetFinal(nextstate, Weight::Zero());
207  } else {
208  total_kept = reweight_plus_(total_kept, next_final);
209  }
210  }
211  }
212 
213  if (total_removed != Weight::Zero()) { // did something...
214  if (total_kept == Weight::Zero()) { // removed everything: remove arc.
215  num_arcs_out_[s]--;
216  num_arcs_in_[arc.nextstate]--;
217  arc.nextstate = non_coacc_state_;
218  SetArc(s, pos, arc);
219  } else {
220  // Have to reweight.
// The factor is total_kept divided by (total_removed + total_kept), using
// reweight_plus_ for the sum.
221  Weight total = reweight_plus_(total_removed, total_kept);
222  Weight reweight = Divide(total_kept, total, DIVIDE_LEFT); // <=1
223  Reweight(s, pos, reweight);
224  }
225  }
226  // Now add the arcs we were going to add.
227  for (size_t i = 0; i < arcs_to_add.size(); i++) {
228  num_arcs_out_[s]++;
229  num_arcs_in_[arcs_to_add[i].nextstate]++;
230  fst_->AddArc(s, arcs_to_add[i]);
231  }
232  }
233 
// Handles the arc `arc` stored at (s, pos) when its destination has exactly
// one outgoing transition, which is either a final weight or a single live
// arc.  If that transition can be merged with `arc`, the merged result is
// attached to s (as a final weight or a new arc) and the arc at (s, pos) is
// deleted by redirecting it to non_coacc_state_.  The destination's own
// transition is removed only when `arc` was its sole incoming transition.
234  void RemoveEpsPattern2(StateId s, size_t pos, Arc arc) {
235 
236  // Pattern 2 is where "nextstate" has only one arc out, counting
237  // being-the-final-state as an arc, but possibly multiple arcs in.
238  // Also, nextstate != s.
239 
240  const StateId nextstate = arc.nextstate;
241  bool can_delete_next = (num_arcs_in_[nextstate] == 1); // if
242  // we combine, can delete the corresponding out-arc/final-prob
243  // of nextstate.
244  bool delete_arc = false; // set to true if this arc to be deleted.
245 
246  Weight next_final = fst_->Final(arc.nextstate);
247  if (next_final != Weight::Zero()) { // nextstate has no actual arcs out, only final-prob.
248  Weight new_final;
249  if (CanCombineFinal(arc, next_final, &new_final)) {
// Add the combined final weight to s; s gains an "out" count only if it was
// not already final.
250  if (fst_->Final(s) == Weight::Zero())
251  num_arcs_out_[s]++; // final is counted as arc.
252  fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
253  delete_arc = true; // will delete "arc".
254  if (can_delete_next) {
255  num_arcs_out_[nextstate]--;
256  fst_->SetFinal(nextstate, Weight::Zero());
257  }
258  }
259  } else { // has an arc but no final prob.
260  MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
261  assert(!aiter_next.Done());
// Skip over arcs that were deleted earlier (they point at the sink state).
262  while (aiter_next.Value().nextstate == non_coacc_state_) {
263  aiter_next.Next();
264  assert(!aiter_next.Done());
265  }
266  // now aiter_next points to a real arc out of nextstate.
267  Arc nextarc = aiter_next.Value();
268  Arc combined;
269  if (CanCombineArcs(arc, nextarc, &combined)) {
270  delete_arc = true;
271  if (can_delete_next) { // do it before we invalidate iterators
272  num_arcs_out_[nextstate]--;
273  num_arcs_in_[nextarc.nextstate]--;
274  nextarc.nextstate = non_coacc_state_;
275  aiter_next.SetValue(nextarc);
276  }
277  num_arcs_out_[s]++;
278  num_arcs_in_[combined.nextstate]++;
279  fst_->AddArc(s, combined);
280  }
281  }
// Finally delete the original arc at (s, pos) if a merge took place.
282  if (delete_arc) {
283  num_arcs_out_[s]--;
284  num_arcs_in_[nextstate]--;
285  arc.nextstate = non_coacc_state_;
286  SetArc(s, pos, arc);
287  }
288  }
289 
290  void RemoveEps(StateId s, size_t pos) {
291  // Tries to do local epsilon-removal for arc sequences starting with this arc
292  Arc arc;
293  GetArc(s, pos, &arc);
294  StateId nextstate = arc.nextstate;
295  if (nextstate == non_coacc_state_) return; // deleted arc.
296  if (nextstate == s) return; // don't handle self-loops: too complex.
297 
298  if (num_arcs_in_[nextstate] == 1 && num_arcs_out_[nextstate] > 1) {
299  RemoveEpsPattern1(s, pos, arc);
300  } else if (num_arcs_out_[nextstate] == 1) {
301  RemoveEpsPattern2(s, pos, arc);
302  }
303  }
304 
305 };
306 
307 
308 template<class Arc>
309 void RemoveEpsLocal(MutableFst<Arc> *fst) {
310  RemoveEpsLocalClass<Arc> c(fst); // work gets done in initializer.
311 }
312 
313 
// Variant of RemoveEpsLocal for MutableFst<StdArc>.
// NOTE(review): in this listing the embedded line numbers jump from 315 to
// 317, so the single body statement appears to have been lost in extraction;
// as shown, the function does nothing.  Presumably the missing line
// constructs RemoveEpsLocalClass<StdArc, ...> on fst with the log-semiring
// ReweightPlus functor -- confirm against the upstream header before use.
// NOTE(review): this is a non-template function defined in a header; verify
// that its declaration elsewhere is marked inline.
314 void RemoveEpsLocalSpecial(MutableFst<StdArc> *fst) {
315  // work gets done in initializer.
317 }
318 
319 } // end namespace fst.
320 
321 #endif
fst::StdArc::StateId StateId
LatticeWeightTpl< FloatType > Divide(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, DivideType typ=DIVIDE_ANY)
std::vector< StateId > num_arcs_out_
bool CanCombineArcs(const Arc &a, const Arc &b, Arc *c)
void RemoveEpsLocal(MutableFst< Arc > *fst)
RemoveEpsLocal remove some (but not necessarily all) epsilons in an FST, using an algorithm that is g...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
LatticeWeightTpl< FloatType > Plus(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
std::vector< StateId > num_arcs_in_
void RemoveEpsLocalSpecial(MutableFst< StdArc > *fst)
As RemoveEpsLocal but takes care to preserve stochasticity when cast to LogArc.
void RemoveEpsPattern2(StateId s, size_t pos, Arc arc)
void RemoveEps(StateId s, size_t pos)
void RemoveEpsPattern1(StateId s, size_t pos, Arc arc)
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
void SetArc(StateId s, size_t pos, const Arc &arc)
Weight operator()(const Weight &a, const Weight &b)
fst::StdArc::Label Label
fst::StdArc::Weight Weight
RemoveEpsLocalClass(MutableFst< Arc > *fst)
void GetArc(StateId s, size_t pos, Arc *arc) const
void Reweight(StateId s, size_t pos, Weight reweight)
static bool CanCombineFinal(const Arc &a, Weight final_prob, Weight *final_prob_out)