determinize-lattice-inl.h
Go to the documentation of this file.
1 // fstext/determinize-lattice-inl.h
2 
3 // Copyright 2009-2012 Microsoft Corporation
4 // 2012-2013 Johns Hopkins University (Author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
22 #define KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
23 // Do not include this file directly. It is included by determinize-lattice.h
24 
25 #include <vector>
26 #include <climits>
27 
28 namespace fst {
29 
30 // This class maps back and forth from/to integer id's to sequences of strings.
31 // used in determinization algorithm. It is constructed in such a way that
32 // finding the string-id of the successor of (string, next-label) has constant time.
33 
34 // Note: class IntType, typically int32, is the type of the element in the
35 // string (typically a template argument of the CompactLatticeWeightTpl).
36 
37 template<class IntType> class LatticeStringRepository {
38  public:
39  struct Entry {
40  const Entry *parent; // NULL for empty string.
41  IntType i;
42  inline bool operator == (const Entry &other) const {
43  return (parent == other.parent && i == other.i);
44  }
45  Entry() { }
46  Entry(const Entry &e): parent(e.parent), i(e.i) {}
47  };
48  // Note: all Entry* pointers returned in function calls are
49  // owned by the repository itself, not by the caller!
50 
51  // Interface guarantees empty string is NULL.
52  inline const Entry *EmptyString() { return NULL; }
53 
54  // Returns string of "parent" with i appended. Pointer
55  // owned by repository
56  const Entry *Successor(const Entry *parent, IntType i) {
58  new_entry_->i = i;
59 
60  std::pair<typename SetType::iterator, bool> pr = set_.insert(new_entry_);
61  if (pr.second) { // Was successfully inserted (was not there). We need to
62  // replace the element we inserted, which resides on the
63  // stack, with one from the heap.
64  const Entry *ans = new_entry_;
65  new_entry_ = new Entry();
66  return ans;
67  } else { // Was not inserted because an equivalent Entry already
68  // existed.
69  return *pr.first;
70  }
71  }
72 
73  const Entry *Concatenate (const Entry *a, const Entry *b) {
74  if (a == NULL) return b;
75  else if (b == NULL) return a;
76  std::vector<IntType> v;
77  ConvertToVector(b, &v);
78  const Entry *ans = a;
79  for(size_t i = 0; i < v.size(); i++)
80  ans = Successor(ans, v[i]);
81  return ans;
82  }
83  const Entry *CommonPrefix (const Entry *a, const Entry *b) {
84  std::vector<IntType> a_vec, b_vec;
85  ConvertToVector(a, &a_vec);
86  ConvertToVector(b, &b_vec);
87  const Entry *ans = NULL;
88  for(size_t i = 0; i < a_vec.size() && i < b_vec.size() &&
89  a_vec[i] == b_vec[i]; i++)
90  ans = Successor(ans, a_vec[i]);
91  return ans;
92  }
93 
94  // removes any elements from b that are not part of
95  // a common prefix with a.
96  void ReduceToCommonPrefix(const Entry *a,
97  std::vector<IntType> *b) {
98  size_t a_size = Size(a), b_size = b->size();
99  while (a_size> b_size) {
100  a = a->parent;
101  a_size--;
102  }
103  if (b_size > a_size)
104  b_size = a_size;
105  typename std::vector<IntType>::iterator b_begin = b->begin();
106  while (a_size != 0) {
107  if (a->i != *(b_begin + a_size - 1))
108  b_size = a_size - 1;
109  a = a->parent;
110  a_size--;
111  }
112  if (b_size != b->size())
113  b->resize(b_size);
114  }
115 
116  // removes the first n elements of a.
117  const Entry *RemovePrefix(const Entry *a, size_t n) {
118  if (n==0) return a;
119  std::vector<IntType> a_vec;
120  ConvertToVector(a, &a_vec);
121  assert(a_vec.size() >= n);
122  const Entry *ans = NULL;
123  for(size_t i = n; i < a_vec.size(); i++)
124  ans = Successor(ans, a_vec[i]);
125  return ans;
126  }
127 
128 
129 
130  // Returns true if a is a prefix of b. If a is prefix of b,
131  // time taken is |b| - |a|. Else, time taken is |b|.
132  bool IsPrefixOf(const Entry *a, const Entry *b) const {
133  if(a == NULL) return true; // empty string prefix of all.
134  if (a == b) return true;
135  if (b == NULL) return false;
136  return IsPrefixOf(a, b->parent);
137  }
138 
139 
140  inline size_t Size(const Entry *entry) const {
141  size_t ans = 0;
142  while (entry != NULL) {
143  ans++;
144  entry = entry->parent;
145  }
146  return ans;
147  }
148 
149  void ConvertToVector(const Entry *entry, std::vector<IntType> *out) const {
150  size_t length = Size(entry);
151  out->resize(length);
152  if (entry != NULL) {
153  typename std::vector<IntType>::reverse_iterator iter = out->rbegin();
154  while (entry != NULL) {
155  *iter = entry->i;
156  entry = entry->parent;
157  ++iter;
158  }
159  }
160  }
161 
162  const Entry *ConvertFromVector(const std::vector<IntType> &vec) {
163  const Entry *e = NULL;
164  for(size_t i = 0; i < vec.size(); i++)
165  e = Successor(e, vec[i]);
166  return e;
167  }
168 
170 
171  void Destroy() {
172  for (typename SetType::iterator iter = set_.begin();
173  iter != set_.end();
174  ++iter)
175  delete *iter;
176  SetType tmp;
177  tmp.swap(set_);
178  if (new_entry_) {
179  delete new_entry_;
180  new_entry_ = NULL;
181  }
182  }
183 
184  // Rebuild will rebuild this object, guaranteeing only
185  // to preserve the Entry values that are in the vector pointed
186  // to (this list does not have to be unique). The point of
187  // this is to save memory.
188  void Rebuild(const std::vector<const Entry*> &to_keep) {
189  SetType tmp_set;
190  for (typename std::vector<const Entry*>::const_iterator
191  iter = to_keep.begin();
192  iter != to_keep.end(); ++iter)
193  RebuildHelper(*iter, &tmp_set);
194  // Now delete all elems not in tmp_set.
195  for (typename SetType::iterator iter = set_.begin();
196  iter != set_.end(); ++iter) {
197  if (tmp_set.count(*iter) == 0)
198  delete (*iter); // delete the Entry; not needed.
199  }
200  set_.swap(tmp_set);
201  }
202 
204  int32 MemSize() const {
205  return set_.size() * sizeof(Entry) * 2; // this is a lower bound
206  // on the size this structure might take.
207  }
208  private:
209  class EntryKey { // Hash function object.
210  public:
211  inline size_t operator()(const Entry *entry) const {
212  size_t prime = 49109;
213  return static_cast<size_t>(entry->i)
214  + prime * reinterpret_cast<size_t>(entry->parent);
215  }
216  };
217  class EntryEqual {
218  public:
219  inline bool operator()(const Entry *e1, const Entry *e2) const {
220  return (*e1 == *e2);
221  }
222  };
223  typedef std::unordered_set<const Entry*, EntryKey, EntryEqual> SetType;
224 
225  void RebuildHelper(const Entry *to_add, SetType *tmp_set) {
226  while(true) {
227  if (to_add == NULL) return;
228  typename SetType::iterator iter = tmp_set->find(to_add);
229  if (iter == tmp_set->end()) { // not in tmp_set.
230  tmp_set->insert(to_add);
231  to_add = to_add->parent; // and loop.
232  } else {
233  return;
234  }
235  }
236  }
237 
239  Entry *new_entry_; // We always have a pre-allocated Entry ready to use,
240  // to avoid unnecessary news and deletes.
241  SetType set_;
242 
243 };
244 
245 
246 
247 
// class LatticeDeterminizer is templated on the same types that
// CompactLatticeWeight is templated on: the base weight (Weight), typically
// LatticeWeightTpl<float> etc. but could also be e.g. TropicalWeight, and the
// IntType, typically int32, used for the output symbols in the compact
// representation of strings [note: the output symbols would usually be
// pdf-ids in the anticipated use of this code].  It has a special requirement
// on the Weight type: there should be a Compare function on the weights such
// that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if
// w1 > w2.  This requires that there be a total order on the weights.
257 
258 template<class Weight, class IntType> class LatticeDeterminizer {
259  public:
260  // Output to Gallic acceptor (so the strings go on weights, and there is a 1-1 correspondence
261  // between our states and the states in ofst. If destroy == true, release memory as we go
262  // (but we cannot output again).
263 
265  typedef ArcTpl<CompactWeight> CompactArc; // arc in compact, acceptor form of lattice
266  typedef ArcTpl<Weight> Arc; // arc in non-compact version of lattice
267 
268 
269  // Output to standard FST with CompactWeightTpl<Weight> as its weight type (the
270  // weight stores the original output-symbol strings). If destroy == true,
271  // release memory as we go (but we cannot output again).
272  void Output(MutableFst<CompactArc> *ofst, bool destroy = true) {
273  assert(determinized_);
274  typedef typename Arc::StateId StateId;
275  StateId nStates = static_cast<StateId>(output_arcs_.size());
276  if (destroy)
277  FreeMostMemory();
278  ofst->DeleteStates();
279  ofst->SetStart(kNoStateId);
280  if (nStates == 0) {
281  return;
282  }
283  for (StateId s = 0;s < nStates;s++) {
284  OutputStateId news = ofst->AddState();
285  assert(news == s);
286  }
287  ofst->SetStart(0);
288  // now process transitions.
289  for (StateId this_state = 0; this_state < nStates; this_state++) {
290  std::vector<TempArc> &this_vec(output_arcs_[this_state]);
291  typename std::vector<TempArc>::const_iterator iter = this_vec.begin(), end = this_vec.end();
292 
293  for (;iter != end; ++iter) {
294  const TempArc &temp_arc(*iter);
295  CompactArc new_arc;
296  std::vector<Label> seq;
297  repository_.ConvertToVector(temp_arc.string, &seq);
298  CompactWeight weight(temp_arc.weight, seq);
299  if (temp_arc.nextstate == kNoStateId) { // is really final weight.
300  ofst->SetFinal(this_state, weight);
301  } else { // is really an arc.
302  new_arc.nextstate = temp_arc.nextstate;
303  new_arc.ilabel = temp_arc.ilabel;
304  new_arc.olabel = temp_arc.ilabel; // acceptor. input == output.
305  new_arc.weight = weight; // includes string and weight.
306  ofst->AddArc(this_state, new_arc);
307  }
308  }
309  // Free up memory. Do this inside the loop as ofst is also allocating memory
310  if (destroy) { std::vector<TempArc> temp; std::swap(temp, this_vec); }
311  }
312  if (destroy) { std::vector<std::vector<TempArc> > temp; std::swap(temp, output_arcs_); }
313  }
314 
315  // Output to standard FST with Weight as its weight type. We will create extra
316  // states to handle sequences of symbols on the output. If destroy == true,
317  // release memory as we go (but we cannot output again).
318  void Output(MutableFst<Arc> *ofst, bool destroy = true) {
319  // Outputs to standard fst.
320  OutputStateId nStates = static_cast<OutputStateId>(output_arcs_.size());
321  ofst->DeleteStates();
322  if (nStates == 0) {
323  ofst->SetStart(kNoStateId);
324  return;
325  }
326  if (destroy)
327  FreeMostMemory();
328  // Add basic states-- but we will add extra ones to account for strings on output.
329  for (OutputStateId s = 0;s < nStates;s++) {
330  OutputStateId news = ofst->AddState();
331  assert(news == s);
332  }
333  ofst->SetStart(0);
334  for (OutputStateId this_state = 0; this_state < nStates; this_state++) {
335  std::vector<TempArc> &this_vec(output_arcs_[this_state]);
336 
337  typename std::vector<TempArc>::const_iterator iter = this_vec.begin(), end = this_vec.end();
338  for (; iter != end; ++iter) {
339  const TempArc &temp_arc(*iter);
340  std::vector<Label> seq;
341  repository_.ConvertToVector(temp_arc.string, &seq);
342 
343  if (temp_arc.nextstate == kNoStateId) { // Really a final weight.
344  // Make a sequence of states going to a final state, with the strings
345  // as labels. Put the weight on the first arc.
346  OutputStateId cur_state = this_state;
347  for (size_t i = 0; i < seq.size(); i++) {
348  OutputStateId next_state = ofst->AddState();
349  Arc arc;
350  arc.nextstate = next_state;
351  arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
352  arc.ilabel = 0; // epsilon.
353  arc.olabel = seq[i];
354  ofst->AddArc(cur_state, arc);
355  cur_state = next_state;
356  }
357  ofst->SetFinal(cur_state, (seq.size() == 0 ? temp_arc.weight : Weight::One()));
358  } else { // Really an arc.
359  OutputStateId cur_state = this_state;
360  // Have to be careful with this integer comparison (i+1 < seq.size()) because unsigned.
361  // i < seq.size()-1 could fail for zero-length sequences.
362  for (size_t i = 0; i+1 < seq.size();i++) {
363  // for all but the last element of seq, create new state.
364  OutputStateId next_state = ofst->AddState();
365  Arc arc;
366  arc.nextstate = next_state;
367  arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
368  arc.ilabel = (i == 0 ? temp_arc.ilabel : 0); // put ilabel on first element of seq.
369  arc.olabel = seq[i];
370  ofst->AddArc(cur_state, arc);
371  cur_state = next_state;
372  }
373  // Add the final arc in the sequence.
374  Arc arc;
375  arc.nextstate = temp_arc.nextstate;
376  arc.weight = (seq.size() <= 1 ? temp_arc.weight : Weight::One());
377  arc.ilabel = (seq.size() <= 1 ? temp_arc.ilabel : 0);
378  arc.olabel = (seq.size() > 0 ? seq.back() : 0);
379  ofst->AddArc(cur_state, arc);
380  }
381  }
382  // Free up memory. Do this inside the loop as ofst is also allocating memory
383  if (destroy) {
384  std::vector<TempArc> temp; temp.swap(this_vec);
385  }
386  }
387  if (destroy) {
388  std::vector<std::vector<TempArc> > temp;
389  temp.swap(output_arcs_);
390  repository_.Destroy();
391  }
392  }
393 
394 
395  // Initializer. After initializing the object you will typically
396  // call Determinize() and then call one of the Output functions.
397  // Note: ifst.Copy() will generally do a
398  // shallow copy. We do it like this for memory safety, rather than
399  // keeping a reference or pointer to ifst_.
400  LatticeDeterminizer(const Fst<Arc> &ifst,
402  num_arcs_(0), num_elems_(0), ifst_(ifst.Copy()), opts_(opts),
403  equal_(opts_.delta), determinized_(false),
404  minimal_hash_(3, hasher_, equal_), initial_hash_(3, hasher_, equal_) {
405  KALDI_ASSERT(Weight::Properties() & kIdempotent); // this algorithm won't
406  // work correctly otherwise.
407  }
408 
409  // frees all except output_arcs_, which contains the important info
410  // we need to output the FST.
411  void FreeMostMemory() {
412  if (ifst_) {
413  delete ifst_;
414  ifst_ = NULL;
415  }
416  for (typename MinimalSubsetHash::iterator iter = minimal_hash_.begin();
417  iter != minimal_hash_.end(); ++iter)
418  delete iter->first;
419  { MinimalSubsetHash tmp; tmp.swap(minimal_hash_); }
420  for (typename InitialSubsetHash::iterator iter = initial_hash_.begin();
421  iter != initial_hash_.end(); ++iter)
422  delete iter->first;
423  { InitialSubsetHash tmp; tmp.swap(initial_hash_); }
424  { std::vector<std::vector<Element>* > output_states_tmp;
425  output_states_tmp.swap(output_states_); }
426  { std::vector<char> tmp; tmp.swap(isymbol_or_final_); }
427  { std::vector<OutputStateId> tmp; tmp.swap(queue_); }
428  { std::vector<std::pair<Label, Element> > tmp; tmp.swap(all_elems_tmp_); }
429  }
430 
432  FreeMostMemory(); // rest is deleted by destructors.
433  }
434  void RebuildRepository() { // rebuild the string repository,
435  // freeing stuff we don't need.. we call this when memory usage
436  // passes a supplied threshold. We need to accumulate all the
437  // strings we need the repository to "remember", then tell it
438  // to clean the repository.
439  std::vector<StringId> needed_strings;
440  for (size_t i = 0; i < output_arcs_.size(); i++)
441  for (size_t j = 0; j < output_arcs_[i].size(); j++)
442  needed_strings.push_back(output_arcs_[i][j].string);
443 
444  // the following loop covers strings present in minimal_hash_
445  // which are also accessible via output_states_.
446  for (size_t i = 0; i < output_states_.size(); i++)
447  for (size_t j = 0; j < output_states_[i]->size(); j++)
448  needed_strings.push_back((*(output_states_[i]))[j].string);
449 
450  // the following loop covers strings present in initial_hash_.
451  for (typename InitialSubsetHash::const_iterator
452  iter = initial_hash_.begin();
453  iter != initial_hash_.end(); ++iter) {
454  const std::vector<Element> &vec = *(iter->first);
455  Element elem = iter->second;
456  for (size_t i = 0; i < vec.size(); i++)
457  needed_strings.push_back(vec[i].string);
458  needed_strings.push_back(elem.string);
459  }
460 
461  std::sort(needed_strings.begin(), needed_strings.end());
462  needed_strings.erase(std::unique(needed_strings.begin(),
463  needed_strings.end()),
464  needed_strings.end()); // uniq the strings.
465  repository_.Rebuild(needed_strings);
466  }
467 
469  int32 repo_size = repository_.MemSize(),
470  arcs_size = num_arcs_ * sizeof(TempArc),
471  elems_size = num_elems_ * sizeof(Element),
472  total_size = repo_size + arcs_size + elems_size;
473  if (opts_.max_mem > 0 && total_size > opts_.max_mem) { // We passed the memory threshold.
474  // This is usually due to the repository getting large, so we
475  // clean this out.
476  RebuildRepository();
477  int32 new_repo_size = repository_.MemSize(),
478  new_total_size = new_repo_size + arcs_size + elems_size;
479 
480  KALDI_VLOG(2) << "Rebuilt repository in determinize-lattice: repository shrank from "
481  << repo_size << " to " << new_repo_size << " bytes (approximately)";
482 
483  if (new_total_size > static_cast<int32>(opts_.max_mem * 0.8)) {
484  // Rebuilding didn't help enough-- we need a margin to stop
485  // having to rebuild too often.
486  KALDI_WARN << "Failure in determinize-lattice: size exceeds maximum "
487  << opts_.max_mem << " bytes; (repo,arcs,elems) = ("
488  << repo_size << "," << arcs_size << "," << elems_size
489  << "), after rebuilding, repo size was " << new_repo_size;
490  return false;
491  }
492  }
493  return true;
494  }
495 
496  // Returns true on success. Can fail for out-of-memory
497  // or max-states related reasons.
498  bool Determinize(bool *debug_ptr) {
499  assert(!determinized_);
500  // This determinizes the input fst but leaves it in the "special format"
501  // in "output_arcs_". Must be called after Initialize(). To get the
502  // output, call one of the Output routines.
503  try {
504  InitializeDeterminization(); // some start-up tasks.
505  while (!queue_.empty()) {
506  OutputStateId out_state = queue_.back();
507  queue_.pop_back();
508  ProcessState(out_state);
509  if (debug_ptr && *debug_ptr) Debug(); // will exit.
510  if (!CheckMemoryUsage()) return false;
511  }
512  return (determinized_ = true);
513  } catch (const std::bad_alloc &) {
514  int32 repo_size = repository_.MemSize(),
515  arcs_size = num_arcs_ * sizeof(TempArc),
516  elems_size = num_elems_ * sizeof(Element),
517  total_size = repo_size + arcs_size + elems_size;
518  KALDI_WARN << "Memory allocation error doing lattice determinization; using "
519  << total_size << " bytes (max = " << opts_.max_mem
520  << " (repo,arcs,elems) = ("
521  << repo_size << "," << arcs_size << "," << elems_size << ")";
522  return (determinized_ = false);
523  } catch (const std::runtime_error &) {
524  KALDI_WARN << "Caught exception doing lattice determinization";
525  return (determinized_ = false);
526  }
527  }
528  private:
529 
530  typedef typename Arc::Label Label;
531  typedef typename Arc::StateId StateId; // use this when we don't know if it's input or output.
532  typedef typename Arc::StateId InputStateId; // state in the input FST.
533  typedef typename Arc::StateId OutputStateId; // same as above but distinguish
534  // states in output Fst.
535 
536 
538  typedef const typename StringRepositoryType::Entry* StringId;
539 
540  // Element of a subset [of original states]
541  struct Element {
542  StateId state; // use StateId as this is usually InputStateId but in one case
543  // OutputStateId.
546  bool operator != (const Element &other) const {
547  return (state != other.state || string != other.string ||
548  weight != other.weight);
549  }
550  // This operator is only intended to support sorting in EpsilonClosure()
551  bool operator < (const Element &other) const {
552  return state < other.state;
553  }
554  };
555 
556  // Arcs in the format we temporarily create in this class (a representation, essentially of
557  // a Gallic Fst).
558  struct TempArc {
560  StringId string; // Look it up in the StringRepository, it's a sequence of Labels.
561  OutputStateId nextstate; // or kNoState for final weights.
563  };
564 
565  // Hashing function used in hash of subsets.
566  // A subset is a pointer to vector<Element>.
567  // The Elements are in sorted order on state id, and without repeated states.
568  // Because the order of Elements is fixed, we can use a hashing function that is
569  // order-dependent. However the weights are not included in the hashing function--
570  // we hash subsets that differ only in weight to the same key. This is not optimal
571  // in terms of the O(N) performance but typically if we have a lot of determinized
572  // states that differ only in weight then the input probably was pathological in some way,
573  // or even non-determinizable.
574  // We don't quantize the weights, in order to avoid inexactness in simple cases.
575  // Instead we apply the delta when comparing subsets for equality, and allow a small
576  // difference.
577 
578  class SubsetKey {
579  public:
580  size_t operator ()(const std::vector<Element> * subset) const { // hashes only the state and string.
581  size_t hash = 0, factor = 1;
582  for (typename std::vector<Element>::const_iterator iter= subset->begin(); iter != subset->end(); ++iter) {
583  hash *= factor;
584  hash += iter->state + reinterpret_cast<size_t>(iter->string);
585  factor *= 23531; // these numbers are primes.
586  }
587  return hash;
588  }
589  };
590 
591  // This is the equality operator on subsets. It checks for exact match on state-id
592  // and string, and approximate match on weights.
593  class SubsetEqual {
594  public:
595  bool operator ()(const std::vector<Element> * s1, const std::vector<Element> * s2) const {
596  size_t sz = s1->size();
597  assert(sz>=0);
598  if (sz != s2->size()) return false;
599  typename std::vector<Element>::const_iterator iter1 = s1->begin(),
600  iter1_end = s1->end(), iter2=s2->begin();
601  for (; iter1 < iter1_end; ++iter1, ++iter2) {
602  if (iter1->state != iter2->state ||
603  iter1->string != iter2->string ||
604  ! ApproxEqual(iter1->weight, iter2->weight, delta_)) return false;
605  }
606  return true;
607  }
608  float delta_;
609  SubsetEqual(float delta): delta_(delta) {}
610  SubsetEqual(): delta_(kDelta) {}
611  };
612 
613  // Operator that says whether two Elements have the same states.
614  // Used only for debug.
616  public:
617  bool operator ()(const std::vector<Element> * s1, const std::vector<Element> * s2) const {
618  size_t sz = s1->size();
619  assert(sz>=0);
620  if (sz != s2->size()) return false;
621  typename std::vector<Element>::const_iterator iter1 = s1->begin(),
622  iter1_end = s1->end(), iter2=s2->begin();
623  for (; iter1 < iter1_end; ++iter1, ++iter2) {
624  if (iter1->state != iter2->state) return false;
625  }
626  return true;
627  }
628  };
629 
630  // Define the hash type we use to map subsets (in minimal
631  // representation) to OutputStateId.
632  typedef std::unordered_map<const std::vector<Element>*, OutputStateId,
634 
635  // Define the hash type we use to map subsets (in initial
636  // representation) to OutputStateId, together with an
637  // extra weight. [note: we interpret the Element.state in here
638  // as an OutputStateId even though it's declared as InputStateId;
639  // these types are the same anyway].
640  typedef std::unordered_map<const std::vector<Element>*, Element,
642 
643 
644  // converts the representation of the subset from canonical (all states) to
645  // minimal (only states with output symbols on arcs leaving them, and final
646  // states). Output is not necessarily normalized, even if input_subset was.
647  void ConvertToMinimal(std::vector<Element> *subset) {
648  assert(!subset->empty());
649  typename std::vector<Element>::iterator cur_in = subset->begin(),
650  cur_out = subset->begin(), end = subset->end();
651  while (cur_in != end) {
652  if(IsIsymbolOrFinal(cur_in->state)) { // keep it...
653  *cur_out = *cur_in;
654  cur_out++;
655  }
656  cur_in++;
657  }
658  subset->resize(cur_out - subset->begin());
659  }
660 
661  // Takes a minimal, normalized subset, and converts it to an OutputStateId.
662  // Involves a hash lookup, and possibly adding a new OutputStateId.
663  // If it creates a new OutputStateId, it adds it to the queue.
664  OutputStateId MinimalToStateId(const std::vector<Element> &subset) {
665  typename MinimalSubsetHash::const_iterator iter
666  = minimal_hash_.find(&subset);
667  if (iter != minimal_hash_.end()) // Found a matching subset.
668  return iter->second;
669  OutputStateId ans = static_cast<OutputStateId>(output_arcs_.size());
670  std::vector<Element> *subset_ptr = new std::vector<Element>(subset);
671  output_states_.push_back(subset_ptr);
672  num_elems_ += subset_ptr->size();
673  output_arcs_.push_back(std::vector<TempArc>());
674  minimal_hash_[subset_ptr] = ans;
675  queue_.push_back(ans);
676  return ans;
677  }
678 
679 
680  // Given a normalized initial subset of elements (i.e. before epsilon closure),
681  // compute the corresponding output-state.
682  OutputStateId InitialToStateId(const std::vector<Element> &subset_in,
683  Weight *remaining_weight,
684  StringId *common_prefix) {
685  typename InitialSubsetHash::const_iterator iter
686  = initial_hash_.find(&subset_in);
687  if (iter != initial_hash_.end()) { // Found a matching subset.
688  const Element &elem = iter->second;
689  *remaining_weight = elem.weight;
690  *common_prefix = elem.string;
691  if (elem.weight == Weight::Zero())
692  KALDI_WARN << "Zero weight!"; // TEMP
693  return elem.state;
694  }
695  // else no matching subset-- have to work it out.
696  std::vector<Element> subset(subset_in);
697  // Follow through epsilons. Will add no duplicate states. note: after
698  // EpsilonClosure, it is the same as "canonical" subset, except not
699  // normalized (actually we never compute the normalized canonical subset,
700  // only the normalized minimal one).
701  EpsilonClosure(&subset); // follow epsilons.
702  ConvertToMinimal(&subset); // remove all but emitting and final states.
703 
704  Element elem; // will be used to store remaining weight and string, and
705  // OutputStateId, in initial_hash_;
706  NormalizeSubset(&subset, &elem.weight, &elem.string); // normalize subset; put
707  // common string and weight in "elem". The subset is now a minimal,
708  // normalized subset.
709 
710  OutputStateId ans = MinimalToStateId(subset);
711  *remaining_weight = elem.weight;
712  *common_prefix = elem.string;
713  if (elem.weight == Weight::Zero())
714  KALDI_WARN << "Zero weight!"; // TEMP
715 
716  // Before returning "ans", add the initial subset to the hash,
717  // so that we can bypass the epsilon-closure etc., next time
718  // we process the same initial subset.
719  std::vector<Element> *initial_subset_ptr = new std::vector<Element>(subset_in);
720  elem.state = ans;
721  initial_hash_[initial_subset_ptr] = elem;
722  num_elems_ += initial_subset_ptr->size(); // keep track of memory usage.
723  return ans;
724  }
725 
726  // returns the Compare value (-1 if a < b, 0 if a == b, 1 if a > b) according
727  // to the ordering we defined on strings for the CompactLatticeWeightTpl.
728  // see function
729  // inline int Compare (const CompactLatticeWeightTpl<WeightType,IntType> &w1,
730  // const CompactLatticeWeightTpl<WeightType,IntType> &w2)
731  // in lattice-weight.h.
732  // this is the same as that, but optimized for our data structures.
733  inline int Compare(const Weight &a_w, StringId a_str,
734  const Weight &b_w, StringId b_str) const {
735  int weight_comp = fst::Compare(a_w, b_w);
736  if (weight_comp != 0) return weight_comp;
737  // now comparing strings.
738  if (a_str == b_str) return 0;
739  std::vector<IntType> a_vec, b_vec;
740  repository_.ConvertToVector(a_str, &a_vec);
741  repository_.ConvertToVector(b_str, &b_vec);
742  // First compare their lengths.
743  int a_len = a_vec.size(), b_len = b_vec.size();
744  // use opposite order on the string lengths (c.f. Compare in
745  // lattice-weight.h)
746  if (a_len > b_len) return -1;
747  else if (a_len < b_len) return 1;
748  for(int i = 0; i < a_len; i++) {
749  if (a_vec[i] < b_vec[i]) return -1;
750  else if (a_vec[i] > b_vec[i]) return 1;
751  }
752  assert(0); // because we checked if a_str == b_str above, shouldn't reach here
753  return 0;
754  }
755 
756 
757  // This function computes epsilon closure of subset of states by following epsilon links.
758  // Called by InitialToStateId and Initialize.
759  // Has no side effects except on the string repository. The "output_subset" is not
760  // necessarily normalized (in the sense of there being no common substring), unless
761  // input_subset was.
762  void EpsilonClosure(std::vector<Element> *subset) {
763  // at input, subset must have only one example of each StateId. [will still
764  // be so at output]. This function follows input-epsilons, and augments the
765  // subset accordingly.
766 
767  std::deque<Element> queue;
768  std::unordered_map<InputStateId, Element> cur_subset;
769  typedef typename std::unordered_map<InputStateId, Element>::iterator MapIter;
770  typedef typename std::vector<Element>::const_iterator VecIter;
771 
772  for (VecIter iter = subset->begin(); iter != subset->end(); ++iter) {
773  queue.push_back(*iter);
774  cur_subset[iter->state] = *iter;
775  }
776 
777  // find whether input fst is known to be sorted on input label.
778  bool sorted = ((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
779  bool replaced_elems = false; // relates to an optimization, see below.
780  int counter = 0; // stops infinite loops here for non-lattice-determinizable input;
781  // useful in testing.
782  while (queue.size() != 0) {
783  Element elem = queue.front();
784  queue.pop_front();
785 
786  // The next if-statement is a kind of optimization. It's to prevent us
787  // unnecessarily repeating the processing of a state. "cur_subset" always
788  // contains only one Element with a particular state. The issue is that
789  // whenever we modify the Element corresponding to that state in "cur_subset",
790  // both the new (optimal) and old (less-optimal) Element will still be in
791  // "queue". The next if-statement stops us from wasting compute by
792  // processing the old Element.
793  if (replaced_elems && cur_subset[elem.state] != elem)
794  continue;
795  if (opts_.max_loop > 0 && counter++ > opts_.max_loop) {
796  KALDI_ERR << "Lattice determinization aborted since looped more than "
797  << opts_.max_loop << " times during epsilon closure";
798  }
799  for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done(); aiter.Next()) {
800  const Arc &arc = aiter.Value();
801  if (sorted && arc.ilabel != 0) break; // Break from the loop: due to sorting there will be no
802  // more transitions with epsilons as input labels.
803  if (arc.ilabel == 0
804  && arc.weight != Weight::Zero()) { // Epsilon transition.
805  Element next_elem;
806  next_elem.state = arc.nextstate;
807  next_elem.weight = Times(elem.weight, arc.weight);
808  // now must append strings
809  if (arc.olabel == 0)
810  next_elem.string = elem.string;
811  else
812  next_elem.string = repository_.Successor(elem.string, arc.olabel);
813 
814  MapIter iter = cur_subset.find(next_elem.state);
815  if (iter == cur_subset.end()) {
816  // was no such StateId: insert and add to queue.
817  cur_subset[next_elem.state] = next_elem;
818  queue.push_back(next_elem);
819  } else {
820  // was not inserted because one already there. In normal determinization we'd
821  // add the weights. Here, we find which one has the better weight, and
822  // keep its corresponding string.
823  int comp = Compare(next_elem.weight, next_elem.string,
824  iter->second.weight, iter->second.string);
825  if(comp == 1) { // next_elem is better, so use its (weight, string)
826  iter->second.string = next_elem.string;
827  iter->second.weight = next_elem.weight;
828  queue.push_back(next_elem);
829  replaced_elems = true;
830  }
831  // else it is the same or worse, so use original one.
832  }
833  }
834  }
835  }
836 
837  { // copy cur_subset to subset.
838  subset->clear();
839  subset->reserve(cur_subset.size());
840  MapIter iter = cur_subset.begin(), end = cur_subset.end();
841  for (; iter != end; ++iter) subset->push_back(iter->second);
842  // sort by state ID, because the subset hash function is order-dependent(see SubsetKey)
843  std::sort(subset->begin(), subset->end());
844  }
845  }
846 
847 
848  // This function works out the final-weight of the determinized state.
849  // called by ProcessSubset.
850  // Has no side effects except on the variable repository_, and output_arcs_.
851 
852  void ProcessFinal(OutputStateId output_state) {
853  const std::vector<Element> &minimal_subset = *(output_states_[output_state]);
854  // processes final-weights for this subset.
855 
856  // minimal_subset may be empty if the graphs is not connected/trimmed, I think,
857  // do don't check that it's nonempty.
858  bool is_final = false;
859  StringId final_string = NULL; // = NULL to keep compiler happy.
860  Weight final_weight = Weight::Zero();
861  typename std::vector<Element>::const_iterator iter = minimal_subset.begin(), end = minimal_subset.end();
862  for (; iter != end; ++iter) {
863  const Element &elem = *iter;
864  Weight this_final_weight = Times(elem.weight, ifst_->Final(elem.state));
865  StringId this_final_string = elem.string;
866  if (this_final_weight != Weight::Zero() &&
867  (!is_final || Compare(this_final_weight, this_final_string,
868  final_weight, final_string) == 1)) { // the new
869  // (weight, string) pair is more in semiring than our current
870  // one.
871  is_final = true;
872  final_weight = this_final_weight;
873  final_string = this_final_string;
874  }
875  }
876  if (is_final) {
877  // store final weights in TempArc structure, just like a transition.
878  TempArc temp_arc;
879  temp_arc.ilabel = 0;
880  temp_arc.nextstate = kNoStateId; // special marker meaning "final weight".
881  temp_arc.string = final_string;
882  temp_arc.weight = final_weight;
883  output_arcs_[output_state].push_back(temp_arc);
884  num_arcs_++;
885  }
886  }
887 
888  // NormalizeSubset normalizes the subset "elems" by
889  // removing any common string prefix (putting it in common_str),
890  // and dividing by the total weight (putting it in tot_weight).
891  void NormalizeSubset(std::vector<Element> *elems,
892  Weight *tot_weight,
893  StringId *common_str) {
894  if(elems->empty()) { // just set common_str, tot_weight
895  KALDI_WARN << "[empty subset]"; // TEMP
896  // to defaults and return...
897  *common_str = repository_.EmptyString();
898  *tot_weight = Weight::Zero();
899  return;
900  }
901  size_t size = elems->size();
902  std::vector<IntType> common_prefix;
903  repository_.ConvertToVector((*elems)[0].string, &common_prefix);
904  Weight weight = (*elems)[0].weight;
905  for (size_t i = 1; i < size; i++) {
906  weight = Plus(weight, (*elems)[i].weight);
907  repository_.ReduceToCommonPrefix((*elems)[i].string, &common_prefix);
908  }
909  assert(weight != Weight::Zero()); // we made sure to ignore arcs with zero
910  // weights on them, so we shouldn't have zero here.
911  size_t prefix_len = common_prefix.size();
912  for (size_t i = 0; i < size; i++) {
913  (*elems)[i].weight = Divide((*elems)[i].weight, weight, DIVIDE_LEFT);
914  (*elems)[i].string =
915  repository_.RemovePrefix((*elems)[i].string, prefix_len);
916  }
917  *common_str = repository_.ConvertFromVector(common_prefix);
918  *tot_weight = weight;
919  }
920 
921  // Take a subset of Elements that is sorted on state, and
922  // merge any Elements that have the same state (taking the best
923  // (weight, string) pair in the semiring).
924  void MakeSubsetUnique(std::vector<Element> *subset) {
925  typedef typename std::vector<Element>::iterator IterType;
926 
927  // This assert is designed to fail (usually) if the subset is not sorted on
928  // state.
929  assert(subset->size() < 2 || (*subset)[0].state <= (*subset)[1].state);
930 
931  IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
932  size_t num_out = 0;
933  // Merge elements with same state-id
934  while (cur_in != end) { // while we have more elements to process.
935  // At this point, cur_out points to location of next place we want to put an element,
936  // cur_in points to location of next element we want to process.
937  if (cur_in != cur_out) *cur_out = *cur_in;
938  cur_in++;
939  while (cur_in != end && cur_in->state == cur_out->state) {
940  if (Compare(cur_in->weight, cur_in->string,
941  cur_out->weight, cur_out->string) == 1) {
942  // if *cur_in > *cur_out in semiring, then take *cur_in.
943  cur_out->string = cur_in->string;
944  cur_out->weight = cur_in->weight;
945  }
946  cur_in++;
947  }
948  cur_out++;
949  num_out++;
950  }
951  subset->resize(num_out);
952  }
953 
954  // ProcessTransition is called from "ProcessTransitions". Broken out for
955  // clarity. Processes a transition from state "state". The set of Elements
956  // represents a set of next-states with associated weights and strings, each
957  // one arising from an arc from some state in a determinized-state; the
958  // next-states are not necessarily unique (i.e. there may be >1 entry
959  // associated with each), and any such sets of Elements have to be merged
960  // within this routine (we take the [weight, string] pair that's better in the
961  // semiring).
962  void ProcessTransition(OutputStateId state, Label ilabel, std::vector<Element> *subset) {
963  MakeSubsetUnique(subset); // remove duplicates with the same state.
964 
965  StringId common_str;
966  Weight tot_weight;
967  NormalizeSubset(subset, &tot_weight, &common_str);
968 
969  OutputStateId nextstate;
970  {
971  Weight next_tot_weight;
972  StringId next_common_str;
973  nextstate = InitialToStateId(*subset,
974  &next_tot_weight,
975  &next_common_str);
976  common_str = repository_.Concatenate(common_str, next_common_str);
977  tot_weight = Times(tot_weight, next_tot_weight);
978  }
979 
980  // Now add an arc to the next state (would have been created if necessary by
981  // InitialToStateId).
982  TempArc temp_arc;
983  temp_arc.ilabel = ilabel;
984  temp_arc.nextstate = nextstate;
985  temp_arc.string = common_str;
986  temp_arc.weight = tot_weight;
987  output_arcs_[state].push_back(temp_arc); // record the arc.
988  num_arcs_++;
989  }
990 
991 
992  // "less than" operator for pair<Label, Element>. Used in ProcessTransitions.
993  // Lexicographical order, which only compares the state when ordering the
994  // "Element" member of the pair.
995 
997  public:
998  inline bool operator () (const std::pair<Label, Element> &p1, const std::pair<Label, Element> &p2) {
999  if (p1.first < p2.first) return true;
1000  else if (p1.first > p2.first) return false;
1001  else {
1002  return p1.second.state < p2.second.state;
1003  }
1004  }
1005  };
1006 
1007 
1008  // ProcessTransitions processes emitting transitions (transitions
1009  // with ilabels) out of this subset of states.
1010  // Does not consider final states. Breaks the emitting transitions up by ilabel,
1011  // and creates a new transition in the determinized FST for each unique ilabel.
1012  // Does this by creating a big vector of pairs <Label, Element> and then sorting them
1013  // using a lexicographical ordering, and calling ProcessTransition for each range
1014  // with the same ilabel.
1015  // Side effects on repository, and (via ProcessTransition) on Q_, hash_,
1016  // and output_arcs_.
1017 
1018  void ProcessTransitions(OutputStateId output_state) {
1019  const std::vector<Element> &minimal_subset = *(output_states_[output_state]);
1020  // it's possible that minimal_subset could be empty if there are
1021  // unreachable parts of the graph, so don't check that it's nonempty.
1022  std::vector<std::pair<Label, Element> > &all_elems(all_elems_tmp_); // use class member
1023  // to avoid memory allocation/deallocation.
1024  {
1025  // Push back into "all_elems", elements corresponding to all
1026  // non-epsilon-input transitions out of all states in "minimal_subset".
1027  typename std::vector<Element>::const_iterator iter = minimal_subset.begin(), end = minimal_subset.end();
1028  for (;iter != end; ++iter) {
1029  const Element &elem = *iter;
1030  for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); ! aiter.Done(); aiter.Next()) {
1031  const Arc &arc = aiter.Value();
1032  if (arc.ilabel != 0
1033  && arc.weight != Weight::Zero()) { // Non-epsilon transition -- ignore epsilons here.
1034  std::pair<Label, Element> this_pr;
1035  this_pr.first = arc.ilabel;
1036  Element &next_elem(this_pr.second);
1037  next_elem.state = arc.nextstate;
1038  next_elem.weight = Times(elem.weight, arc.weight);
1039  if (arc.olabel == 0) // output epsilon
1040  next_elem.string = elem.string;
1041  else
1042  next_elem.string = repository_.Successor(elem.string, arc.olabel);
1043  all_elems.push_back(this_pr);
1044  }
1045  }
1046  }
1047  }
1048  PairComparator pc;
1049  std::sort(all_elems.begin(), all_elems.end(), pc);
1050  // now sorted first on input label, then on state.
1051  typedef typename std::vector<std::pair<Label, Element> >::const_iterator PairIter;
1052  PairIter cur = all_elems.begin(), end = all_elems.end();
1053  std::vector<Element> this_subset;
1054  while (cur != end) {
1055  // Process ranges that share the same input symbol.
1056  Label ilabel = cur->first;
1057  this_subset.clear();
1058  while (cur != end && cur->first == ilabel) {
1059  this_subset.push_back(cur->second);
1060  cur++;
1061  }
1062  // We now have a subset for this ilabel.
1063  assert(!this_subset.empty()); // temp.
1064  ProcessTransition(output_state, ilabel, &this_subset);
1065  }
1066  all_elems.clear(); // as it's a class variable-- want it to stay
1067  // emtpy.
1068  }
1069 
1070 
1071 
1072  // ProcessState does the processing of a determinized state, i.e. it creates
1073  // transitions out of it and the final-probability if any.
1074  void ProcessState(OutputStateId output_state) {
1075  ProcessFinal(output_state);
1076  ProcessTransitions(output_state);
1077  }
1078 
1079 
1080  void Debug() { // this function called if you send a signal
1081  // SIGUSR1 to the process (and it's caught by the handler in
1082  // fstdeterminizestar). It prints out some traceback
1083  // info and exits.
1084 
1085  KALDI_WARN << "Debug function called (probably SIGUSR1 caught)";
1086  // free up memory from the hash as we need a little memory
1087  { MinimalSubsetHash hash_tmp; hash_tmp.swap(minimal_hash_); }
1088 
1089  if (output_arcs_.size() <= 2) {
1090  KALDI_ERR << "Nothing to trace back";
1091  }
1092  size_t max_state = output_arcs_.size() - 2; // Don't take the last
1093  // one as we might be halfway into constructing it.
1094 
1095  std::vector<OutputStateId> predecessor(max_state+1, kNoStateId);
1096  for (size_t i = 0; i < max_state; i++) {
1097  for (size_t j = 0; j < output_arcs_[i].size(); j++) {
1098  OutputStateId nextstate = output_arcs_[i][j].nextstate;
1099  // Always find an earlier-numbered predecessor; this
1100  // is always possible because of the way the algorithm
1101  // works.
1102  if (nextstate <= max_state && nextstate > i)
1103  predecessor[nextstate] = i;
1104  }
1105  }
1106  std::vector<std::pair<Label, StringId> > traceback;
1107  // 'traceback' is a pair of (ilabel, olabel-seq).
1108  OutputStateId cur_state = max_state; // A recently constructed state.
1109 
1110  while (cur_state != 0 && cur_state != kNoStateId) {
1111  OutputStateId last_state = predecessor[cur_state];
1112  std::pair<Label, StringId> p;
1113  size_t i;
1114  for (i = 0; i < output_arcs_[last_state].size(); i++) {
1115  if (output_arcs_[last_state][i].nextstate == cur_state) {
1116  p.first = output_arcs_[last_state][i].ilabel;
1117  p.second = output_arcs_[last_state][i].string;
1118  traceback.push_back(p);
1119  break;
1120  }
1121  }
1122  KALDI_ASSERT(i != output_arcs_[last_state].size()); // Or fell off loop.
1123  cur_state = last_state;
1124  }
1125  if (cur_state == kNoStateId)
1126  KALDI_WARN << "Traceback did not reach start state "
1127  << "(possibly debug-code error)";
1128 
1129  std::stringstream ss;
1130  ss << "Traceback follows in format "
1131  << "ilabel (olabel olabel) ilabel (olabel) ... :";
1132  for (ssize_t i = traceback.size() - 1; i >= 0; i--) {
1133  ss << ' ' << traceback[i].first << " ( ";
1134  std::vector<Label> seq;
1135  repository_.ConvertToVector(traceback[i].second, &seq);
1136  for (size_t j = 0; j < seq.size(); j++)
1137  ss << seq[j] << ' ';
1138  ss << ')';
1139  }
1140  KALDI_ERR << ss.str();
1141  }
1142 
1143  bool IsIsymbolOrFinal(InputStateId state) { // returns true if this state
1144  // of the input FST either is final or has an osymbol on an arc out of it.
1145  // Uses the vector isymbol_or_final_ as a cache for this info.
1146  assert(state >= 0);
1147  if (isymbol_or_final_.size() <= state)
1148  isymbol_or_final_.resize(state+1, static_cast<char>(OSF_UNKNOWN));
1149  if (isymbol_or_final_[state] == static_cast<char>(OSF_NO))
1150  return false;
1151  else if (isymbol_or_final_[state] == static_cast<char>(OSF_YES))
1152  return true;
1153  // else work it out...
1154  isymbol_or_final_[state] = static_cast<char>(OSF_NO);
1155  if (ifst_->Final(state) != Weight::Zero())
1156  isymbol_or_final_[state] = static_cast<char>(OSF_YES);
1157  for (ArcIterator<Fst<Arc> > aiter(*ifst_, state);
1158  !aiter.Done();
1159  aiter.Next()) {
1160  const Arc &arc = aiter.Value();
1161  if (arc.ilabel != 0 && arc.weight != Weight::Zero()) {
1162  isymbol_or_final_[state] = static_cast<char>(OSF_YES);
1163  return true;
1164  }
1165  }
1166  return IsIsymbolOrFinal(state); // will only recurse once.
1167  }
1168 
1170  if(ifst_->Properties(kExpanded, false) != 0) { // if we know the number of
1171  // states in ifst_, it might be a bit more efficient
1172  // to pre-size the hashes so we're not constantly rebuilding them.
1173 #if !(__GNUC__ == 4 && __GNUC_MINOR__ == 0)
1174  StateId num_states =
1175  down_cast<const ExpandedFst<Arc>*, const Fst<Arc> >(ifst_)->NumStates();
1176  minimal_hash_.rehash(num_states/2 + 3);
1177  initial_hash_.rehash(num_states/2 + 3);
1178 #endif
1179  }
1180  InputStateId start_id = ifst_->Start();
1181  if (start_id != kNoStateId) {
1182  /* Insert determinized-state corresponding to the start state into hash and
1183  queue. Unlike all the other states, we don't "normalize" the representation
1184  of this determinized-state before we put it into minimal_hash_. This is actually
1185  what we want, as otherwise we'd have problems dealing with any extra weight
1186  and string and might have to create a "super-initial" state which would make
1187  the output nondeterministic. Normalization is only needed to make the
1188  determinized output more minimal anyway, it's not needed for correctness.
1189  Note, we don't put anything in the initial_hash_. The initial_hash_ is only
1190  a lookaside buffer anyway, so this isn't a problem-- it will get populated
1191  later if it needs to be.
1192  */
1193  Element elem;
1194  elem.state = start_id;
1195  elem.weight = Weight::One();
1196  elem.string = repository_.EmptyString(); // Id of empty sequence.
1197  std::vector<Element> subset;
1198  subset.push_back(elem);
1199  EpsilonClosure(&subset); // follow through epsilon-inputs links
1200  ConvertToMinimal(&subset); // remove all but final states and
1201  // states with input-labels on arcs out of them.
1202  std::vector<Element> *subset_ptr = new std::vector<Element>(subset);
1203  assert(output_arcs_.empty() && output_states_.empty());
1204  // add the new state...
1205  output_states_.push_back(subset_ptr);
1206  output_arcs_.push_back(std::vector<TempArc>());
1207  OutputStateId initial_state = 0;
1208  minimal_hash_[subset_ptr] = initial_state;
1209  queue_.push_back(initial_state);
1210  }
1211  }
1212 
1214 
1215 
1216  std::vector<std::vector<Element>* > output_states_; // maps from output state to
1217  // minimal representation [normalized].
1218  // View pointers as owned in
1219  // minimal_hash_.
1220  std::vector<std::vector<TempArc> > output_arcs_; // essentially an FST in our format.
1221 
1222  int num_arcs_; // keep track of memory usage: number of arcs in output_arcs_
1223  int num_elems_; // keep track of memory usage: number of elems in output_states_
1224 
1225  const Fst<Arc> *ifst_;
1227  SubsetKey hasher_; // object that computes keys-- has no data members.
1228  SubsetEqual equal_; // object that compares subsets-- only data member is delta_.
1229  bool determinized_; // set to true when user called Determinize(); used to make
1230  // sure this object is used correctly.
1231  MinimalSubsetHash minimal_hash_; // hash from Subset to OutputStateId. Subset is "minimal
1232  // representation" (only include final and states and states with
1233  // nonzero ilabel on arc out of them. Owns the pointers
1234  // in its keys.
1235  InitialSubsetHash initial_hash_; // hash from Subset to Element, which
1236  // represents the OutputStateId together
1237  // with an extra weight and string. Subset
1238  // is "initial representation". The extra
1239  // weight and string is needed because after
1240  // we convert to minimal representation and
1241  // normalize, there may be an extra weight
1242  // and string. Owns the pointers
1243  // in its keys.
1244  std::vector<OutputStateId> queue_; // Queue of output-states to process. Starts with
1245  // state 0, and increases and then (hopefully) decreases in length during
1246  // determinization. LIFO queue (queue discipline doesn't really matter).
1247 
1248  std::vector<std::pair<Label, Element> > all_elems_tmp_; // temporary vector used in ProcessTransitions.
1249 
1250  enum IsymbolOrFinal { OSF_UNKNOWN = 0, OSF_NO = 1, OSF_YES = 2 };
1251 
1252  std::vector<char> isymbol_or_final_; // A kind of cache; it says whether
1253  // each state is (emitting or final) where emitting means it has at least one
1254  // non-epsilon output arc. Only accessed by IsIsymbolOrFinal()
1255 
1256  LatticeStringRepository<IntType> repository_; // defines a compact and fast way of
1257  // storing sequences of labels.
1258 };
1259 
1260 
1261 // normally Weight would be LatticeWeight<float> (which has two floats),
1262 // or possibly TropicalWeightTpl<float>, and IntType would be int32.
1263 template<class Weight, class IntType>
1264 bool DeterminizeLattice(const Fst<ArcTpl<Weight> > &ifst,
1265  MutableFst<ArcTpl<Weight> > *ofst,
1267  bool *debug_ptr) {
1268  ofst->SetInputSymbols(ifst.InputSymbols());
1269  ofst->SetOutputSymbols(ifst.OutputSymbols());
1270  LatticeDeterminizer<Weight, IntType> det(ifst, opts);
1271  if (!det.Determinize(debug_ptr))
1272  return false;
1273  det.Output(ofst);
1274  return true;
1275 }
1276 
1277 
1278 // normally Weight would be LatticeWeight<float> (which has two floats),
1279 // or possibly TropicalWeightTpl<float>, and IntType would be int32.
1280 template<class Weight, class IntType>
1281 bool DeterminizeLattice(const Fst<ArcTpl<Weight> >&ifst,
1282  MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > >*ofst,
1284  bool *debug_ptr) {
1285  ofst->SetInputSymbols(ifst.InputSymbols());
1286  ofst->SetOutputSymbols(ifst.OutputSymbols());
1287  LatticeDeterminizer<Weight, IntType> det(ifst, opts);
1288  if (!det.Determinize(debug_ptr))
1289  return false;
1290  det.Output(ofst);
1291  return true;
1292 }
1293 
1294 } // namespace fst
1295 
1296 #endif // KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
fst::StdArc::StateId StateId
OutputStateId MinimalToStateId(const std::vector< Element > &subset)
const Entry * ConvertFromVector(const std::vector< IntType > &vec)
LatticeWeightTpl< FloatType > Divide(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, DivideType typ=DIVIDE_ANY)
const StringRepositoryType::Entry * StringId
void Rebuild(const std::vector< const Entry *> &to_keep)
const Entry * CommonPrefix(const Entry *a, const Entry *b)
bool operator!=(const LatticeWeightTpl< FloatType > &wa, const LatticeWeightTpl< FloatType > &wb)
void Output(MutableFst< Arc > *ofst, bool destroy=true)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
bool Determinize(bool *debug_ptr)
LatticeWeightTpl< FloatType > Plus(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
void Output(MutableFst< CompactArc > *ofst, bool destroy=true)
void ConvertToVector(const Entry *entry, std::vector< IntType > *out) const
LatticeStringRepository< IntType > repository_
std::unordered_set< const Entry *, EntryKey, EntryEqual > SetType
void swap(basic_filebuf< CharT, Traits > &x, basic_filebuf< CharT, Traits > &y)
kaldi::int32 int32
bool ApproxEqual(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, float delta=kDelta)
OutputStateId InitialToStateId(const std::vector< Element > &subset_in, Weight *remaining_weight, StringId *common_prefix)
bool operator()(const Entry *e1, const Entry *e2) const
const Entry * Successor(const Entry *parent, IntType i)
std::vector< char > isymbol_or_final_
size_t operator()(const Entry *entry) const
void ProcessTransition(OutputStateId state, Label ilabel, std::vector< Element > *subset)
void RebuildHelper(const Entry *to_add, SetType *tmp_set)
const Entry * RemovePrefix(const Entry *a, size_t n)
void NormalizeSubset(std::vector< Element > *elems, Weight *tot_weight, StringId *common_str)
std::vector< std::vector< TempArc > > output_arcs_
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
void ConvertToMinimal(std::vector< Element > *subset)
int Compare(const Weight &a_w, StringId a_str, const Weight &b_w, StringId b_str) const
LatticeDeterminizer(const Fst< Arc > &ifst, DeterminizeLatticeOptions opts)
bool IsIsymbolOrFinal(InputStateId state)
void EpsilonClosure(std::vector< Element > *subset)
struct rnnlm::@11::@12 n
ArcTpl< CompactWeight > CompactArc
#define KALDI_ERR
Definition: kaldi-error.h:147
void MakeSubsetUnique(std::vector< Element > *subset)
bool operator==(const Entry &other) const
#define KALDI_WARN
Definition: kaldi-error.h:150
std::unordered_map< const std::vector< Element > *, Element, SubsetKey, SubsetEqual > InitialSubsetHash
fst::StdArc::Label Label
DeterminizeLatticeOptions opts_
fst::StdArc::Weight Weight
std::vector< std::pair< Label, Element > > all_elems_tmp_
bool operator<(const Int32Pair &a, const Int32Pair &b)
Definition: cu-matrixdim.h:83
std::vector< OutputStateId > queue_
CompactLatticeWeightTpl< Weight, IntType > CompactWeight
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeStringRepository)
void ProcessTransitions(OutputStateId output_state)
int Compare(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2.
LatticeStringRepository< IntType > StringRepositoryType
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
bool DeterminizeLattice(const Fst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< Weight > > *ofst, DeterminizeLatticeOptions opts, bool *debug_ptr)
This function implements the normal version of DeterminizeLattice, in which the output strings are re...
std::vector< std::vector< Element > *> output_states_
void ProcessState(OutputStateId output_state)
bool IsPrefixOf(const Entry *a, const Entry *b) const
const Entry * Concatenate(const Entry *a, const Entry *b)
void ReduceToCommonPrefix(const Entry *a, std::vector< IntType > *b)
void ProcessFinal(OutputStateId output_state)
std::unordered_map< const std::vector< Element > *, OutputStateId, SubsetKey, SubsetEqual > MinimalSubsetHash
size_t Size(const Entry *entry) const
void Copy(const CuMatrixBase< Real > &src, const CuArray< int32 > &copy_from_indices, CuMatrixBase< Real > *tgt)
Copies elements from src into tgt as given by copy_from_indices.
Definition: cu-math.cc:173