lattice-functions.cc
Go to the documentation of this file.
1 // lat/lattice-functions.cc
2 
3 // Copyright 2009-2011 Saarland University (Author: Arnab Ghoshal)
4 // 2012-2013 Johns Hopkins University (Author: Daniel Povey); Chao Weng;
5 // Bagher BabaAli
6 // 2013 Cisco Systems (author: Neha Agrawal) [code modified
7 // from original code in ../gmmbin/gmm-rescore-lattice.cc]
8 // 2014 Guoguo Chen
9 
10 // See ../../COPYING for clarification regarding multiple authors
11 //
12 // Licensed under the Apache License, Version 2.0 (the "License");
13 // you may not use this file except in compliance with the License.
14 // You may obtain a copy of the License at
15 //
16 // http://www.apache.org/licenses/LICENSE-2.0
17 //
18 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
19 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
20 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
21 // MERCHANTABLITY OR NON-INFRINGEMENT.
22 // See the Apache 2 License for the specific language governing permissions and
23 // limitations under the License.
24 
25 
26 #include "lat/lattice-functions.h"
27 #include "hmm/transition-model.h"
28 #include "util/stl-utils.h"
29 #include "base/kaldi-math.h"
30 #include "hmm/hmm-utils.h"
31 
32 namespace kaldi {
33 using std::map;
34 using std::vector;
35 
36 void GetPerFrameAcousticCosts(const Lattice &nbest, Vector<BaseFloat> *per_frame_loglikes) {
37  using namespace fst;
39  vector<BaseFloat> loglikes;
40 
41  int32 cur_state = nbest.Start();
42  int32 prev_frame = -1;
43  BaseFloat eps_acwt = 0.0;
44  while(1) {
45  Weight w = nbest.Final(cur_state);
46  if (w != Weight::Zero()) {
47  KALDI_ASSERT(nbest.NumArcs(cur_state) == 0);
48  if (per_frame_loglikes != NULL) {
49  SubVector<BaseFloat> subvec(&(loglikes[0]), loglikes.size());
50  Vector<BaseFloat> vec(subvec);
51  *per_frame_loglikes = vec;
52  }
53  break;
54  } else {
55  KALDI_ASSERT(nbest.NumArcs(cur_state) == 1);
56  fst::ArcIterator<Lattice> iter(nbest, cur_state);
57  const Lattice::Arc &arc = iter.Value();
58  BaseFloat acwt = arc.weight.Value2();
59  if (arc.ilabel != 0) {
60  if (eps_acwt > 0) {
61  acwt += eps_acwt;
62  eps_acwt = 0.0;
63  }
64  loglikes.push_back(acwt);
65  prev_frame++;
66  } else if (acwt == acwt){
67  if (prev_frame > -1) {
68  loglikes[prev_frame] += acwt;
69  } else {
70  eps_acwt += acwt;
71  }
72  }
73  cur_state = arc.nextstate;
74  }
75  }
76 }
77 
78 int32 LatticeStateTimes(const Lattice &lat, vector<int32> *times) {
79  if (!lat.Properties(fst::kTopSorted, true))
80  KALDI_ERR << "Input lattice must be topologically sorted.";
81  KALDI_ASSERT(lat.Start() == 0);
82  int32 num_states = lat.NumStates();
83  times->clear();
84  times->resize(num_states, -1);
85  (*times)[0] = 0;
86  for (int32 state = 0; state < num_states; state++) {
87  int32 cur_time = (*times)[state];
88  for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
89  aiter.Next()) {
90  const LatticeArc &arc = aiter.Value();
91 
92  if (arc.ilabel != 0) { // Non-epsilon input label on arc
93  // next time instance
94  if ((*times)[arc.nextstate] == -1) {
95  (*times)[arc.nextstate] = cur_time + 1;
96  } else {
97  KALDI_ASSERT((*times)[arc.nextstate] == cur_time + 1);
98  }
99  } else { // epsilon input label on arc
100  // Same time instance
101  if ((*times)[arc.nextstate] == -1)
102  (*times)[arc.nextstate] = cur_time;
103  else
104  KALDI_ASSERT((*times)[arc.nextstate] == cur_time);
105  }
106  }
107  }
108  return (*std::max_element(times->begin(), times->end()));
109 }
110 
111 int32 CompactLatticeStateTimes(const CompactLattice &lat, vector<int32> *times) {
112  if (!lat.Properties(fst::kTopSorted, true))
113  KALDI_ERR << "Input lattice must be topologically sorted.";
114  KALDI_ASSERT(lat.Start() == 0);
115  int32 num_states = lat.NumStates();
116  times->clear();
117  times->resize(num_states, -1);
118  (*times)[0] = 0;
119  int32 utt_len = -1;
120  for (int32 state = 0; state < num_states; state++) {
121  int32 cur_time = (*times)[state];
122  for (fst::ArcIterator<CompactLattice> aiter(lat, state); !aiter.Done();
123  aiter.Next()) {
124  const CompactLatticeArc &arc = aiter.Value();
125  int32 arc_len = static_cast<int32>(arc.weight.String().size());
126  if ((*times)[arc.nextstate] == -1)
127  (*times)[arc.nextstate] = cur_time + arc_len;
128  else
129  KALDI_ASSERT((*times)[arc.nextstate] == cur_time + arc_len);
130  }
131  if (lat.Final(state) != CompactLatticeWeight::Zero()) {
132  int32 this_utt_len = (*times)[state] + lat.Final(state).String().size();
133  if (utt_len == -1) utt_len = this_utt_len;
134  else {
135  if (this_utt_len != utt_len) {
136  KALDI_WARN << "Utterance does not "
137  "seem to have a consistent length.";
138  utt_len = std::max(utt_len, this_utt_len);
139  }
140  }
141  }
142  }
143  if (utt_len == -1) {
144  KALDI_WARN << "Utterance does not have a final-state.";
145  return 0;
146  }
147  return utt_len;
148 }
149 
151  vector<double> *alpha) {
 // NOTE(review): the line carrying this function's name and first parameter is
 // not part of this listing.  The body reads a CompactLattice called `clat`
 // and fills `*alpha` with one forward log-score per state.  It returns false
 // (after a warning) when the lattice is not topologically sorted or does not
 // start at state 0, and true otherwise.
152  using namespace fst;
153 
154  // typedef the arc, weight types
155  typedef CompactLattice::Arc Arc;
156  typedef Arc::Weight Weight;
157  typedef Arc::StateId StateId;
158 
159  // Make sure the lattice is topologically sorted.
160  if (clat.Properties(fst::kTopSorted, true) == 0) {
161  KALDI_WARN << "Input lattice must be topologically sorted.";
162  return false;
163  }
164  if (clat.Start() != 0) {
165  KALDI_WARN << "Input lattice must start from state 0.";
166  return false;
167  }
168 
169  int32 num_states = clat.NumStates();
 // Discard any previous contents, then fill every entry with log(0).
170  (*alpha).resize(0);
171  (*alpha).resize(num_states, kLogZeroDouble);
172 
173  // Now propagate alphas forward. Note that we don't account the weight of the
174  // final state to alpha[final_state] -- we account it to beta[final_state];
175  (*alpha)[0] = 0.0;
176  for (StateId s = 0; s < num_states; s++) {
177  double this_alpha = (*alpha)[s];
178  for (ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done(); aiter.Next()) {
179  const Arc &arc = aiter.Value();
 // Log-likelihood of the arc: negated sum of the two cost components.
180  double arc_like = -(arc.weight.Weight().Value1() + arc.weight.Weight().Value2());
181  (*alpha)[arc.nextstate] = LogAdd((*alpha)[arc.nextstate], this_alpha + arc_like);
182  }
183  }
184 
185  return true;
186 }
187 
189  vector<double> *beta) {
 // NOTE(review): the line carrying this function's name and first parameter is
 // not part of this listing.  The body reads a CompactLattice called `clat`
 // and fills `*beta` with one backward log-score per state.  It returns false
 // (after a warning) when the lattice is not topologically sorted or does not
 // start at state 0, and true otherwise.
190  using namespace fst;
191 
192  // typedef the arc, weight types
193  typedef CompactLattice::Arc Arc;
194  typedef Arc::Weight Weight;
195  typedef Arc::StateId StateId;
196 
197  // Make sure the lattice is topologically sorted.
198  if (clat.Properties(fst::kTopSorted, true) == 0) {
199  KALDI_WARN << "Input lattice must be topologically sorted.";
200  return false;
201  }
202  if (clat.Start() != 0) {
203  KALDI_WARN << "Input lattice must start from state 0.";
204  return false;
205  }
206 
207  int32 num_states = clat.NumStates();
 // Discard any previous contents, then fill every entry with log(0).
208  (*beta).resize(0);
209  (*beta).resize(num_states, kLogZeroDouble);
210 
211  // Now propagate betas backward. Note that beta[final_state] contains the
212  // weight of the final state in the lattice -- compare that with alpha.
 // States are visited in decreasing order, so beta[arc.nextstate] is read
 // after it was written provided arcs only lead to higher-numbered states.
213  for (StateId s = num_states-1; s >= 0; s--) {
214  Weight f = clat.Final(s);
 // Start from the final log-likelihood of this state.
215  double this_beta = -(f.Weight().Value1()+f.Weight().Value2());
216  for (ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done(); aiter.Next()) {
217  const Arc &arc = aiter.Value();
218  double arc_like = -(arc.weight.Weight().Value1()+arc.weight.Weight().Value2());
219  double arc_beta = (*beta)[arc.nextstate] + arc_like;
220  this_beta = LogAdd(this_beta, arc_beta);
221  }
222  (*beta)[s] = this_beta;
223  }
224 
225  return true;
226 }
227 
228 template<class LatType> // could be Lattice or CompactLattice
229 bool PruneLattice(BaseFloat beam, LatType *lat) {
230  typedef typename LatType::Arc Arc;
231  typedef typename Arc::Weight Weight;
232  typedef typename Arc::StateId StateId;
233 
234  KALDI_ASSERT(beam > 0.0);
235  if (!lat->Properties(fst::kTopSorted, true)) {
236  if (fst::TopSort(lat) == false) {
237  KALDI_WARN << "Cycles detected in lattice";
238  return false;
239  }
240  }
241  // We assume states before "start" are not reachable, since
242  // the lattice is topologically sorted.
243  int32 start = lat->Start();
244  int32 num_states = lat->NumStates();
245  if (num_states == 0) return false;
246  std::vector<double> forward_cost(num_states,
247  std::numeric_limits<double>::infinity()); // viterbi forward.
248  forward_cost[start] = 0.0; // lattice can't have cycles so couldn't be
249  // less than this.
250  double best_final_cost = std::numeric_limits<double>::infinity();
251  // Update the forward probs.
252  // Thanks to Jing Zheng for finding a bug here.
253  for (int32 state = 0; state < num_states; state++) {
254  double this_forward_cost = forward_cost[state];
255  for (fst::ArcIterator<LatType> aiter(*lat, state);
256  !aiter.Done();
257  aiter.Next()) {
258  const Arc &arc(aiter.Value());
259  StateId nextstate = arc.nextstate;
260  KALDI_ASSERT(nextstate > state && nextstate < num_states);
261  double next_forward_cost = this_forward_cost +
262  ConvertToCost(arc.weight);
263  if (forward_cost[nextstate] > next_forward_cost)
264  forward_cost[nextstate] = next_forward_cost;
265  }
266  Weight final_weight = lat->Final(state);
267  double this_final_cost = this_forward_cost +
268  ConvertToCost(final_weight);
269  if (this_final_cost < best_final_cost)
270  best_final_cost = this_final_cost;
271  }
272  int32 bad_state = lat->AddState(); // this state is not final.
273  double cutoff = best_final_cost + beam;
274 
275  // Go backwards updating the backward probs (which share memory with the
276  // forward probs), and pruning arcs and deleting final-probs. We prune arcs
277  // by making them point to the non-final state "bad_state". We'll then use
278  // Trim() to remove unnecessary arcs and states. [this is just easier than
279  // doing it ourselves.]
280  std::vector<double> &backward_cost(forward_cost);
281  for (int32 state = num_states - 1; state >= 0; state--) {
282  double this_forward_cost = forward_cost[state];
283  double this_backward_cost = ConvertToCost(lat->Final(state));
284  if (this_backward_cost + this_forward_cost > cutoff
285  && this_backward_cost != std::numeric_limits<double>::infinity())
286  lat->SetFinal(state, Weight::Zero());
287  for (fst::MutableArcIterator<LatType> aiter(lat, state);
288  !aiter.Done();
289  aiter.Next()) {
290  Arc arc(aiter.Value());
291  StateId nextstate = arc.nextstate;
292  KALDI_ASSERT(nextstate > state && nextstate < num_states);
293  double arc_cost = ConvertToCost(arc.weight),
294  arc_backward_cost = arc_cost + backward_cost[nextstate],
295  this_fb_cost = this_forward_cost + arc_backward_cost;
296  if (arc_backward_cost < this_backward_cost)
297  this_backward_cost = arc_backward_cost;
298  if (this_fb_cost > cutoff) { // Prune the arc.
299  arc.nextstate = bad_state;
300  aiter.SetValue(arc);
301  }
302  }
303  backward_cost[state] = this_backward_cost;
304  }
305  fst::Connect(lat);
306  return (lat->NumStates() > 0);
307 }
308 
309 // instantiate the template for lattice and CompactLattice.
310 template bool PruneLattice(BaseFloat beam, Lattice *lat);
311 template bool PruneLattice(BaseFloat beam, CompactLattice *lat);
312 
313 
315  double *acoustic_like_sum) {
 // NOTE(review): the line(s) carrying this function's name and leading
 // parameters are not part of this listing.  The body reads a Lattice called
 // `lat`, fills `*post` with per-frame (transition-id, posterior) pairs, and,
 // when `acoustic_like_sum` is non-NULL, accumulates into it the
 // posterior-weighted negated acoustic costs (Value2()) of arcs and final
 // weights.  The value returned is the total backward log-probability.
316  // Note, Posterior is defined as follows: Indexed [frame], then a list
317  // of (transition-id, posterior-probability) pairs.
318  // typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
319  using namespace fst;
320  typedef Lattice::Arc Arc;
321  typedef Arc::Weight Weight;
322  typedef Arc::StateId StateId;
323 
324  if (acoustic_like_sum) *acoustic_like_sum = 0.0;
325 
326  // Make sure the lattice is topologically sorted.
327  if (lat.Properties(fst::kTopSorted, true) == 0)
328  KALDI_ERR << "Input lattice must be topologically sorted.";
329  KALDI_ASSERT(lat.Start() == 0);
330 
331  int32 num_states = lat.NumStates();
332  vector<int32> state_times;
333  int32 max_time = LatticeStateTimes(lat, &state_times);
334  std::vector<double> alpha(num_states, kLogZeroDouble);
335  std::vector<double> &beta(alpha); // we re-use the same memory for
336  // this, but it's semantically distinct so we name it differently.
337  double tot_forward_prob = kLogZeroDouble;
338 
339  post->clear();
340  post->resize(max_time);
341 
342  alpha[0] = 0.0;
343  // Propagate alphas forward.
344  for (StateId s = 0; s < num_states; s++) {
345  double this_alpha = alpha[s];
346  for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
347  const Arc &arc = aiter.Value();
348  double arc_like = -ConvertToCost(arc.weight);
349  alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like);
350  }
351  Weight f = lat.Final(s);
352  if (f != Weight::Zero()) {
353  double final_like = this_alpha - (f.Value1() + f.Value2());
354  tot_forward_prob = LogAdd(tot_forward_prob, final_like);
355  KALDI_ASSERT(state_times[s] == max_time &&
356  "Lattice is inconsistent (final-prob not at max_time)");
357  }
358  }
 // Backward pass.  Because beta aliases alpha, alpha[s] is still the forward
 // value while state s is being processed (it is overwritten only at the end
 // of the iteration), whereas beta[arc.nextstate] already holds the backward
 // value of a state handled earlier in this loop.
