push-lattice.cc
Go to the documentation of this file.
1 // lat/push-lattice.cc
2 
3 // Copyright 2009-2011 Saarland University (Author: Arnab Ghoshal)
4 // 2012-2013 Johns Hopkins University (Author: Daniel Povey); Chao Weng;
5 // Bagher BabaAli
6 // 2014 Guoguo Chen
7 
8 // See ../../COPYING for clarification regarding multiple authors
9 //
10 // Licensed under the Apache License, Version 2.0 (the "License");
11 // you may not use this file except in compliance with the License.
12 // You may obtain a copy of the License at
13 //
14 // http://www.apache.org/licenses/LICENSE-2.0
15 //
16 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
18 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
19 // MERCHANTABILITY OR NON-INFRINGEMENT.
20 // See the Apache 2 License for the specific language governing permissions and
21 // limitations under the License.
22 
23 
24 #include "lat/push-lattice.h"
25 #include "hmm/transition-model.h"
26 #include "util/stl-utils.h"
27 
28 namespace fst {
29 
30 
31 template<class Weight, class IntType> class CompactLatticePusher {
32  public:
34  typedef ArcTpl<CompactWeight> CompactArc;
35  typedef typename CompactArc::StateId StateId;
36 
37  CompactLatticePusher(MutableFst<CompactArc> *clat): clat_(clat) { }
38  bool Push() {
39  if (clat_->Properties(kTopSorted, true) == 0) {
40  if (!TopSort(clat_)) {
41  KALDI_WARN << "Topological sorting of state-level lattice failed "
42  "(probably your lexicon has empty words or your LM has epsilon cycles; this "
43  " is a bad idea.)";
44  return false;
45  }
46  }
47  ComputeShifts();
48  ApplyShifts();
49  return true;
50  }
51 
52  // Gets the string of length [end - begin], starting at this
53  // state and taking arc "arc_idx" (and thereafter an arbitrary sequence).
54  // Note: here, arc_idx == -1 means take an arbitrary path.
55  static void GetString(const ExpandedFst<CompactArc> &clat,
56  StateId state,
57  size_t arc_idx,
58  typename std::vector<IntType>::iterator begin,
59  typename std::vector<IntType>::iterator end) {
60  CompactWeight final = clat.Final(state);
61  size_t len = end - begin;
62  KALDI_ASSERT(len >= 0);
63  if (len == 0) return;
64  if (arc_idx == -1 && final != CompactWeight::Zero()) {
65  const std::vector<IntType> &string = final.String();
66  KALDI_ASSERT(string.size() >= len &&
67  "Either code error, or paths in lattice have inconsistent lengths");
68  std::copy(string.begin(), string.begin() + len, begin);
69  return;
70  }
71 
72  ArcIterator<ExpandedFst<CompactArc> > aiter(clat, state);
73  if (arc_idx != -1)
74  aiter.Seek(arc_idx);
75  KALDI_ASSERT(!aiter.Done() &&
76  "Either code error, or paths in lattice are inconsistent in length.");
77 
78  const CompactArc &arc = aiter.Value();
79  size_t arc_len = arc.weight.String().size();
80  if (arc_len >= len) {
81  std::copy(arc.weight.String().begin(), arc.weight.String().begin() + len, begin);
82  } else {
83  std::copy(arc.weight.String().begin(), arc.weight.String().end(), begin);
84  // Recurse.
85  GetString(clat, arc.nextstate, -1, begin + arc_len, end);
86  }
87  }
88 
89  void CheckForConflict(const CompactWeight &final,
90  StateId state,
91  int32 *shift) {
92  if (shift == NULL) return;
93  // At input, "shift" has the maximum value that we could shift back assuming
94  // there is no conflict between the values of the strings. We need to check
95  // if there is conflict, and if so, reduce the "shift".
96  bool is_final = (final != CompactWeight::Zero());
97  size_t num_arcs = clat_->NumArcs(state);
98  if (num_arcs + (is_final ? 1 : 0) > 1 && *shift > 0) {
99  // There is potential for conflict between string values, because >1
100  // [arc or final-prob]. Find the longest shift up to and including the
101  // current shift, that gives no conflict.
102 
103  std::vector<IntType> string(*shift), compare_string(*shift);
104  size_t arc;
105  if (is_final) {
106  KALDI_ASSERT(final.String().size() >= *shift);
107  std::copy(final.String().begin(), final.String().begin() + *shift,
108  string.begin());
109  arc = 0;
110  } else {
111  // set "string" to string if we take 1st arc.
112  GetString(*clat_, state, 0, string.begin(), string.end());
113  arc = 1;
114  }
115  for (; arc < num_arcs; arc++) { // for the other arcs..
116  GetString(*clat_, state, arc,
117  compare_string.begin(), compare_string.end());
118  std::pair<typename std::vector<IntType>::iterator,
119  typename std::vector<IntType>::iterator> pr =
120  std::mismatch(string.begin(), string.end(),
121  compare_string.begin());
122  if (pr.first != string.end()) { // There was a mismatch. Reduce the shift
123  // to a value where they will match.
124  *shift = pr.first - string.begin();
125  string.resize(*shift);
126  compare_string.resize(*shift);
127  }
128  }
129  }
130  }
131 
132  void ComputeShifts() {
133  StateId num_states = clat_->NumStates();
134  shift_vec_.resize(num_states, 0);
135 
136  // The for loop will only work if StateId is signed, so assert this.
137  KALDI_COMPILE_TIME_ASSERT(static_cast<StateId>(-1) < static_cast<StateId>(0));
138  // We rely on the topological sorting, so clat_->Start() should be zero or
139  // at least any preceding states should be non-accessible. We leave the
140  // shift at zero for the start state because we can't shift to before that.
141  for (StateId state = num_states - 1; state > clat_->Start(); state--) {
142  size_t num_arcs = clat_->NumArcs(state);
143  CompactWeight final = clat_->Final(state);
144  if (num_arcs == 0) {
145  // we can shift back by the number of transition-ids on the
146  // final-prob, if any.
147  shift_vec_[state] = final.String().size();
148  } else { // We have arcs ...
149  int32 shift = std::numeric_limits<int32>::max();
150  size_t num_arcs = 0;
151  bool is_final = (final != CompactWeight::Zero());
152  if (is_final)
153  shift = std::min(shift, static_cast<int32>(final.String().size()));
154  for (ArcIterator<MutableFst<CompactArc> > aiter(*clat_, state);
155  !aiter.Done(); aiter.Next(), num_arcs++) {
156  const CompactArc &arc (aiter.Value());
157  shift = std::min(shift, shift_vec_[arc.nextstate] +
158  static_cast<int32>(arc.weight.String().size()));
159  }
160  CheckForConflict(final, state, &shift);
161  shift_vec_[state] = shift;
162  }
163  }
164  }
165 
166  void ApplyShifts() {
167  StateId num_states = clat_->NumStates();
168  for (StateId state = 0; state < num_states; state++) {
169  int32 shift = shift_vec_[state];
170  std::vector<IntType> string;
171  for (MutableArcIterator<MutableFst<CompactArc> > aiter(clat_, state);
172  !aiter.Done(); aiter.Next()) {
173  CompactArc arc(aiter.Value());
174  KALDI_ASSERT(arc.nextstate > state && "Cyclic lattice");
175 
176  string = arc.weight.String();
177  size_t orig_len = string.size(), next_shift = shift_vec_[arc.nextstate];
178  // extend "string" by next_shift.
179  string.resize(string.size() + next_shift);
180  // The next command sets the last "next_shift" elements of 'string' to
181  // the string starting from arc.nextstate (taking an arbitrary path).
182  GetString(*clat_, arc.nextstate, -1,
183  string.begin() + orig_len, string.end());
184  // Remove the first "shift" elements of this string and set the
185  // arc-weight string to this.
186  arc.weight.SetString(std::vector<IntType>(string.begin() + shift,
187  string.end()));
188  aiter.SetValue(arc);
189  }
190 
191  CompactWeight final = clat_->Final(state);
192  if (final != CompactWeight::Zero()) {
193  // Erase first "shift" elements of final-prob.
194  final.SetString(std::vector<IntType>(final.String().begin() + shift,
195  final.String().end()));
196  clat_->SetFinal(state, final);
197  }
198  }
199  }
200 
201  private:
202  MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *clat_;
203 
204  // For each state s, shift_vec_[s] >= 0 is how much we will shift the
205  // transition-ids back at this state.
206  std::vector<int32> shift_vec_;
207 };
208 
209 template<class Weight, class IntType>
211  MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *clat) {
213  return pusher.Push();
214 }
215 
216 template<class Weight, class IntType>
218  MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *clat) {
219  if (clat->Properties(kTopSorted, true) == 0) {
220  if (!TopSort(clat)) {
221  KALDI_WARN << "Topological sorting of state-level lattice failed "
222  "(probably your lexicon has empty words or your LM has epsilon cycles; this "
223  " is a bad idea.)";
224  return false;
225  }
226  }
228  typedef ArcTpl<CompactWeight> CompactArc;
229  typedef typename CompactArc::StateId StateId;
230 
231  StateId num_states = clat->NumStates();
232  if (num_states == 0) {
233  KALDI_WARN << "Pushing weights of empty compact lattice";
234  return true; // this is technically success because an empty
235  // lattice is already pushed.
236  }
237  std::vector<Weight> weight_to_end(num_states); // Note: LatticeWeight
238  // contains two floats.
239  for (StateId s = num_states - 1; s >= 0; s--) {
240  Weight this_weight_to_end = clat->Final(s).Weight();
241  for (ArcIterator<MutableFst<CompactArc> > aiter(*clat, s);
242  !aiter.Done(); aiter.Next()) {
243  const CompactArc &arc = aiter.Value();
244  KALDI_ASSERT(arc.nextstate > s && "Cyclic lattices not allowed.");
245  this_weight_to_end = Plus(this_weight_to_end,
246  Times(aiter.Value().weight.Weight(),
247  weight_to_end[arc.nextstate]));
248  }
249  if (this_weight_to_end == Weight::Zero()) {
250  KALDI_WARN << "Lattice has non-coaccessible states.";
251  }
252  weight_to_end[s] = this_weight_to_end;
253  }
254  weight_to_end[0] = Weight::One(); // We leave the "leftover weight" on
255  // the start state, which won't
256  // necessarily end up summing to one.
257  for (StateId s = 0; s < num_states; s++) {
258  Weight this_weight_to_end = weight_to_end[s];
259  if (this_weight_to_end == Weight::Zero())
260  continue;
261  for (MutableArcIterator<MutableFst<CompactArc> > aiter(clat, s);
262  !aiter.Done(); aiter.Next()) {
263  CompactArc arc = aiter.Value();
264  Weight next_weight_to_end = weight_to_end[arc.nextstate];
265  if (next_weight_to_end != Weight::Zero()) {
266  arc.weight.SetWeight(Times(arc.weight.Weight(),
267  Divide(next_weight_to_end,
268  this_weight_to_end)));
269  aiter.SetValue(arc);
270  }
271  }
272  CompactWeight final_weight = clat->Final(s);
273  if (final_weight != CompactWeight::Zero()) {
274  final_weight.SetWeight(Divide(final_weight.Weight(), this_weight_to_end));
275  clat->SetFinal(s, final_weight);
276  }
277  }
278 
279  return true;
280 }
281 
282 // Instantiate for CompactLattice.
283 template
284 bool PushCompactLatticeStrings<kaldi::LatticeWeight, kaldi::int32>(
285  MutableFst<kaldi::CompactLatticeArc> *clat);
286 
287 template
288 bool PushCompactLatticeWeights<kaldi::LatticeWeight, kaldi::int32>(
289  MutableFst<kaldi::CompactLatticeArc> *clat);
290 
291 } // namespace fst
fst::StdArc::StateId StateId
LatticeWeightTpl< FloatType > Divide(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, DivideType typ=DIVIDE_ANY)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
LatticeWeightTpl< FloatType > Plus(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
CompactLatticeWeightTpl< Weight, IntType > CompactWeight
Definition: push-lattice.cc:33
bool PushCompactLatticeStrings(MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, IntType > > > *clat)
This function pushes the transition-ids as far towards the start as they will go. ...
kaldi::int32 int32
static void GetString(const ExpandedFst< CompactArc > &clat, StateId state, size_t arc_idx, typename std::vector< IntType >::iterator begin, typename std::vector< IntType >::iterator end)
Definition: push-lattice.cc:55
void CheckForConflict(const CompactWeight &final, StateId state, int32 *shift)
Definition: push-lattice.cc:89
bool PushCompactLatticeWeights(MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, IntType > > > *clat)
This function pushes the weights in the CompactLattice so that all states except possibly the start state have weights that sum to one; any leftover total weight stays on the start state.
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
CompactLatticePusher(MutableFst< CompactArc > *clat)
Definition: push-lattice.cc:37
MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, IntType > > > * clat_
#define KALDI_WARN
Definition: kaldi-error.h:150
fst::StdArc::Weight Weight
std::vector< int32 > shift_vec_
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
ArcTpl< CompactWeight > CompactArc
Definition: push-lattice.cc:34
static const CompactLatticeWeightTpl< WeightType, IntType > Zero()
CompactArc::StateId StateId
Definition: push-lattice.cc:35
#define KALDI_COMPILE_TIME_ASSERT(b)
Definition: kaldi-utils.h:131