minimize-lattice.cc
Go to the documentation of this file.
1 // lat/minimize-lattice.cc
2 
3 // Copyright 2009-2011 Saarland University (Author: Arnab Ghoshal)
4 // 2012-2013 Johns Hopkins University (Author: Daniel Povey); Chao Weng;
5 // Bagher BabaAli
6 // 2014 Guoguo Chen
7 
8 // See ../../COPYING for clarification regarding multiple authors
9 //
10 // Licensed under the Apache License, Version 2.0 (the "License");
11 // you may not use this file except in compliance with the License.
12 // You may obtain a copy of the License at
13 //
14 // http://www.apache.org/licenses/LICENSE-2.0
15 //
16 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
18 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
19 // MERCHANTABLITY OR NON-INFRINGEMENT.
20 // See the Apache 2 License for the specific language governing permissions and
21 // limitations under the License.
22 
23 
24 #include "lat/minimize-lattice.h"
25 #include "hmm/transition-model.h"
26 #include "util/stl-utils.h"
27 
28 namespace fst {
29 
30 /*
31  Process the states in reverse topological order.
32  For each state, compute a hash-value that will be the same for states
33  that can be combined. Then for each pair of states with the
34  same hash value, check that the "to-states" map to the
35  same equivalence class and that the weights are sufficiently similar.
36 */
37 
38 template<class Weight, class IntType> class CompactLatticeMinimizer {
39  public:
41  typedef ArcTpl<CompactWeight> CompactArc;
42  typedef typename CompactArc::StateId StateId;
43  typedef typename CompactArc::Label Label;
44  typedef size_t HashType;
45 
46  CompactLatticeMinimizer(MutableFst<CompactArc> *clat,
47  float delta = fst::kDelta):
48  clat_(clat), delta_(delta) { }
49 
50  bool Minimize() {
51  if (clat_->Properties(kTopSorted, true) == 0) {
52  if (!TopSort(clat_)) {
53  KALDI_WARN << "Topological sorting of state-level lattice failed "
54  "(probably your lexicon has empty words or your LM has epsilon cycles; this "
55  " is a bad idea.)";
56  return false;
57  }
58  }
61  ModifyModel();
62  return true;
63  }
64 
65  static HashType ConvertStringToHashValue(const std::vector<IntType> &vec) {
66  const HashType prime = 53281;
68  HashType ans = static_cast<HashType>(h(vec));
69  if (ans == 0) ans = prime;
70  // We don't allow a zero answer, as this can cause too many values to be the
71  // same.
72  return ans;
73  }
74 
75  static void InitHashValue(const CompactWeight &final_weight, HashType *h) {
76  const HashType prime1 = 33317, prime2 = 607; // it's pretty random.
77  if (final_weight == CompactWeight::Zero()) *h = prime1;
78  else *h = prime2 * ConvertStringToHashValue(final_weight.String());
79  }
80 
81  // It's important that this function and UpdateHashValueForFinalProb be
82  // insensitive to the order in which it's called, as the order of the arcs
83  // won't necessarily be the same for different equivalent states.
84  static void UpdateHashValueForTransition(const CompactWeight &weight,
85  Label label,
86  HashType &next_state_hash,
87  HashType *h) {
88  const HashType prime1 = 1447, prime2 = 51907;
89  if (label == 0) label = prime2; // Zeros will cause problems.
90  *h += prime1 * label *
91  (1 + ConvertStringToHashValue(weight.String()) * next_state_hash);
92  // Above, the "1 +" is to ensure that if somehow we get zeros due to
93  // weird word sequences, they don't propagate.
94  }
95 
97  // Note: clat_ is topologically sorted, and StateId is
98  // signed. Each state's hash value is only a function of toplogically-later
99  // states' hash values.
100  state_hashes_.resize(clat_->NumStates());
101  for (StateId s = clat_->NumStates() - 1; s >= 0; s--) {
102  HashType this_hash;
103  InitHashValue(clat_->Final(s), &this_hash);
104  for (ArcIterator<MutableFst<CompactArc> > aiter(*clat_, s);
105  !aiter.Done(); aiter.Next()) {
106  const CompactArc &arc = aiter.Value();
107  HashType next_hash;
108  if (arc.nextstate > s) {
109  next_hash = state_hashes_[arc.nextstate];
110  } else {
111  KALDI_ASSERT(s == arc.nextstate &&
112  "Lattice not topologically sorted [code error]");
113  next_hash = 1;
114  KALDI_WARN << "Minimizing lattice with self-loops "
115  "(lattices should not have self-loops)";
116  }
117  UpdateHashValueForTransition(arc.weight, arc.ilabel,
118  next_hash, &this_hash);
119  }
120  state_hashes_[s] = this_hash;
121  }
122  }
123 
124 
125 
127  // This struct has an operator () which you can interpret as a less-than (<)
128  // operator for arcs. We sort on ilabel; since the lattice is supposed to
129  // be deterministic, this should completely determine the ordering (there
130  // should not be more than one arc with the same ilabel, out of the same
131  // state). For identical ilabels we next sort on the nextstate, simply to
132  // better handle non-deterministic input (we do our best on this, without
133  // guaranteeing full minimization). We could sort on the strings next, but
134  // this would be an unnecessary hassle as we only really need good
135  // performance on deterministic input.
136  bool operator () (const CompactArc &a, const CompactArc &b) const {
137  if (a.ilabel < b.ilabel) return true;
138  else if (a.ilabel > b.ilabel) return false;
139  else if (a.nextstate < b.nextstate) return true;
140  else return false;
141  }
142  };
143 
144 
145  // This function works out whether s and t are equivalent, assuming
146  // we have already partitioned all topologically-later states into
147  // equivalence classes (i.e. set up state_map_).
148  bool Equivalent(StateId s, StateId t) const {
149  if (!ApproxEqual(clat_->Final(s), clat_->Final(t), delta_))
150  return false;
151  if (clat_->NumArcs(s) != clat_->NumArcs(t))
152  return false;
153  std::vector<CompactArc> s_arcs;
154  std::vector<CompactArc> t_arcs;
155  for (int32 iter = 0; iter <= 1; iter++) {
156  StateId state = (iter == 0 ? s : t);
157  std::vector<CompactArc> &arcs = (iter == 0 ? s_arcs : t_arcs);
158  arcs.reserve(clat_->NumArcs(s));
159  for (ArcIterator<MutableFst<CompactArc> > aiter(*clat_, state);
160  !aiter.Done(); aiter.Next()) {
161  CompactArc arc = aiter.Value();
162  if (arc.nextstate == state) {
163  // This is a special case for states that have self-loops. If two
164  // states have an identical self-loop arc, they may be equivalent.
165  arc.nextstate = kNoStateId;
166  } else {
167  KALDI_ASSERT(arc.nextstate > state);
168  //while (state_map_[arc.nextstate] != arc.nextstate)
169  arc.nextstate = state_map_[arc.nextstate];
170  arcs.push_back(arc);
171  }
172  }
174  std::sort(arcs.begin(), arcs.end(), s);
175  }
176  KALDI_ASSERT(s_arcs.size() == t_arcs.size());
177  for (size_t i = 0; i < s_arcs.size(); i++) {
178  if (s_arcs[i].nextstate != t_arcs[i].nextstate) return false;
179  KALDI_ASSERT(s_arcs[i].ilabel == s_arcs[i].olabel); // CompactLattices are
180  // supposed to be
181  // acceptors.
182  if (s_arcs[i].ilabel != t_arcs[i].ilabel) return false;
183  // We've already mapped to equivalence classes.
184  if (s_arcs[i].nextstate != t_arcs[i].nextstate) return false;
185  if (!ApproxEqual(s_arcs[i].weight, t_arcs[i].weight)) return false;
186  }
187  return true;
188  }
189 
191  // We have to compute the state mapping in reverse topological order also,
192  // since the equivalence test relies on later states being already sorted
193  // out into equivalence classes (by state_map_).
194  StateId num_states = clat_->NumStates();
195  unordered_map<HashType, std::vector<StateId> > hash_groups_;
196 
197  for (StateId s = 0; s < num_states; s++)
198  hash_groups_[state_hashes_[s]].push_back(s);
199 
200  state_map_.resize(num_states);
201  for (StateId s = 0; s < num_states; s++)
202  state_map_[s] = s; // Default mapping.
203 
204 
205  { // This block is just diagnostic.
206  typedef typename unordered_map<HashType,
207  std::vector<StateId> >::const_iterator HashIter;
208  size_t max_size = 0;
209  for (HashIter iter = hash_groups_.begin(); iter != hash_groups_.end();
210  ++iter)
211  max_size = std::max(max_size, iter->second.size());
212  if (max_size > 1000) {
213  KALDI_WARN << "Largest equivalence group (using hash) is " << max_size
214  << ", minimization might be slow.";
215  }
216  }
217 
218  for (StateId s = num_states - 1; s >= 0; s--) {
219  HashType hash = state_hashes_[s];
220  const std::vector<StateId> &equivalence_class = hash_groups_[hash];
221  KALDI_ASSERT(!equivalence_class.empty());
222  for (size_t i = 0; i < equivalence_class.size(); i++) {
223  StateId t = equivalence_class[i];
224  // Below, there is no point doing the test if state_map_[t] != t, because
225  // in that case we will, before after this, be comparing with another state
226  // that is equivalent to t.
227  if (t > s && state_map_[t] == t && Equivalent(s, t)) {
228  state_map_[s] = t;
229  break;
230  }
231  }
232  }
233  }
234 
235  void ModifyModel() {
236  // Modifies the model according to state_map_;
237 
238  StateId num_removed = 0;
239  StateId num_states = clat_->NumStates();
240  for (StateId s = 0; s < num_states; s++)
241  if (state_map_[s] != s)
242  num_removed++;
243  KALDI_VLOG(3) << "Removing " << num_removed << " of "
244  << num_states << " states.";
245  if (num_removed == 0) return; // Nothing to do.
246 
247  clat_->SetStart(state_map_[clat_->Start()]);
248 
249  for (StateId s = 0; s < num_states; s++) {
250  if (state_map_[s] != s)
251  continue; // There is no point modifying states we're removing.
252  for (MutableArcIterator<MutableFst<CompactArc> > aiter(clat_, s);
253  !aiter.Done(); aiter.Next()) {
254  CompactArc arc = aiter.Value();
255  StateId mapped_nextstate = state_map_[arc.nextstate];
256  if (mapped_nextstate != arc.nextstate) {
257  arc.nextstate = mapped_nextstate;
258  aiter.SetValue(arc);
259  }
260  }
261  }
262  fst::Connect(clat_);
263  }
264  private:
265  MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *clat_;
266  float delta_;
267  std::vector<HashType> state_hashes_;
268  std::vector<StateId> state_map_; // maps each state to itself or to some
269  // equivalent state. Within each equivalence
270  // class, we pick one arbitrarily.
271 };
272 
273 template<class Weight, class IntType>
275  MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *clat,
276  float delta) {
277  CompactLatticeMinimizer<Weight, IntType> minimizer(clat, delta);
278  return minimizer.Minimize();
279 }
280 
281 // Instantiate for CompactLattice type.
282 template
283 bool MinimizeCompactLattice<kaldi::LatticeWeight, kaldi::int32>(
284  MutableFst<kaldi::CompactLatticeArc> *clat, float delta);
285 
286 
287 } // namespace fst
fst::StdArc::StateId StateId
static void InitHashValue(const CompactWeight &final_weight, HashType *h)
static void UpdateHashValueForTransition(const CompactWeight &weight, Label label, HashType &next_state_hash, HashType *h)
A hashing function-object for vectors.
Definition: stl-utils.h:216
bool operator()(const CompactArc &a, const CompactArc &b) const
bool Equivalent(StateId s, StateId t) const
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
kaldi::int32 int32
bool ApproxEqual(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, float delta=kDelta)
std::vector< StateId > state_map_
CompactLatticeWeightTpl< Weight, IntType > CompactWeight
bool MinimizeCompactLattice(MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, IntType > > > *clat, float delta)
This function minimizes the compact lattice.
static HashType ConvertStringToHashValue(const std::vector< IntType > &vec)
#define KALDI_WARN
Definition: kaldi-error.h:150
fst::StdArc::Label Label
ArcTpl< CompactWeight > CompactArc
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< HashType > state_hashes_
static const CompactLatticeWeightTpl< WeightType, IntType > Zero()
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, IntType > > > * clat_
CompactLatticeMinimizer(MutableFst< CompactArc > *clat, float delta=fst::kDelta)
const std::vector< IntType > & String() const