359  for (StateId s = num_states-1; s >= 0; s--) {
360  Weight f = lat.Final(s);
361  double this_beta = -(f.Value1() + f.Value2());
362  for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
363  const Arc &arc = aiter.Value();
364  double arc_like = -ConvertToCost(arc.weight),
365  arc_beta = beta[arc.nextstate] + arc_like;
366  this_beta = LogAdd(this_beta, arc_beta);
367  int32 transition_id = arc.ilabel;
368 
369  // The following "if" is an optimization to avoid un-needed exp().
370  if (transition_id != 0 || acoustic_like_sum != NULL) {
371  double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
372 
373  if (transition_id != 0) // Arc has a transition-id on it [not epsilon]
374  (*post)[state_times[s]].push_back(std::make_pair(transition_id,
375  static_cast<kaldi::BaseFloat>(posterior)));
376  if (acoustic_like_sum != NULL)
377  *acoustic_like_sum -= posterior * arc.weight.Value2();
378  }
379  }
 // The final weight of a final state also carries an acoustic component.
380  if (acoustic_like_sum != NULL && f != Weight::Zero()) {
381  double final_logprob = - ConvertToCost(f),
382  posterior = Exp(alpha[s] + final_logprob - tot_forward_prob);
383  *acoustic_like_sum -= posterior * f.Value2();
384  }
385  beta[s] = this_beta;
386  }
387  double tot_backward_prob = beta[0];
 // A mismatch between the two totals is only reported, not treated as fatal.
388  if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
389  KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob
390  << ", while total backward probability = " << tot_backward_prob;
391  }
392  // Now combine any posteriors with the same transition-id.
393  for (int32 t = 0; t < max_time; t++)
394  MergePairVectorSumming(&((*post)[t]));
395  return tot_backward_prob;
396 }
397 
398 
399 void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
400  const vector<int32> &silence_phones,
401  vector< std::set<int32> > *active_phones) {
402  KALDI_ASSERT(IsSortedAndUniq(silence_phones));
403  vector<int32> state_times;
404  int32 num_states = lat.NumStates();
405  int32 max_time = LatticeStateTimes(lat, &state_times);
406  active_phones->clear();
407  active_phones->resize(max_time);
408  for (int32 state = 0; state < num_states; state++) {
409  int32 cur_time = state_times[state];
410  for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
411  aiter.Next()) {
412  const LatticeArc &arc = aiter.Value();
413  if (arc.ilabel != 0) { // Non-epsilon arc
414  int32 phone = trans.TransitionIdToPhone(arc.ilabel);
415  if (!std::binary_search(silence_phones.begin(),
416  silence_phones.end(), phone))
417  (*active_phones)[cur_time].insert(phone);
418  }
419  } // end looping over arcs
420  } // end looping over states
421 }
422 
424  Lattice *lat) {
 // NOTE(review): the line carrying this function's name and first parameter is
 // not part of this listing; the body uses an object called `trans` to query
 // transition-ids (presumably a TransitionModel -- confirm).
 // Rewrites the output label of every arc in *lat: it is set to the phone of
 // the input transition-id when that transition-id belongs to HMM-state 0 and
 // is not a self-loop, and to 0 (epsilon) otherwise.
425  typedef LatticeArc Arc;
426  int32 num_states = lat->NumStates();
427  for (int32 state = 0; state < num_states; state++) {
428  for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
429  aiter.Next()) {
430  Arc arc(aiter.Value());
431  arc.olabel = 0; // remove any word.
432  if ((arc.ilabel != 0) // has a transition-id on input..
433  && (trans.TransitionIdToHmmState(arc.ilabel) == 0)
434  && (!trans.IsSelfLoop(arc.ilabel))) {
435  // && trans.IsFinal(arc.ilabel)) // there is one of these per phone...
436  arc.olabel = trans.TransitionIdToPhone(arc.ilabel);
437  }
 // Written back unconditionally, since olabel was reset above.
438  aiter.SetValue(arc);
439  } // end looping over arcs
440  } // end looping over states
441 }
442 
443 
444 static inline double LogAddOrMax(bool viterbi, double a, double b) {
445  if (viterbi)
446  return std::max(a, b);
447  else
448  return LogAdd(a, b);
449 }
450 
451 template<typename LatticeType>
452 double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
453  bool viterbi,
454  vector<double> *alpha,
455  vector<double> *beta) {
456  typedef typename LatticeType::Arc Arc;
457  typedef typename Arc::Weight Weight;
458  typedef typename Arc::StateId StateId;
459 
460  StateId num_states = lat.NumStates();
461  KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted);
462  KALDI_ASSERT(lat.Start() == 0);
463  alpha->clear();
464  beta->clear();
465  alpha->resize(num_states, kLogZeroDouble);
466  beta->resize(num_states, kLogZeroDouble);
467 
468  double tot_forward_prob = kLogZeroDouble;
469  (*alpha)[0] = 0.0;
470  // Propagate alphas forward.
471  for (StateId s = 0; s < num_states; s++) {
472  double this_alpha = (*alpha)[s];
473  for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
474  aiter.Next()) {
475  const Arc &arc = aiter.Value();
476  double arc_like = -ConvertToCost(arc.weight);
477  (*alpha)[arc.nextstate] = LogAddOrMax(viterbi, (*alpha)[arc.nextstate],
478  this_alpha + arc_like);
479  }
480  Weight f = lat.Final(s);
481  if (f != Weight::Zero()) {
482  double final_like = this_alpha - ConvertToCost(f);
483  tot_forward_prob = LogAddOrMax(viterbi, tot_forward_prob, final_like);
484  }
485  }
486  for (StateId s = num_states-1; s >= 0; s--) { // it's guaranteed signed.
487  double this_beta = -ConvertToCost(lat.Final(s));
488  for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
489  aiter.Next()) {
490  const Arc &arc = aiter.Value();
491  double arc_like = -ConvertToCost(arc.weight),
492  arc_beta = (*beta)[arc.nextstate] + arc_like;
493  this_beta = LogAddOrMax(viterbi, this_beta, arc_beta);
494  }
495  (*beta)[s] = this_beta;
496  }
497  double tot_backward_prob = (*beta)[lat.Start()];
498  if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
499  KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob
500  << ", while total backward probability = " << tot_backward_prob;
501  }
502  // Split the difference when returning... they should be the same.
503  return 0.5 * (tot_backward_prob + tot_forward_prob);
504 }
505 
506 // instantiate the template for Lattice and CompactLattice
507 template
508 double ComputeLatticeAlphasAndBetas(const Lattice &lat,
509  bool viterbi,
510  vector<double> *alpha,
511  vector<double> *beta);
512 
513 template
 // NOTE(review): the declarator line (function name and first parameter) is
 // missing from this listing; presumably this is the CompactLattice
 // instantiation announced by the comment further up -- confirm.
