prune-special-inl.h
Go to the documentation of this file.
1 // fstext/prune-special-inl.h
2 
3 // Copyright 2014 Johns Hopkins University (Author: Daniel Povey)
4 // Guoguo Chen
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_FSTEXT_PRUNE_SPECIAL_INL_H_
22 #define KALDI_FSTEXT_PRUNE_SPECIAL_INL_H_
23 // Do not include this file directly. It is included by prune-special.h
24 
25 #include "fstext/prune-special.h"
26 #include "base/kaldi-error.h"
27 
28 namespace fst {
29 
30 
32 template<class Arc> class PruneSpecialClass {
33  public:
34  typedef typename Arc::StateId InputStateId;
35  typedef typename Arc::StateId OutputStateId;
36  typedef typename Arc::Weight Weight;
37  typedef typename Arc::Label Label;
38 
39  PruneSpecialClass(const Fst<Arc> &ifst,
40  VectorFst<Arc> *ofst,
41  Weight beam,
42  size_t max_states):
43  ifst_(ifst), ofst_(ofst), beam_(beam), max_states_(max_states),
44  best_weight_(Weight::Zero()) {
45  KALDI_ASSERT(beam != Weight::One());
46  KALDI_ASSERT(queue_.size() == 0);
47  ofst_->DeleteStates(); // make sure it's empty.
48  if (ifst_.Start() == kNoStateId)
49  return;
50  ofst_->SetStart(ProcessState(ifst_.Start(), Weight::One()));
51 
52  while (!queue_.empty()) {
53  Task task = queue_.top();
54  queue_.pop();
55  if (Done(task)) break;
56  else ProcessTask(task);
57  }
58  Connect(ofst);
59  if (beam_ != Weight::One())
60  Prune(ofst, beam_);
61  }
62 
63  struct Task {
64  InputStateId istate;
65  OutputStateId ostate; // could be looked up; this is for speed.
66  size_t position; // arc position, or -1 if final-prob.
67  Weight weight;
68 
69  Task(InputStateId istate, OutputStateId ostate, size_t position,
70  Weight weight): istate(istate), ostate(ostate), position(position),
71  weight(weight) { }
72  bool operator < (const Task &other) const {
73  return Compare(weight, other.weight) < 0;
74  }
75  };
76 
77  bool Done(const Task &task) {
78  if (beam_ != Weight::One() && best_weight_ != Weight::Zero() &&
79  Compare(task.weight, Times(best_weight_, beam_)) < 0)
80  return true;
81  if (max_states_ > 0 &&
82  static_cast<size_t>(ofst_->NumStates()) > max_states_)
83  return true;
84  return false;
85  }
86 
87 
88  // This function assumes "state" has not been seen before, so we need to
89  // create a new output-state for it and add tasks. It returns the
90  // output-state id. "weight" is the best cost from the start-state to this
91  // state.
92  inline OutputStateId ProcessState(InputStateId istate, const Weight &weight) {
93  OutputStateId ostate = ofst_->AddState();
95  for (ArcIterator<Fst<Arc> > aiter(ifst_, istate); !aiter.Done();
96  aiter.Next()) {
97  const Arc &arc = aiter.Value();
98  Task new_task(istate, ostate, aiter.Position(),
99  Times(weight, arc.weight));
100  KALDI_ASSERT(Compare(arc.weight, Weight::One()) != 1);
101  queue_.push(new_task);
102  }
103  Weight final = ifst_.Final(istate);
104  if (final != Weight::Zero()) {
105  Task final_task(istate, ostate, static_cast<size_t>(-1),
106  Times(weight, final));
107  KALDI_ASSERT(Compare(final, Weight::One()) != 1);
108  queue_.push(final_task);
109  }
110  return ostate;
111  }
112 
113  // Returns the output-state id corresponding to "istate". This assumes we are
114  // processing a task corresponding to an arc to "istate", and the cost from
115  // the start-state to this state is "weight". Since we process tasks in
116  // order, if this is the first time we see this istate, then this is the best
117  // cost from the start-state to this state, and it can be used in setting the
118  // priority costs in ProcessState().
119  inline OutputStateId GetOutputStateId(InputStateId istate,
120  const Weight &weight) {
121  typedef typename unordered_map<InputStateId, OutputStateId>::iterator IterType;
122  IterType iter = state_map_.find(istate);
123  if (iter == state_map_.end())
124  return ProcessState(istate, weight);
125  else
126  return iter->second;
127  }
128 
129  void ProcessTask(const Task &task) {
130  if (task.position == static_cast<size_t>(-1)) {
131  ofst_->SetFinal(task.ostate, ifst_.Final(task.istate));
132  if (best_weight_ == Weight::Zero())
133  best_weight_ = task.weight; // best-path cost through FST, used for
134  // beam-pruning.
135  } else {
136  ArcIterator<Fst<Arc> > aiter(ifst_, task.istate);
137  aiter.Seek(task.position); // if we spend most of our time here, we may
138  // need to store the arc in the Task.
139  const Arc &arc = aiter.Value();
140  InputStateId next_istate = arc.nextstate;
141  OutputStateId next_ostate = GetOutputStateId(next_istate, task.weight);
142  Arc oarc(arc.ilabel, arc.olabel, arc.weight, next_ostate);
143  ofst_->AddArc(task.ostate, oarc);
144  }
145  }
146 
147  private:
148  const Fst<Arc> &ifst_;
149  VectorFst<Arc> *ofst_;
150  Weight beam_;
151  size_t max_states_;
152 
153  unordered_map<InputStateId, OutputStateId> state_map_;
154  std::priority_queue<Task> queue_;
155  Weight best_weight_; // if not Zero(), then we have now processed a successful path
156  // through ifst_, and this is the weight.
157 
158 };
159 
160 template<class Arc>
161 void PruneSpecial(const Fst<Arc> &ifst,
162  VectorFst<Arc> *ofst,
163  typename Arc::Weight beam,
164  size_t max_states) {
165  PruneSpecialClass<Arc> c(ifst, ofst, beam, max_states);
166 }
167 
168 
169 
170 }
171 
172 
173 #endif
fst::StdArc::StateId StateId
void ProcessTask(const Task &task)
void PruneSpecial(const Fst< Arc > &ifst, VectorFst< Arc > *ofst, typename Arc::Weight beam, size_t max_states)
The function PruneSpecial is like the standard OpenFst function "prune", except it does not expand th...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
PruneSpecialClass(const Fst< Arc > &ifst, VectorFst< Arc > *ofst, Weight beam, size_t max_states)
std::priority_queue< Task > queue_
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
bool operator<(const Task &other) const
unordered_map< InputStateId, OutputStateId > state_map_
VectorFst< Arc > * ofst_
fst::StdArc::Label Label
Task(InputStateId istate, OutputStateId ostate, size_t position, Weight weight)
fst::StdArc::Weight Weight
int Compare(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
OutputStateId ProcessState(InputStateId istate, const Weight &weight)
This class is used to implement the function PruneSpecial.
OutputStateId GetOutputStateId(InputStateId istate, const Weight &weight)
bool Done(const Task &task)
const Fst< Arc > & ifst_