1 // lat/compose-lattice-pruned.cc
2 
3 // Copyright 2009-2012 Microsoft Corporation
4 // 2012-2013 Johns Hopkins University (Author: Daniel Povey)
5 // 2014 Guoguo Chen
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABILITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #include "lat/compose-lattice-pruned.h"
23 #include "lat/lattice-functions.h"
24 
25 namespace kaldi {
26 
46 // PrunedCompactLatticeComposer implements an algorithm for pruned composition.
47 class PrunedCompactLatticeComposer {
48  public:
49  PrunedCompactLatticeComposer(
50  const ComposeLatticePrunedOptions &opts,
51  const CompactLattice &clat,
52  fst::DeterministicOnDemandFst<fst::StdArc> *det_fst,
53  CompactLattice* composed_clat);
54 
55  // Does the composition. You must call this just once per object.
56  void Compose();
57 
58  private:
59 
60  // Gets the num-arcs limit for this iteration of the algorithm, which will be
61  // opts_.initial_num_arcs if there are currently no arcs; or otherwise
62  // opts_.growth_ratio * the current number of arcs (subject to the
63  // opts_.max_arcs limit if we have already reached a final-state). This helps
64  // ensure that we call RecomputePruningInfo() on an appropriate schedule.
65  int32 GetCurrentArcLimit() const;
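 // For instance, with illustrative (not necessarily default) values
 // opts_.initial_num_arcs = 100 and opts_.growth_ratio = 1.5, successive calls
 // would return limits of roughly 100, 150, 225, 337, ..., and
 // RecomputePruningInfo() is re-run after each such batch of newly added arcs.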
66 
67  // This function, called just once at the start, computes all the static
68  // information about the input lattice 'clat', in lat_state_info_. (However,
69  // the 'composed_states' members are just set to the empty vector for now).
70  void ComputeLatticeStateInfo();
71 
72  // Called just once at the start, this sets up the first state in the
73  // composed output.
74  void AddFirstState();
75 
76  // This function processes the next un-expanded transition (or final-state)
77  // out of the composed state numbered 'composed_state_to_expand'.
78  void ProcessQueueElement(int32 composed_state_to_expand);
79 
80  // This is a part of ProcessQueueElement() that has been broken out
81  // for clarity. It processes the arc_index'th arc out of this source state.
82  void ProcessTransition(int32 composed_src_state,
83  int32 arc_index);
84 
85  // This function recomputes certain members of the ComposedStateInfo relating
86  // to the output states: namely, 'forward_cost', 'backward_cost' and
87  // 'delta_backward_cost'. In between calls to this function, we try to
88  // keep those quantities as accurate as possible, but they aren't
89  // completely accurate (see comments by their declarations for more info).
90  void RecomputePruningInfo();
91 
92  // Sets '*composed_states' to a list of the states that currently
93  // exist in the composed output, in topologically sorted order.
94  // At exit, *composed_states will be a permutation of numbers
95  // [0, 1, ... clat_out_->NumStates() - 1], beginning with the
96  // start-state 0.
97  void GetTopsortedStateList(std::vector<int32> *composed_states) const;
98 
99  // Called from RecomputePruningInfo(), this computes all the 'forward_cost'
100  // and 'prev_composed_state' members of the ComposedStateInfo.
101  // @param [in] composed_states This is expected to be a list,
102  // in topological order, of all currently existing composed states,
103  // as produced by GetTopsortedStateList().
104  void ComputeForwardCosts(const std::vector<int32> &composed_states);
105 
106  // Called from RecomputePruningInfo(), this computes all the 'backward_cost'
107  // members of the ComposedStateInfo. It also sets 'output_best_cost_'.
108  // 'composed_states' is expected to be a list, in topological order, of all
109  // currently existing composed states, as produced by GetTopsortedStateList().
110  void ComputeBackwardCosts(const std::vector<int32> &composed_states);
111 
112  // Called from RecomputePruningInfo(), this computes all the
113  // 'delta_backward_cost' members of the ComposedStateInfo. 'composed_states'
114  // is expected to be a list, in topological order, of all currently existing
115  // composed states, as produced by GetTopsortedStateList(). It also computes
116  // the 'expected_cost_offset' values for all states, and uses them recreate
117  // 'composed_state_queue_'.
118  void ComputeDeltaBackwardCosts(const std::vector<int32> &composed_states);
119 
120 
121  // This struct contains information about a state of the input lattice.
123  // 'backward_cost' is the total cost of the best path from this state to
124  // the final state in the source lattice, including the final-prob.
126 
127  // 'arc_delta_costs' is an array, one for each arc (and the final-prob, if
128  // present), showing how much the cost to the final-state for the best path
129  // starting in this state and exiting through each arc (or final-prob),
130  // differs from 'backward_cost'. Specifically, it contains pairs
131  // (delta_cost, arc_index), where delta_cost >= 0 and arc_index is
132  // either the index into this state's array of arcs (for arcs), or -1
133  // if this represents the final-prob.
134  //
135  // 'arc_delta_costs' will be sorted, so that the first element has
136  // .first=0.0 and the delta-costs will be in increasing order. This means that
137  // we expand them from the start of the array, in order to process the best
138  // arcs first.
139  // lat_state_info_[i].arc_delta_costs.size() will equal
140  // clat_in_.NumArcs(i), plus one if clat_in_.Final(i) != Zero().
141  std::vector<std::pair<BaseFloat, int32> > arc_delta_costs;
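 // As a purely illustrative example (made-up costs): for a lattice state with
 // two arcs and a final-prob, where the final-prob is the cheapest way out and
 // the two arcs are respectively 0.5 and 2.3 worse than it, we would have
 //   arc_delta_costs = { {0.0, -1}, {0.5, 0}, {2.3, 1} }.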
142 
143 
144  // 'composed_states' is a list of the state-ids in the composed output
145  // that correspond to this state in the lattice, so we expect
146  // that composed_state_info_[composed_states[i]].lat_state
147  // equals the index of this lattice state. This is helpful in
148  // accessing the states in the output lattice in topological
149  // order.
150  std::vector<int32> composed_states;
151  };
152 
153  // This struct contains information about a state of the composed
154  // output.
155  struct ComposedStateInfo {
156  // 'lat_state' and 'lm_state' form the pair of states in the two FSTs
157  // that this state corresponds to. The unordered map 'pair_to_state_' maps these
158  // state-pairs to the index of the composed state (the state-index in clat_out_).
159  int32 lat_state;
160  int32 lm_state;
161 
162  // The number of arcs on the path from the start state to this state, in the
163  // composed lattice, by which this state was first reached.
164  int32 depth;
165 
166  // If you have just called RecomputePruningInfo(), then
167  // 'forward_cost' will equal the cost of the best path from the start-state
168  // to this state, in the composed output.
169  //
170  // In between calls to RecomputePruningInfo() it may not always be fully up
171  // to date; instead it will be an upper bound on what it would be if you had
172  // just called RecomputePruningInfo(); it will be the cost of some path but
173  // not necessarily the best path.
174  double forward_cost;
175 
176  // 'backward_cost' relates to the cost from this state to the final-state in
177  // the composed FST. (By this we mean, more precisely, the cost of the best
178  // path from this state to any final state, including the final-prob in that
179  // final state).
180  //
181  // If we have just called RecomputePruningInfo(), then the following rules
182  // specify what the value of 'backward_cost' will be:
183  // - If a final state is reachable from this state, backward_cost
184  // will contain the cost of the best path from this state to the
185  // final state (including the corresponding final-prob).
186  // - Otherwise, it will contain +infinity.
187  //
188  // If RecomputePruningInfo() has not just been called, it may contain any
189  // value that is >= the value that the rules above specify (since, for
190  // existing states, we don't modify it between calls to
191  // RecomputePruningInfo()). For states that have been added since
192  // RecomputePruningInfo() was last called, it will be infinity.
193  double backward_cost;
194 
195  // 'delta_backward_cost' is a quantity that is used in our heuristic of the
196  // cost to an end-state from expanding a previously un-expanded arc. It is
197  // an estimate of the difference between the backward cost in this struct
198  // (this->backward_cost) and the backward cost in the input lattice
199  // (LatticeStateInfo::backward_cost). This term reflects the anticipated
200  // extra costs from 'det_fst_', which, while fairly close to zero, may be
201  // substantial enough to want to correct for.
202  //
203  // The following is the value that 'delta_backward_cost' will have if
204  // RecomputePruningInfo() has just been called:
205  // - If backward_cost is finite (this state in the composed result can
206  // reach the final state via currently expanded states), then
207  // delta_backward_cost is this->backward_cost minus
208  // lat_state_info_[lat_state].backward_cost. (It will mostly, but
209  // not always, be <= 0, reflecting that the new LM is better than
210  // the old LM).
211  // - On the other hand, if backward_cost is infinite: delta_backward_cost
212  // is set to the delta_backward_cost of the previous state on the best
213  // path from the start state of the composed result to this state (or
214  // zero if this is the start state).
215  //
216  // If RecomputePruningInfo() has not just been called, then:
217  // - For states created since RecomputePruningInfo() was last called,
218  // delta_backward_cost will be inherited from the source state from
219  // which the new state was expanded.
220  // - For other states, delta_backward_cost will be unchanged since
221  // RecomputePruningInfo() was last called.
222  // The above rules may make the delta_backward_cost a less accurate, but
223  // still probably reasonable, heuristic. What it is a heuristic for,
224  // is: if we were to successfully reach an end-state of the composed output
225  // from this state, what would be the resulting backward_cost
226  // minus lat_state_info_[lat_state].backward_cost.
227  BaseFloat delta_backward_cost;
228 
229  // 'prev_composed_state' is the previous state on the best path from
230  // the start-state to the current state (or -1 if this is the start state).
231  // It is computed in RecomputePruningInfo() when setting up 'forward_cost',
232  // and then used to compute delta_backward_cost. It is not otherwise
233  // used.
234  int32 prev_composed_state;
235 
236  // 'sorted_arc_index' is an index into the 'arc_delta_costs' array which is
237  // a member of the LatticeStateInfo object corresponding to the lattice
238  // state 'lat_state'. It corresponds to the next arc (or final-prob) out of
239  // the input lattice that we have yet to expand in the composition; or -1 if
240  // we have expanded all of them. When we first reach a composed state,
241  // 'sorted_arc_index' will be zero; then it will increase one at a time as
242  // we expand arcs until either the composition terminates or we have
243  // expanded all the arcs and it becomes -1.
244  int32 sorted_arc_index;
245 
246  // 'arc_delta_cost' is a derived quantity that we store here for easier
247  // access. Suppose this_lat_info is lat_state_info_[lat_state]; then
248  // if sorted_arc_index >= 0, then:
249  // arc_delta_cost == this_lat_info.arc_delta_costs[sorted_arc_index].first
250  // else: arc_delta_cost == +infinity.
251  //
252  // what 'arc_delta_cost' represents (or is a heuristic for), is the expected
253  // cost of a path to the final-state leaving through the arc we're about to
254  // expand, minus the expected cost of any path to the final-state starting
255  // from this state.
256  BaseFloat arc_delta_cost;
257 
258  // view 'expected_cost_offset' a phantom field of this struct, that has
259  // been optimized out. It's clearer if we act like it's a field, but
260  // actually it's not stored.
261  //
262  // 'expected_cost_offset' is a derived quantity that reflects the expected
263  // cost (according to our heuristic) of the best path we might encounter
264  // when expanding the next previously unseen arc (or final-prob),
265  // corresponding to 'sorted_arc_index'. (This is the expected cost of a
266  // successful path, from the beginning to the end of the lattice, but
267  // constrained to be a path that contains the arc we're about to expand).
268  //
269  // The 'offset' part is about subtracting the best cost of the lattice, so we
270  // can cast to float without too much loss of accuracy:
271  // expected_cost_offset = expected_cost - lat_best_cost_.
272  //
273  // We define expected_cost_offset by defining the 'expected_cost' part;
274  // for clarity:
275  // First, let lat_backward_cost equal the backward_cost of the LatticeStateInfo
276  // corresponding to 'lat_state', i.e.
277  // lat_backward_cost = lat_state_info_[lat_state].backward_cost. Then:
278  // expected_cost = forward_cost + lat_backward_cost +
279  // delta_backward_cost + arc_delta_cost.
280  // expected_cost_offset will always equal the above minus lat_best_cost_.
281  //
282  // The formula for expected_cost above is a pretty good heuristic for what
283  // the cost to the end-state will be. If the costs in det_fst_ were zero,
284  // then the expression (forward_cost + lat_backward_cost + arc_delta_cost)
285  // would be exact, and we would expand things in the ideal, best-first
286  // order. "delta_backward_cost" is a reasonable approximation for the extra
287  // costs from 'det_fst_'.
288  // BaseFloat expected_cost_offset;
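 // A made-up numeric illustration of the formula above: with
 //   forward_cost = 10.0, lat_backward_cost = 30.0,
 //   delta_backward_cost = -1.5, arc_delta_cost = 0.5, lat_best_cost_ = 42.0,
 // we would get expected_cost = 10.0 + 30.0 + (-1.5) + 0.5 = 39.0 and
 // expected_cost_offset = 39.0 - 42.0 = -3.0; smaller offsets are expanded
 // earlier because the queue pops the lowest value first.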
289  };
290 
291  // This bool variable is initialized to false, and will be updated to true
292  // the first time a Final() function is called on the det_fst_. Then we will
293  // immediately call RecomputePruningInfo() so that the output_best_cost_ is
294  // changed from +inf to a finite value, to be used in beam search. This is the
295  // only time the RecomputePruningInfo() function is called manually; otherwise
296  // it always follows an automatic schedule based on the num-arcs of the output
297  // lattice.
298  bool output_reached_final_;
299 
300  // This variable, which we set initially to -1000, makes sure that in the
301  // beginning of the algorithm, we always prioritize exploring the lattice
302  // in a depth-first way. Once we find a path reaching a final state, this
303  // variable will be reset to 0.
304  // The reason we do this is that the beam search depends on a good estimate
305  // of the composed best-cost; before we reach a final state we instead borrow
306  // the best-cost from the input lattice, which is usually systematically worse
307  // than the RNNLM scores, so the algorithm would otherwise spend a lot of time
308  // before reaching any final state, especially if the input lattices are large.
310  BaseFloat depth_penalty_;
311  const ComposeLatticePrunedOptions &opts_;
312  const CompactLattice &clat_in_;
313  fst::DeterministicOnDemandFst<fst::StdArc> *det_fst_;
314  CompactLattice *clat_out_;
315 
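 // To illustrate the effect of depth_penalty_ with made-up numbers: while it is
 // -1000, a composed state at depth 5 gets 5 * -1000 = -5000 added to its
 // delta_backward_cost (see ComputeDeltaBackwardCosts() and ProcessTransition()),
 // so deeper states look much cheaper and expansion proceeds depth-first; once
 // some final state is reached, depth_penalty_ becomes 0.0 and the term vanishes.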
316  // This counter keeps track of the number of arcs in the output lattice
317  // clat_out_. When it exceeds opts_.max_arcs (once a final state has been
318  // reached), the composition stops adding arcs.
319  int32 num_arcs_out_;
320  std::vector<LatticeStateInfo> lat_state_info_;
321 
322  // 'lat_best_cost' is the cost of the best path in the input lattice,
323  // equal to lat_state_info_[0].backward_cost (we check that 0 is the
324  // start state in the input lattice).
325  double lat_best_cost_;
326 
327  // 'output_best_cost_' is the cost of the best successful path in the output
328  // lattice 'clat_out_'; or +infinity if 'clat_out_' does not yet have any
329  // successful paths. It is updated only when RecomputePruningInfo() is
330  // called.
331  double output_best_cost_;
332 
333 
334  // current_cutoff_ is a value used in deciding which composed states
335  // need to be included in the queue. Each time RecomputePruningInfo()
336  // is called, current_cutoff_ is set to
337  // (output_best_cost_ - lat_best_cost_ + opts_.lattice_compose_beam).
338  // It will be +infinity if the output lattice doesn't yet have any
339  // successful paths. It decreases with time. You can compare the
340  // phantom 'expected_cost_offset' members of ComposedStateInfo with this
341  // value; if they are more than this value, then there is no need
342  // to enter the corresponding state into the queue.
343  BaseFloat current_cutoff_;
344 
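 // For example, with made-up values output_best_cost_ = 45.0,
 // lat_best_cost_ = 42.0 and opts_.lattice_compose_beam = 6.0, we would have
 // current_cutoff_ = 45.0 - 42.0 + 6.0 = 9.0, so a state whose (phantom)
 // expected_cost_offset exceeds 9.0 is not worth putting on the queue.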
345  typedef std::priority_queue<std::pair<BaseFloat, int32>,
346  std::vector<std::pair<BaseFloat, int32> >,
347  std::greater<std::pair<BaseFloat, int32> > > QueueType;
348 
349  // composed_state_queue_ is a priority queue of the composed states
350  // that we are intending to expand. It contains pairs
351  // (expected_cost_offset, composed_state_index),
352  // where expected_cost_offset == the phantom variable
353  // composed_state_info_[composed_state_index].expected_cost_offset.
354  // We process the states from lowest cost first.
355  // Every time RecomputePruningInfo() is called, this is cleared and repopulated
356  // (since the states' expected_cost_offset values may change), and in between
357  // calls to RecomputePruningInfo(), we do insert elements for newly created
358  // states.
359  QueueType composed_state_queue_;
360 
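 // A small illustration (made-up values) of the ordering: std::greater<> makes
 // the priority_queue behave as a min-heap, e.g.
 //   composed_state_queue_.emplace(2.5, 7);
 //   composed_state_queue_.emplace(0.3, 4);
 //   composed_state_queue_.emplace(1.0, 9);
 //   composed_state_queue_.top();   // == (0.3, 4): the cheapest expansion.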
361 
362  std::vector<ComposedStateInfo> composed_state_info_;
363 
364  // This maps a pair (lat_state, lm_state) to the index of the
365  // state in the composed FST. That would correspond to a state-id in
366  // clat_out_, and also to an index into 'composed_state_info_'.
367  unordered_map<std::pair<int32,int32>,
368  int32, PairHasher<int32> > pair_to_state_;
369 
370  // This contains the set of state-indexes of the input lattice that already
371  // have states in the composed output (i.e. i is in accessed_lat_states_ if and
372  // only if !lat_state_info_[i].composed_states.empty()). The point is to be
373  // able to enumerate, in order or in reverse order, just those states in the
374  // lattice that appear in the composed output (it's an efficiency thing that
375  // will matter more for early iterations of the composition, when we need
376  // to access the output lattice in topological order).
377  std::set<int32> accessed_lat_states_;
378 };
379 
380 
381 void PrunedCompactLatticeComposer::GetTopsortedStateList(
382  std::vector<int32> *composed_states) const {
383  composed_states->clear();
384  composed_states->reserve(clat_out_->NumStates());
385  std::set<int32>::const_iterator iter = accessed_lat_states_.begin(),
386  end = accessed_lat_states_.end();
387  for (; iter != end; ++iter) {
388  int32 lat_state = *iter;
389  const LatticeStateInfo &input_lat_info = lat_state_info_[lat_state];
390  composed_states->insert(composed_states->end(),
391  input_lat_info.composed_states.begin(),
392  input_lat_info.composed_states.end());
393  }
394  KALDI_ASSERT((*composed_states)[0] == 0 &&
395  static_cast<int32>(composed_states->size()) ==
396  clat_out_->NumStates());
397 }
398 
399 int32 PrunedCompactLatticeComposer::GetCurrentArcLimit() const {
400  int32 current_num_arcs = num_arcs_out_;
401  if (current_num_arcs == 0) {
402  return opts_.initial_num_arcs;
403  } else {
405  int32 ans = static_cast<int32>(current_num_arcs *
406  opts_.growth_ratio);
407  if (ans == current_num_arcs) // make sure the target increases.
408  ans = current_num_arcs + 1;
409  // if we have already reached a final state, then
410  // apply the max_arcs limit.
411  if (output_best_cost_ - output_best_cost_ == 0.0 &&
412  ans > opts_.max_arcs)
413  ans = opts_.max_arcs;
414  return ans;
415  }
416 
417 }
418 
419 
420 void PrunedCompactLatticeComposer::RecomputePruningInfo() {
421  std::vector<int32> all_composed_states;
422  GetTopsortedStateList(&all_composed_states);
423  ComputeForwardCosts(all_composed_states);
424  ComputeBackwardCosts(all_composed_states);
425  ComputeDeltaBackwardCosts(all_composed_states);
426 }
427 
428 void PrunedCompactLatticeComposer::ComputeForwardCosts(
429  const std::vector<int32> &composed_states) {
430  KALDI_ASSERT(composed_states[0] == 0);
431 
432  // Note: when we initialized composed_state_info_[0]
433  // we set forward_cost = 0.0, prev_composed_state = -1.
434 
435  std::vector<ComposedStateInfo>::iterator
436  state_iter = composed_state_info_.begin(),
437  state_end = composed_state_info_.end();
438 
439  state_iter->depth = 0; // start state has depth 0
440  ++state_iter; // Skip over the start state.
441  // Set all other forward_cost fields to infinity and prev_composed_state to
442  // -1.
443  for (; state_iter != state_end; ++state_iter) {
444  state_iter->forward_cost = std::numeric_limits<double>::infinity();
445  state_iter->prev_composed_state = -1;
446  }
447 
448  std::vector<int32>::const_iterator state_index_iter = composed_states.begin(),
449  state_index_end = composed_states.end();
450  for (; state_index_iter != state_index_end; ++state_index_iter) {
451  int32 composed_state_index = *state_index_iter;
452  const ComposedStateInfo &info = composed_state_info_[
453  composed_state_index];
454  double forward_cost = info.forward_cost;
455  // The next line is a check for infinity. If infinities have appeared, it
456  // either means there is a bug in the algorithm or there were infinities or
457  // NaN's in the lattice.
458  KALDI_ASSERT(forward_cost - forward_cost == 0.0);
459  fst::ArcIterator<CompactLattice> aiter(*clat_out_,
460  composed_state_index);
461  for (; !aiter.Done(); aiter.Next()) {
462  const CompactLatticeArc &arc = aiter.Value();
463  double arc_cost = ConvertToCost(arc.weight),
464  next_forward_cost = forward_cost + arc_cost;
465  ComposedStateInfo &next_info = composed_state_info_[arc.nextstate];
466  if (next_info.forward_cost > next_forward_cost) {
467  next_info.forward_cost = next_forward_cost;
468  next_info.prev_composed_state = composed_state_index;
469  next_info.depth = composed_state_info_[composed_state_index].depth + 1;
470  }
471  }
472  }
473 }
474 
475 void PrunedCompactLatticeComposer::ComputeBackwardCosts(
476  const std::vector<int32> &composed_states) {
477  // Access the composed states in reverse topological order from latest to
478  // earliest.
479  std::vector<int32>::const_reverse_iterator iter = composed_states.rbegin(),
480  end = composed_states.rend();
481  for (; iter != end; ++iter) {
482  int32 composed_state_index = *iter;
483  ComposedStateInfo &info = composed_state_info_[composed_state_index];
484  double backward_cost =
485  ConvertToCost(clat_out_->Final(composed_state_index));
486  fst::ArcIterator<CompactLattice> aiter(*clat_out_,
487  composed_state_index);
488  for (; !aiter.Done(); aiter.Next()) {
489  const CompactLatticeArc &arc = aiter.Value();
490  double arc_cost = ConvertToCost(arc.weight),
491  next_backward_cost = composed_state_info_[arc.nextstate].backward_cost,
492  this_backward_cost = arc_cost + next_backward_cost;
493  if (this_backward_cost < backward_cost)
494  backward_cost = this_backward_cost;
495  }
496  // It's OK if at this point, backward_cost is still +infinity. This means
497  // that this state cannot reach the end yet, which means we have not yet
498  // expanded any path from this state all the way to a final-state of the
499  // output.
500  info.backward_cost = backward_cost;
501  }
502  output_best_cost_ = composed_state_info_[0].backward_cost;
503  // See the declaration of current_cutoff_ for more information. Note: on
504  // early iterations, before any path reaches a final state of the composed
505  // lattice, current_cutoff_ may be +infinity, and this is OK.
506  current_cutoff_ =
507  output_best_cost_ - lat_best_cost_ + opts_.lattice_compose_beam;
508 }
509 
510 void PrunedCompactLatticeComposer::ComputeDeltaBackwardCosts(
511  const std::vector<int32> &composed_states) {
512 
513  int32 num_states = clat_out_->NumStates();
514  for (int32 composed_state_index = 0; composed_state_index < num_states;
515  ++composed_state_index) {
516  ComposedStateInfo &info = composed_state_info_[composed_state_index];
517  int32 lat_state = info.lat_state;
518  // Note: delta_backward_cost will be +infinity at this stage if the
519  // backward_cost was +infinity. This is OK; we'll set them all to
520  // finite values later in this function.
521  info.delta_backward_cost =
522  info.backward_cost - lat_state_info_[lat_state].backward_cost + info.depth * depth_penalty_;
523  }
524 
525  // 'queue_elements' is a list of items (expected_cost_offset,
526  // composed_state_index) that we are going to add to composed_state_queue_,
527  // after clearing it. It's more efficient to accumulate them as a vector
528  // and add them all at once, than adding them one by one (search online for
529  // "heapify" if this seems confusing).
530  std::vector<std::pair<BaseFloat, int32> > queue_elements;
531  queue_elements.reserve(num_states);
532 
533  double lat_best_cost = lat_best_cost_;
534  BaseFloat current_cutoff = current_cutoff_;
535  std::vector<int32>::const_iterator iter = composed_states.begin(),
536  end = composed_states.end();
537  for (; iter != end; ++iter) {
538  int32 composed_state_index = *iter;
539  ComposedStateInfo &info = composed_state_info_[composed_state_index];
540  if (info.delta_backward_cost - info.delta_backward_cost != 0) {
541  // if info.delta_backward_cost is +infinity...
542  int32 prev_composed_state = info.prev_composed_state;
543  if (prev_composed_state < 0) {
544  KALDI_ASSERT(composed_state_index == 0);
545  info.delta_backward_cost = 0.0;
546  } else {
547  const ComposedStateInfo &prev_info =
548  composed_state_info_[prev_composed_state];
549  // Check that prev_info.delta_backward_cost is finite.
550  KALDI_ASSERT(prev_info.delta_backward_cost -
551  prev_info.delta_backward_cost == 0.0);
552  info.delta_backward_cost = prev_info.delta_backward_cost;
553  }
554  }
555  double lat_backward_cost = lat_state_info_[info.lat_state].backward_cost;
556  // See the formula by where expected_cost_offset is declared in the
557  // struct for explanation.
558  BaseFloat expected_cost_offset =
559  info.forward_cost + lat_backward_cost + info.delta_backward_cost +
560  info.arc_delta_cost - lat_best_cost;
561  // If info.expected_cost_offset were real, we'd set it here:
562  //info.expected_cost_offset = expected_cost_offset;
563 
564  // At this point expected_cost_offset may be infinite, if arc_delta_cost was
565  // infinite (reflecting that we processed all the arcs, and the final-state
566  // if applicable, of the lattice state corresponding to this composed state).
567  if (expected_cost_offset < current_cutoff) {
568  queue_elements.push_back(std::pair<BaseFloat, int32>(
569  expected_cost_offset, composed_state_index));
570  }
571  }
572 
573  // Reinitialize composed_state_queue_ from 'queue_elements'.
574  QueueType temp_queue(queue_elements.begin(), queue_elements.end());
575  composed_state_queue_.swap(temp_queue);
576 }
577 
578 void PrunedCompactLatticeComposer::ComputeLatticeStateInfo() {
579  KALDI_ASSERT(clat_in_.Properties(fst::kTopSorted, true) ==
580  fst::kTopSorted && clat_in_.NumStates() > 0 &&
581  clat_in_.Start() == 0);
582  int32 num_lat_states = clat_in_.NumStates();
583  lat_state_info_.resize(num_lat_states);
584 
585  for (int32 s = num_lat_states - 1; s >= 0; s--) {
586  LatticeStateInfo &info = lat_state_info_[s];
587  std::vector<std::pair<double, int32> > arc_costs;
588  double backward_cost = ConvertToCost(clat_in_.Final(s));
589  if (backward_cost != std::numeric_limits<double>::infinity())
590  arc_costs.push_back(std::pair<BaseFloat,int32>(backward_cost, -1));
591  fst::ArcIterator<CompactLattice> aiter(clat_in_, s);
592  int32 arc_index = 0;
593  for (; !aiter.Done(); aiter.Next(), ++arc_index) {
594  const CompactLatticeArc &arc = aiter.Value();
595  KALDI_ASSERT(arc.nextstate > s);
596  backward_cost = lat_state_info_[arc.nextstate].backward_cost +
597  ConvertToCost(arc.weight);
598  KALDI_ASSERT(backward_cost - backward_cost == 0.0 &&
599  "Possibly not all states of input lattice are co-accessible?");
600  arc_costs.push_back(std::pair<BaseFloat,int32>(backward_cost, arc_index));
601  }
602  std::sort(arc_costs.begin(), arc_costs.end());
603  KALDI_ASSERT(!arc_costs.empty() &&
604  "Possibly not all states of input lattice are co-accessible?");
605  backward_cost = arc_costs[0].first;
606  info.backward_cost = backward_cost; // this is the state's backward_cost,
607  // reflecting the best path to the end.
608  info.arc_delta_costs.resize(arc_costs.size());
609  std::vector<std::pair<double, int32> >::const_iterator
610  src_iter = arc_costs.begin(), src_end = arc_costs.end();
611  std::vector<std::pair<BaseFloat, int32> >::iterator
612  dest_iter = info.arc_delta_costs.begin();
613  for (; src_iter != src_end; ++src_iter, ++dest_iter) {
614  dest_iter->first = BaseFloat(src_iter->first - backward_cost);
615  dest_iter->second = src_iter->second;
616  }
617  }
618  lat_best_cost_ = lat_state_info_[0].backward_cost;
619 }
620 
621 PrunedCompactLatticeComposer::PrunedCompactLatticeComposer(
622  const ComposeLatticePrunedOptions &opts,
623  const CompactLattice &clat_in,
624  fst::DeterministicOnDemandFst<fst::StdArc> *det_fst,
625  CompactLattice* composed_clat): output_reached_final_(false),
626  opts_(opts), clat_in_(clat_in), det_fst_(det_fst),
627  clat_out_(composed_clat),
628  num_arcs_out_(0),
629  output_best_cost_(std::numeric_limits<double>::infinity()),
630  current_cutoff_(std::numeric_limits<double>::infinity()) {
631  clat_out_->DeleteStates();
632  depth_penalty_ = -1000;
633 }
634 
635 
636 void PrunedCompactLatticeComposer::AddFirstState() {
637  int32 state_id = clat_out_->AddState();
638  clat_out_->SetStart(state_id);
639  KALDI_ASSERT(state_id == 0);
640  composed_state_info_.resize(1);
641  ComposedStateInfo &composed_state = composed_state_info_[0];
642  composed_state.lat_state = 0;
643  composed_state.lm_state = det_fst_->Start();
644  composed_state.depth = 0;
645  composed_state.forward_cost = 0.0;
646  composed_state.backward_cost = std::numeric_limits<double>::infinity();
647  composed_state.delta_backward_cost = 0.0;
648  composed_state.prev_composed_state = -1;
649  composed_state.sorted_arc_index = 0;
650  composed_state.arc_delta_cost = 0.0; // the first arc_delta_cost is always 0.0
651  // due to sorting; no need to look it up.
652  lat_state_info_[0].composed_states.push_back(state_id);
653  accessed_lat_states_.insert(state_id);
654  pair_to_state_[std::pair<int32, int32>(0, det_fst_->Start())] = state_id;
655 
656  BaseFloat expected_cost_offset = 0.0; // the formula simplifies to zero
657  // in this case.
658  composed_state_queue_.push(
659  std::pair<BaseFloat, int32>(expected_cost_offset,
660  state_id)); // actually (0.0, 0).
661 
662 }
663 
664 
665 void PrunedCompactLatticeComposer::ProcessQueueElement(
666  int32 src_composed_state) {
667  KALDI_ASSERT(static_cast<size_t>(src_composed_state) <
668  composed_state_info_.size());
669 
670  ComposedStateInfo &src_composed_state_info = composed_state_info_[
671  src_composed_state];
672  int32 lat_state = src_composed_state_info.lat_state;
673  const LatticeStateInfo &lat_state_info =
674  lat_state_info_[lat_state];
675 
676  int32 sorted_arc_index = src_composed_state_info.sorted_arc_index,
677  num_sorted_arcs = lat_state_info.arc_delta_costs.size();
678  // note: num_sorted_arcs will be the number of arcs from this
679  // lattice state; plus one if there is a final-prob.
680  KALDI_ASSERT(sorted_arc_index >= 0);
681 
682  { // This block updates the state's 'sorted_arc_index', 'arc_delta_cost' and
683  // 'expected_cost_offset' to reflect the fact that (by the time we exit from
684  // this function) we will have processed this arc (or the final-prob);
685  // it also re-inserts this state into the queue, if appropriate.
686  BaseFloat expected_cost_offset;
687  if (sorted_arc_index + 1 == num_sorted_arcs) {
688  src_composed_state_info.sorted_arc_index = -1;
689  src_composed_state_info.arc_delta_cost =
690  std::numeric_limits<BaseFloat>::infinity();
691  expected_cost_offset =
692  std::numeric_limits<BaseFloat>::infinity();
693  } else {
694  src_composed_state_info.sorted_arc_index = sorted_arc_index + 1;
695  src_composed_state_info.arc_delta_cost =
696  lat_state_info.arc_delta_costs[sorted_arc_index+1].first;
697  expected_cost_offset =
698  (src_composed_state_info.forward_cost +
699  lat_state_info.backward_cost +
700  src_composed_state_info.delta_backward_cost +
701  src_composed_state_info.arc_delta_cost - lat_best_cost_);
702  }
703  // We do '<' here rather than '<=', so that if current_cutoff_ is infinity
704  // and expected_cost_offset is infinity (because we've exhausted all the
705  // transitions from this state, and sorted_arc_index is now -1), we don't
706  // add this element to the queue.
707  if (expected_cost_offset < current_cutoff_) {
708  // this state has another exit arc (or final prob) that is good
709  // enough to re-enter into the queue. Note: if we are processing
710  // an arc out of this state and the destination state is new,
711  // we may also add something new to the queue at that time.
712 
713  // the following call should be equivalent to
714  // composed_state_queue_.push(std::pair<BaseFloat,int32>(...)) with
715  // the same pair of args.
716  composed_state_queue_.emplace(
717  expected_cost_offset, src_composed_state);
718  }
719  }
720 
721  int32 arc_index = lat_state_info.arc_delta_costs[sorted_arc_index].second;
722  if (arc_index < 0) { // This (arc_index == -1) means it is not really an arc
723  // index; it's a final-prob.
724  int32 lm_state = src_composed_state_info.lm_state;
725  BaseFloat lm_final_cost = det_fst_->Final(lm_state).Value();
726  if (lm_final_cost != std::numeric_limits<BaseFloat>::infinity()) {
727  // If there is a final-prob on this LM state (note: there always will be
728  // for conventional language models), then add the final-prob of this
729  // state...
730  CompactLattice::Weight final_weight = clat_in_.Final(lat_state);
731  // assume 'final_weight' is not Zero(); otherwise the final-prob should
732  // not have been present in 'arc_delta_costs'.
733  Lattice::Weight final_lat_weight = final_weight.Weight();
734  final_lat_weight.SetValue1(final_lat_weight.Value1() +
735  lm_final_cost);
736  final_weight.SetWeight(final_lat_weight);
737  clat_out_->SetFinal(src_composed_state, final_weight);
738  double final_cost = ConvertToCost(final_lat_weight);
739  if (final_cost < src_composed_state_info.backward_cost)
740  src_composed_state_info.backward_cost = final_cost;
741  if (!output_reached_final_) {
742  output_reached_final_ = true;
743  depth_penalty_ = 0.0;
744  RecomputePruningInfo();
745  }
746  }
747  } else {
748  // It really was an arc. This code is very complicated, so we make it its
749  // own function.
750  ProcessTransition(src_composed_state, arc_index);
751  }
752 }
753 
754 void PrunedCompactLatticeComposer::ProcessTransition(int32 src_composed_state,
755  int32 arc_index) {
756  // Make src_info a const pointer rather than a reference, as we may have to
757  // update the pointer if composed_state_info_ is resized.
758  const ComposedStateInfo *src_info = &(composed_state_info_[
759  src_composed_state]);
760  int32 src_lat_state = src_info->lat_state;
761  // Get the arc we are going to expand.
762  fst::ArcIterator<CompactLattice> aiter(clat_in_, src_lat_state);
763  aiter.Seek(arc_index);
764  const CompactLatticeArc &lat_arc = aiter.Value();
765  // Note: this code is for CompactLatticeArc, in which the ilabel and olabel
766  // are the same, but we're writing it in such a way that it will naturally
767  // generalize to LatticeArc, so there are separate variables for the ilabel
768  // and the olabel.
769  int32 dest_lat_state = lat_arc.nextstate,
770  ilabel = lat_arc.ilabel,
771  olabel = lat_arc.olabel;
772  // Note: we expect that ilabel == olabel, since this is a CompactLattice, but this
773  // may not be so if we extend this to work with Lattice.
774  fst::StdArc lm_arc;
775 
776  // the input lattice might have epsilons
777  if (olabel == 0) {
778  lm_arc.ilabel = 0;
779  lm_arc.olabel = 0;
780  lm_arc.nextstate = src_info->lm_state;
781  lm_arc.weight = fst::StdArc::Weight(0.0);
782  } else if (!det_fst_->GetArc(src_info->lm_state, olabel, &lm_arc)) {
783  // for normal language models we don't expect this to happen, but the
784  // appropriate behavior is to do nothing; the composed arc does not exist,
785  // so there is no arc to add and no new state to create.
786  return;
787  }
788  int32 dest_lm_state = lm_arc.nextstate;
789  // The following assertion is necessary because CompactLattice cannot support
790  // different ilabel vs. olabel; and also it's an expectation about
791  // language-models.
792  KALDI_ASSERT(lm_arc.ilabel == lm_arc.olabel);
793 
794  LatticeStateInfo &dest_lat_state_info =
795  lat_state_info_[dest_lat_state];
796 
797  int32 dest_composed_state;
798  ComposedStateInfo *dest_info;
799 
800  { // The next block works out 'dest_composed_state' and
801  // 'dest_info', and if the destination state did not already
802  // exist, creates a new composed state.
803  typedef std::unordered_map<std::pair<int32,int32>, int32,
804  PairHasher<int32> > MapType;
805  int32 new_composed_state = clat_out_->NumStates();
806  std::pair<const std::pair<int32,int32>, int32> value(
807  std::pair<int32,int32>(dest_lat_state, dest_lm_state), new_composed_state);
808  std::pair<MapType::iterator, bool> ret =
809  pair_to_state_.insert(value);
810  if (ret.second) {
811  // Successfully inserted: this dest-state did not already exist. Most of
812  // the rest of this block deals with the consequences of adding a new
813  // state.
814  int32 ans = clat_out_->AddState();
815  KALDI_ASSERT(ans == new_composed_state);
816  dest_composed_state = new_composed_state;
817  composed_state_info_.resize(dest_composed_state + 1);
818  dest_info = &(composed_state_info_[dest_composed_state]);
819  // Re-assign src_info as the vector might have been reallocated.
820  src_info = &(composed_state_info_[src_composed_state]);
821  if (dest_lat_state_info.composed_states.empty())
822  accessed_lat_states_.insert(dest_lat_state);
823  dest_lat_state_info.composed_states.push_back(new_composed_state);
824  dest_info->lat_state = dest_lat_state;
825  dest_info->lm_state = dest_lm_state;
826  dest_info->depth = src_info->depth + 1;
827  dest_info->forward_cost =
828  src_info->forward_cost +
829  ConvertToCost(lat_arc.weight) + lm_arc.weight.Value();
830  dest_info->backward_cost =
831  std::numeric_limits<double>::infinity();
832  dest_info->delta_backward_cost =
833  src_info->delta_backward_cost + dest_info->depth * depth_penalty_;
834  // The 'prev_composed_state' field will not be read again until after it's
835  // overwritten; we set it as below only for debugging purposes (the
836  // negation is also for debugging purposes).
837  dest_info->prev_composed_state = -src_composed_state;
838  dest_info->sorted_arc_index = 0;
839  dest_info->arc_delta_cost = 0.0;
840  // Note: in the expression below, which can be understood with reference
841  // to the comment by the declaration of the phantom variable
842  // 'expected_cost_offset', 'arc_delta_cost' is known to equal 0.0 so it
843  // has been removed.
844  BaseFloat expected_cost_offset =
845  (dest_info->forward_cost +
846  dest_lat_state_info.backward_cost +
847  dest_info->delta_backward_cost -
848  lat_best_cost_);
849  if (expected_cost_offset < current_cutoff_) {
850  // the following call should be equivalent to
851  // composed_state_queue_.push(std::pair<BaseFloat,int32>(...)) with
852  // the same pair of args.
853  composed_state_queue_.emplace(expected_cost_offset,
854  dest_composed_state);
855  }
856  } else { // the destination composed state already existed.
857  dest_composed_state = ret.first->second;
858  dest_info = &(composed_state_info_[dest_composed_state]);
859  }
860  }
861  // Add the arc from the src to dest state in the composed output.
862  CompactLatticeArc new_arc;
863  new_arc.nextstate = dest_composed_state;
864  // Actually the ilabel and olabel are the same, but writing it this way will
865  // generalize better to type Lattice if we need that later.
866  new_arc.ilabel = ilabel;
867  new_arc.olabel = olabel;
868  new_arc.weight = lat_arc.weight;
869  // 'weight' is the weight part, as opposed to the string part.
870  LatticeArc::Weight weight = new_arc.weight.Weight();
871  // include the LM-arc's weight in the weight of the new arc.
872  weight.SetValue1(fst::Times(weight.Value1(), lm_arc.weight).Value());
873  new_arc.weight.SetWeight(weight);
874  clat_out_->AddArc(src_composed_state, new_arc);
875  num_arcs_out_++;
876 }
877 
878 static int32 TotalNumArcs(const CompactLattice &clat) {
879  int32 num_states = clat.NumStates(),
880  num_arcs = 0;
881  for (int32 s = 0; s < num_states; s++)
882  num_arcs += clat.NumArcs(s);
883  return num_arcs;
884 }
885 
886 
887 void PrunedCompactLatticeComposer::Compose() {
888  if (clat_in_.NumStates() == 0) {
889  KALDI_WARN << "Input lattice to composition is empty.";
890  return;
891  }
892  ComputeLatticeStateInfo();
893  AddFirstState();
894  // while (we have not reached final state ||
895  // num-arcs produced < target num-arcs) { ...
896  while (output_best_cost_ == std::numeric_limits<double>::infinity() ||
897  num_arcs_out_ < opts_.max_arcs) {
898  RecomputePruningInfo();
899  int32 this_iter_arc_limit = GetCurrentArcLimit();
900  while (num_arcs_out_ < this_iter_arc_limit &&
901  !composed_state_queue_.empty()) {
902  int32 src_composed_state = composed_state_queue_.top().second;
903  composed_state_queue_.pop();
904  ProcessQueueElement(src_composed_state);
905  }
906  if (composed_state_queue_.empty())
907  break;
908  }
909 
910  fst::Connect(clat_out_);
911  TopSortCompactLatticeIfNeeded(clat_out_);
912 
913  if (GetVerboseLevel() >= 2) {
914  int32 num_arcs_in = TotalNumArcs(clat_in_),
915  orig_num_arcs_out = num_arcs_out_,
916  num_arcs_out = TotalNumArcs(*clat_out_),
917  num_states_in = clat_in_.NumStates(),
918  orig_num_states_out = composed_state_info_.size(),
919  num_states_out = clat_out_->NumStates();
920  std::ostringstream os;
921  os << "Input lattice had " << num_arcs_in << '/' << num_states_in
922  << " arcs/states; output lattice has " << num_arcs_out << '/'
923  << num_states_out;
924  if (num_arcs_out != orig_num_arcs_out) {
925  os << " (before pruning: " << orig_num_arcs_out << '/'
926  << orig_num_states_out << ")";
927  }
928  if (!composed_state_queue_.empty()) {
929  // Below, composed_state_queue_.top().first + lat_best_cost is an
930  // expected-cost of the best path from the composed output that we *did
931  // not* expand. This, minus the best cost in the output compact lattice,
932  // can be interpreted as the beam to which we effectively pruned the output
933  // lattice to.
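 // For example, with made-up values: if the cheapest unexpanded queue entry
 // has .first = 8.0, lat_best_cost_ = 42.0 and output_best_cost_ = 45.0, the
 // effective beam reported below would be 8.0 + 42.0 - 45.0 = 5.0.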
934  BaseFloat effective_beam =
936  os << ". Effective beam was " << effective_beam;
937  }
938  KALDI_VLOG(2) << os.str();
939  }
940 
941  if (clat_out_->NumStates() == 0) {
942  KALDI_WARN << "Composed lattice has no states: something went wrong.";
943  }
944 }
945 
946 void ComposeCompactLatticePruned(
947  const ComposeLatticePrunedOptions &opts,
948  const CompactLattice &clat,
949  fst::DeterministicOnDemandFst<fst::StdArc> *det_fst,
950  CompactLattice* composed_clat) {
951  PrunedCompactLatticeComposer composer(opts, clat, det_fst, composed_clat);
952  composer.Compose();
953 }
954 
955 } // namespace kaldi
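// A minimal usage sketch of ComposeCompactLatticePruned(). The wrapper name
// RescoreLatticeWithLm(), the header "fstext/deterministic-fst.h" and the
// option values below are illustrative assumptions; any
// fst::DeterministicOnDemandFst<fst::StdArc> (e.g. one wrapping a backoff LM
// or an RNNLM) can be passed as 'det_fst':
//
//   #include "lat/compose-lattice-pruned.h"
//   #include "fstext/deterministic-fst.h"
//
//   void RescoreLatticeWithLm(const kaldi::CompactLattice &clat,
//                             fst::DeterministicOnDemandFst<fst::StdArc> *det_fst,
//                             kaldi::CompactLattice *composed_clat) {
//     kaldi::ComposeLatticePrunedOptions opts;
//     opts.lattice_compose_beam = 6.0;  // pruning beam (illustrative value).
//     opts.max_arcs = 100000;           // arc budget (illustrative value).
//     kaldi::ComposeCompactLatticePruned(opts, clat, det_fst, composed_clat);
//   }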