515  bool viterbi,
516  vector<double> *alpha,
517  vector<double> *beta);
518 
519 
520 
 // NOTE(review): the opening line of this record type is not part of this
 // listing; operator< below shows that the type is named LatticeArcRecord.
 // One record identifies a single arc of a compact lattice together with a
 // score used to rank arcs.
523  BaseFloat logprob; // logprob <= 0 is the best Viterbi logprob of this arc,
524  // minus the overall best-cost of the lattice.
525  CompactLatticeArc::StateId state; // state in the lattice.
526  size_t arc; // arc index within the state.
 // Orders records by increasing logprob, i.e. worst arc first.
527  bool operator < (const LatticeArcRecord &other) const {
528  return logprob < other.logprob;
529  }
530 };
531 
// Limits the number of arcs that cross any single frame of a compact lattice.
// For every frame on which more than max_depth_per_frame arcs are active, the
// arcs with the lowest Viterbi score (best path through the arc, relative to
// the best path overall) are redirected to a dead state; Connect() then
// removes them.  An empty lattice is left untouched (with a warning); the
// lattice is topologically sorted first if necessary.
532 void CompactLatticeLimitDepth(int32 max_depth_per_frame,
533  CompactLattice *clat) {
534  typedef CompactLatticeArc Arc;
535  typedef Arc::Weight Weight;
536  typedef Arc::StateId StateId;
537 
538  if (clat->Start() == fst::kNoStateId) {
539  KALDI_WARN << "Limiting depth of empty lattice.";
540  return;
541  }
542  if (clat->Properties(fst::kTopSorted, true) == 0) {
543  if (!TopSort(clat))
544  KALDI_ERR << "Topological sorting of lattice failed.";
545  }
546 
547  vector<int32> state_times;
548  int32 T = CompactLatticeStateTimes(*clat, &state_times);
549 
550  // The alpha and beta quantities here are "viterbi" alphas and beta.
551  std::vector<double> alpha;
552  std::vector<double> beta;
553  bool viterbi = true;
554  double best_prob = ComputeLatticeAlphasAndBetas(*clat, viterbi,
555  &alpha, &beta);
556 
 // arc_records[t] lists every arc that is active on frame t.
557  std::vector<std::vector<LatticeArcRecord> > arc_records(T);
558 
559  StateId num_states = clat->NumStates();
560  for (StateId s = 0; s < num_states; s++) {
561  for (fst::ArcIterator<CompactLattice> aiter(*clat, s); !aiter.Done();
562  aiter.Next()) {
563  const Arc &arc = aiter.Value();
564  LatticeArcRecord arc_record;
565  arc_record.state = s;
566  arc_record.arc = aiter.Position();
567  arc_record.logprob =
568  (alpha[s] + beta[arc.nextstate] - ConvertToCost(arc.weight))
569  - best_prob;
570  KALDI_ASSERT(arc_record.logprob < 0.1); // Should be zero or negative.
 // An arc spans as many frames as there are entries in its string.
571  int32 num_frames = arc.weight.String().size(), start_t = state_times[s];
572  for (int32 t = start_t; t < start_t + num_frames; t++) {
573  KALDI_ASSERT(t < T);
574  arc_records[t].push_back(arc_record);
575  }
576  }
577  }
578  StateId dead_state = clat->AddState(); // A non-coaccessible state which we use
579  // to remove arcs (make them end
580  // there).
 // NOTE(review): a negative max_depth_per_frame becomes a very large value in
 // this conversion, which disables pruning -- confirm callers never pass one.
581  size_t max_depth = max_depth_per_frame;
582  for (int32 t = 0; t < T; t++) {
583  size_t size = arc_records[t].size();
584  if (size > max_depth) {
585  // we sort from worst to best, so we keep the later-numbered ones,
586  // and delete the lower-numbered ones.
587  size_t cutoff = size - max_depth;
 // nth_element only partitions: the first `cutoff` records are the worst.
588  std::nth_element(arc_records[t].begin(),
589  arc_records[t].begin() + cutoff,
590  arc_records[t].end());
591  for (size_t index = 0; index < cutoff; index++) {
592  LatticeArcRecord record(arc_records[t][index]);
593  fst::MutableArcIterator<CompactLattice> aiter(clat, record.state);
594  aiter.Seek(record.arc);
595  Arc arc = aiter.Value();
596  if (arc.nextstate != dead_state) { // not already killed.
597  arc.nextstate = dead_state;
598  aiter.SetValue(arc);
599  }
600  }
601  }
602  }
603  Connect(clat);
605 }
606 
607 
 // NOTE(review): the signature line of this function is not part of this
 // listing; the body works on a pointer called `clat`.
 // Topologically sorts *clat in place unless it already has the kTopSorted
 // property; a failed sort is reported with KALDI_ERR.
609  if (clat->Properties(fst::kTopSorted, true) == 0) {
610  if (fst::TopSort(clat) == false) {
611  KALDI_ERR << "Topological sorting failed";
612  }
613  }
614 }
615 
 // NOTE(review): the signature line of this function is not part of this
 // listing; the body works on a pointer called `lat`.
 // Topologically sorts *lat in place unless it already has the kTopSorted
 // property; a failed sort is reported with KALDI_ERR.
