sausages.cc
// lat/sausages.cc

// Copyright 2012  Johns Hopkins University (Author: Daniel Povey)
//           2015  Guoguo Chen
//           2019  Dogan Can

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "lat/sausages.h"
#include "lat/lattice-functions.h"

namespace kaldi {
// this is Figure 6 in the paper.
void MinimumBayesRisk::MbrDecode() {

  for (size_t counter = 0; ; counter++) {
    NormalizeEps(&R_);
    AccStats(); // writes to gamma_
    double delta_Q = 0.0; // change in objective function.

    one_best_times_.clear();
    one_best_confidences_.clear();

    // Caution: q in the line below is (q-1) in the algorithm
    // in the paper; both R_ and gamma_ are indexed by q-1.
    for (size_t q = 0; q < R_.size(); q++) {
      if (opts_.decode_mbr) { // This loop updates R_ [indexed same as gamma_].
        // gamma_[i] is sorted in reverse order so the most likely entry is first.
        const std::vector<std::pair<int32, BaseFloat> > &this_gamma = gamma_[q];
        double old_gamma = 0, new_gamma = this_gamma[0].second;
        int32 rq = R_[q], rhat = this_gamma[0].first; // rq: old word, rhat: new.
        for (size_t j = 0; j < this_gamma.size(); j++)
          if (this_gamma[j].first == rq) old_gamma = this_gamma[j].second;
        delta_Q += (old_gamma - new_gamma); // will be 0 or negative; a bound
                                            // on the change in error.
        if (rq != rhat)
          KALDI_VLOG(2) << "Changing word " << rq << " to " << rhat;
        R_[q] = rhat;
      }
      // build the outputs (times, confidences),
      if (R_[q] != 0 || opts_.print_silence) {
        // see which 'item' from the sausage-bin we should select
        // (not necessarily the 1st one when MBR decoding is disabled).
        int32 s = 0;
        for (int32 j = 0; j < gamma_[q].size(); j++) {
          if (gamma_[q][j].first == R_[q]) {
            s = j;
            break;
          }
        }
        one_best_times_.push_back(times_[q][s]);
        // post-process the times,
        size_t i = one_best_times_.size();
        if (i > 1 && one_best_times_[i-2].second > one_best_times_[i-1].first) {
          // It's quite possible for this to happen, but it seems like it would
          // have a bad effect on the downstream processing, so we fix it here.
          // We resolve overlaps by redistributing the available time interval.
          BaseFloat prev_right = i > 2 ? one_best_times_[i-3].second : 0.0;
          BaseFloat left = std::max(prev_right,
                                    std::min(one_best_times_[i-2].first,
                                             one_best_times_[i-1].first));
          BaseFloat right = std::max(one_best_times_[i-2].second,
                                     one_best_times_[i-1].second);
          BaseFloat first_dur =
              one_best_times_[i-2].second - one_best_times_[i-2].first;
          BaseFloat second_dur =
              one_best_times_[i-1].second - one_best_times_[i-1].first;
          BaseFloat mid = first_dur > 0 ? left + (right - left) * first_dur /
              (first_dur + second_dur) : left;
          one_best_times_[i-2].first = left;
          one_best_times_[i-2].second = one_best_times_[i-1].first = mid;
          one_best_times_[i-1].second = right;
        }
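        // Illustrative example of the redistribution above (hypothetical
        // values): if the word before last ended at 0.5 and the last two
        // entries are (1.0, 2.0) and (1.6, 2.4), then left = 1.0, right = 2.4,
        // first_dur = 1.0, second_dur = 0.8, so the shared boundary becomes
        // mid = 1.0 + 1.4 * 1.0 / 1.8 ~= 1.78, giving the non-overlapping
        // intervals (1.0, 1.78) and (1.78, 2.4).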
        BaseFloat confidence = 0.0;
        for (int32 j = 0; j < gamma_[q].size(); j++) {
          if (gamma_[q][j].first == R_[q]) {
            confidence = gamma_[q][j].second;
            break;
          }
        }
        one_best_confidences_.push_back(confidence);
      }
    }
    KALDI_VLOG(2) << "Iter = " << counter << ", delta-Q = " << delta_Q;
    if (delta_Q == 0) break;
    if (counter > 100) {
      KALDI_WARN << "Iterating too many times in MbrDecode; stopping.";
      break;
    }
  }
  RemoveEps(&R_); // Remove the epsilons that NormalizeEps() interleaved into R_.
}

struct Int32IsZero {
  bool operator() (int32 i) { return (i == 0); }
};
// static
void MinimumBayesRisk::RemoveEps(std::vector<int32> *vec) {
  Int32IsZero pred;
  vec->erase(std::remove_if(vec->begin(), vec->end(), pred),
             vec->end());
}

// static
void MinimumBayesRisk::NormalizeEps(std::vector<int32> *vec) {
  RemoveEps(vec);
  vec->resize(1 + vec->size() * 2);
  int32 s = vec->size();
  for (int32 i = s/2 - 1; i >= 0; i--) {
    (*vec)[i*2 + 1] = (*vec)[i];
    (*vec)[i*2 + 2] = 0;
  }
  (*vec)[0] = 0;
}
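// For example (illustrative word-ids): RemoveEps turns {3, 0, 7} into {3, 7},
// and NormalizeEps turns {3, 7} into {0, 3, 0, 7, 0}, i.e. it interleaves
// epsilons (symbol 0) before, between and after the words, which is the form
// of the reference sequence R_ that MbrDecode() passes to AccStats().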

double MinimumBayesRisk::EditDistance(int32 N, int32 Q,
                                      Vector<double> &alpha,
                                      Matrix<double> &alpha_dash,
                                      Vector<double> &alpha_dash_arc) {
  alpha(1) = 0.0; // = log(1).  Line 5.
  alpha_dash(1, 0) = 0.0; // Line 5.
  for (int32 q = 1; q <= Q; q++)
    alpha_dash(1, q) = alpha_dash(1, q-1) + l(0, r(q)); // Line 7.
  for (int32 n = 2; n <= N; n++) {
    double alpha_n = kLogZeroDouble;
    for (size_t i = 0; i < pre_[n].size(); i++) {
      const Arc &arc = arcs_[pre_[n][i]];
      alpha_n = LogAdd(alpha_n, alpha(arc.start_node) + arc.loglike);
    }
    alpha(n) = alpha_n; // Line 10.
    // Line 11 omitted: the matrix was initialized to zero.
    for (size_t i = 0; i < pre_[n].size(); i++) {
      const Arc &arc = arcs_[pre_[n][i]];
      int32 s_a = arc.start_node, w_a = arc.word;
      BaseFloat p_a = arc.loglike;
      for (int32 q = 0; q <= Q; q++) {
        if (q == 0) {
          alpha_dash_arc(q) = // line 15.
              alpha_dash(s_a, q) + l(w_a, 0, true);
        } else { // a1,a2,a3 are the 3 parts of the min expression of line 17.
          int32 r_q = r(q);
          double a1 = alpha_dash(s_a, q-1) + l(w_a, r_q),
              a2 = alpha_dash(s_a, q) + l(w_a, 0, true),
              a3 = alpha_dash_arc(q-1) + l(0, r_q);
          alpha_dash_arc(q) = std::min(a1, std::min(a2, a3));
        }
        // line 19:
        alpha_dash(n, q) += Exp(alpha(s_a) + p_a - alpha(n)) * alpha_dash_arc(q);
      }
    }
  }
  return alpha_dash(N, Q); // line 23.
}
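// A note on the term Exp(alpha(s_a) + p_a - alpha(n)) used in EditDistance()
// above: alpha(n) is the log-sum of alpha(s_a) + p_a over the arcs entering n,
// so this factor is the posterior probability of entering n via arc a, and
// line 19 forms a posterior-weighted average of the per-arc quantities
// alpha_dash_arc(q).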

// Figure 5 in the paper.
void MinimumBayesRisk::AccStats() {
  using std::map;

  int32 N = static_cast<int32>(pre_.size()) - 1,
      Q = static_cast<int32>(R_.size());

  Vector<double> alpha(N+1); // index (1...N)
  Matrix<double> alpha_dash(N+1, Q+1); // index (1...N, 0...Q)
  Vector<double> alpha_dash_arc(Q+1); // index 0...Q
  Matrix<double> beta_dash(N+1, Q+1); // index (1...N, 0...Q)
  Vector<double> beta_dash_arc(Q+1); // index 0...Q
  std::vector<char> b_arc(Q+1); // integer in {1,2,3}; index 1...Q
  std::vector<map<int32, double> > gamma(Q+1); // temp. form of gamma.
  // index 1...Q [word] -> occ.

  // The tau maps below are the sums, over arcs with the same word label,
  // of the tau_b and tau_e timing quantities mentioned in Appendix C of
  // the paper... we are using these to get averaged times for both
  // the sausage bins and the 1-best output.
  std::vector<map<int32, double> > tau_b(Q+1), tau_e(Q+1);

  double Ltmp = EditDistance(N, Q, alpha, alpha_dash, alpha_dash_arc);
  if (L_ != 0 && Ltmp > L_) { // L_ != 0 is to rule out the 1st iter.
    KALDI_WARN << "Edit distance increased: " << Ltmp << " > "
               << L_;
  }
  L_ = Ltmp;
  KALDI_VLOG(2) << "L = " << L_;
  // omit line 10: zero when initialized.
  beta_dash(N, Q) = 1.0; // Line 11.
  for (int32 n = N; n >= 2; n--) {
    for (size_t i = 0; i < pre_[n].size(); i++) {
      const Arc &arc = arcs_[pre_[n][i]];
      int32 s_a = arc.start_node, w_a = arc.word;
      BaseFloat p_a = arc.loglike;
      alpha_dash_arc(0) = alpha_dash(s_a, 0) + l(w_a, 0, true); // line 14.
      for (int32 q = 1; q <= Q; q++) { // this loop == lines 15-18.
        int32 r_q = r(q);
        double a1 = alpha_dash(s_a, q-1) + l(w_a, r_q),
            a2 = alpha_dash(s_a, q) + l(w_a, 0, true),
            a3 = alpha_dash_arc(q-1) + l(0, r_q);
        if (a1 <= a2) {
          if (a1 <= a3) { b_arc[q] = 1; alpha_dash_arc(q) = a1; }
          else { b_arc[q] = 3; alpha_dash_arc(q) = a3; }
        } else {
          if (a2 <= a3) { b_arc[q] = 2; alpha_dash_arc(q) = a2; }
          else { b_arc[q] = 3; alpha_dash_arc(q) = a3; }
        }
      }
      beta_dash_arc.SetZero(); // line 19.
      for (int32 q = Q; q >= 1; q--) {
        // line 21:
        beta_dash_arc(q) += Exp(alpha(s_a) + p_a - alpha(n)) * beta_dash(n, q);
        switch (static_cast<int>(b_arc[q])) { // lines 22 and 23:
          case 1:
            beta_dash(s_a, q-1) += beta_dash_arc(q);
            // next: gamma(q, w(a)) += beta_dash_arc(q)
            AddToMap(w_a, beta_dash_arc(q), &(gamma[q]));
            // next: accumulating times, see the decl. for tau_b, tau_e.
            AddToMap(w_a, state_times_[s_a] * beta_dash_arc(q), &(tau_b[q]));
            AddToMap(w_a, state_times_[n] * beta_dash_arc(q), &(tau_e[q]));
            break;
          case 2:
            beta_dash(s_a, q) += beta_dash_arc(q);
            break;
          case 3:
            beta_dash_arc(q-1) += beta_dash_arc(q);
            // next: gamma(q, epsilon) += beta_dash_arc(q)
            AddToMap(0, beta_dash_arc(q), &(gamma[q]));
            // next: accumulating times, see the decl. for tau_b, tau_e.
            // WARNING: there was an error in Appendix C.  If we followed the
            // instructions there, the next line would say state_times_[s_a],
            // but that would be wrong.  I will try to publish an erratum.
            AddToMap(0, state_times_[n] * beta_dash_arc(q), &(tau_b[q]));
            AddToMap(0, state_times_[n] * beta_dash_arc(q), &(tau_e[q]));
            break;
          default:
            KALDI_ERR << "Invalid b_arc value"; // error in code.
        }
      }
      beta_dash_arc(0) += Exp(alpha(s_a) + p_a - alpha(n)) * beta_dash(n, 0);
      beta_dash(s_a, 0) += beta_dash_arc(0); // line 26.
    }
  }
  beta_dash_arc.SetZero(); // line 29.
  for (int32 q = Q; q >= 1; q--) {
    beta_dash_arc(q) += beta_dash(1, q);
    beta_dash_arc(q-1) += beta_dash_arc(q);
    AddToMap(0, beta_dash_arc(q), &(gamma[q]));
    // the statements below are actually redundant because
    // state_times_[1] is zero.
    AddToMap(0, state_times_[1] * beta_dash_arc(q), &(tau_b[q]));
    AddToMap(0, state_times_[1] * beta_dash_arc(q), &(tau_e[q]));
  }
  for (int32 q = 1; q <= Q; q++) { // a check (line 35)
    double sum = 0.0;
    for (map<int32, double>::iterator iter = gamma[q].begin();
         iter != gamma[q].end(); ++iter) sum += iter->second;
    if (fabs(sum - 1.0) > 0.1)
      KALDI_WARN << "sum of gamma[" << q << ",s] is " << sum;
  }
  // The next part is where we take gamma and convert it to the class member
  // gamma_, which uses a different data structure and is indexed from zero,
  // not one.
  gamma_.clear();
  gamma_.resize(Q);
  for (int32 q = 1; q <= Q; q++) {
    for (map<int32, double>::iterator iter = gamma[q].begin();
         iter != gamma[q].end(); ++iter)
      gamma_[q-1].push_back(
          std::make_pair(iter->first, static_cast<BaseFloat>(iter->second)));
    // sort gamma_[q-1] from largest to smallest posterior.
    GammaCompare comp;
    std::sort(gamma_[q-1].begin(), gamma_[q-1].end(), comp);
  }
  // We do the same conversion for the state times tau_b and tau_e:
  // they get turned into the times_ data member, which has zero-based
  // indexing.
  times_.clear();
  times_.resize(Q);
  sausage_times_.clear();
  sausage_times_.resize(Q);
  for (int32 q = 1; q <= Q; q++) {
    double t_b = 0.0, t_e = 0.0;
    for (std::vector<std::pair<int32, BaseFloat> >::iterator iter = gamma_[q-1].begin();
         iter != gamma_[q-1].end(); ++iter) {
      double w_b = tau_b[q][iter->first], w_e = tau_e[q][iter->first];
      if (w_b > w_e)
        KALDI_WARN << "Times out of order"; // this is quite bad.
      times_[q-1].push_back(
          std::make_pair(static_cast<BaseFloat>(w_b / iter->second),
                         static_cast<BaseFloat>(w_e / iter->second)));
      t_b += w_b;
      t_e += w_e;
    }
    sausage_times_[q-1].first = t_b;
    sausage_times_[q-1].second = t_e;
    if (sausage_times_[q-1].first > sausage_times_[q-1].second)
      KALDI_WARN << "Times out of order"; // this is quite bad.
    if (q > 1 && sausage_times_[q-2].second > sausage_times_[q-1].first) {
      // We previously had a warning here, but now we'll just set both of those
      // values to their average.  It's quite possible for this condition to
      // happen, but it seems like it would have a bad effect on the downstream
      // processing, so we fix it.
      sausage_times_[q-2].second = sausage_times_[q-1].first =
          0.5 * (sausage_times_[q-2].second + sausage_times_[q-1].first);
    }
  }
}
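// After AccStats() above, gamma_[q-1] holds the sausage bin for position q as
// (word-id, posterior) pairs sorted by decreasing posterior, e.g.
// {(1002, 0.85), (0, 0.10), (731, 0.05)} (illustrative values), and
// times_[q-1][j] holds the averaged (begin, end) frame times of the j'th entry.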

void MinimumBayesRisk::PrepareLatticeAndInitStats(CompactLattice *clat) {
  KALDI_ASSERT(clat != NULL);

  CreateSuperFinal(clat); // Add super-final state to clat... this is
  // one of the requirements of the MBR algorithm, as mentioned in the
  // paper (i.e. just one final state).

  // Topologically sort the lattice, if not already sorted.
  kaldi::uint64 props = clat->Properties(fst::kFstProperties, false);
  if (!(props & fst::kTopSorted)) {
    if (fst::TopSort(clat) == false)
      KALDI_ERR << "Cycles detected in lattice.";
  }
  CompactLatticeStateTimes(*clat, &state_times_); // work out the times of
  // the states in clat
  state_times_.push_back(0); // we'll convert to 1-based numbering.
  for (size_t i = state_times_.size()-1; i > 0; i--)
    state_times_[i] = state_times_[i-1];

  // Now we convert the information in "clat" into a special internal
  // format (pre_, post_ and arcs_) which allows us to access the
  // arcs preceding any given state.
  // Note: in our internal format the states will be numbered from 1,
  // which involves adding 1 to the OpenFst states.
  int32 N = clat->NumStates();
  pre_.resize(N+1);

  // Careful: "Arc" is a class-member struct, not an OpenFst type of arc as one
  // would normally assume.
  for (int32 n = 1; n <= N; n++) {
    for (fst::ArcIterator<CompactLattice> aiter(*clat, n-1);
         !aiter.Done();
         aiter.Next()) {
      const CompactLatticeArc &carc = aiter.Value();
      Arc arc; // in our local format.
      arc.word = carc.ilabel; // == carc.olabel
      arc.start_node = n;
      arc.end_node = carc.nextstate + 1; // convert to 1-based.
      arc.loglike = -(carc.weight.Weight().Value1() +
                      carc.weight.Weight().Value2());
      // loglike: sum the graph/LM and acoustic costs and negate, to convert
      // them to a loglike.  We assume acoustic scaling has already been done.

      pre_[arc.end_node].push_back(arcs_.size()); // record index of this arc.
      arcs_.push_back(arc);
    }
  }
}
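// For example (illustrative): if OpenFst state 3 of clat has two incoming arcs
// that were appended to arcs_ at indices 10 and 14, then after
// PrepareLatticeAndInitStats() we have pre_[4] == {10, 14}, because state 3
// becomes node 4 in the 1-based internal numbering.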

MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in,
                                   MinimumBayesRiskOptions opts) : opts_(opts) {
  CompactLattice clat(clat_in); // copy.

  PrepareLatticeAndInitStats(&clat);

  // We don't need to look at clat.Start() or clat.Final(state):
  // we know clat.Start() == 0 since it's topologically sorted,
  // and clat.Final(state) is Zero() except for One() at the last-
  // numbered state, thanks to CreateSuperFinal and the topological
  // sorting.

  { // Now set R_ to the one-best path in the FST.
    RemoveAlignmentsFromCompactLattice(&clat); // will be more efficient
    // in best-path if we do this.
    Lattice lat;
    ConvertLattice(clat, &lat); // convert from CompactLattice to Lattice.
    fst::VectorFst<fst::StdArc> fst;
    ConvertLattice(lat, &fst); // convert from Lattice to a normal FST.
    fst::VectorFst<fst::StdArc> fst_shortest_path;
    fst::ShortestPath(fst, &fst_shortest_path); // take the shortest path of the FST.
    std::vector<int32> alignment, words;
    fst::TropicalWeight weight;
    GetLinearSymbolSequence(fst_shortest_path, &alignment, &words, &weight);
    KALDI_ASSERT(alignment.empty()); // we removed the alignments.
    R_ = words;
    L_ = 0.0; // Set the current edit-distance to 0 [just so we know
    // when we're on the 1st iter.]
  }

  MbrDecode();
}

MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in,
                                   const std::vector<int32> &words,
                                   MinimumBayesRiskOptions opts) : opts_(opts) {
  CompactLattice clat(clat_in); // copy.

  PrepareLatticeAndInitStats(&clat);

  R_ = words;
  L_ = 0.0;

  MbrDecode();
}

MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in,
                                   const std::vector<int32> &words,
                                   const std::vector<std::pair<BaseFloat, BaseFloat> > &times,
                                   MinimumBayesRiskOptions opts) : opts_(opts) {
  CompactLattice clat(clat_in); // copy.

  PrepareLatticeAndInitStats(&clat);

  R_ = words;
  sausage_times_ = times;
  L_ = 0.0;

  MbrDecode();
}


} // namespace kaldi
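
A minimal usage sketch follows, for orientation; it is not part of sausages.cc. The constructor and the decode_mbr / print_silence options are the ones defined in this file and in lat/sausages.h, while the accessor names GetOneBest() and GetOneBestConfidences() are assumptions recalled from lat/sausages.h and should be checked against that header; the helper function ExampleMbrDecode is purely illustrative.

    #include <vector>
    #include "lat/sausages.h"

    // Illustrative helper (not part of Kaldi): run MBR decoding on a compact
    // lattice whose acoustic scaling has already been applied.
    void ExampleMbrDecode(const kaldi::CompactLattice &clat) {
      kaldi::MinimumBayesRiskOptions opts;
      opts.decode_mbr = true;      // update the hypothesis to the MBR estimate
      opts.print_silence = false;  // do not keep <eps> bins in the 1-best output

      kaldi::MinimumBayesRisk mbr(clat, opts);  // runs MbrDecode() internally

      // Assumed accessors from lat/sausages.h:
      const std::vector<kaldi::int32> &words = mbr.GetOneBest();
      const std::vector<kaldi::BaseFloat> &confs = mbr.GetOneBestConfidences();
      // words[i] is the i'th word-id of the MBR hypothesis and confs[i] is its
      // posterior in the corresponding sausage bin.
    }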