deterministic-fst-inl.h
Go to the documentation of this file.
1 // fstext/deterministic-fst-inl.h
2 
3 // Copyright 2011-2012 Gilles Boulianne
4 // 2014 Telepoint Global Hosting Service, LLC. (Author: David Snyder)
5 // 2012-2015 Johns Hopkins University (author: Daniel Povey)
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #ifndef KALDI_FSTEXT_DETERMINISTIC_FST_INL_H_
23 #define KALDI_FSTEXT_DETERMINISTIC_FST_INL_H_
24 #include "base/kaldi-common.h"
25 #include "fstext/fstext-utils.h"
26 
27 
28 namespace fst {
29 // Do not include this file directly. It is included by deterministic-fst.h.
30 
// Finds the backoff transition of state `s` in fst_.
// Only the first arc leaving the state is inspected.  If that arc has input
// label 0 (epsilon), its weight is written to *w and its destination state is
// returned.  If the state has no arcs, or its first arc is not an epsilon arc,
// kNoStateId is returned and *w is left untouched.
// NOTE(review): the listing this was taken from drops line 33 (the
// "Class<Arc>::GetBackoffState(StateId s," part of the signature).
31 template<class Arc>
32 typename Arc::StateId
34  Weight *w) {
35  ArcIterator<Fst<Arc> > aiter(fst_, s);
36  if (aiter.Done()) // no arcs, hence no backoff arc.
37  return kNoStateId;
38  const Arc &arc = aiter.Value();
39  if (arc.ilabel == 0) { // first arc is an epsilon (backoff) arc.
40  *w = arc.weight;
41  return arc.nextstate;
42  } else {
43  return kNoStateId;
44  }
45 }
46 
// Returns the final weight of `state`, following backoff arcs if needed.
// If fst_ gives the state a non-Zero final weight, that weight is returned
// directly.  Otherwise the backoff state is looked up with GetBackoffState();
// when there is none the result is Weight::Zero(), else it is the backoff
// weight times the (recursively computed) final weight of the backoff state.
// NOTE(review): the signature line (48) is missing from this listing.
47 template<class Arc>
49  Weight w = fst_.Final(state);
50  if (w != Weight::Zero()) return w;
51  Weight backoff_w;
52  StateId backoff_state = GetBackoffState(state, &backoff_w);
53  if (backoff_state == kNoStateId) return Weight::Zero();
54  else return Times(backoff_w, this->Final(backoff_state));
55 }
56 
// Constructor: initializes fst_ from `fst`.
// When KALDI_PARANOID is defined it additionally asserts that the FST is
// input-label sorted and input-deterministic (the properties are requested
// with test = true, so they are computed if not already known).
// NOTE(review): the first signature line (58) is missing from this listing.
57 template<class Arc>
59  const Fst<Arc> &fst): fst_(fst) {
60 #ifdef KALDI_PARANOID
61  KALDI_ASSERT(fst_.Properties(kILabelSorted|kIDeterministic, true) ==
62  (kILabelSorted|kIDeterministic) &&
63  "Input FST is not i-label sorted and deterministic.");
64 #endif
65 }
66 
// Looks up the arc leaving state `s` with input label `ilabel`.
// If fst_ has such an arc it is copied to *oarc.  Otherwise the lookup is
// retried recursively from the backoff state of `s`, and the backoff weight is
// multiplied into the weight of the arc that was found.
// Returns false when neither `s` nor any state reachable through its chain of
// backoff arcs has a matching arc; `ilabel` must not be 0 (epsilon).
// NOTE(review): the first signature line (68) is missing from this listing.
67 template<class Arc>
69  StateId s, Label ilabel, Arc *oarc) {
70  KALDI_ASSERT(ilabel != 0); // We don't allow GetArc for epsilon.
71 
72  SortedMatcher<Fst<Arc> > sm(fst_, MATCH_INPUT, 1);
73  sm.SetState(s);
74  if (sm.Find(ilabel)) {
75  const Arc &arc = sm.Value();
76  *oarc = arc;
77  return true;
78  } else {
79  Weight backoff_w;
80  StateId backoff_state = GetBackoffState(s, &backoff_w);
81  if (backoff_state == kNoStateId) return false;
82  if (!this->GetArc(backoff_state, ilabel, oarc)) return false;
83  oarc->weight = Times(oarc->weight, backoff_w); // add the backoff cost.
84  return true;
85  }
86 }
87 
// Constructor: registers the start state, whose label history is the empty
// vector, as state 0 in both state_vec_ and state_map_.
// NOTE(review): the signature line (89) is missing from this listing;
// presumably it also initializes n_, which GetArc reads -- confirm.
88 template<class Arc>
90  // Starting state is an empty vector
91  std::vector<Label> start_state;
92  state_vec_.push_back(start_state);
93  start_state_ = 0;
94  state_map_[start_state] = 0;
95 }
96 
// Creates (on demand) the arc leaving state `s` with label `ilabel`.
// The destination state is identified by the label history of `s` with
// `ilabel` appended, truncated from the front so that it never holds more than
// n_-1 labels.  A state id is allocated the first time a history is seen.
// The arc written to *oarc has ilabel == olabel == `ilabel` and weight One();
// the function always returns true.
// NOTE(review): the first signature line (98) is missing from this listing.
97 template<class Arc>
99  StateId s, Label ilabel, Arc *oarc) {
100 
101  // The state ids increment with each state we encounter.
102  // if the assert fails, then we are trying to access
103  // unseen states that are not immediately traversable.
104  KALDI_ASSERT(static_cast<size_t>(s) < state_vec_.size());
105  std::vector<Label> seq = state_vec_[s];
106  // Update state info.
107  seq.push_back(ilabel);
108  if (seq.size() > n_-1) {
109  // Remove oldest word in the history.
110  seq.erase(seq.begin());
111  }
112  std::pair<const std::vector<Label>, StateId> new_state(
113  seq,
114  static_cast<Label>(state_vec_.size()));
115  // Now get state id for destination state.  insert() leaves an existing
  // entry untouched, so the candidate id is only used for a new history.
116  typedef typename MapType::iterator IterType;
117  std::pair<IterType, bool> result = state_map_.insert(new_state);
118  if (result.second == true) {
119  state_vec_.push_back(seq);
120  }
121  oarc->weight = Weight::One(); // Because the FST is unweighted.
122  oarc->ilabel = ilabel;
123  oarc->olabel = ilabel;
124  oarc->nextstate = result.first->second; // The next state id.
125  // All arcs can be matched.
126  return true;
127 }
128 
// Returns the final weight of `state`, which is Weight::One() for every state
// that has already been created (asserted against state_vec_.size()).
// NOTE(review): the signature line (130) is missing from this listing.
129 template<class Arc>
131  KALDI_ASSERT(state < static_cast<StateId>(state_vec_.size()));
132  return Weight::One();
133 }
134 
// Constructor: stores the two (non-NULL) component FST pointers.
// If either component has no start state (Start() == -1) the composed FST has
// no start state either.  Otherwise composed state 0 is registered as the pair
// (fst1 start, fst2 start) in state_map_ / state_vec_.
// NOTE(review): the first two signature lines (136-137) are missing from this
// listing.
135 template<class Arc>
138  DeterministicOnDemandFst<Arc> *fst2): fst1_(fst1), fst2_(fst2) {
139  KALDI_ASSERT(fst1 != NULL && fst2 != NULL);
140  if (fst1_->Start() == -1 || fst2_->Start() == -1) {
141  start_state_ = -1;
142  next_state_ = 0; // actually we don't care about this value.
143  } else {
144  start_state_ = 0;
145  std::pair<StateId,StateId> start_pair(fst1_->Start(), fst2_->Start());
146  state_map_[start_pair] = start_state_;
147  state_vec_.push_back(start_pair);
148  next_state_ = 1; // id that the next newly created state will get.
149  }
150 }
151 
// Returns the final weight of composed state `s`: the product (Times) of the
// final weights of the two component states it stands for.
// NOTE(review): the signature line (153) is missing from this listing.
152 template<class Arc>
154  KALDI_ASSERT(s < static_cast<StateId>(state_vec_.size()));
155  const std::pair<StateId, StateId> &pr (state_vec_[s]);
156  return Times(fst1_->Final(pr.first), fst2_->Final(pr.second));
157 }
158 
// Looks up the arc of the composed FST leaving state `s` with input label
// `ilabel` (which must not be 0).
// The arc is first requested from fst1_.  If that arc has an epsilon output
// only the fst1_ component of the state pair advances; otherwise an arc with
// the matching input label is requested from fst2_ and both components
// advance, with the two weights multiplied.  Destination pairs that have not
// been seen before get the next free state id.
// Returns false if either component has no matching arc.
// NOTE(review): the first signature line (160) is missing from this listing.
159 template<class Arc>
161  Arc *oarc) {
162  typedef typename MapType::iterator IterType;
163  KALDI_ASSERT(ilabel != 0 &&
164  "This program expects epsilon-free compact lattices as input");
165  KALDI_ASSERT(s < static_cast<StateId>(state_vec_.size()));
166  const std::pair<StateId, StateId> pr (state_vec_[s]);
167 
168  Arc arc1;
169  if (!fst1_->GetArc(pr.first, ilabel, &arc1)) return false;
170  if (arc1.olabel == 0) { // There is no output label on the
171  // arc, so only the first state changes.
172  std::pair<const std::pair<StateId, StateId>, StateId> new_value(
173  std::pair<StateId, StateId>(arc1.nextstate, pr.second),
174  next_state_);
175 
176  std::pair<IterType, bool> result = state_map_.insert(new_value);
177  oarc->ilabel = ilabel;
178  oarc->olabel = 0;
179  oarc->nextstate = result.first->second;
180  oarc->weight = arc1.weight;
181  if (result.second == true) { // was inserted
182  next_state_++;
183  const std::pair<StateId, StateId> &new_pair (new_value.first);
184  state_vec_.push_back(new_pair);
185  }
186  return true;
187  }
188  // There is an output label, so we need to traverse an arc on the
189  // second fst also.
190  Arc arc2;
191  if (!fst2_->GetArc(pr.second, arc1.olabel, &arc2)) return false;
192  std::pair<const std::pair<StateId, StateId>, StateId> new_value(
193  std::pair<StateId, StateId>(arc1.nextstate, arc2.nextstate),
194  next_state_);
195  std::pair<IterType, bool> result =
196  state_map_.insert(new_value);
197  oarc->ilabel = ilabel;
198  oarc->olabel = arc2.olabel;
199  oarc->nextstate = result.first->second;
200  oarc->weight = Times(arc1.weight, arc2.weight);
201  if (result.second == true) { // was inserted
202  next_state_++;
203  const std::pair<StateId, StateId> &new_pair (new_value.first);
204  state_vec_.push_back(new_pair);
205  }
206  return true;
207 }
208 
209 template<class Arc>
211  StateId src_state, Label ilabel) {
212  const StateId p1 = 26597, p2 = 50329; // these are two
213  // values that I drew at random from a table of primes.
214  // note: num_cached_arcs_ > 0.
215 
216  // We cast to size_t before the modulus, to ensure the
217  // result is positive.
218  return static_cast<size_t>(src_state * p1 + ilabel * p2) %
219  static_cast<size_t>(num_cached_arcs_);
220 }
221 
// Constructor: stores the wrapped FST pointer and allocates a cache with
// `num_cached_arcs` slots (which must be > 0).  Every slot's state field is
// set to kNoStateId so that it cannot match a real lookup.
// NOTE(review): the first two signature lines (223-224) are missing from this
// listing.
222 template<class Arc>
225  StateId num_cached_arcs): fst_(fst),
226  num_cached_arcs_(num_cached_arcs),
227  cached_arcs_(num_cached_arcs) {
228  KALDI_ASSERT(num_cached_arcs > 0);
229  for (StateId i = 0; i < num_cached_arcs; i++)
230  cached_arcs_[i].first = kNoStateId; // Invalidate all elements of the cache.
231 }
232 
// Looks up the arc leaving state `s` with input label `ilabel`, consulting the
// cache first.
// A cache slot holds a hit when both its stored source state and the ilabel of
// its stored arc match the request.  On a miss the request is forwarded to
// fst_, and a successful result overwrites whatever the slot held before.
// Requires s >= 0 and ilabel != 0.  Returns false if fst_ has no such arc.
// NOTE(review): the first signature line (234) is missing from this listing.
233 template<class Arc>
235  Arc *oarc) {
236  // Note: we don't cache anything in case a requested arc does not exist.
237  // In the uses that we imagine this will be put to, essentially all the
238  // requested arcs will exist. This only affects efficiency.
239  KALDI_ASSERT(s >= 0 && ilabel != 0);
240  size_t index = this->GetIndex(s, ilabel);
241  if (cached_arcs_[index].first == s &&
242  cached_arcs_[index].second.ilabel == ilabel) {
243  *oarc = cached_arcs_[index].second;
244  return true;
245  } else {
246  Arc arc;
247  if (fst_->GetArc(s, ilabel, &arc)) {
248  cached_arcs_[index].first = s;
249  cached_arcs_[index].second = arc;
250  *oarc = arc;
251  return true;
252  } else {
253  return false;
254  }
255  }
256 }
257 
// Constructor: stores the opaque LM pointer and the begin/end-of-sentence
// symbols, and registers state 0 as the history consisting of the single
// label `bos_symbol`.
// NOTE(review): the first signature line (259) is missing from this listing.
258 template<class Arc>
260  void *lm, Label bos_symbol, Label eos_symbol):
261  lm_(lm), bos_symbol_(bos_symbol), eos_symbol_(eos_symbol) {
262  std::vector<Label> begin_state; // history state corresponding to beginning of sentence
263  begin_state.push_back(bos_symbol); // Depending how your LM is set up, you might
264  // want to have a history vector with more than one bos_symbol on it.
265 
266  state_vec_.push_back(begin_state);
267  start_state_ = 0;
268  state_map_[begin_state] = 0;
269 }
270 
// Returns the final weight of state `s`.
// This is example code: the log-probability is the hard-coded placeholder
// -0.5 rather than a value obtained from the LM, and the returned weight is
// its negation.
// NOTE(review): the signature line (272) is missing from this listing.
271 template<class Arc>
273  KALDI_ASSERT(static_cast<size_t>(s) < state_vec_.size());
274  // In a real version you would probably use the following variable somehow
275  // (commenting it because it's generating warnings).
276  // const std::vector<Label> &wseq = state_vec_[s];
277  float log_prob = -0.5; // e.g. log_prob = lm->GetLogProb(wseq, eos_symbol_);
278  return Weight(-log_prob); // assuming weight is FloatWeight.
279 }
280 
// Creates (on demand) the arc leaving state `s` with label `ilabel`.
// This is example code: the log-probability is the hard-coded placeholder
// -0.25 and the history-truncation loop is disabled with while (0); both mark
// the places where a real LM would be queried.
// The destination state is identified by the history of `s` with `ilabel`
// appended; a state id is allocated the first time a history is seen.  The arc
// written to *oarc has ilabel == olabel == `ilabel` and weight -log_prob.
// Returns false only if log_prob is minus infinity.
// NOTE(review): the first signature line (282) is missing from this listing.
281 template<class Arc>
283  StateId s, Label ilabel, Arc *oarc) {
284  KALDI_ASSERT(static_cast<size_t>(s) < state_vec_.size());
285  std::vector<Label> wseq = state_vec_[s];
286  float log_prob = -0.25; // e.g. log_prob = lm->GetLogProb(wseq, ilabel);
287  wseq.push_back(ilabel); // the code might be different if your histories are the
288  // other way around.
289 
290  while (0) { // e.g. while !lm->HistoryStateExists(wseq)
291  wseq.erase(wseq.begin(), wseq.begin() + 1); // remove most distant element of history.
292  // note: if your histories are the other way round, you might just do
293  // wseq.pop() here.
294  }
295  if (log_prob == -std::numeric_limits<float>::infinity()) { // assume this
296  // is what happens if prob of the word is zero. Some LMs will never
297  // return zero.
298  return false; // no arc.
299  }
300  std::pair<const std::vector<Label>, StateId> new_value(
301  wseq,
302  static_cast<Label>(state_vec_.size()));
303 
304  // Now get state id for destination state.
305  typedef typename MapType::iterator IterType;
306  std::pair<IterType, bool> result = state_map_.insert(new_value);
307  if (result.second == true) // was inserted
308  state_vec_.push_back(wseq);
309  oarc->ilabel = ilabel;
310  oarc->olabel = ilabel;
311  oarc->nextstate = result.first->second; // the next-state id.
312  oarc->weight = Weight(-log_prob);
313  return true;
314 }
315 
316 
317 template<class Arc>
318 void ComposeDeterministicOnDemand(const Fst<Arc> &fst1,
320  MutableFst<Arc> *fst_composed) {
321  typedef typename Arc::Weight Weight;
322  typedef typename Arc::StateId StateId;
323  typedef std::pair<StateId, StateId> StatePair;
324  typedef unordered_map<StatePair, StateId,
326  typedef typename MapType::iterator IterType;
327 
328  fst_composed->DeleteStates();
329 
330  MapType state_map;
331  std::queue<StatePair> state_queue;
332 
333  // Set start state in fst_composed.
334  StateId s1 = fst1.Start(),
335  s2 = fst2->Start(),
336  start_state = fst_composed->AddState();
337  StatePair start_pair(s1, s2);
338  state_queue.push(start_pair);
339  fst_composed->SetStart(start_state);
340  // A mapping between pairs of states in fst1 and fst2 and the corresponding
341  // state in fst_composed.
342  std::pair<const StatePair, StateId> start_map(start_pair, start_state);
343  std::pair<IterType, bool> result = state_map.insert(start_map);
344  KALDI_ASSERT(result.second == true);
345 
346  while (!state_queue.empty()) {
347  StatePair q = state_queue.front();
348  StateId q1 = q.first,
349  q2 = q.second;
350  state_queue.pop();
351  // If the product of the final weights of the two fsts is non-zero then
352  // we can set a final-prob in fst_composed
353  Weight final_weight = Times(fst1.Final(q1), fst2->Final(q2));
354  if (final_weight != Weight::Zero()) {
355  KALDI_ASSERT(state_map.find(q) != state_map.end());
356  fst_composed->SetFinal(state_map[q], final_weight);
357  }
358 
359  // for each pair of edges from fst1 and fst2 at q1 and q2.
360  for (ArcIterator<Fst<Arc> > aiter(fst1, q1); !aiter.Done(); aiter.Next()) {
361  const Arc &arc1 = aiter.Value();
362  Arc arc2;
363  StatePair next_pair;
364  StateId next_state1 = arc1.nextstate,
365  next_state2,
366  next_state;
367  // If there is an epsilon on the arc of fst1 we transition to the next
368  // state but keep fst2 at the current state.
369  if (arc1.olabel == 0) {
370  next_state2 = q2;
371  } else {
372  bool match = fst2->GetArc(q2, arc1.olabel, &arc2);
373  if (!match) // There is no matching arc -> nothing to do.
374  continue;
375  next_state2 = arc2.nextstate;
376  }
377  next_pair = StatePair(next_state1, next_state2);
378  IterType sitr = state_map.find(next_pair);
379  // If sitr == state_map.end() then the state isn't in fst_composed yet.
380  if (sitr == state_map.end()) {
381  next_state = fst_composed->AddState();
382  std::pair<const StatePair, StateId> new_state(
383  next_pair, next_state);
384  std::pair<IterType, bool> result = state_map.insert(new_state);
385  // Since we already checked if state_map contained new_state,
386  // it should always be added if we reach here.
387  KALDI_ASSERT(result.second == true);
388  state_queue.push(next_pair);
389  // If sitr != state_map.end() then the next state is already in
390  // the state_map.
391  } else {
392  next_state = sitr->second;
393  }
394  if (arc1.olabel == 0) {
395  fst_composed->AddArc(state_map[q], Arc(arc1.ilabel, 0, arc1.weight,
396  next_state));
397  } else {
398  fst_composed->AddArc(state_map[q], Arc(arc1.ilabel, arc2.olabel,
399  Times(arc1.weight, arc2.weight), next_state));
400  }
401  }
402  }
403 }
404 
405 
406 // we are doing *fst_composed = Compose(Inverse(*left), right).
// `right` is enumerated arc by arc; for each of its arcs with a non-epsilon
// input label, *left is queried with GetArc() on that label and the labels of
// the returned arc are swapped.  States of the result are pairs
// (state-in-left, state-in-right), discovered breadth-first from the pair of
// start states.  The result is empty if either input has no start state.
// NOTE(review): this listing drops line 409 (the `left` parameter) and line
// 415 (the hasher argument of MapType).
407 template<class Arc>
408 void ComposeDeterministicOnDemandInverse(const Fst<Arc> &right,
410  MutableFst<Arc> *fst_composed) {
411  typedef typename Arc::Weight Weight;
412  typedef typename Arc::StateId StateId;
413  typedef std::pair<StateId, StateId> StatePair;
414  typedef unordered_map<StatePair, StateId,
416  typedef typename MapType::iterator IterType;
417 
418  fst_composed->DeleteStates();
419 
420  // the queue and map contain pairs (state-in-left, state-in-right)
421  MapType state_map;
422  std::queue<StatePair> state_queue;
423 
424  // Set start state in fst_composed.
425  StateId s_left = left->Start(),
426  s_right = right.Start();
427  if (s_left == kNoStateId || s_right == kNoStateId)
428  return; // Empty result.
429  StatePair start_pair(s_left, s_right);
430  StateId start_state = fst_composed->AddState();
431  state_queue.push(start_pair);
432  fst_composed->SetStart(start_state);
433  // A mapping between pairs of states in *left and right, and the corresponding
434  // state in fst_composed.
435  std::pair<const StatePair, StateId> start_map(start_pair, start_state);
436  std::pair<IterType, bool> result = state_map.insert(start_map);
437  KALDI_ASSERT(result.second == true);
438 
439  while (!state_queue.empty()) {
440  StatePair q = state_queue.front();
441  StateId q_left = q.first,
442  q_right = q.second;
443  state_queue.pop();
444  // If the product of the final weights of the two fsts is non-zero then
445  // we can set a final-prob in fst_composed
446  Weight final_weight = Times(left->Final(q_left), right.Final(q_right));
447  if (final_weight != Weight::Zero()) {
448  KALDI_ASSERT(state_map.find(q) != state_map.end());
449  fst_composed->SetFinal(state_map[q], final_weight);
450  }
451 
452  for (ArcIterator<Fst<Arc> > aiter(right, q_right); !aiter.Done(); aiter.Next()) {
453  const Arc &arc_right = aiter.Value();
454  Arc arc_left;
455  StatePair next_pair;
456  StateId next_state_right = arc_right.nextstate,
457  next_state_left,
458  next_state;
459  // If there is an epsilon on the input side of the right arc, we
460  // transition to the next state of the output but keep 'left' at the
461  // current state.
462  if (arc_right.ilabel == 0) {
463  next_state_left = q_left;
464  } else {
465  bool match = left->GetArc(q_left, arc_right.ilabel, &arc_left);
466  if (!match) // There is no matching arc -> nothing to do.
467  continue;
468  // the next 'swap' is because we are composing with the inverse of
469  // *left. Just removing the swap statement wouldn't let us compose
470  // with non-inverted *left though, because the GetArc function call
471  // above interprets the second argument as an ilabel not an olabel.
472  std::swap(arc_left.ilabel, arc_left.olabel);
473  next_state_left = arc_left.nextstate;
474  }
475  next_pair = StatePair(next_state_left, next_state_right);
476  IterType sitr = state_map.find(next_pair);
477  // If sitr == state_map.end() then the state isn't in fst_composed yet.
478  if (sitr == state_map.end()) {
479  next_state = fst_composed->AddState();
480  std::pair<const StatePair, StateId> new_state(
481  next_pair, next_state);
482  std::pair<IterType, bool> result = state_map.insert(new_state);
483  // Since we already checked if state_map contained new_state,
484  // it should always be added if we reach here.
485  KALDI_ASSERT(result.second == true);
486  state_queue.push(next_pair);
487  // If sitr != state_map.end() then the next state is already in
488  // the state_map.
489  } else {
490  next_state = sitr->second;
491  }
492  if (arc_right.ilabel == 0) {
493  // we didn't get an actual arc from the left FST.
494  fst_composed->AddArc(state_map[q], Arc(0, arc_right.olabel,
495  arc_right.weight,
496  next_state));
497  } else {
498  fst_composed->AddArc(state_map[q],
499  Arc(arc_left.ilabel, arc_right.olabel,
500  Times(arc_left.weight, arc_right.weight),
501  next_state));
502  }
503  }
504  }
505 }
506 
507 
508 
509 } // end namespace fst
510 
511 
512 #endif
std::vector< std::vector< Label > > state_vec_
fst::StdArc::StateId StateId
ComposeDeterministicOnDemandFst(DeterministicOnDemandFst< Arc > *fst1, DeterministicOnDemandFst< Arc > *fst2)
Note: constructor does not "take ownership" of the input fst's.
virtual bool GetArc(StateId s, Label ilabel, Arc *oarc)=0
Note: ilabel must not be epsilon.
BackoffDeterministicOnDemandFst(const Fst< Arc > &fst)
virtual Weight Final(StateId s)=0
unordered_map< std::vector< Label >, StateId, kaldi::VectorHasher< Label > > MapType
virtual bool GetArc(StateId s, Label ilabel, Arc *oarc)
Note: ilabel must not be epsilon.
StateId GetBackoffState(StateId s, Weight *w)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void ComposeDeterministicOnDemand(const Fst< Arc > &fst1, DeterministicOnDemandFst< Arc > *fst2, MutableFst< Arc > *fst_composed)
void swap(basic_filebuf< CharT, Traits > &x, basic_filebuf< CharT, Traits > &y)
virtual StateId Start()=0
std::vector< std::vector< Label > > state_vec_
LmExampleDeterministicOnDemandFst(void *lm, Label bos_symbol, Label eos_symbol)
DeterministicOnDemandFst< Arc > * fst1_
void ComposeDeterministicOnDemandInverse(const Fst< Arc > &right, DeterministicOnDemandFst< Arc > *left, MutableFst< Arc > *fst_composed)
This function does '*fst_composed = Compose(Inverse(*fst2), fst1)' Note that the arguments are revers...
bool GetArc(StateId s, Label ilabel, Arc *oarc)
Note: ilabel must not be epsilon.
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
class DeterministicOnDemandFst is an "FST-like" base-class.
size_t GetIndex(StateId src_state, Label ilabel)
virtual bool GetArc(StateId s, Label ilabel, Arc *oarc)
Note: ilabel must not be epsilon.
struct rnnlm::@11::@12 n
std::vector< std::pair< StateId, Arc > > cached_arcs_
bool GetArc(StateId s, Label ilabel, Arc *oarc)
Note: ilabel must not be epsilon.
DeterministicOnDemandFst< Arc > * fst_
fst::StdArc::Weight Weight
DeterministicOnDemandFst< Arc > * fst2_
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
CacheDeterministicOnDemandFst(DeterministicOnDemandFst< Arc > *fst, StateId num_cached_arcs=100000)
We don't take ownership of this pointer. The argument is "really" const.
virtual bool GetArc(StateId s, Label ilabel, Arc *oarc)
Note: ilabel must not be epsilon.
virtual Weight Final(StateId s)
We don't bother caching the final-probs, just the arcs.
std::vector< std::pair< StateId, StateId > > state_vec_
A hashing function-object for pairs of ints.
Definition: stl-utils.h:235