617  if (lat->Properties(fst::kTopSorted, true) == 0) {
618  if (fst::TopSort(lat) == false) {
619  KALDI_ERR << "Topological sorting failed";
620  }
621  }
622 }
623 
624 
629  int32 *num_frames) {
631  if (clat.Properties(fst::kTopSorted, true) == 0) {
632  KALDI_ERR << "Lattice input to CompactLatticeDepth was not topologically "
633  << "sorted.";
634  }
635  if (clat.Start() == fst::kNoStateId) {
636  *num_frames = 0;
637  return 1.0;
638  }
639  size_t num_arc_frames = 0;
640  int32 t;
641  {
642  vector<int32> state_times;
643  t = CompactLatticeStateTimes(clat, &state_times);
644  }
645  if (num_frames != NULL)
646  *num_frames = t;
647  for (StateId s = 0; s < clat.NumStates(); s++) {
648  for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
649  aiter.Next()) {
650  const CompactLatticeArc &arc = aiter.Value();
651  num_arc_frames += arc.weight.String().size();
652  }
653  num_arc_frames += clat.Final(s).String().size();
654  }
655  return num_arc_frames / static_cast<BaseFloat>(t);
656 }
657 
658 
660  std::vector<int32> *depth_per_frame) {
 // NOTE(review): the line carrying this function's name and first parameter is
 // not part of this listing, and neither is the declaration of the StateId
 // type used below.  The body reads a CompactLattice called `clat`.
 // Fills *depth_per_frame with, for each frame, the number of arcs and final
 // weights whose string covers that frame.  The output is left empty when the
 // lattice is empty or has no frames.
662  if (clat.Properties(fst::kTopSorted, true) == 0) {
663  KALDI_ERR << "Lattice input to CompactLatticeDepthPerFrame was not "
664  << "topologically sorted.";
665  }
666  if (clat.Start() == fst::kNoStateId) {
667  depth_per_frame->clear();
668  return;
669  }
670  vector<int32> state_times;
671  int32 T = CompactLatticeStateTimes(clat, &state_times);
672 
673  depth_per_frame->clear();
674  if (T <= 0) {
675  return;
676  } else {
677  depth_per_frame->resize(T, 0);
678  for (StateId s = 0; s < clat.NumStates(); s++) {
679  int32 start_time = state_times[s];
680  for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
681  aiter.Next()) {
682  const CompactLatticeArc &arc = aiter.Value();
 // The arc covers `len` consecutive frames starting at this state's time.
683  int32 len = arc.weight.String().size();
684  for (int32 t = start_time; t < start_time + len; t++) {
685  KALDI_ASSERT(t < T);
686  (*depth_per_frame)[t]++;
687  }
688  }
 // Frames stored in the final weight are counted in the same way.
689  int32 final_len = clat.Final(s).String().size();
690  for (int32 t = start_time; t < start_time + final_len; t++) {
691  KALDI_ASSERT(t < T);
692  (*depth_per_frame)[t]++;
693  }
694  }
695  }
696 }
697 
698 
699 
701  CompactLattice *clat) {
 // NOTE(review): the line carrying this function's name and first parameter is
 // not part of this listing; the body uses an object called `trans` to query
 // transition-ids (presumably a TransitionModel -- confirm).
 // Replaces the transition-id string in every arc weight and every final
 // weight of *clat by a phone string: one phone is emitted for each
 // transition-id for which trans.IsFinal() is true.
702  typedef CompactLatticeArc Arc;
703  typedef Arc::Weight Weight;
704  int32 num_states = clat->NumStates();
705  for (int32 state = 0; state < num_states; state++) {
706  for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
707  !aiter.Done();
708  aiter.Next()) {
709  Arc arc(aiter.Value());
710  std::vector<int32> phone_seq;
711  const std::vector<int32> &tid_seq = arc.weight.String();
712  for (std::vector<int32>::const_iterator iter = tid_seq.begin();
713  iter != tid_seq.end(); ++iter) {
714  if (trans.IsFinal(*iter))// note: there is one of these per phone...
715  phone_seq.push_back(trans.TransitionIdToPhone(*iter));
716  }
717  arc.weight.SetString(phone_seq);
718  aiter.SetValue(arc);
719  } // end looping over arcs
 // Apply the same conversion to the string stored in the final weight.
720  Weight f = clat->Final(state);
721  if (f != Weight::Zero()) {
722  std::vector<int32> phone_seq;
723  const std::vector<int32> &tid_seq = f.String();
724  for (std::vector<int32>::const_iterator iter = tid_seq.begin();
725  iter != tid_seq.end(); ++iter) {
726  if (trans.IsFinal(*iter))// note: there is one of these per phone...
727  phone_seq.push_back(trans.TransitionIdToPhone(*iter));
728  }
729  f.SetString(phone_seq);
730  clat->SetFinal(state, f);
731  }
732  } // end looping over states
733 }
734 
735 bool LatticeBoost(const TransitionModel &trans,
736  const std::vector<int32> &alignment,
737  const std::vector<int32> &silence_phones,
738  BaseFloat b,
739  BaseFloat max_silence_error,
740  Lattice *lat) {
742 
743  // get all stored properties (test==false means don't test if not known).
744  uint64 props = lat->Properties(fst::kFstProperties,
745  false);
746 
747  KALDI_ASSERT(IsSortedAndUniq(silence_phones));
748  KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0);
749  vector<int32> state_times;
750  int32 num_states = lat->NumStates();
751  int32 num_frames = LatticeStateTimes(*lat, &state_times);
752  KALDI_ASSERT(num_frames == static_cast<int32>(alignment.size()));
753  for (int32 state = 0; state < num_states; state++) {
754  int32 cur_time = state_times[state];
755  for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
756  aiter.Next()) {
757  LatticeArc arc = aiter.Value();
758  if (arc.ilabel != 0) { // Non-epsilon arc
759  if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) {
760  KALDI_WARN << "Lattice has out-of-range transition-ids: "
761  << "lattice/model mismatch?";
762  return false;
763  }
764  int32 phone = trans.TransitionIdToPhone(arc.ilabel),
765  ref_phone = trans.TransitionIdToPhone(alignment[cur_time]);
766  BaseFloat frame_error;
767  if (phone == ref_phone) {
768  frame_error = 0.0;
769  } else { // an error...
770  if (std::binary_search(silence_phones.begin(), silence_phones.end(), phone))
771  frame_error = max_silence_error;
772  else
773  frame_error = 1.0;
774  }
775  BaseFloat delta_cost = -b * frame_error; // negative cost if
776  // frame is wrong, to boost likelihood of arcs with errors on them.
777  // Add this cost to the graph part.
778  arc.weight.SetValue1(arc.weight.Value1() + delta_cost);
779  aiter.SetValue(arc);
780  }
781  }
782  }
783  // All we changed is the weights, so any properties that were
784  // known before, are still known, except for whether or not the
785  // lattice was weighted.
786  lat->SetProperties(props,
787  ~(fst::kWeighted|fst::kUnweighted));
788 
789  return true;
790 }
791 
792 
793 
795  const TransitionModel &trans,
796  const std::vector<int32> &silence_phones,
797  const Lattice &lat,
798  const std::vector<int32> &num_ali,
799  std::string criterion,
800  bool one_silence_class,
801  Posterior *post) {
802  using namespace fst;
803  typedef Lattice::Arc Arc;
804  typedef Arc::Weight Weight;
805  typedef Arc::StateId StateId;
806 
807  KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr");
808  bool is_mpfe = (criterion == "mpfe");
809 
810  if (lat.Properties(fst::kTopSorted, true) == 0)
811  KALDI_ERR << "Input lattice must be topologically sorted.";
812  KALDI_ASSERT(lat.Start() == 0);
813 
814  int32 num_states = lat.NumStates();
815  vector<int32> state_times;
816  int32 max_time = LatticeStateTimes(lat, &state_times);
817  KALDI_ASSERT(max_time == static_cast<int32>(num_ali.size()));
818  std::vector<double> alpha(num_states, kLogZeroDouble),
819  alpha_smbr(num_states, 0), //forward variable for sMBR
820  beta(num_states, kLogZeroDouble),
821  beta_smbr(num_states, 0); //backward variable for sMBR
822 
823  double tot_forward_prob = kLogZeroDouble;
824  double tot_forward_score = 0;
825 
826  post->clear();
827  post->resize(max_time);
828 
829  alpha[0] = 0.0;
830  // First Pass Forward,
831  for (StateId s = 0; s < num_states; s++) {
832  double this_alpha = alpha[s];
833  for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
834  const Arc &arc = aiter.Value();
835  double arc_like = -ConvertToCost(arc.weight);
836  alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like);
837  }
838  Weight f = lat.Final(s);
839  if (f != Weight::Zero()) {
840  double final_like = this_alpha - (f.Value1() + f.Value2());
841  tot_forward_prob = LogAdd(tot_forward_prob, final_like);
842  KALDI_ASSERT(state_times[s] == max_time &&
843  "Lattice is inconsistent (final-prob not at max_time)");
844  }
845  }
846  // First Pass Backward,
847  for (StateId s = num_states-1; s >= 0; s--) {
848  Weight f = lat.Final(s);
849  double this_beta = -(f.Value1() + f.Value2());
850  for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
851  const Arc &arc = aiter.Value();
852  double arc_like = -ConvertToCost(arc.weight),
853  arc_beta = beta[arc.nextstate] + arc_like;
854  this_beta = LogAdd(this_beta, arc_beta);
855  }
856  beta[s] = this_beta;
857  }
858  // First Pass Forward-Backward Check
859  double tot_backward_prob = beta[0];
860  // may loose the condition somehow here 1e-6 (was 1e-8)
861  if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-6)) {
862  KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob
863  << ", while total backward probability = " << tot_backward_prob;
864  }
865 
866  alpha_smbr[0] = 0.0;
867  // Second Pass Forward, calculate forward for MPFE/SMBR
868  for (StateId s = 0; s < num_states; s++) {
869  double this_alpha = alpha[s];
870  for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
871  const Arc &arc = aiter.Value();
872  double arc_like = -ConvertToCost(arc.weight);
873  double frame_acc = 0.0;
874  if (arc.ilabel != 0) {
875  int32 cur_time = state_times[s];
876  int32 phone = trans.TransitionIdToPhone(arc.ilabel),
877  ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
878  bool phone_is_sil = std::binary_search(silence_phones.begin(),
879  silence_phones.end(),
880  phone),
881  ref_phone_is_sil = std::binary_search(silence_phones.begin(),
882  silence_phones.end(),
883  ref_phone),
884  both_sil = phone_is_sil && ref_phone_is_sil;
885  if (!is_mpfe) { // smbr.
886  int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
887  ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
888  if (!one_silence_class) // old behavior
889  frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
890  else
891  frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
892  } else {
893  if (!one_silence_class) // old behavior
894  frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
895  else
896  frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
897  }
898  }
899  double arc_scale = Exp(alpha[s] + arc_like - alpha[arc.nextstate]);
900  alpha_smbr[arc.nextstate] += arc_scale * (alpha_smbr[s] + frame_acc);
901  }
902  Weight f = lat.Final(s);
903  if (f != Weight::Zero()) {
904  double final_like = this_alpha - (f.Value1() + f.Value2());
905  double arc_scale = Exp(final_like - tot_forward_prob);
906  tot_forward_score += arc_scale * alpha_smbr[s];
907  KALDI_ASSERT(state_times[s] == max_time &&
908  "Lattice is inconsistent (final-prob not at max_time)");
909  }
910  }
911  // Second Pass Backward, collect Mpe style posteriors
912  for (StateId s = num_states-1; s >= 0; s--) {
913  for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
914  const Arc &arc = aiter.Value();
915  double arc_like = -ConvertToCost(arc.weight),
916  arc_beta = beta[arc.nextstate] + arc_like;
917  double frame_acc = 0.0;
918  int32 transition_id = arc.ilabel;
919  if (arc.ilabel != 0) {
920  int32 cur_time = state_times[s];
921  int32 phone = trans.TransitionIdToPhone(arc.ilabel),
922  ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
923  bool phone_is_sil = std::binary_search(silence_phones.begin(),
924  silence_phones.end(), phone),
925  ref_phone_is_sil = std::binary_search(silence_phones.begin(),
926  silence_phones.end(),
927  ref_phone),
928  both_sil = phone_is_sil && ref_phone_is_sil;
929  if (!is_mpfe) { // smbr.
930  int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
931  ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
932  if (!one_silence_class) // old behavior
933  frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
934  else
935  frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
936  } else {
937  if (!one_silence_class) // old behavior
938  frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
939  else
940  frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
941  }
942  }
943  double arc_scale = Exp(beta[arc.nextstate] + arc_like - beta[s]);
944  // check arc_scale NAN,
945  // this is to prevent partial paths in Lattices
946  // i.e., paths don't survive to the final state
947  if (KALDI_ISNAN(arc_scale)) arc_scale = 0;
948  beta_smbr[s] += arc_scale * (beta_smbr[arc.nextstate] + frame_acc);
949 
950  if (transition_id != 0) { // Arc has a transition-id on it [not epsilon]
951  double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
952  double acc_diff = alpha_smbr[s] + frame_acc + beta_smbr[arc.nextstate]
953  - tot_forward_score;
954  double posterior_smbr = posterior * acc_diff;
955  (*post)[state_times[s]].push_back(std::make_pair(transition_id,
956  static_cast<BaseFloat>(posterior_smbr)));
957  }
958  }
959  }
960 
961  //Second Pass Forward Backward check
962  double tot_backward_score = beta_smbr[0]; // Initial state id == 0
963  // may loose the condition somehow here 1e-5/1e-4
964  if (!ApproxEqual(tot_forward_score, tot_backward_score, 1e-4)) {
965  KALDI_ERR << "Total forward score over lattice = " << tot_forward_score
966  << ", while total backward score = " << tot_backward_score;
967  }
968 
969  // Output the computed posteriors
970  for (int32 t = 0; t < max_time; t++)
971  MergePairVectorSumming(&((*post)[t]));
972  return tot_forward_score;
973 }
974 
976  std::vector<int32> *words,
977  std::vector<int32> *begin_times,
978  std::vector<int32> *lengths) {
979  words->clear();
980  begin_times->clear();
981  lengths->clear();
982  typedef CompactLattice::Arc Arc;
983  typedef Arc::Label Label;
986  using namespace fst;
987  StateId state = clat.Start();
988  int32 cur_time = 0;
989  if (state == kNoStateId) {
990  KALDI_WARN << "Empty lattice.";
991  return false;
992  }
993  while (1) {
994  Weight final = clat.Final(state);
995  size_t num_arcs = clat.NumArcs(state);
996  if (final != Weight::Zero()) {
997  if (num_arcs != 0) {
998  KALDI_WARN << "Lattice is not linear.";
999  return false;
1000  }
1001  if (! final.String().empty()) {
1002  KALDI_WARN << "Lattice has alignments on final-weight: probably "
1003  "was not word-aligned (alignments will be approximate)";
1004  }
1005  return true;
1006  } else {
1007  if (num_arcs != 1) {
1008  KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
1009  return false;
1010  }
1011  fst::ArcIterator<CompactLattice> aiter(clat, state);
1012  const Arc &arc = aiter.Value();
1013  Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
1014  // Also note: word_id may be zero; we output it anyway.
1015  int32 length = arc.weight.String().size();
1016  words->push_back(word_id);
1017  begin_times->push_back(cur_time);
1018  lengths->push_back(length);
1019  cur_time += length;
1020  state = arc.nextstate;
1021  }
1022  }
1023 }
1024 
1025 
1027  const TransitionModel &tmodel,
1028  const CompactLattice &clat,
1029  std::vector<int32> *words,
1030  std::vector<int32> *begin_times,
1031  std::vector<int32> *lengths,
1032  std::vector<std::vector<int32> > *prons,
1033  std::vector<std::vector<int32> > *phone_lengths) {
1034  words->clear();
1035  begin_times->clear();
1036  lengths->clear();
1037  prons->clear();
1038  phone_lengths->clear();
1039  typedef CompactLattice::Arc Arc;
1040  typedef Arc::Label Label;
1043  using namespace fst;
1044  StateId state = clat.Start();
1045  int32 cur_time = 0;
1046  if (state == kNoStateId) {
1047  KALDI_WARN << "Empty lattice.";
1048  return false;
1049  }
1050  while (1) {
1051  Weight final = clat.Final(state);
1052  size_t num_arcs = clat.NumArcs(state);
1053  if (final != Weight::Zero()) {
1054  if (num_arcs != 0) {
1055  KALDI_WARN << "Lattice is not linear.";
1056  return false;
1057  }
1058  if (! final.String().empty()) {
1059  KALDI_WARN << "Lattice has alignments on final-weight: probably "
1060  "was not word-aligned (alignments will be approximate)";
1061  }
1062  return true;
1063  } else {
1064  if (num_arcs != 1) {
1065  KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
1066  return false;
1067  }
1068  fst::ArcIterator<CompactLattice> aiter(clat, state);
1069  const Arc &arc = aiter.Value();
1070  Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
1071  // Also note: word_id may be zero; we output it anyway.
1072  int32 length = arc.weight.String().size();
1073  words->push_back(word_id);
1074  begin_times->push_back(cur_time);
1075  lengths->push_back(length);
1076  const std::vector<int32> &arc_alignment = arc.weight.String();
1077  std::vector<std::vector<int32> > split_alignment;
1078  SplitToPhones(tmodel, arc_alignment, &split_alignment);
1079  std::vector<int32> phones(split_alignment.size());
1080  std::vector<int32> plengths(split_alignment.size());
1081  for (size_t i = 0; i < split_alignment.size(); i++) {
1082  KALDI_ASSERT(!split_alignment[i].empty());
1083  phones[i] = tmodel.TransitionIdToPhone(split_alignment[i][0]);
1084  plengths[i] = split_alignment[i].size();
1085  }
1086  prons->push_back(phones);
1087  phone_lengths->push_back(plengths);
1088 
1089  cur_time += length;
1090  state = arc.nextstate;
1091  }
1092  }
1093 }
1094 
1095 
1096 
1098  CompactLattice *shortest_path) {
1099  using namespace fst;
1100  if (clat.Properties(fst::kTopSorted, true) == 0) {
1101  CompactLattice clat_copy(clat);
1102  if (!TopSort(&clat_copy))
1103  KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
1104  CompactLatticeShortestPath(clat_copy, shortest_path);
1105  return;
1106  }
1107  // Now we can assume it's topologically sorted.
1108  shortest_path->DeleteStates();
1109  if (clat.Start() == kNoStateId) return;
1110  typedef CompactLatticeArc Arc;
1111  typedef Arc::StateId StateId;
1112  typedef CompactLatticeWeight Weight;
1113  vector<std::pair<double, StateId> > best_cost_and_pred(clat.NumStates() + 1);
1114  StateId superfinal = clat.NumStates();
1115  for (StateId s = 0; s <= clat.NumStates(); s++) {
1116  best_cost_and_pred[s].first = std::numeric_limits<double>::infinity();
1117  best_cost_and_pred[s].second = fst::kNoStateId;
1118  }
1119  best_cost_and_pred[clat.Start()].first = 0;
1120  for (StateId s = 0; s < clat.NumStates(); s++) {
1121  double my_cost = best_cost_and_pred[s].first;
1122  for (ArcIterator<CompactLattice> aiter(clat, s);
1123  !aiter.Done();
1124  aiter.Next()) {
1125  const Arc &arc = aiter.Value();
1126  double arc_cost = ConvertToCost(arc.weight),
1127  next_cost = my_cost + arc_cost;
1128  if (next_cost < best_cost_and_pred[arc.nextstate].first) {
1129  best_cost_and_pred[arc.nextstate].first = next_cost;
1130  best_cost_and_pred[arc.nextstate].second = s;
1131  }
1132  }
1133  double final_cost = ConvertToCost(clat.Final(s)),
1134  tot_final = my_cost + final_cost;
1135  if (tot_final < best_cost_and_pred[superfinal].first) {
1136  best_cost_and_pred[superfinal].first = tot_final;
1137  best_cost_and_pred[superfinal].second = s;
1138  }
1139  }
1140  std::vector<StateId> states; // states on best path.
1141  StateId cur_state = superfinal, start_state = clat.Start();
1142  while (cur_state != start_state) {
1143  StateId prev_state = best_cost_and_pred[cur_state].second;
1144  if (prev_state == kNoStateId) {
1145  KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)";
1146  return; // return empty best-path.
1147  }
1148  states.push_back(prev_state);
1149  KALDI_ASSERT(cur_state != prev_state && "Lattice with cycles");
1150  cur_state = prev_state;
1151  }
1152  std::reverse(states.begin(), states.end());
1153  for (size_t i = 0; i < states.size(); i++)
1154  shortest_path->AddState();
1155  for (StateId s = 0; static_cast<size_t>(s) < states.size(); s++) {
1156  if (s == 0) shortest_path->SetStart(s);
1157  if (static_cast<size_t>(s + 1) < states.size()) { // transition to next state.
1158  bool have_arc = false;
1159  Arc cur_arc;
1160  for (ArcIterator<CompactLattice> aiter(clat, states[s]);
1161  !aiter.Done();
1162  aiter.Next()) {
1163  const Arc &arc = aiter.Value();
1164  if (arc.nextstate == states[s+1]) {
1165  if (!have_arc ||
1166  ConvertToCost(arc.weight) < ConvertToCost(cur_arc.weight)) {
1167  cur_arc = arc;
1168  have_arc = true;
1169  }
1170  }
1171  }
1172  KALDI_ASSERT(have_arc && "Code error.");
1173  shortest_path->AddArc(s, Arc(cur_arc.ilabel, cur_arc.olabel,
1174  cur_arc.weight, s+1));
1175  } else { // final-prob.
1176  shortest_path->SetFinal(s, clat.Final(states[s]));
1177  }
1178  }
1179 }
1180 
1182  CompactLattice *clat) {
1183  typedef CompactLatticeArc Arc;
1184  int32 num_states = clat->NumStates();
1185 
1186  //scan the lattice
1187  for (int32 state = 0; state < num_states; state++) {
1188  for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
1189  !aiter.Done(); aiter.Next()) {
1190 
1191  Arc arc(aiter.Value());
1192 
1193  if (arc.ilabel != 0) { // if there is a word on this arc
1194  LatticeWeight weight = arc.weight.Weight();
1195  // add word insertion penalty to lattice
1196  weight.SetValue1( weight.Value1() + word_ins_penalty);
1197  arc.weight.SetWeight(weight);
1198  aiter.SetValue(arc);
1199  }
1200  } // end looping over arcs
1201  } // end looping over states
1202 }
1203 
1205  ClatRescoreTuple(int32 state, int32 arc, int32 tid):
1206  state_id(state), arc_id(arc), tid(tid) { }
1210 };
1211 
1217  const TransitionModel *tmodel,
1218  BaseFloat speedup_factor,
1219  DecodableInterface *decodable,
1220  CompactLattice *clat) {
1221  KALDI_ASSERT(speedup_factor >= 1.0);
1222  if (clat->NumStates() == 0) {
1223  KALDI_WARN << "Rescoring empty lattice";
1224  return false;
1225  }
1226  if (!clat->Properties(fst::kTopSorted, true)) {
1227  if (fst::TopSort(clat) == false) {
1228  KALDI_WARN << "Cycles detected in lattice.";
1229  return false;
1230  }
1231  }
1232  std::vector<int32> state_times;
1233  int32 utt_len = kaldi::CompactLatticeStateTimes(*clat, &state_times);
1234 
1235  std::vector<std::vector<ClatRescoreTuple> > time_to_state(utt_len);
1236 
1237  int32 num_states = clat->NumStates();
1238  KALDI_ASSERT(num_states == state_times.size());
1239  for (size_t state = 0; state < num_states; state++) {
1240  KALDI_ASSERT(state_times[state] >= 0);
1241  int32 t = state_times[state];
1242  int32 arc_id = 0;
1243  for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
1244  !aiter.Done(); aiter.Next(), arc_id++) {
1245  CompactLatticeArc arc = aiter.Value();
1246  std::vector<int32> arc_string = arc.weight.String();
1247 
1248  for (size_t offset = 0; offset < arc_string.size(); offset++) {
1249  if (t < utt_len) { // end state may be past this..
1250  int32 tid = arc_string[offset];
1251  time_to_state[t+offset].push_back(ClatRescoreTuple(state, arc_id, tid));
1252  } else {
1253  if (t != utt_len) {
1254  KALDI_WARN << "There appears to be lattice/feature mismatch, "
1255  << "aborting.";
1256  return false;
1257  }
1258  }
1259  }
1260  }
1261  if (clat->Final(state) != CompactLatticeWeight::Zero()) {
1262  arc_id = -1;
1263  std::vector<int32> arc_string = clat->Final(state).String();
1264  for (size_t offset = 0; offset < arc_string.size(); offset++) {
1265  KALDI_ASSERT(t + offset < utt_len); // already checked in
1266  // CompactLatticeStateTimes, so would be code error.
1267  time_to_state[t+offset].push_back(
1268  ClatRescoreTuple(state, arc_id, arc_string[offset]));
1269  }
1270  }
1271  }
1272 
1273  for (int32 t = 0; t < utt_len; t++) {
1274  if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
1275  KALDI_WARN << "Features are too short for lattice: utt-len is "
1276  << utt_len << ", " << t << " is last frame";
1277  return false;
1278  }
1279  // frame_scale is the scale we put on the computed acoustic probs for this
1280  // frame. It will always be 1.0 if tmodel == NULL (i.e. if we are not doing
1281  // the "speedup" code). For frames with multiple pdf-ids it will be one.
1282  // For frames with only one pdf-id, it will equal speedup_factor (>=1.0)
1283  // with probability 1.0 / speedup_factor, and zero otherwise. If it is zero,
1284  // we can avoid computing the probabilities.
1285  BaseFloat frame_scale = 1.0;
1286  KALDI_ASSERT(!time_to_state[t].empty());
1287  if (tmodel != NULL) {
1288  int32 pdf_id = tmodel->TransitionIdToPdf(time_to_state[t][0].tid);
1289  bool frame_has_multiple_pdfs = false;
1290  for (size_t i = 1; i < time_to_state[t].size(); i++) {
1291  if (tmodel->TransitionIdToPdf(time_to_state[t][i].tid) != pdf_id) {
1292  frame_has_multiple_pdfs = true;
1293  break;
1294  }
1295  }
1296  if (frame_has_multiple_pdfs) {
1297  frame_scale = 1.0;
1298  } else {
1299  if (WithProb(1.0 / speedup_factor)) {
1300  frame_scale = speedup_factor;
1301  } else {
1302  frame_scale = 0.0;
1303  }
1304  }
1305  if (frame_scale == 0.0)
1306  continue; // the code below would be pointless.
1307  }
1308 
1309  for (size_t i = 0; i < time_to_state[t].size(); i++) {
1310  int32 state = time_to_state[t][i].state_id;
1311  int32 arc_id = time_to_state[t][i].arc_id;
1312  int32 tid = time_to_state[t][i].tid;
1313 
1314  if (arc_id == -1) { // Final state
1315  // Access the trans_id
1316  CompactLatticeWeight curr_clat_weight = clat->Final(state);
1317 
1318  // Calculate likelihood
1319  BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
1320  // update weight
1321  CompactLatticeWeight new_clat_weight = curr_clat_weight;
1322  LatticeWeight new_lat_weight = new_clat_weight.Weight();
1323  new_lat_weight.SetValue2(-log_like + curr_clat_weight.Weight().Value2());
1324  new_clat_weight.SetWeight(new_lat_weight);
1325  clat->SetFinal(state, new_clat_weight);
1326  } else {
1327  fst::MutableArcIterator<CompactLattice> aiter(clat, state);
1328 
1329  aiter.Seek(arc_id);
1330  CompactLatticeArc arc = aiter.Value();
1331 
1332  // Calculate likelihood
1333  BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
1334  // update weight
1335  LatticeWeight new_weight = arc.weight.Weight();
1336  new_weight.SetValue2(-log_like + arc.weight.Weight().Value2());
1337  arc.weight.SetWeight(new_weight);
1338  aiter.SetValue(arc);
1339  }
1340  }
1341  }
1342  return true;
1343 }
1344 
1345 
1347  const TransitionModel &tmodel,
1348  BaseFloat speedup_factor,
1349  DecodableInterface *decodable,
1350  CompactLattice *clat) {
1351  return RescoreCompactLatticeInternal(&tmodel, speedup_factor, decodable, clat);
1352 }
1353 
1355  CompactLattice *clat) {
1356  return RescoreCompactLatticeInternal(NULL, 1.0, decodable, clat);
1357 }
1358 
1359 
1361  Lattice *lat) {
1362  if (lat->NumStates() == 0) {
1363  KALDI_WARN << "Rescoring empty lattice";
1364  return false;
1365  }
1366  if (!lat->Properties(fst::kTopSorted, true)) {
1367  if (fst::TopSort(lat) == false) {
1368  KALDI_WARN << "Cycles detected in lattice.";
1369  return false;
1370  }
1371  }
1372  std::vector<int32> state_times;
1373  int32 utt_len = kaldi::LatticeStateTimes(*lat, &state_times);
1374 
1375  std::vector<std::vector<int32> > time_to_state(utt_len );
1376 
1377  int32 num_states = lat->NumStates();
1378  KALDI_ASSERT(num_states == state_times.size());
1379  for (size_t state = 0; state < num_states; state++) {
1380  int32 t = state_times[state];
1381  // Don't check t >= 0 because non-accessible states could have t = -1.
1382  KALDI_ASSERT(t <= utt_len);
1383  if (t >= 0 && t < utt_len)
1384  time_to_state[t].push_back(state);
1385  }
1386 
1387  for (int32 t = 0; t < utt_len; t++) {
1388  if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
1389  KALDI_WARN << "Features are too short for lattice: utt-len is "
1390  << utt_len << ", " << t << " is last frame";
1391  return false;
1392  }
1393  for (size_t i = 0; i < time_to_state[t].size(); i++) {
1394  int32 state = time_to_state[t][i];
1395  for (fst::MutableArcIterator<Lattice> aiter(lat, state);
1396  !aiter.Done(); aiter.Next()) {
1397  LatticeArc arc = aiter.Value();
1398  if (arc.ilabel != 0) {
1399  int32 trans_id = arc.ilabel; // Note: it doesn't necessarily
1400  // have to be a transition-id, just whatever the Decodable
1401  // object is expecting, but it's normally a transition-id.
1402 
1403  BaseFloat log_like = decodable->LogLikelihood(t, trans_id);
1404  arc.weight.SetValue2(-log_like + arc.weight.Value2());
1405  aiter.SetValue(arc);
1406  }
1407  }
1408  }
1409  }
1410  return true;
1411 }
1412 
1413 
1415  const TransitionModel &tmodel,
1416  const Lattice &lat,
1417  const std::vector<int32> &num_ali,
1418  bool drop_frames,
1419  bool convert_to_pdf_ids,
1420  bool cancel,
1421  Posterior *post) {
1422  // First compute the MMI posteriors.
1423 
1424  Posterior den_post;
1426  &den_post,
1427  NULL);
1428 
1429  Posterior num_post;
1430  AlignmentToPosterior(num_ali, &num_post);
1431 
1432  // Now negate the MMI posteriors and add the numerator
1433  // posteriors.
1434  ScalePosterior(-1.0, &den_post);
1435 
1436  if (convert_to_pdf_ids) {
1437  Posterior num_tmp;
1438  ConvertPosteriorToPdfs(tmodel, num_post, &num_tmp);
1439  num_tmp.swap(num_post);
1440  Posterior den_tmp;
1441  ConvertPosteriorToPdfs(tmodel, den_post, &den_tmp);
1442  den_tmp.swap(den_post);
1443  }
1444 
1445  MergePosteriors(num_post, den_post,
1446  cancel, drop_frames, post);
1447 
1448  return ans;
1449 }
1450 
1451 
1453  typedef Lattice::Arc Arc;
1454  typedef Arc::Label Label;
1455  typedef Arc::StateId StateId;
1456 
1457  if (lat.Properties(fst::kTopSorted, true) == 0) {
1458  Lattice lat_copy(lat);
1459  if (!TopSort(&lat_copy))
1460  KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
1461  return LongestSentenceLength(lat_copy);
1462  }
1463  std::vector<int32> max_length(lat.NumStates(), 0);
1464  int32 lattice_max_length = 0;
1465  for (StateId s = 0; s < lat.NumStates(); s++) {
1466  int32 this_max_length = max_length[s];
1467  for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
1468  const Arc &arc = aiter.Value();
1469  bool arc_has_word = (arc.olabel != 0);
1470  StateId nextstate = arc.nextstate;
1471  KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
1472  if (arc_has_word) {
1473  // A lattice should ideally not have cycles anyway; a cycle with a word
1474  // on is something very bad.
1475  KALDI_ASSERT(nextstate > s && "Lattice has cycles with words on.");
1476  max_length[nextstate] = std::max(max_length[nextstate],
1477  this_max_length + 1);
1478  } else {
1479  max_length[nextstate] = std::max(max_length[nextstate],
1480  this_max_length);
1481  }
1482  }
1483  if (lat.Final(s) != LatticeWeight::Zero())
1484  lattice_max_length = std::max(lattice_max_length, max_length[s]);
1485  }
1486  return lattice_max_length;
1487 }
1488 
1490  typedef CompactLattice::Arc Arc;
1491  typedef Arc::Label Label;
1492  typedef Arc::StateId StateId;
1493 
1494  if (clat.Properties(fst::kTopSorted, true) == 0) {
1495  CompactLattice clat_copy(clat);
1496  if (!TopSort(&clat_copy))
1497  KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
1498  return LongestSentenceLength(clat_copy);
1499  }
1500  std::vector<int32> max_length(clat.NumStates(), 0);
1501  int32 lattice_max_length = 0;
1502  for (StateId s = 0; s < clat.NumStates(); s++) {
1503  int32 this_max_length = max_length[s];
1504  for (fst::ArcIterator<CompactLattice> aiter(clat, s);
1505  !aiter.Done(); aiter.Next()) {
1506  const Arc &arc = aiter.Value();
1507  bool arc_has_word = (arc.ilabel != 0); // note: olabel == ilabel.
1508  // also note: for normal CompactLattice, e.g. as produced by
1509  // determinization, all arcs will have nonzero labels, but the user might
1510  // decide to remplace some of the labels with zero for some reason, and we
1511  // want to support this.
1512  StateId nextstate = arc.nextstate;
1513  KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
1514  KALDI_ASSERT(nextstate > s && "CompactLattice has cycles");
1515  if (arc_has_word)
1516  max_length[nextstate] = std::max(max_length[nextstate],
1517  this_max_length + 1);
1518  else
1519  max_length[nextstate] = std::max(max_length[nextstate],
1520  this_max_length);
1521  }
1522  if (clat.Final(s) != CompactLatticeWeight::Zero())
1523  lattice_max_length = std::max(lattice_max_length, max_length[s]);
1524  }
1525  return lattice_max_length;
1526 }
1527 
1529  const CompactLattice& clat,
1531  CompactLattice* composed_clat) {
1532  // StdFst::Arc and CompactLatticeArc has the same StateId type.
1533  typedef fst::StdArc::StateId StateId;
1534  typedef fst::StdArc::Weight Weight1;
1535  typedef CompactLatticeArc::Weight Weight2;
1536  typedef std::pair<StateId, StateId> StatePair;
1537  typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
1538  typedef MapType::iterator IterType;
1539 
1540  // Empties the output FST.
1541  KALDI_ASSERT(composed_clat != NULL);
1542  composed_clat->DeleteStates();
1543 
1544  MapType state_map;
1545  std::queue<StatePair> state_queue;
1546 
1547  // Sets start state in <composed_clat>.
1548  StateId start_state = composed_clat->AddState();
1549  StatePair start_pair(clat.Start(), det_fst->Start());
1550  composed_clat->SetStart(start_state);
1551  state_queue.push(start_pair);
1552  std::pair<IterType, bool> result =
1553  state_map.insert(std::make_pair(start_pair, start_state));
1554  KALDI_ASSERT(result.second == true);
1555 
1556  // Starts composition here.
1557  while (!state_queue.empty()) {
1558  // Gets the first state in the queue.
1559  StatePair s = state_queue.front();
1560  StateId s1 = s.first;
1561  StateId s2 = s.second;
1562  state_queue.pop();
1563 
1564 
1565  Weight2 clat_final = clat.Final(s1);
1566  if (clat_final.Weight().Value1() !=
1567  std::numeric_limits<BaseFloat>::infinity()) {
1568  // Test for whether the final-prob of state s1 was zero.
1569  Weight1 det_fst_final = det_fst->Final(s2);
1570  if (det_fst_final.Value() !=
1571  std::numeric_limits<BaseFloat>::infinity()) {
1572  // Test for whether the final-prob of state s2 was zero. If neither
1573  // source-state final prob was zero, then we should create final state
1574  // in fst_composed. We compute the product manually since this is more
1575  // efficient.
1576  Weight2 final_weight(LatticeWeight(clat_final.Weight().Value1() +
1577  det_fst_final.Value(),
1578  clat_final.Weight().Value2()),
1579  clat_final.String());
1580  // we can assume final_weight is not Zero(), since neither of
1581  // the sources was zero.
1582  KALDI_ASSERT(state_map.find(s) != state_map.end());
1583  composed_clat->SetFinal(state_map[s], final_weight);
1584  }
1585  }
1586 
1587  // Loops over pair of edges at s1 and s2.
1588  for (fst::ArcIterator<CompactLattice> aiter(clat, s1);
1589  !aiter.Done(); aiter.Next()) {
1590  const CompactLatticeArc& arc1 = aiter.Value();
1591  fst::StdArc arc2;
1592  StateId next_state1 = arc1.nextstate, next_state2;
1593  bool matched = false;
1594 
1595  if (arc1.olabel == 0) {
1596  // If the symbol on <arc1> is <epsilon>, we transit to the next state
1597  // for <clat>, but keep <det_fst> at the current state.
1598  matched = true;
1599  next_state2 = s2;
1600  } else {
1601  // Otherwise try to find the matched arc in <det_fst>.
1602  matched = det_fst->GetArc(s2, arc1.olabel, &arc2);
1603  if (matched) {
1604  next_state2 = arc2.nextstate;
1605  }
1606  }
1607 
1608  // If matched arc is found in <det_fst>, then we have to add new arcs to
1609  // <composed_clat>.
1610  if (matched) {
1611  StatePair next_state_pair(next_state1, next_state2);
1612  IterType siter = state_map.find(next_state_pair);
1613  StateId next_state;
1614 
1615  // Adds composed state to <state_map>.
1616  if (siter == state_map.end()) {
1617  // If the composed state has not been created yet, create it.
1618  next_state = composed_clat->AddState();
1619  std::pair<const StatePair, StateId> next_state_map(next_state_pair,
1620  next_state);
1621  std::pair<IterType, bool> result = state_map.insert(next_state_map);
1622  KALDI_ASSERT(result.second);
1623  state_queue.push(next_state_pair);
1624  } else {
1625  // If the composed state is already in <state_map>, we can directly
1626  // use that.
1627  next_state = siter->second;
1628  }
1629 
1630  // Adds arc to <composed_clat>.
1631  if (arc1.olabel == 0) {
1632  composed_clat->AddArc(state_map[s],
1633  CompactLatticeArc(arc1.ilabel, 0,
1634  arc1.weight, next_state));
1635  } else {
1636  Weight2 composed_weight(
1637  LatticeWeight(arc1.weight.Weight().Value1() +
1638  arc2.weight.Value(),
1639  arc1.weight.Weight().Value2()),
1640  arc1.weight.String());
1641  composed_clat->AddArc(state_map[s],
1642  CompactLatticeArc(arc1.ilabel, arc2.olabel,
1643  composed_weight, next_state));
1644  }
1645  }
1646  }
1647  }
1648  fst::Connect(composed_clat);
1649 }
1650 
1651 
1653  const Lattice &lat,
1654  unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
1655  PairHasher<int32> > *acoustic_scores) {
1656  // typedef the arc, weight types
1657  typedef Lattice::Arc Arc;
1658  typedef Arc::Weight LatticeWeight;
1659  typedef Arc::StateId StateId;
1660 
1661  acoustic_scores->clear();
1662 
1663  std::vector<int32> state_times;
1664  LatticeStateTimes(lat, &state_times); // Assumes the input is top sorted
1665 
1666  KALDI_ASSERT(lat.Start() == 0);
1667 
1668  for (StateId s = 0; s < lat.NumStates(); s++) {
1669  int32 t = state_times[s];
1670  for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done();
1671  aiter.Next()) {
1672  const Arc &arc = aiter.Value();
1673  const LatticeWeight &weight = arc.weight;
1674 
1675  int32 tid = arc.ilabel;
1676 
1677  if (tid != 0) {
1678  unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
1679  PairHasher<int32> >::iterator it = acoustic_scores->find(std::make_pair(t, tid));
1680  if (it == acoustic_scores->end()) {
1681  acoustic_scores->insert(std::make_pair(std::make_pair(t, tid),
1682  std::make_pair(weight.Value2(), 1)));
1683  } else {
1684  if (it->second.second == 2
1685  && it->second.first / it->second.second != weight.Value2()) {
1686  KALDI_VLOG(2) << "Transitions on the same frame have different "
1687  << "acoustic costs for tid " << tid << "; "
1688  << it->second.first / it->second.second
1689  << " vs " << weight.Value2();
1690  }
1691  it->second.first += weight.Value2();
1692  it->second.second++;
1693  }
1694  } else {
1695  // Arcs with epsilon input label (tid) must have 0 acoustic cost
1696  KALDI_ASSERT(weight.Value2() == 0);
1697  }
1698  }
1699 
1700  LatticeWeight f = lat.Final(s);
1701  if (f != LatticeWeight::Zero()) {
1702  // Final acoustic cost must be 0 as we are reading from
1703  // non-determinized, non-compact lattice
1704  KALDI_ASSERT(f.Value2() == 0.0);
1705  }
1706  }
1707 }
1708 
1710  const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
1711  PairHasher<int32> > &acoustic_scores,
1712  Lattice *lat) {
1713  // typedef the arc, weight types
1714  typedef Lattice::Arc Arc;
1715  typedef Arc::Weight LatticeWeight;
1716  typedef Arc::StateId StateId;
1717 
1719 
1720  std::vector<int32> state_times;
1721  LatticeStateTimes(*lat, &state_times);
1722 
1723  KALDI_ASSERT(lat->Start() == 0);
1724 
1725  for (StateId s = 0; s < lat->NumStates(); s++) {
1726  int32 t = state_times[s];
1727  for (fst::MutableArcIterator<Lattice> aiter(lat, s);
1728  !aiter.Done(); aiter.Next()) {
1729  Arc arc(aiter.Value());
1730 
1731  int32 tid = arc.ilabel;
1732  if (tid != 0) {
1733  unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
1734  PairHasher<int32> >::const_iterator it = acoustic_scores.find(std::make_pair(t, tid));
1735  if (it == acoustic_scores.end()) {
1736  KALDI_ERR << "Could not find tid " << tid << " at time " << t
1737  << " in the acoustic scores map.";
1738  } else {
1739  arc.weight.SetValue2(it->second.first / it->second.second);
1740  }
1741  } else {
1742  // For epsilon arcs, set acoustic cost to 0.0
1743  arc.weight.SetValue2(0.0);
1744  }
1745  aiter.SetValue(arc);
1746  }
1747 
1748  LatticeWeight f = lat->Final(s);
1749  if (f != LatticeWeight::Zero()) {
1750  // Set final acoustic cost to 0.0
1751  f.SetValue2(0.0);
1752  lat->SetFinal(s, f);
1753  }
1754  }
1755 }
1756 
1757 } // namespace kaldi
int32 words[kMaxOrder]
fst::StdArc::StateId StateId
fst::StdArc::Label Label
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
double Exp(double x)
Definition: kaldi-math.h:83
virtual bool GetArc(StateId s, Label ilabel, Arc *oarc)=0
Note: ilabel must not be epsilon.
fst::ArcTpl< LatticeWeight > LatticeArc
Definition: kaldi-lattice.h:40
CompactLatticeArc::StateId state
virtual Weight Final(StateId s)=0
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
Definition: decodable-itf.h:82
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
double ComputeLatticeAlphasAndBetas(const LatticeType &lat, bool viterbi, vector< double > *alpha, vector< double > *beta)
Lattice::StateId StateId
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
fst::StdArc StdArc
bool WithProb(BaseFloat prob, struct RandomState *state)
Definition: kaldi-math.cc:72
void ReplaceAcousticScoresFromMap(const unordered_map< std::pair< int32, int32 >, std::pair< BaseFloat, int32 >, PairHasher< int32 > > &acoustic_scores, Lattice *lat)
This function restores acoustic scores computed using the function ComputeAcousticScoresMap into the ...
void TopSortLatticeIfNeeded(Lattice *lat)
Topologically sort the lattice if not already topologically sorted.
virtual bool IsLastFrame(int32 frame) const =0
Returns true if this is the last frame.
ClatRescoreTuple(int32 state, int32 arc, int32 tid)
kaldi::int32 int32
void GetPerFrameAcousticCosts(const Lattice &nbest, Vector< BaseFloat > *per_frame_loglikes)
This function extracts the per-frame log likelihoods from a linear lattice (which we refer to as an '...
bool RescoreCompactLatticeSpeedup(const TransitionModel &tmodel, BaseFloat speedup_factor, DecodableInterface *decodable, CompactLattice *clat)
This function is like RescoreCompactLattice, but it is modified to avoid computing probabilities on m...
virtual StateId Start()=0
int32 TransitionIdToPdf(int32 trans_id) const
bool RescoreCompactLattice(DecodableInterface *decodable, CompactLattice *clat)
This function *adds* the negated scores obtained from the Decodable object, to the acoustic scores on...
void ConvertCompactLatticeToPhones(const TransitionModel &trans, CompactLattice *clat)
Given a lattice, and a transition model to map pdf-ids to phones, replace the sequences of transition...
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
Definition: kaldi-lattice.h:32
void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans, const vector< int32 > &silence_phones, vector< std::set< int32 > > *active_phones)
Given a lattice, and a transition model to map pdf-ids to phones, outputs for each frame the set of p...
bool SplitToPhones(const TransitionModel &trans_model, const std::vector< int32 > &alignment, std::vector< std::vector< int32 > > *split_alignment)
SplitToPhones splits up the TransitionIds in "alignment" into their individual phones (one vector per...
Definition: hmm-utils.cc:723
void CompactLatticeShortestPath(const CompactLattice &clat, CompactLattice *shortest_path)
A form of the shortest-path/best-path algorithm that's specially coded for CompactLattice.
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
BaseFloat LatticeForwardBackwardMmi(const TransitionModel &tmodel, const Lattice &lat, const std::vector< int32 > &num_ali, bool drop_frames, bool convert_to_pdf_ids, bool cancel, Posterior *post)
This function can be used to compute posteriors for MMI, with a positive contribution for the numerat...
BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post, double *acoustic_like_sum)
This function does the forward-backward over lattices and computes the posterior probabilities of the...
int32 NumTransitionIds() const
Returns the total number of transition-ids (note, these are one-based).
BaseFloat CompactLatticeDepth(const CompactLattice &clat, int32 *num_frames)
Returns the depth of the lattice, defined as the average number of arcs crossing any given frame...
bool ComputeCompactLatticeBetas(const CompactLattice &clat, vector< double > *beta)
double ConvertToCost(const LatticeWeightTpl< Float > &w)
int32 TransitionIdToHmmState(int32 trans_id) const
bool IsSelfLoop(int32 trans_id) const
static const LatticeWeightTpl Zero()
void ComposeCompactLatticeDeterministic(const CompactLattice &clat, fst::DeterministicOnDemandFst< fst::StdArc > *det_fst, CompactLattice *composed_clat)
This function Composes a CompactLattice format lattice with a DeterministicOnDemandFst<fst::StdFst> f...
void AlignmentToPosterior(const std::vector< int32 > &ali, Posterior *post)
Convert an alignment to a posterior (with a scale of 1.0 on each entry).
Definition: posterior.cc:290
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
#define KALDI_ERR
Definition: kaldi-error.h:147
int32 CompactLatticeStateTimes(const CompactLattice &lat, vector< int32 > *times)
As LatticeStateTimes, but in the CompactLattice format.
BaseFloat LatticeForwardBackwardMpeVariants(const TransitionModel &trans, const std::vector< int32 > &silence_phones, const Lattice &lat, const std::vector< int32 > &num_ali, std::string criterion, bool one_silence_class, Posterior *post)
This function implements either the MPFE (minimum phone frame error) or SMBR (state-level minimum bay...
bool RescoreCompactLatticeInternal(const TransitionModel *tmodel, BaseFloat speedup_factor, DecodableInterface *decodable, CompactLattice *clat)
RescoreCompactLatticeInternal is the internal code for both RescoreCompactLattice and RescoreCompatLa...
#define KALDI_WARN
Definition: kaldi-error.h:150
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty, CompactLattice *clat)
This function adds the word insertion penalty to the graph score of each word in the compact lattice...
fst::StdArc::Label Label
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
fst::StdArc::Weight Weight
bool operator<(const Int32Pair &a, const Int32Pair &b)
Definition: cu-matrixdim.h:83
double LogAdd(double x, double y)
Definition: kaldi-math.h:184
void ScalePosterior(BaseFloat scale, Posterior *post)
Scales the BaseFloat (weight) element in the posterior entries.
Definition: posterior.cc:218
void ComputeAcousticScoresMap(const Lattice &lat, unordered_map< std::pair< int32, int32 >, std::pair< BaseFloat, int32 >, PairHasher< int32 > > *acoustic_scores)
This function computes the mapping from the pair (frame-index, transition-id) to the pair (sum-of-aco...
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ISNAN
Definition: kaldi-math.h:72
bool PruneLattice(BaseFloat beam, LatType *lat)
Arc::Weight Weight
Definition: kws-search.cc:31
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void CompactLatticeLimitDepth(int32 max_depth_per_frame, CompactLattice *clat)
This function limits the depth of the lattice, per frame: that means, it does not allow more than a s...
static const CompactLatticeWeightTpl< WeightType, IntType > Zero()
void MergePairVectorSumming(std::vector< std::pair< I, F > > *vec)
For a vector of pair<I, F> where I is an integer and F a floating-point or integer type...
Definition: stl-utils.h:288
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
bool LatticeBoost(const TransitionModel &trans, const std::vector< int32 > &alignment, const std::vector< int32 > &silence_phones, BaseFloat b, BaseFloat max_silence_error, Lattice *lat)
Boosts LM probabilities by b * [number of frame errors]; equivalently, adds -b*[number of frame error...
fst::ArcTpl< CompactLatticeWeight > CompactLatticeArc
Definition: kaldi-lattice.h:42
void TopSortCompactLatticeIfNeeded(CompactLattice *clat)
Topologically sort the compact lattice if not already topologically sorted.
void CompactLatticeDepthPerFrame(const CompactLattice &clat, std::vector< int32 > *depth_per_frame)
This function returns, for each frame, the number of arcs crossing that frame.
void ConvertLatticeToPhones(const TransitionModel &trans, Lattice *lat)
Given a lattice, and a transition model to map pdf-ids to phones, replace the output symbols (presuma...
This is used in CompactLatticeLimitDepth.
void ConvertPosteriorToPdfs(const TransitionModel &tmodel, const Posterior &post_in, Posterior *post_out)
Converts a posterior over transition-ids to be a posterior over pdf-ids.
Definition: posterior.cc:322
bool ComputeCompactLatticeAlphas(const CompactLattice &clat, vector< double > *alpha)
int32 MergePosteriors(const Posterior &post1, const Posterior &post2, bool merge, bool drop_frames, Posterior *post)
Merge two sets of posteriors, which must have the same length.
Definition: posterior.cc:258
virtual BaseFloat LogLikelihood(int32 frame, int32 index)=0
Returns the log likelihood, which will be negated in the decoder.
bool IsSortedAndUniq(const std::vector< T > &vec)
Returns true if the vector is sorted and contains each element only once.
Definition: stl-utils.h:63
const double kLogZeroDouble
Definition: kaldi-math.h:129
static double LogAddOrMax(bool viterbi, double a, double b)
bool IsFinal(int32 trans_id) const
int32 TransitionIdToPhone(int32 trans_id) const
bool RescoreLattice(DecodableInterface *decodable, Lattice *lat)
This function *adds* the negated scores obtained from the Decodable object, to the acoustic scores on...
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
Definition: kaldi-math.h:265
int32 LongestSentenceLength(const Lattice &lat)
This function returns the number of words in the longest sentence in a CompactLattice (i...
bool CompactLatticeToWordAlignment(const CompactLattice &clat, std::vector< int32 > *words, std::vector< int32 > *begin_times, std::vector< int32 > *lengths)
This function takes a CompactLattice that should only contain a single linear sequence (e...
bool CompactLatticeToWordProns(const TransitionModel &tmodel, const CompactLattice &clat, std::vector< int32 > *words, std::vector< int32 > *begin_times, std::vector< int32 > *lengths, std::vector< std::vector< int32 > > *prons, std::vector< std::vector< int32 > > *phone_lengths)
This function takes a CompactLattice that should only contain a single linear sequence (e...
A hashing function-object for pairs of ints.
Definition: stl-utils.h:235