determinize-lattice-pruned.cc
Go to the documentation of this file.
1 // lat/determinize-lattice-pruned.cc
2 
3 // Copyright 2009-2012 Microsoft Corporation
4 // 2012-2013 Johns Hopkins University (Author: Daniel Povey)
5 // 2014 Guoguo Chen
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #include <vector>
23 #include <climits>
24 #include "fstext/determinize-lattice.h" // for LatticeStringRepository
25 #include "fstext/fstext-utils.h"
26 #include "lat/lattice-functions.h" // for PruneLattice
27 #include "lat/minimize-lattice.h" // for minimization
28 #include "lat/push-lattice.h" // for minimization
30 
31 namespace fst {
32 
33 using std::vector;
34 using std::pair;
35 using std::greater;
36 
// class LatticeDeterminizerPruned is templated on the same types that
// CompactLatticeWeight is templated on: the base weight (Weight), typically
// LatticeWeightTpl<float> etc. but could also be e.g. TropicalWeight, and the
// IntType, typically int32, used for the output symbols in the compact
// representation of strings [note: the output symbols would usually be
// pdf-ids in the anticipated use of this code]. It has a special requirement
// on the Weight type: there must be a Compare function on the weights
// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if
// w1 > w2. This requires that there be a total order on the weights.
46 
47 template<class Weight, class IntType> class LatticeDeterminizerPruned {
48  public:
49  // Arc types for the two output formats produced by the Output() overloads:
50  // CompactArc is an acceptor arc whose weight (CompactWeight) holds both the
51  // base weight and the output-label sequence; Arc is the plain lattice arc.
52
54  typedef ArcTpl<CompactWeight> CompactArc; // arc in compact, acceptor form of lattice
55  typedef ArcTpl<Weight> Arc; // arc in non-compact version of lattice
56 
57  // Output to standard FST with CompactWeightTpl<Weight> as its weight type (the
58  // weight stores the original output-symbol strings). If destroy == true,
59  // release memory as we go (but we cannot output again).
60  void Output(MutableFst<CompactArc> *ofst, bool destroy = true) {
62  typedef typename Arc::StateId StateId;
63  StateId nStates = static_cast<StateId>(output_states_.size());
64  if (destroy)
66  ofst->DeleteStates();
67  ofst->SetStart(kNoStateId);
68  if (nStates == 0) {
69  return;
70  }
71  for (StateId s = 0;s < nStates;s++) {
72  OutputStateId news = ofst->AddState();
73  KALDI_ASSERT(news == s);
74  }
75  ofst->SetStart(0);
76  // now process transitions.
77  for (StateId this_state_id = 0; this_state_id < nStates; this_state_id++) {
78  OutputState &this_state = *(output_states_[this_state_id]);
79  vector<TempArc> &this_vec(this_state.arcs);
80  typename vector<TempArc>::const_iterator iter = this_vec.begin(), end = this_vec.end();
81 
82  for (;iter != end; ++iter) {
83  const TempArc &temp_arc(*iter);
84  CompactArc new_arc;
85  vector<Label> olabel_seq;
86  repository_.ConvertToVector(temp_arc.string, &olabel_seq);
87  CompactWeight weight(temp_arc.weight, olabel_seq);
88  if (temp_arc.nextstate == kNoStateId) { // is really final weight.
89  ofst->SetFinal(this_state_id, weight);
90  } else { // is really an arc.
91  new_arc.nextstate = temp_arc.nextstate;
92  new_arc.ilabel = temp_arc.ilabel;
93  new_arc.olabel = temp_arc.ilabel; // acceptor. input == output.
94  new_arc.weight = weight; // includes string and weight.
95  ofst->AddArc(this_state_id, new_arc);
96  }
97  }
98  // Free up memory. Do this inside the loop as ofst is also allocating memory,
99  // and we want to reduce the maximum amount ever allocated.
100  if (destroy) { vector<TempArc> temp; temp.swap(this_vec); }
101  }
102  if (destroy) {
104  repository_.Destroy();
105  }
106  }
107 
108  // Output to standard FST with Weight as its weight type. We will create extra
109  // states to handle sequences of symbols on the output. If destroy == true,
110  // release memory as we go (but we cannot output again).
111  void Output(MutableFst<Arc> *ofst, bool destroy = true) {
112  // Outputs to standard fst.
113  OutputStateId nStates = static_cast<OutputStateId>(output_states_.size());
114  ofst->DeleteStates();
115  if (nStates == 0) {
116  ofst->SetStart(kNoStateId);
117  return;
118  }
119  if (destroy)
120  FreeMostMemory();
121  // Add basic states-- but we will add extra ones to account for strings on output.
122  for (OutputStateId s = 0; s< nStates;s++) {
123  OutputStateId news = ofst->AddState();
124  KALDI_ASSERT(news == s);
125  }
126  ofst->SetStart(0);
127  for (OutputStateId this_state_id = 0; this_state_id < nStates; this_state_id++) {
128  OutputState &this_state = *(output_states_[this_state_id]);
129  vector<TempArc> &this_vec(this_state.arcs);
130 
131  typename vector<TempArc>::const_iterator iter = this_vec.begin(), end = this_vec.end();
132  for (; iter != end; ++iter) {
133  const TempArc &temp_arc(*iter);
134  vector<Label> seq;
135  repository_.ConvertToVector(temp_arc.string, &seq);
136 
137  if (temp_arc.nextstate == kNoStateId) { // Really a final weight.
138  // Make a sequence of states going to a final state, with the strings
139  // as labels. Put the weight on the first arc.
140  OutputStateId cur_state = this_state_id;
141  for (size_t i = 0; i < seq.size(); i++) {
142  OutputStateId next_state = ofst->AddState();
143  Arc arc;
144  arc.nextstate = next_state;
145  arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
146  arc.ilabel = 0; // epsilon.
147  arc.olabel = seq[i];
148  ofst->AddArc(cur_state, arc);
149  cur_state = next_state;
150  }
151  ofst->SetFinal(cur_state, (seq.size() == 0 ? temp_arc.weight : Weight::One()));
152  } else { // Really an arc.
153  OutputStateId cur_state = this_state_id;
154  // Have to be careful with this integer comparison (i+1 < seq.size()) because unsigned.
155  // i < seq.size()-1 could fail for zero-length sequences.
156  for (size_t i = 0; i+1 < seq.size();i++) {
157  // for all but the last element of seq, create new state.
158  OutputStateId next_state = ofst->AddState();
159  Arc arc;
160  arc.nextstate = next_state;
161  arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
162  arc.ilabel = (i == 0 ? temp_arc.ilabel : 0); // put ilabel on first element of seq.
163  arc.olabel = seq[i];
164  ofst->AddArc(cur_state, arc);
165  cur_state = next_state;
166  }
167  // Add the final arc in the sequence.
168  Arc arc;
169  arc.nextstate = temp_arc.nextstate;
170  arc.weight = (seq.size() <= 1 ? temp_arc.weight : Weight::One());
171  arc.ilabel = (seq.size() <= 1 ? temp_arc.ilabel : 0);
172  arc.olabel = (seq.size() > 0 ? seq.back() : 0);
173  ofst->AddArc(cur_state, arc);
174  }
175  }
176  // Free up memory. Do this inside the loop as ofst is also allocating memory
177  if (destroy) { vector<TempArc> temp; temp.swap(this_vec); }
178  }
179  if (destroy) {
181  repository_.Destroy();
182  }
183  }
184 
185 
186  // Constructor. After constructing the object you will typically
187  // call Determinize() and then call one of the Output functions.
188  // A copy of ifst is stored (via ifst.Copy(), which will generally do a
189  // shallow copy) rather than a reference or pointer to the caller's FST,
190  // for memory safety. "beam" is the requested pruning beam (kept in beam_).
// The subset-equality functor equal_ is given opts_.delta as its weight
// tolerance. The Weight semiring must be idempotent; this is asserted below.
191  LatticeDeterminizerPruned(const ExpandedFst<Arc> &ifst,
192  double beam,
194  num_arcs_(0), num_elems_(0), ifst_(ifst.Copy()), beam_(beam), opts_(opts),
195  equal_(opts_.delta), determinized_(false),
197  KALDI_ASSERT(Weight::Properties() & kIdempotent); // this algorithm won't
198  // work correctly otherwise.
199  }
200 
// Deletes every heap-allocated OutputState in output_states_ and then
// releases the vector's own storage (swapping with an empty temporary frees
// the capacity, which clear() would not).
202  for (size_t i = 0; i < output_states_.size(); i++)
203  delete output_states_[i];
204  vector<OutputState*> temp;
205  temp.swap(output_states_);
206  }
207 
208  // frees all memory except the info (in output_states_[ ]->arcs)
209  // that we need to output the FST.
210  void FreeMostMemory() {
211  if (ifst_) {
212  delete ifst_;
213  ifst_ = NULL;
214  }
215  { MinimalSubsetHash tmp; tmp.swap(minimal_hash_); }
216 
217  for (size_t i = 0; i < output_states_.size(); i++) {
218  vector<Element> empty_subset;
219  empty_subset.swap(output_states_[i]->minimal_subset);
220  }
221 
222  for (typename InitialSubsetHash::iterator iter = initial_hash_.begin();
223  iter != initial_hash_.end(); ++iter)
224  delete iter->first;
225  { InitialSubsetHash tmp; tmp.swap(initial_hash_); }
226  for (size_t i = 0; i < output_states_.size(); i++) {
227  vector<Element> tmp;
228  tmp.swap(output_states_[i]->minimal_subset);
229  }
230  { vector<char> tmp; tmp.swap(isymbol_or_final_); }
231  { // Free up the queue. I'm not sure how to make sure all
232  // the memory is really freed (no swap() function)... doesn't really
233  // matter much though.
234  while (!queue_.empty()) {
235  Task *t = queue_.top();
236  delete t;
237  queue_.pop();
238  }
239  }
240  { vector<pair<Label, Element> > tmp; tmp.swap(all_elems_tmp_); }
241  }
242 
// Presumably the destructor body: releases the determinization data
// structures via FreeMostMemory().
// NOTE(review): FreeMostMemory() does not delete the OutputState objects in
// output_states_ (allocated with new in MinimalToStateId); confirm they are
// also freed on destruction.
244  FreeMostMemory();
246  // rest (members that own their storage) is deleted by destructors.
247  }
248 
249  void RebuildRepository() { // rebuild the string repository,
250  // freeing stuff we don't need.. we call this when memory usage
251  // passes a supplied threshold. We need to accumulate all the
252  // strings we need the repository to "remember", then tell it
253  // to clean the repository.
254  std::vector<StringId> needed_strings;
255  for (size_t i = 0; i < output_states_.size(); i++) {
256  AddStrings(output_states_[i]->minimal_subset, &needed_strings);
257  for (size_t j = 0; j < output_states_[i]->arcs.size(); j++)
258  needed_strings.push_back(output_states_[i]->arcs[j].string);
259  }
260 
261  { // the queue doesn't allow us access to the underlying vector,
262  // so we have to resort to a temporary collection.
263  std::vector<Task*> tasks;
264  while (!queue_.empty()) {
265  Task *task = queue_.top();
266  queue_.pop();
267  tasks.push_back(task);
268  AddStrings(task->subset, &needed_strings);
269  }
270  for (size_t i = 0; i < tasks.size(); i++)
271  queue_.push(tasks[i]);
272  }
273 
274  // the following loop covers strings present in initial_hash_.
275  for (typename InitialSubsetHash::const_iterator
276  iter = initial_hash_.begin();
277  iter != initial_hash_.end(); ++iter) {
278  const vector<Element> &vec = *(iter->first);
279  Element elem = iter->second;
280  AddStrings(vec, &needed_strings);
281  needed_strings.push_back(elem.string);
282  }
283  std::sort(needed_strings.begin(), needed_strings.end());
284  needed_strings.erase(std::unique(needed_strings.begin(),
285  needed_strings.end()),
286  needed_strings.end()); // uniq the strings.
287  KALDI_LOG << "Rebuilding repository.";
288 
289  repository_.Rebuild(needed_strings);
290  }
291 
// Memory check: estimates usage as the repository size plus num_arcs_
// TempArcs plus num_elems_ Elements, in bytes (approximately).  Returns true
// if no limit is set (opts_.max_mem <= 0) or the estimate is within it.
// Otherwise the repository is measured again. If the new total is still
// above 80% of opts_.max_mem, it warns (reporting the effective beam
// reached) and returns false; if not, it returns true.
// NOTE(review): these sizes are int32, and num_arcs_ * sizeof(TempArc) etc.
// are narrowed from size_t, so very large lattices could overflow here.
293  int32 repo_size = repository_.MemSize(),
294  arcs_size = num_arcs_ * sizeof(TempArc),
295  elems_size = num_elems_ * sizeof(Element),
296  total_size = repo_size + arcs_size + elems_size;
297  if (opts_.max_mem > 0 && total_size > opts_.max_mem) { // We passed the memory threshold.
298  // This is usually due to the repository getting large, so we
299  // clean this out.
// NOTE(review): the log below says the repository was rebuilt, but no call
// that rebuilds it (e.g. RebuildRepository()) is visible between the size
// check above and the re-measurement below -- confirm it is present.
301  int32 new_repo_size = repository_.MemSize(),
302  new_total_size = new_repo_size + arcs_size + elems_size;
303
304  KALDI_VLOG(2) << "Rebuilt repository in determinize-lattice: repository shrank from "
305  << repo_size << " to " << new_repo_size << " bytes (approximately)";
306
307  if (new_total_size > static_cast<int32>(opts_.max_mem * 0.8)) {
308  // Rebuilding didn't help enough-- we need a margin to stop
309  // having to rebuild too often. We'll just return to the user at
310  // this point, with a partial lattice that's pruned tighter than
311  // the specified beam. Here we figure out what the effective
312  // beam was.
313  double effective_beam = beam_;
314  if (!queue_.empty()) { // Note: queue should probably not be empty; we're
315  // just being paranoid here.
316  Task *task = queue_.top();
317  double total_weight = backward_costs_[ifst_->Start()]; // best weight of FST.
318  effective_beam = task->priority_cost - total_weight;
319  }
320  KALDI_WARN << "Did not reach requested beam in determinize-lattice: "
321  << "size exceeds maximum " << opts_.max_mem
322  << " bytes; (repo,arcs,elems) = (" << repo_size << ","
323  << arcs_size << "," << elems_size
324  << "), after rebuilding, repo size was " << new_repo_size
325  << ", effective beam was " << effective_beam
326  << " vs. requested beam " << beam_;
327  return false;
328  }
329  }
330  return true;
331  }
332 
333  bool Determinize(double *effective_beam) {
335  // This determinizes the input fst but leaves it in the "special format"
336  // in "output_arcs_". Must be called after Initialize(). To get the
337  // output, call one of the Output routines.
338 
339  InitializeDeterminization(); // some start-up tasks.
340  while (!queue_.empty()) {
341  Task *task = queue_.top();
342  // Note: the queue contains only tasks that are "within the beam".
343  // We also have to check whether we have reached one of the user-specified
344  // maximums, of estimated memory, arcs, or states. The condition for
345  // ending is:
346  // num-states is more than user specified, OR
347  // num-arcs is more than user specified, OR
348  // memory passed a user-specified threshold and cleanup failed
349  // to get it below that threshold.
350  size_t num_states = output_states_.size();
351  if ((opts_.max_states > 0 && num_states > opts_.max_states) ||
352  (opts_.max_arcs > 0 && num_arcs_ > opts_.max_arcs) ||
353  (num_states % 10 == 0 && !CheckMemoryUsage())) { // note: at some point
354  // it was num_states % 100, not num_states % 10, but I encountered an example
355  // where memory was exhausted before we reached state #100.
356  KALDI_VLOG(1) << "Lattice determinization terminated but not "
357  << " because of lattice-beam. (#states, #arcs) is ( "
358  << output_states_.size() << ", " << num_arcs_
359  << " ), versus limits ( " << opts_.max_states << ", "
360  << opts_.max_arcs << " ) (else, may be memory limit).";
361  break;
362  // we terminate the determinization here-- whatever we already expanded is
363  // what we'll return... because we expanded stuff in order of total
364  // (forward-backward) weight, the stuff we returned first is the most
365  // important.
366  }
367  queue_.pop();
368  ProcessTransition(task->state, task->label, &(task->subset));
369  delete task;
370  }
371  determinized_ = true;
372  if (effective_beam != NULL) {
373  if (queue_.empty()) *effective_beam = beam_;
374  else
375  *effective_beam = queue_.top()->priority_cost -
376  backward_costs_[ifst_->Start()];
377  }
378  return (queue_.empty()); // return success if queue was empty, i.e. we processed
379  // all tasks and did not break out of the loop early due to reaching a memory,
380  // arc or state limit.
381  }
382  private:
383
384  typedef typename Arc::Label Label;
385  typedef typename Arc::StateId StateId; // use this when we don't know if it's input or output.
386  typedef typename Arc::StateId InputStateId; // state in the input FST.
387  typedef typename Arc::StateId OutputStateId; // same as above but distinguish
388  // states in output Fst.
389
// A StringId is a pointer to an entry of the string repository.  Strings are
// compared (Compare()) and hashed (SubsetKey) by this pointer value; equal
// StringIds denote equal label sequences.
391  typedef const typename StringRepositoryType::Entry* StringId;
392 
393  // Element of a subset [of original states]: a state together with the
// output-label string (StringId) and the weight accumulated on the way to it.
394  struct Element {
395  StateId state; // use StateId as this is usually InputStateId but in one case
396  // OutputStateId (as the value type of initial_hash_).
397  StringId string;
// Exact inequality on state, string and weight (no delta tolerance, unlike
// SubsetEqual); EpsilonClosure() uses it to skip stale queue entries.
399  bool operator != (const Element &other) const {
400  return (state != other.state || string != other.string ||
401  weight != other.weight);
402  }
403  // This operator is only intended for the priority_queue in the function
404  // EpsilonClosure(). It compares state ids only.
405  bool operator > (const Element &other) const {
406  return state > other.state;
407  }
408  // This operator is only intended to support sorting in EpsilonClosure(); it compares state ids only.
409  bool operator < (const Element &other) const {
410  return state < other.state;
411  }
412  };
413 
414  // Arcs in the format we temporarily create in this class (a representation, essentially of
415  // a Gallic Fst). They are stored in OutputState::arcs, carry a weight as well, and
// are converted to real arcs / final weights by the Output() functions.
416  struct TempArc {
417  Label ilabel;
418  StringId string; // Look it up in the StringRepository, it's a sequence of Labels.
419  OutputStateId nextstate; // or kNoStateId for final weights (ilabel is then 0).
421  };
422 
423  // Hashing function used in hash of subsets.
424  // A subset is a pointer to vector<Element>.
425  // The Elements are in sorted order on state id, and without repeated states.
426  // Because the order of Elements is fixed, we can use a hashing function that is
427  // order-dependent. However the weights are not included in the hashing function--
428  // we hash subsets that differ only in weight to the same key. This is not optimal
429  // in terms of the O(N) performance but typically if we have a lot of determinized
430  // states that differ only in weight then the input probably was pathological in some way,
431  // or even non-determinizable.
432  // We don't quantize the weights, in order to avoid inexactness in simple cases.
433  // Instead we apply the delta when comparing subsets for equality, and allow a small
434  // difference.
435 
436  class SubsetKey {
437  public:
438  size_t operator ()(const vector<Element> * subset) const { // hashes only the state and string.
439  size_t hash = 0, factor = 1;
440  for (typename vector<Element>::const_iterator iter= subset->begin(); iter != subset->end(); ++iter) {
441  hash *= factor;
442  hash += iter->state + reinterpret_cast<size_t>(iter->string);
443  factor *= 23531; // these numbers are primes.
444  }
445  return hash;
446  }
447  };
448 
449  // This is the equality operator on subsets. It checks for exact match on state-id
450  // and string, and approximate match on weights.
451  class SubsetEqual {
452  public:
453  bool operator ()(const vector<Element> * s1, const vector<Element> * s2) const {
454  size_t sz = s1->size();
455  KALDI_ASSERT(sz>=0);
456  if (sz != s2->size()) return false;
457  typename vector<Element>::const_iterator iter1 = s1->begin(),
458  iter1_end = s1->end(), iter2=s2->begin();
459  for (; iter1 < iter1_end; ++iter1, ++iter2) {
460  if (iter1->state != iter2->state ||
461  iter1->string != iter2->string ||
462  ! ApproxEqual(iter1->weight, iter2->weight, delta_)) return false;
463  }
464  return true;
465  }
466  float delta_;
467  SubsetEqual(float delta): delta_(delta) {}
468  SubsetEqual(): delta_(kDelta) {}
469  };
470 
471  // Operator that says whether two subsets (vectors of Elements) contain the same
// sequence of states, ignoring strings and weights.
472  // Used only for debug.
474  public:
475  bool operator ()(const vector<Element> * s1, const vector<Element> * s2) const {
476  size_t sz = s1->size();
// NOTE(review): sz is a size_t, so this assertion can never fail.
477  KALDI_ASSERT(sz>=0);
478  if (sz != s2->size()) return false;
479  typename vector<Element>::const_iterator iter1 = s1->begin(),
480  iter1_end = s1->end(), iter2=s2->begin();
481  for (; iter1 < iter1_end; ++iter1, ++iter2) {
482  if (iter1->state != iter2->state) return false;
483  }
484  return true;
485  }
486  };
487 
488  // Define the hash type we use to map subsets (in minimal
489  // representation) to OutputStateId. Its keys point at the minimal_subset
// member of each OutputState (see MinimalToStateId), so they are not owned here.
490  typedef unordered_map<const vector<Element>*, OutputStateId,
492
493  // Define the hash type we use to map subsets (in initial
494  // representation) to OutputStateId, together with the remaining
495  // weight and common-prefix string. [note: we interpret the Element.state in here
496  // as an OutputStateId even though it's declared as StateId, which is mostly used
// for input states; all these state-id types are the same typedef anyway].
// Its keys are heap-allocated copies, deleted in FreeMostMemory().
497  // (see InitialToStateId for how entries are created.)
498  typedef unordered_map<const vector<Element>*, Element,
500 
501 
502  // converts the representation of the subset from canonical (all states) to
503  // minimal (only states with output symbols on arcs leaving them, and final
504  // states). Output is not necessarily normalized, even if input_subset was.
505  void ConvertToMinimal(vector<Element> *subset) {
506  KALDI_ASSERT(!subset->empty());
507  typename vector<Element>::iterator cur_in = subset->begin(),
508  cur_out = subset->begin(), end = subset->end();
509  while (cur_in != end) {
510  if(IsIsymbolOrFinal(cur_in->state)) { // keep it...
511  *cur_out = *cur_in;
512  cur_out++;
513  }
514  cur_in++;
515  }
516  subset->resize(cur_out - subset->begin());
517  }
518 
519  // Takes a minimal, normalized subset, and converts it to an OutputStateId.
520  // Involves a hash lookup, and possibly adding a new OutputStateId.
521  // If it creates a new OutputStateId, it creates a new record for it, works
522  // out its final-weight, and puts stuff on the queue relating to its
523  // transitions.
524  OutputStateId MinimalToStateId(const vector<Element> &subset,
525  const double forward_cost) {
526  typename MinimalSubsetHash::const_iterator iter
527  = minimal_hash_.find(&subset);
528  if (iter != minimal_hash_.end()) { // Found a matching subset.
529  OutputStateId state_id = iter->second;
530  const OutputState &state = *(output_states_[state_id]);
531  // Below is just a check that the algorithm is working...
532  if (forward_cost < state.forward_cost - 0.1) {
533  // for large weights, this check could fail due to roundoff.
534  KALDI_WARN << "New cost is less (check the difference is small) "
535  << forward_cost << ", "
536  << state.forward_cost;
537  }
538  return state_id;
539  }
540  OutputStateId state_id = static_cast<OutputStateId>(output_states_.size());
541  OutputState *new_state = new OutputState(subset, forward_cost);
542  minimal_hash_[&(new_state->minimal_subset)] = state_id;
543  output_states_.push_back(new_state);
544  num_elems_ += subset.size();
545  // Note: in the previous algorithm, we pushed the new state-id onto the queue
546  // at this point. Here, the queue happens elsewhere, and we directly process
547  // the state (which result in stuff getting added to the queue).
548  ProcessFinal(state_id); // will work out the final-prob.
549  ProcessTransitions(state_id); // will process transitions and add stuff to the queue.
550  return state_id;
551  }
552 
553 
554  // Given a normalized initial subset of elements (i.e. before epsilon closure),
555  // compute the corresponding output-state.
556  OutputStateId InitialToStateId(const vector<Element> &subset_in,
557  double forward_cost,
558  Weight *remaining_weight,
559  StringId *common_prefix) {
560  typename InitialSubsetHash::const_iterator iter
561  = initial_hash_.find(&subset_in);
562  if (iter != initial_hash_.end()) { // Found a matching subset.
563  const Element &elem = iter->second;
564  *remaining_weight = elem.weight;
565  *common_prefix = elem.string;
566  if (elem.weight == Weight::Zero())
567  KALDI_WARN << "Zero weight!";
568  return elem.state;
569  }
570  // else no matching subset-- have to work it out.
571  vector<Element> subset(subset_in);
572  // Follow through epsilons. Will add no duplicate states. note: after
573  // EpsilonClosure, it is the same as "canonical" subset, except not
574  // normalized (actually we never compute the normalized canonical subset,
575  // only the normalized minimal one).
576  EpsilonClosure(&subset); // follow epsilons.
577  ConvertToMinimal(&subset); // remove all but emitting and final states.
578 
579  Element elem; // will be used to store remaining weight and string, and
580  // OutputStateId, in initial_hash_;
581  NormalizeSubset(&subset, &elem.weight, &elem.string); // normalize subset; put
582  // common string and weight in "elem". The subset is now a minimal,
583  // normalized subset.
584 
585  forward_cost += ConvertToCost(elem.weight);
586  OutputStateId ans = MinimalToStateId(subset, forward_cost);
587  *remaining_weight = elem.weight;
588  *common_prefix = elem.string;
589  if (elem.weight == Weight::Zero())
590  KALDI_WARN << "Zero weight!";
591 
592  // Before returning "ans", add the initial subset to the hash,
593  // so that we can bypass the epsilon-closure etc., next time
594  // we process the same initial subset.
595  vector<Element> *initial_subset_ptr = new vector<Element>(subset_in);
596  elem.state = ans;
597  initial_hash_[initial_subset_ptr] = elem;
598  num_elems_ += initial_subset_ptr->size(); // keep track of memory usage.
599  return ans;
600  }
601 
602  // returns the Compare value (-1 if a < b, 0 if a == b, 1 if a > b) according
603  // to the ordering we defined on strings for the CompactLatticeWeightTpl.
604  // see function
605  // inline int Compare (const CompactLatticeWeightTpl<WeightType,IntType> &w1,
606  // const CompactLatticeWeightTpl<WeightType,IntType> &w2)
607  // in lattice-weight.h.
608  // this is the same as that, but optimized for our data structures.
609  inline int Compare(const Weight &a_w, StringId a_str,
610  const Weight &b_w, StringId b_str) const {
611  int weight_comp = fst::Compare(a_w, b_w);
612  if (weight_comp != 0) return weight_comp;
613  // now comparing strings.
614  if (a_str == b_str) return 0;
615  vector<IntType> a_vec, b_vec;
616  repository_.ConvertToVector(a_str, &a_vec);
617  repository_.ConvertToVector(b_str, &b_vec);
618  // First compare their lengths.
619  int a_len = a_vec.size(), b_len = b_vec.size();
620  // use opposite order on the string lengths (c.f. Compare in
621  // lattice-weight.h)
622  if (a_len > b_len) return -1;
623  else if (a_len < b_len) return 1;
624  for(int i = 0; i < a_len; i++) {
625  if (a_vec[i] < b_vec[i]) return -1;
626  else if (a_vec[i] > b_vec[i]) return 1;
627  }
628  KALDI_ASSERT(0); // because we checked if a_str == b_str above, shouldn't reach here
629  return 0;
630  }
631 
632  // This function computes epsilon closure of subset of states by following epsilon links.
633  // Called by InitialToStateId, among others.
634  // Side effects: new strings may be added to the string repository (Successor()), and
635  // KALDI_ERR is raised if opts_.max_loop is exceeded. On output *subset is sorted by state
636  // id but not necessarily normalized (no common-prefix removal), unless it was on input.
637  void EpsilonClosure(vector<Element> *subset) {
638  // at input, subset must have only one example of each StateId. [will still
639  // be so at output]. This function follows input-epsilons, and augments the
640  // subset accordingly.
641
// greater<Element> compares state ids only, so "queue" pops Elements in
// increasing state-id order; "cur_subset" holds the best Element per state.
642  std::priority_queue<Element, vector<Element>, greater<Element> > queue;
643  unordered_map<InputStateId, Element> cur_subset;
644  typedef typename unordered_map<InputStateId, Element>::iterator MapIter;
645  typedef typename vector<Element>::const_iterator VecIter;
646
647  for (VecIter iter = subset->begin(); iter != subset->end(); ++iter) {
648  queue.push(*iter);
649  cur_subset[iter->state] = *iter;
650  }
651
652  // find whether input fst is known to be sorted on input label.
653  bool sorted = ((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
654  bool replaced_elems = false; // relates to an optimization, see below.
655  int counter = 0; // stops infinite loops here for non-lattice-determinizable input
656  // (e.g. input with negative-cost epsilon loops); useful in testing.
657  while (queue.size() != 0) {
658  Element elem = queue.top();
659  queue.pop();
660
661  // The next if-statement is a kind of optimization. It's to prevent us
662  // unnecessarily repeating the processing of a state. "cur_subset" always
663  // contains only one Element with a particular state. The issue is that
664  // whenever we modify the Element corresponding to that state in "cur_subset",
665  // both the new (optimal) and old (less-optimal) Element will still be in
666  // "queue". The next if-statement stops us from wasting compute by
667  // processing the old Element.
668  if (replaced_elems && cur_subset[elem.state] != elem)
669  continue;
670  if (opts_.max_loop > 0 && counter++ > opts_.max_loop) {
671  KALDI_ERR << "Lattice determinization aborted since looped more than "
672  << opts_.max_loop << " times during epsilon closure.";
673  }
674  for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, elem.state); !aiter.Done(); aiter.Next()) {
675  const Arc &arc = aiter.Value();
676  if (sorted && arc.ilabel != 0) break; // Break from the loop: due to sorting there will be no
677  // more transitions with epsilons as input labels.
678  if (arc.ilabel == 0
679  && arc.weight != Weight::Zero()) { // Epsilon transition.
680  Element next_elem;
681  next_elem.state = arc.nextstate;
682  next_elem.weight = Times(elem.weight, arc.weight);
683  // next_elem.string is not set up yet... create it only
684  // when we know we need it (this is an optimization)
685
686  MapIter iter = cur_subset.find(next_elem.state);
687  if (iter == cur_subset.end()) {
688  // was no such StateId: insert and add to queue.
689  next_elem.string = (arc.olabel == 0 ? elem.string :
690  repository_.Successor(elem.string, arc.olabel));
691  cur_subset[next_elem.state] = next_elem;
692  queue.push(next_elem);
693  } else {
694  // was not inserted because one already there. In normal
695  // determinization we'd add the weights. Here, we find which one
696  // has the better weight, and keep its corresponding string.
// comp == 1 means next_elem's (weight, string) pair is the better one.
697  int comp = fst::Compare(next_elem.weight, iter->second.weight);
698  if (comp == 0) { // A tie on weights. This should be a rare case;
699  // we don't optimize for it.
700  next_elem.string = (arc.olabel == 0 ? elem.string :
701  repository_.Successor(elem.string,
702  arc.olabel));
703  comp = Compare(next_elem.weight, next_elem.string,
704  iter->second.weight, iter->second.string);
705  }
706  if(comp == 1) { // next_elem is better, so use its (weight, string)
707  next_elem.string = (arc.olabel == 0 ? elem.string :
708  repository_.Successor(elem.string, arc.olabel));
709  iter->second.string = next_elem.string;
710  iter->second.weight = next_elem.weight;
711  queue.push(next_elem);
712  replaced_elems = true;
713  }
714  // else it is the same or worse, so use original one.
715  }
716  }
717  }
718  }
719
720  { // copy cur_subset to subset.
721  subset->clear();
722  subset->reserve(cur_subset.size());
723  MapIter iter = cur_subset.begin(), end = cur_subset.end();
724  for (; iter != end; ++iter) subset->push_back(iter->second);
725  // sort by state ID, because the subset hash function is order-dependent(see SubsetKey)
726  std::sort(subset->begin(), subset->end());
727  }
728  }
729 
730 
731  // This function works out the final-weight of the determinized state.
732  // called by ProcessSubset.
733  // Has no side effects except on the variable repository_, and
734  // output_states_[output_state_id].arcs
735 
736  void ProcessFinal(OutputStateId output_state_id) {
737  OutputState &state = *(output_states_[output_state_id]);
738  const vector<Element> &minimal_subset = state.minimal_subset;
739  // processes final-weights for this subset. state.minimal_subset_ may be
740  // empty if the graphs is not connected/trimmed, I think, do don't check
741  // that it's nonempty.
742  StringId final_string = repository_.EmptyString(); // set it to keep the
743  // compiler happy; if it doesn't get set in the loop, we won't use the value anyway.
744  Weight final_weight = Weight::Zero();
745  bool is_final = false;
746  typename vector<Element>::const_iterator iter = minimal_subset.begin(), end = minimal_subset.end();
747  for (; iter != end; ++iter) {
748  const Element &elem = *iter;
749  Weight this_final_weight = Times(elem.weight, ifst_->Final(elem.state));
750  StringId this_final_string = elem.string;
751  if (this_final_weight != Weight::Zero() &&
752  (!is_final || Compare(this_final_weight, this_final_string,
753  final_weight, final_string) == 1)) { // the new
754  // (weight, string) pair is more in semiring than our current
755  // one.
756  is_final = true;
757  final_weight = this_final_weight;
758  final_string = this_final_string;
759  }
760  }
761  if (is_final &&
762  ConvertToCost(final_weight) + state.forward_cost <= cutoff_) {
763  // store final weights in TempArc structure, just like a transition.
764  // Note: we only store the final-weight if it's inside the pruning beam, hence
765  // the stuff with Compare.
766  TempArc temp_arc;
767  temp_arc.ilabel = 0;
768  temp_arc.nextstate = kNoStateId; // special marker meaning "final weight".
769  temp_arc.string = final_string;
770  temp_arc.weight = final_weight;
771  state.arcs.push_back(temp_arc);
772  num_arcs_++;
773  }
774  }
775 
776  // NormalizeSubset normalizes the subset "elems" by
777  // removing any common string prefix (putting it in common_str),
778  // and dividing by the total weight (putting it in tot_weight).
779  void NormalizeSubset(vector<Element> *elems,
780  Weight *tot_weight,
781  StringId *common_str) {
782  if(elems->empty()) { // just set common_str, tot_weight
783  // to defaults and return...
784  KALDI_WARN << "empty subset";
785  *common_str = repository_.EmptyString();
786  *tot_weight = Weight::Zero();
787  return;
788  }
789  size_t size = elems->size();
790  vector<IntType> common_prefix;
791  repository_.ConvertToVector((*elems)[0].string, &common_prefix);
792  Weight weight = (*elems)[0].weight;
793  for(size_t i = 1; i < size; i++) {
794  weight = Plus(weight, (*elems)[i].weight);
795  repository_.ReduceToCommonPrefix((*elems)[i].string, &common_prefix);
796  }
797  KALDI_ASSERT(weight != Weight::Zero()); // we made sure to ignore arcs with zero
798  // weights on them, so we shouldn't have zero here.
799  size_t prefix_len = common_prefix.size();
800  for(size_t i = 0; i < size; i++) {
801  (*elems)[i].weight = Divide((*elems)[i].weight, weight, DIVIDE_LEFT);
802  (*elems)[i].string =
803  repository_.RemovePrefix((*elems)[i].string, prefix_len);
804  }
805  *common_str = repository_.ConvertFromVector(common_prefix);
806  *tot_weight = weight;
807  }
808 
809  // Take a subset of Elements that is sorted on state, and
810  // merge any Elements that have the same state (taking the best
811  // (weight, string) pair in the semiring).
812  void MakeSubsetUnique(vector<Element> *subset) {
813  typedef typename vector<Element>::iterator IterType;
814 
815  // This KALDI_ASSERT is designed to fail (usually) if the subset is not sorted on
816  // state.
817  KALDI_ASSERT(subset->size() < 2 || (*subset)[0].state <= (*subset)[1].state);
818 
819  IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
820  size_t num_out = 0;
821  // Merge elements with same state-id
822  while (cur_in != end) { // while we have more elements to process.
823  // At this point, cur_out points to location of next place we want to put an element,
824  // cur_in points to location of next element we want to process.
825  if (cur_in != cur_out) *cur_out = *cur_in;
826  cur_in++;
827  while (cur_in != end && cur_in->state == cur_out->state) {
828  if (Compare(cur_in->weight, cur_in->string,
829  cur_out->weight, cur_out->string) == 1) {
830  // if *cur_in > *cur_out in semiring, then take *cur_in.
831  cur_out->string = cur_in->string;
832  cur_out->weight = cur_in->weight;
833  }
834  cur_in++;
835  }
836  cur_out++;
837  num_out++;
838  }
839  subset->resize(num_out);
840  }
841 
842  // ProcessTransition was called from "ProcessTransitions" in the non-pruned
843  // code, but now we in effect put the calls to ProcessTransition on a priority
844  // queue, and it now gets called directly from Determinize(). This function
845  // processes a transition from state "ostate_id". The set "subset" of Elements
846  // represents a set of next-states with associated weights and strings, each
847  // one arising from an arc from some state in a determinized-state; the
848  // next-states are unique (there is only one Entry assocated with each)
849  void ProcessTransition(OutputStateId ostate_id, Label ilabel, vector<Element> *subset) {
850 
851  double forward_cost = output_states_[ostate_id]->forward_cost;
852  StringId common_str;
853  Weight tot_weight;
854  NormalizeSubset(subset, &tot_weight, &common_str);
855  forward_cost += ConvertToCost(tot_weight);
856 
857  OutputStateId nextstate;
858  {
859  Weight next_tot_weight;
860  StringId next_common_str;
861  nextstate = InitialToStateId(*subset,
862  forward_cost,
863  &next_tot_weight,
864  &next_common_str);
865  common_str = repository_.Concatenate(common_str, next_common_str);
866  tot_weight = Times(tot_weight, next_tot_weight);
867  }
868 
869  // Now add an arc to the next state (would have been created if necessary by
870  // InitialToStateId).
871  TempArc temp_arc;
872  temp_arc.ilabel = ilabel;
873  temp_arc.nextstate = nextstate;
874  temp_arc.string = common_str;
875  temp_arc.weight = tot_weight;
876  output_states_[ostate_id]->arcs.push_back(temp_arc); // record the arc.
877  num_arcs_++;
878  }
879 
880 
881  // "less than" operator for pair<Label, Element>. Used in ProcessTransitions.
882  // Lexicographical order, which only compares the state when ordering the
883  // "Element" member of the pair.
884 
886  public:
887  inline bool operator () (const pair<Label, Element> &p1, const pair<Label, Element> &p2) {
888  if (p1.first < p2.first) return true;
889  else if (p1.first > p2.first) return false;
890  else {
891  return p1.second.state < p2.second.state;
892  }
893  }
894  };
895 
896 
897  // ProcessTransitions processes emitting transitions (transitions with
898  // ilabels) out of this subset of states. It actualy only creates records
899  // ("Task") that get added to the queue. The transitions will be processed in
900  // priority order from Determinize(). This function soes not consider final
901  // states. Partitions the emitting transitions up by ilabel (by sorting on
902  // ilabel), and for each unique ilabel, it creates a Task record that contains
903  // the information we need to process the transition.
904 
905  void ProcessTransitions(OutputStateId output_state_id) {
906  const vector<Element> &minimal_subset = output_states_[output_state_id]->minimal_subset;
907  // it's possible that minimal_subset could be empty if there are
908  // unreachable parts of the graph, so don't check that it's nonempty.
909  vector<pair<Label, Element> > &all_elems(all_elems_tmp_); // use class member
910  // to avoid memory allocation/deallocation.
911  {
912  // Push back into "all_elems", elements corresponding to all
913  // non-epsilon-input transitions out of all states in "minimal_subset".
914  typename vector<Element>::const_iterator iter = minimal_subset.begin(), end = minimal_subset.end();
915  for (;iter != end; ++iter) {
916  const Element &elem = *iter;
917  for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, elem.state); ! aiter.Done(); aiter.Next()) {
918  const Arc &arc = aiter.Value();
919  if (arc.ilabel != 0
920  && arc.weight != Weight::Zero()) { // Non-epsilon transition -- ignore epsilons here.
921  pair<Label, Element> this_pr;
922  this_pr.first = arc.ilabel;
923  Element &next_elem(this_pr.second);
924  next_elem.state = arc.nextstate;
925  next_elem.weight = Times(elem.weight, arc.weight);
926  if (arc.olabel == 0) // output epsilon
927  next_elem.string = elem.string;
928  else
929  next_elem.string = repository_.Successor(elem.string, arc.olabel);
930  all_elems.push_back(this_pr);
931  }
932  }
933  }
934  }
935  PairComparator pc;
936  std::sort(all_elems.begin(), all_elems.end(), pc);
937  // now sorted first on input label, then on state.
938  typedef typename vector<pair<Label, Element> >::const_iterator PairIter;
939  PairIter cur = all_elems.begin(), end = all_elems.end();
940  while (cur != end) {
941  // The old code (non-pruned) called ProcessTransition; here, instead,
942  // we'll put the calls into a priority queue.
943  Task *task = new Task;
944  // Process ranges that share the same input symbol.
945  Label ilabel = cur->first;
946  task->state = output_state_id;
947  task->priority_cost = std::numeric_limits<double>::infinity();
948  task->label = ilabel;
949  while (cur != end && cur->first == ilabel) {
950  task->subset.push_back(cur->second);
951  const Element &element = cur->second;
952  // Note: we'll later include the term "forward_cost" in the
953  // priority_cost.
954  task->priority_cost = std::min(task->priority_cost,
955  ConvertToCost(element.weight) +
956  backward_costs_[element.state]);
957  cur++;
958  }
959 
960  // After the command below, the "priority_cost" is a value comparable to
961  // the total-weight of the input FST, like a total-path weight... of
962  // course, it will typically be less (in the semiring) than that.
963  // note: we represent it just as a double.
964  task->priority_cost += output_states_[output_state_id]->forward_cost;
965 
966  if (task->priority_cost > cutoff_) {
967  // This task would never get done as it's past the pruning cutoff.
968  delete task;
969  } else {
970  MakeSubsetUnique(&(task->subset)); // remove duplicate Elements with the same state.
971  queue_.push(task); // Push the task onto the queue. The queue keeps it
972  // in prioritized order, so we always process the one with the "best"
973  // weight (highest in the semiring).
974 
975  { // this is a check.
976  double best_cost = backward_costs_[ifst_->Start()],
977  tolerance = 0.01 + 1.0e-04 * std::abs(best_cost);
978  if (task->priority_cost < best_cost - tolerance) {
979  KALDI_WARN << "Cost below best cost was encountered:"
980  << task->priority_cost << " < " << best_cost;
981  }
982  }
983  }
984  }
985  all_elems.clear(); // as it's a reference to a class variable; we want it to stay
986  // empty.
987  }
988 
989 
990  bool IsIsymbolOrFinal(InputStateId state) { // returns true if this state
991  // of the input FST either is final or has an osymbol on an arc out of it.
992  // Uses the vector isymbol_or_final_ as a cache for this info.
993  KALDI_ASSERT(state >= 0);
994  if (isymbol_or_final_.size() <= state)
995  isymbol_or_final_.resize(state+1, static_cast<char>(OSF_UNKNOWN));
996  if (isymbol_or_final_[state] == static_cast<char>(OSF_NO))
997  return false;
998  else if (isymbol_or_final_[state] == static_cast<char>(OSF_YES))
999  return true;
1000  // else work it out...
1001  isymbol_or_final_[state] = static_cast<char>(OSF_NO);
1002  if (ifst_->Final(state) != Weight::Zero())
1003  isymbol_or_final_[state] = static_cast<char>(OSF_YES);
1004  for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, state);
1005  !aiter.Done();
1006  aiter.Next()) {
1007  const Arc &arc = aiter.Value();
1008  if (arc.ilabel != 0 && arc.weight != Weight::Zero()) {
1009  isymbol_or_final_[state] = static_cast<char>(OSF_YES);
1010  return true;
1011  }
1012  }
1013  return IsIsymbolOrFinal(state); // will only recurse once.
1014  }
1015 
1017  // Sets up the backward_costs_ array, and the cutoff_ variable.
1018  KALDI_ASSERT(beam_ > 0);
1019 
1020  // Only handle the toplogically sorted case.
1021  backward_costs_.resize(ifst_->NumStates());
1022  for (StateId s = ifst_->NumStates() - 1; s >= 0; s--) {
1023  double &cost = backward_costs_[s];
1024  cost = ConvertToCost(ifst_->Final(s));
1025  for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, s);
1026  !aiter.Done(); aiter.Next()) {
1027  const Arc &arc = aiter.Value();
1028  cost = std::min(cost,
1029  ConvertToCost(arc.weight) + backward_costs_[arc.nextstate]);
1030  }
1031  }
1032 
1033  if (ifst_->Start() == kNoStateId) return; // we'll be returning
1034  // an empty FST.
1035 
1036  double best_cost = backward_costs_[ifst_->Start()];
1037  if (best_cost == std::numeric_limits<double>::infinity())
1038  KALDI_WARN << "Total weight of input lattice is zero.";
1039  cutoff_ = best_cost + beam_;
1040  }
1041 
1043  // We insist that the input lattice be topologically sorted. This is not a
1044  // fundamental limitation of the algorithm (which in principle should be
1045  // applicable to even cyclic FSTs), but it helps us more efficiently
1046  // compute the backward_costs_ array. There may be some other reason we
1047  // require this, that escapes me at the moment.
1048  KALDI_ASSERT(ifst_->Properties(kTopSorted, true) != 0);
1050 #if !(__GNUC__ == 4 && __GNUC_MINOR__ == 0)
1051  if(ifst_->Properties(kExpanded, false) != 0) { // if we know the number of
1052  // states in ifst_, it might be a bit more efficient
1053  // to pre-size the hashes so we're not constantly rebuilding them.
1054  StateId num_states =
1055  down_cast<const ExpandedFst<Arc>*, const Fst<Arc> >(ifst_)->NumStates();
1056  minimal_hash_.rehash(num_states/2 + 3);
1057  initial_hash_.rehash(num_states/2 + 3);
1058  }
1059 #endif
1060  InputStateId start_id = ifst_->Start();
1061  if (start_id != kNoStateId) {
1062  /* Create determinized-state corresponding to the start state....
1063  Unlike all the other states, we don't "normalize" the representation
1064  of this determinized-state before we put it into minimal_hash_. This is actually
1065  what we want, as otherwise we'd have problems dealing with any extra weight
1066  and string and might have to create a "super-initial" state which would make
1067  the output nondeterministic. Normalization is only needed to make the
1068  determinized output more minimal anyway, it's not needed for correctness.
1069  Note, we don't put anything in the initial_hash_. The initial_hash_ is only
1070  a lookaside buffer anyway, so this isn't a problem-- it will get populated
1071  later if it needs to be.
1072  */
1073  vector<Element> subset(1);
1074  subset[0].state = start_id;
1075  subset[0].weight = Weight::One();
1076  subset[0].string = repository_.EmptyString(); // Id of empty sequence.
1077  EpsilonClosure(&subset); // follow through epsilon-input links
1078  ConvertToMinimal(&subset); // remove all but final states and
1079  // states with input-labels on arcs out of them.
1080  // Weight::One() is the "forward-weight" of this determinized state...
1081  // i.e. the minimal cost from the start of the determinized FST to this
1082  // state [One() because it's the start state].
1083  OutputState *initial_state = new OutputState(subset, 0);
1084  KALDI_ASSERT(output_states_.empty());
1085  output_states_.push_back(initial_state);
1086  num_elems_ += subset.size();
1087  OutputStateId initial_state_id = 0;
1088  minimal_hash_[&(initial_state->minimal_subset)] = initial_state_id;
1089  ProcessFinal(initial_state_id);
1090  ProcessTransitions(initial_state_id); // this will add tasks to
1091  // the queue, which we'll start processing in Determinize().
1092  }
1093  }
1094 
1096 
1097  struct OutputState {
1098  vector<Element> minimal_subset;
1099  vector<TempArc> arcs; // arcs out of the state-- those that have been processed.
1100  // Note: the final-weight is included here with kNoStateId as the state id. We
1101  // always process the final-weight regardless of the beam; when producing the
1102  // output we may have to ignore some of these.
1103  double forward_cost; // Represents minimal cost from start-state
1104  // to this state. Used in prioritization of tasks, and pruning.
1105  // Note: we know this minimal cost from when we first create the OutputState;
1106  // this is because of the priority-queue we use, that ensures that the
1107  // "best" path into the state will be expanded first.
1108  OutputState(const vector<Element> &minimal_subset,
1109  double forward_cost): minimal_subset(minimal_subset),
1110  forward_cost(forward_cost) { }
1111  };
1112 
1113  vector<OutputState*> output_states_; // All the info about the output states.
1114 
1115  int num_arcs_; // keep track of memory usage: number of arcs in output_states_[ ]->arcs
1116  int num_elems_; // keep track of memory usage: number of elems in output_states_ and
1117  // the keys of initial_hash_
1118 
1119  const ExpandedFst<Arc> *ifst_;
1120  std::vector<double> backward_costs_; // This vector stores, for every state in ifst_,
1121  // the minimal cost to the end-state (i.e. the sum of weights; they are guaranteed to
1122  // have "take-the-minimum" semantics). We get the double from the ConvertToCost()
1123  // function on the lattice weights.
1124 
1125  double beam_;
1126  double cutoff_; // beam plus total-weight of input (and note, the weight is
1127  // guaranteed to be "tropical-like" so the sum does represent a min-cost.
1128 
1130  SubsetKey hasher_; // object that computes keys-- has no data members.
1131  SubsetEqual equal_; // object that compares subsets-- only data member is delta_.
1132  bool determinized_; // set to true when user called Determinize(); used to make
1133  // sure this object is used correctly.
1134  MinimalSubsetHash minimal_hash_; // hash from Subset to OutputStateId. Subset is "minimal
1135  // representation" (only include final and states and states with
1136  // nonzero ilabel on arc out of them. Owns the pointers
1137  // in its keys.
1138  InitialSubsetHash initial_hash_; // hash from Subset to Element, which
1139  // represents the OutputStateId together
1140  // with an extra weight and string. Subset
1141  // is "initial representation". The extra
1142  // weight and string is needed because after
1143  // we convert to minimal representation and
1144  // normalize, there may be an extra weight
1145  // and string. Owns the pointers
1146  // in its keys.
1147 
1148  struct Task {
1149  OutputStateId state; // State from which we're processing the transition.
1150  Label label; // Label on the transition we're processing out of this state.
1151  vector<Element> subset; // Weighted subset of states (with strings)-- not normalized.
1152  double priority_cost; // Cost used in deciding priority of tasks. Note:
1153  // we assume there is a ConvertToCost() function that converts the semiring to double.
1154  };
1155 
1156  struct TaskCompare {
1157  inline int operator() (const Task *t1, const Task *t2) {
1158  // view this like operator <, which is the default template parameter
1159  // to std::priority_queue.
1160  // returns true if t1 is worse than t2.
1161  return (t1->priority_cost > t2->priority_cost);
1162  }
1163  };
1164 
1165  // This priority queue contains "Task"s to be processed; these correspond
1166  // to transitions out of determinized states. We process these in priority
1167  // order according to the best weight of any path passing through these
1168  // determinized states... it's possible to work this out.
1169  std::priority_queue<Task*, vector<Task*>, TaskCompare> queue_;
1170 
1171  vector<pair<Label, Element> > all_elems_tmp_; // temporary vector used in ProcessTransitions.
1172 
1173  enum IsymbolOrFinal { OSF_UNKNOWN = 0, OSF_NO = 1, OSF_YES = 2 };
1174 
1175  vector<char> isymbol_or_final_; // A kind of cache; it says whether
1176  // each state is (emitting or final) where emitting means it has at least one
1177  // non-epsilon output arc. Only accessed by IsIsymbolOrFinal()
1178 
1179  LatticeStringRepository<IntType> repository_; // defines a compact and fast way of
1180  // storing sequences of labels.
1181 
1182  void AddStrings(const vector<Element> &vec,
1183  vector<StringId> *needed_strings) {
1184  for (typename std::vector<Element>::const_iterator iter = vec.begin();
1185  iter != vec.end(); ++iter)
1186  needed_strings->push_back(iter->string);
1187  }
1188 };
1189 
1190 
1191 // normally Weight would be LatticeWeight<float> (which has two floats),
1192 // or possibly TropicalWeightTpl<float>, and IntType would be int32.
1193 // Caution: there are two versions of the function DeterminizeLatticePruned,
1194 // with identical code but different output FST types.
1195 template<class Weight, class IntType>
1197  const ExpandedFst<ArcTpl<Weight> >&ifst,
1198  double beam,
1199  MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > >*ofst,
1201  ofst->SetInputSymbols(ifst.InputSymbols());
1202  ofst->SetOutputSymbols(ifst.OutputSymbols());
1203  if (ifst.NumStates() == 0) {
1204  ofst->DeleteStates();
1205  return true;
1206  }
1207  KALDI_ASSERT(opts.retry_cutoff >= 0.0 && opts.retry_cutoff < 1.0);
1208  int32 max_num_iters = 10; // avoid the potential for infinite loops if
1209  // retrying.
1210  VectorFst<ArcTpl<Weight> > temp_fst;
1211 
1212  for (int32 iter = 0; iter < max_num_iters; iter++) {
1213  LatticeDeterminizerPruned<Weight, IntType> det(iter == 0 ? ifst : temp_fst,
1214  beam, opts);
1215  double effective_beam;
1216  bool ans = det.Determinize(&effective_beam);
1217  // if it returns false it will typically still produce reasonable output,
1218  // just with a narrower beam than "beam". If the user specifies an infinite
1219  // beam we don't do this beam-narrowing.
1220  if (effective_beam >= beam * opts.retry_cutoff ||
1221  beam == std::numeric_limits<double>::infinity() ||
1222  iter + 1 == max_num_iters) {
1223  det.Output(ofst);
1224  return ans;
1225  } else {
1226  // The code below to set "beam" is a heuristic.
1227  // If effective_beam is very small, we want to reduce by a lot.
1228  // But never change the beam by more than a factor of two.
1229  if (effective_beam < 0.0) effective_beam = 0.0;
1230  double new_beam = beam * sqrt(effective_beam / beam);
1231  if (new_beam < 0.5 * beam) new_beam = 0.5 * beam;
1232  beam = new_beam;
1233  if (iter == 0) temp_fst = ifst;
1234  kaldi::PruneLattice(beam, &temp_fst);
1235  KALDI_LOG << "Pruned state-level lattice with beam " << beam
1236  << " and retrying determinization with that beam.";
1237  }
1238  }
1239  return false; // Suppress compiler warning; this code is unreachable.
1240 }
1241 
1242 
1243 // normally Weight would be LatticeWeight<float> (which has two floats),
1244 // or possibly TropicalWeightTpl<float>, and IntType would be int32.
1245 // Caution: there are two versions of the function DeterminizeLatticePruned,
1246 // with identical code but different output FST types.
1247 template<class Weight>
1248 bool DeterminizeLatticePruned(const ExpandedFst<ArcTpl<Weight> > &ifst,
1249  double beam,
1250  MutableFst<ArcTpl<Weight> > *ofst,
1252  typedef int32 IntType;
1253  ofst->SetInputSymbols(ifst.InputSymbols());
1254  ofst->SetOutputSymbols(ifst.OutputSymbols());
1255  KALDI_ASSERT(opts.retry_cutoff >= 0.0 && opts.retry_cutoff < 1.0);
1256  if (ifst.NumStates() == 0) {
1257  ofst->DeleteStates();
1258  return true;
1259  }
1260  int32 max_num_iters = 10; // avoid the potential for infinite loops if
1261  // retrying.
1262  VectorFst<ArcTpl<Weight> > temp_fst;
1263 
1264  for (int32 iter = 0; iter < max_num_iters; iter++) {
1265  LatticeDeterminizerPruned<Weight, IntType> det(iter == 0 ? ifst : temp_fst,
1266  beam, opts);
1267  double effective_beam;
1268  bool ans = det.Determinize(&effective_beam);
1269  // if it returns false it will typically still
1270  // produce reasonable output, just with a
1271  // narrower beam than "beam".
1272  if (effective_beam >= beam * opts.retry_cutoff ||
1273  iter + 1 == max_num_iters) {
1274  det.Output(ofst);
1275  return ans;
1276  } else {
1277  // The code below to set "beam" is a heuristic.
1278  // If effective_beam is very small, we want to reduce by a lot.
1279  // But never change the beam by more than a factor of two.
1280  if (effective_beam < 0)
1281  effective_beam = 0;
1282  double new_beam = beam * sqrt(effective_beam / beam);
1283  if (new_beam < 0.5 * beam) new_beam = 0.5 * beam;
1284  KALDI_WARN << "Effective beam " << effective_beam << " was less than beam "
1285  << beam << " * cutoff " << opts.retry_cutoff << ", pruning raw "
1286  << "lattice with new beam " << new_beam << " and retrying.";
1287  beam = new_beam;
1288  if (iter == 0) temp_fst = ifst;
1289  kaldi::PruneLattice(beam, &temp_fst);
1290  }
1291  }
1292  return false; // Suppress compiler warning; this code is unreachable.
1293 }
1294 
1295 template<class Weight>
1297  const kaldi::TransitionModel &trans_model,
1298  MutableFst<ArcTpl<Weight> > *fst) {
1299  // Define some types.
1300  typedef ArcTpl<Weight> Arc;
1301  typedef typename Arc::StateId StateId;
1302  typedef typename Arc::Label Label;
1303 
1304  // Work out the first phone symbol. This is more related to the phone
1305  // insertion function, so we put it here and make it the returning value of
1306  // DeterminizeLatticeInsertPhones().
1307  Label first_phone_label = HighestNumberedInputSymbol(*fst) + 1;
1308 
1309  // Insert phones here.
1310  for (StateIterator<MutableFst<Arc> > siter(*fst);
1311  !siter.Done(); siter.Next()) {
1312  StateId state = siter.Value();
1313  if (state == fst->Start())
1314  continue;
1315  for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
1316  !aiter.Done(); aiter.Next()) {
1317  Arc arc = aiter.Value();
1318 
1319  // Note: the words are on the input symbol side and transition-id's are on
1320  // the output symbol side.
1321  if ((arc.olabel != 0)
1322  && (trans_model.TransitionIdToHmmState(arc.olabel) == 0)
1323  && (!trans_model.IsSelfLoop(arc.olabel))) {
1324  Label phone =
1325  static_cast<Label>(trans_model.TransitionIdToPhone(arc.olabel));
1326 
1327  // Skips <eps>.
1328  KALDI_ASSERT(phone != 0);
1329 
1330  if (arc.ilabel == 0) {
1331  // If there is no word on the arc, insert the phone directly.
1332  arc.ilabel = first_phone_label + phone;
1333  } else {
1334  // Otherwise, add an additional arc.
1335  StateId additional_state = fst->AddState();
1336  StateId next_state = arc.nextstate;
1337  arc.nextstate = additional_state;
1338  fst->AddArc(additional_state,
1339  Arc(first_phone_label + phone, 0,
1340  Weight::One(), next_state));
1341  }
1342  }
1343 
1344  aiter.SetValue(arc);
1345  }
1346  }
1347 
1348  return first_phone_label;
1349 }
1350 
1351 template<class Weight>
1353  typename ArcTpl<Weight>::Label first_phone_label,
1354  MutableFst<ArcTpl<Weight> > *fst) {
1355  // Define some types.
1356  typedef ArcTpl<Weight> Arc;
1357  typedef typename Arc::StateId StateId;
1358  typedef typename Arc::Label Label;
1359 
1360  // Delete phones here.
1361  for (StateIterator<MutableFst<Arc> > siter(*fst);
1362  !siter.Done(); siter.Next()) {
1363  StateId state = siter.Value();
1364  for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
1365  !aiter.Done(); aiter.Next()) {
1366  Arc arc = aiter.Value();
1367 
1368  if (arc.ilabel >= first_phone_label)
1369  arc.ilabel = 0;
1370 
1371  aiter.SetValue(arc);
1372  }
1373  }
1374 }
1375 // instantiate for type LatticeWeight
1376 template
1378  ArcTpl<kaldi::LatticeWeight>::Label first_phone_label,
1379  MutableFst<ArcTpl<kaldi::LatticeWeight> > *fst);
1380 
1392 template<class Weight, class IntType>
1394  const kaldi::TransitionModel &trans_model,
1395  double beam,
1396  MutableFst<ArcTpl<Weight> > *fst,
1397  const DeterminizeLatticePrunedOptions &opts) {
1398  // First, insert the phones.
1399  typename ArcTpl<Weight>::Label first_phone_label =
1400  DeterminizeLatticeInsertPhones(trans_model, fst);
1401  TopSort(fst);
1402 
1403  // Second, do determinization with phone inserted.
1404  bool ans = DeterminizeLatticePruned<Weight>(*fst, beam, fst, opts);
1405 
1406  // Finally, remove the inserted phones.
1407  DeterminizeLatticeDeletePhones(first_phone_label, fst);
1408  TopSort(fst);
1409 
1410  return ans;
1411 }
1412 
1413 // "Destructive" version of DeterminizeLatticePhonePruned() where the input
1414 // lattice might be modified.
1415 template<class Weight, class IntType>
1417  const kaldi::TransitionModel &trans_model,
1418  MutableFst<ArcTpl<Weight> > *ifst,
1419  double beam,
1420  MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
1422  // Returning status.
1423  bool ans = true;
1424 
1425  // Make sure at least one of opts.phone_determinize and opts.word_determinize
1426  // is not false, otherwise calling this function doesn't make any sense.
1427  if ((opts.phone_determinize || opts.word_determinize) == false) {
1428  KALDI_WARN << "Both --phone-determinize and --word-determinize are set to "
1429  << "false, copying lattice without determinization.";
1430  // We are expecting the words on the input side.
1431  ConvertLattice<Weight, IntType>(*ifst, ofst, false);
1432  return ans;
1433  }
1434 
1435  // Determinization options.
1437  det_opts.delta = opts.delta;
1438  det_opts.max_mem = opts.max_mem;
1439 
1440  // If --phone-determinize is true, do the determinization on phone + word
1441  // lattices.
1442  if (opts.phone_determinize) {
1443  KALDI_VLOG(3) << "Doing first pass of determinization on phone + word "
1444  << "lattices.";
1445  ans = DeterminizeLatticePhonePrunedFirstPass<Weight, IntType>(
1446  trans_model, beam, ifst, det_opts) && ans;
1447 
1448  // If --word-determinize is false, we've finished the job and return here.
1449  if (!opts.word_determinize) {
1450  // We are expecting the words on the input side.
1451  ConvertLattice<Weight, IntType>(*ifst, ofst, false);
1452  return ans;
1453  }
1454  }
1455 
1456  // If --word-determinize is true, do the determinization on word lattices.
1457  if (opts.word_determinize) {
1458  KALDI_VLOG(3) << "Doing second pass of determinization on word lattices.";
1459  ans = DeterminizeLatticePruned<Weight, IntType>(
1460  *ifst, beam, ofst, det_opts) && ans;
1461  }
1462 
1463  // If --minimize is true, push and minimize after determinization.
1464  if (opts.minimize) {
1465  KALDI_VLOG(3) << "Pushing and minimizing on word lattices.";
1466  ans = PushCompactLatticeStrings<Weight, IntType>(ofst) && ans;
1467  ans = PushCompactLatticeWeights<Weight, IntType>(ofst) && ans;
1468  ans = MinimizeCompactLattice<Weight, IntType>(ofst) && ans;
1469  }
1470 
1471  return ans;
1472 }
1473 
1474 // Normal verson of DeterminizeLatticePhonePruned(), where the input lattice
1475 // will be kept as unchanged.
1476 template<class Weight, class IntType>
1478  const kaldi::TransitionModel &trans_model,
1479  const ExpandedFst<ArcTpl<Weight> > &ifst,
1480  double beam,
1481  MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
1483  VectorFst<ArcTpl<Weight> > temp_fst(ifst);
1484  return DeterminizeLatticePhonePruned(trans_model, &temp_fst,
1485  beam, ofst, opts);
1486 }
1487 
1489  const kaldi::TransitionModel &trans_model,
1490  MutableFst<kaldi::LatticeArc> *ifst,
1491  double beam,
1492  MutableFst<kaldi::CompactLatticeArc> *ofst,
1494  bool ans = true;
1495  Invert(ifst);
1496  if (ifst->Properties(fst::kTopSorted, true) == 0) {
1497  if (!TopSort(ifst)) {
1498  // Cannot topologically sort the lattice -- determinization will fail.
1499  KALDI_ERR << "Topological sorting of state-level lattice failed (probably"
1500  << " your lexicon has empty words or your LM has epsilon cycles"
1501  << ").";
1502  }
1503  }
1504  ILabelCompare<kaldi::LatticeArc> ilabel_comp;
1505  ArcSort(ifst, ilabel_comp);
1506  ans = DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
1507  trans_model, ifst, beam, ofst, opts);
1508  Connect(ofst);
1509  return ans;
1510 }
1511 
1512 // Instantiate the templates for the types we might need.
1513 // Note: there are actually four templates, each of which
1514 // we instantiate for a single type.
1515 template
1516 bool DeterminizeLatticePruned<kaldi::LatticeWeight>(
1517  const ExpandedFst<kaldi::LatticeArc> &ifst,
1518  double prune,
1519  MutableFst<kaldi::CompactLatticeArc> *ofst,
1521 
1522 template
1523 bool DeterminizeLatticePruned<kaldi::LatticeWeight>(
1524  const ExpandedFst<kaldi::LatticeArc> &ifst,
1525  double prune,
1526  MutableFst<kaldi::LatticeArc> *ofst,
1528 
1529 template
1530 bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
1531  const kaldi::TransitionModel &trans_model,
1532  const ExpandedFst<kaldi::LatticeArc> &ifst,
1533  double prune,
1534  MutableFst<kaldi::CompactLatticeArc> *ofst,
1536 
1537 template
1538 bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
1539  const kaldi::TransitionModel &trans_model,
1540  MutableFst<kaldi::LatticeArc> *ifst,
1541  double prune,
1542  MutableFst<kaldi::CompactLatticeArc> *ofst,
1544 
1545 }
fst::StdArc::StateId StateId
void ProcessTransitions(OutputStateId output_state_id)
DeterminizeLatticePrunedOptions opts_
LatticeWeightTpl< FloatType > Divide(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, DivideType typ=DIVIDE_ANY)
OutputStateId MinimalToStateId(const vector< Element > &subset, const double forward_cost)
void Output(MutableFst< CompactArc > *ofst, bool destroy=true)
bool DeterminizeLatticePruned(const ExpandedFst< ArcTpl< Weight > > &ifst, double beam, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, IntType > > > *ofst, DeterminizeLatticePrunedOptions opts)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
LatticeWeightTpl< FloatType > Plus(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
LatticeDeterminizerPruned(const ExpandedFst< Arc > &ifst, double beam, DeterminizeLatticePrunedOptions opts)
bool DeterminizeLatticePhonePruned(const kaldi::TransitionModel &trans_model, MutableFst< ArcTpl< Weight > > *ifst, double beam, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, IntType > > > *ofst, DeterminizeLatticePhonePrunedOptions opts)
"Destructive" version of DeterminizeLatticePhonePruned() where the input lattice might be changed...
kaldi::int32 int32
bool ApproxEqual(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, float delta=kDelta)
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeDeterminizerPruned)
void AddStrings(const vector< Element > &vec, vector< StringId > *needed_strings)
bool DeterminizeLatticePhonePrunedFirstPass(const kaldi::TransitionModel &trans_model, double beam, MutableFst< ArcTpl< Weight > > *fst, const DeterminizeLatticePrunedOptions &opts)
This function does a first pass determinization with phone symbols inserted at phone boundary...
CompactLatticeWeightTpl< Weight, IntType > CompactWeight
LatticeStringRepository< IntType > StringRepositoryType
double ConvertToCost(const LatticeWeightTpl< Float > &w)
int32 TransitionIdToHmmState(int32 trans_id) const
bool IsSelfLoop(int32 trans_id) const
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
void ProcessFinal(OutputStateId output_state_id)
fst::StdArc::Label Label
unordered_map< const vector< Element > *, Element, SubsetKey, SubsetEqual > InitialSubsetHash
bool Determinize(double *effective_beam)
fst::StdArc::Weight Weight
OutputStateId InitialToStateId(const vector< Element > &subset_in, double forward_cost, Weight *remaining_weight, StringId *common_prefix)
LatticeStringRepository< IntType > repository_
void NormalizeSubset(vector< Element > *elems, Weight *tot_weight, StringId *common_str)
Arc::Label HighestNumberedInputSymbol(const Fst< Arc > &fst)
Returns the highest numbered input symbol id of the FST (or zero for an empty FST).
void ProcessTransition(OutputStateId ostate_id, Label ilabel, vector< Element > *subset)
int Compare(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2.
bool PruneLattice(BaseFloat beam, LatType *lat)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
OutputState(const vector< Element > &minimal_subset, double forward_cost)
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
vector< pair< Label, Element > > all_elems_tmp_
ArcTpl< Weight >::Label DeterminizeLatticeInsertPhones(const kaldi::TransitionModel &trans_model, MutableFst< ArcTpl< Weight > > *fst)
This function takes in lattices and inserts phones at phone boundaries.
void Output(MutableFst< Arc > *ofst, bool destroy=true)
void EpsilonClosure(vector< Element > *subset)
const StringRepositoryType::Entry * StringId
unordered_map< const vector< Element > *, OutputStateId, SubsetKey, SubsetEqual > MinimalSubsetHash
void MakeSubsetUnique(vector< Element > *subset)
#define KALDI_LOG
Definition: kaldi-error.h:153
int32 TransitionIdToPhone(int32 trans_id) const
std::priority_queue< Task *, vector< Task * >, TaskCompare > queue_
bool DeterminizeLatticePhonePrunedWrapper(const kaldi::TransitionModel &trans_model, MutableFst< kaldi::LatticeArc > *ifst, double beam, MutableFst< kaldi::CompactLatticeArc > *ofst, DeterminizeLatticePhonePrunedOptions opts)
This function is a wrapper of DeterminizeLatticePhonePruned() that works for Lattice type FSTs...
void DeterminizeLatticeDeletePhones(typename ArcTpl< Weight >::Label first_phone_label, MutableFst< ArcTpl< Weight > > *fst)
This function takes in lattices and deletes "phones" from them.
void ConvertToMinimal(vector< Element > *subset)
int Compare(const Weight &a_w, StringId a_str, const Weight &b_w, StringId b_str) const
void Copy(const CuMatrixBase< Real > &src, const CuArray< int32 > &copy_from_indices, CuMatrixBase< Real > *tgt)
Copies elements from src into tgt as given by copy_from_indices.
Definition: cu-math.cc:173