MinimumBayesRisk Class Reference

This class does the word-level Minimum Bayes Risk computation, and gives you either the 1-best MBR output together with the expected Bayes Risk, or a sausage-like structure.

#include <sausages.h>


Classes

struct Arc
 
struct GammaCompare
 

Public Member Functions

 MinimumBayesRisk (const CompactLattice &clat, MinimumBayesRiskOptions opts=MinimumBayesRiskOptions())
 Initialize with a compact lattice; any acoustic scaling etc. is assumed to have been done already.
 
 MinimumBayesRisk (const CompactLattice &clat, const std::vector< int32 > &words, MinimumBayesRiskOptions opts=MinimumBayesRiskOptions())
 
 MinimumBayesRisk (const CompactLattice &clat, const std::vector< int32 > &words, const std::vector< std::pair< BaseFloat, BaseFloat > > &times, MinimumBayesRiskOptions opts=MinimumBayesRiskOptions())
 
const std::vector< int32 > & GetOneBest () const
 
const std::vector< std::vector< std::pair< BaseFloat, BaseFloat > > > GetTimes () const
 
const std::vector< std::pair< BaseFloat, BaseFloat > > GetSausageTimes () const
 
const std::vector< std::pair< BaseFloat, BaseFloat > > & GetOneBestTimes () const
 
const std::vector< BaseFloat > & GetOneBestConfidences () const
 Outputs the confidences for the one-best transcript.
 
BaseFloat GetBayesRisk () const
 Returns the expected number of word errors (the expected edit distance) over this sentence, assuming model correctness.
 
const std::vector< std::vector< std::pair< int32, BaseFloat > > > & GetSausageStats () const
 

Private Member Functions

void PrepareLatticeAndInitStats (CompactLattice *clat)
 
void MbrDecode ()
 Minimum-Bayes-Risk Decode. Top-level algorithm. Figure 6 of the paper.
 
double l (int32 a, int32 b, bool penalize=false)
 Without the 'penalize' argument this gives us the basic edit-distance function l(a,b), as in the paper.
 
int32 r (int32 q)
 Returns r_q, in one-based indexing, as in the paper.
 
double EditDistance (int32 N, int32 Q, Vector< double > &alpha, Matrix< double > &alpha_dash, Vector< double > &alpha_dash_arc)
 Figure 4 of the paper; called from AccStats (Fig. 5).
 
void AccStats ()
 Figure 5 of the paper. Outputs to gamma_ and L_.
 

Static Private Member Functions

static void RemoveEps (std::vector< int32 > *vec)
 Removes epsilons (symbol 0) from a vector.
 
static void NormalizeEps (std::vector< int32 > *vec)
 
static BaseFloat delta ()
 
static void AddToMap (int32 i, double d, std::map< int32, double > *gamma)
 Function used to increment map.
 

Private Attributes

MinimumBayesRiskOptions opts_
 
std::vector< Arc > arcs_
 Arcs in the topologically sorted acceptor form of the word-level lattice, with one final state.
 
std::vector< std::vector< int32 > > pre_
 For each node in the lattice, a list of arcs entering that node.
 
std::vector< int32 > state_times_
 
std::vector< int32 > R_
 
double L_
 
std::vector< std::vector< std::pair< int32, BaseFloat > > > gamma_
 
std::vector< std::vector< std::pair< BaseFloat, BaseFloat > > > times_
 
std::vector< std::pair< BaseFloat, BaseFloat > > sausage_times_
 
std::vector< std::pair< BaseFloat, BaseFloat > > one_best_times_
 
std::vector< BaseFloat > one_best_confidences_
 

Detailed Description

This class does the word-level Minimum Bayes Risk computation, and gives you either the 1-best MBR output together with the expected Bayes Risk, or a sausage-like structure.

Definition at line 77 of file sausages.h.

Constructor & Destructor Documentation

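◆ MinimumBayesRisk() [1/3]

MinimumBayesRisk ( const CompactLattice &clat,
MinimumBayesRiskOptions opts = MinimumBayesRiskOptions()
)

Initialize with a compact lattice; any acoustic scaling etc. is assumed to have been done already.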

This does the whole computation. You get the output with GetOneBest(), GetBayesRisk(), and GetSausageStats().

Definition at line 369 of file sausages.cc.

References fst::ConvertLattice(), fst::GetLinearSymbolSequence(), KALDI_ASSERT, MinimumBayesRisk::L_, MinimumBayesRisk::MbrDecode(), MinimumBayesRisk::PrepareLatticeAndInitStats(), MinimumBayesRisk::R_, and fst::RemoveAlignmentsFromCompactLattice().

MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in,
                                   MinimumBayesRiskOptions opts) : opts_(opts) {
  CompactLattice clat(clat_in); // copy.

  PrepareLatticeAndInitStats(&clat);

  // We don't need to look at clat.Start() or clat.Final(state):
  // we know clat.Start() == 0 since it's topologically sorted,
  // and clat.Final(state) is Zero() except for One() at the last-
  // numbered state, thanks to CreateSuperFinal and the topological
  // sorting.

  { // Now set R_ to one best in the FST.
    RemoveAlignmentsFromCompactLattice(&clat); // will be more efficient
                                               // in best-path if we do this.
    Lattice lat;
    ConvertLattice(clat, &lat); // convert from CompactLattice to Lattice.
    fst::VectorFst<fst::StdArc> fst;
    ConvertLattice(lat, &fst); // convert from lattice to normal FST.
    fst::VectorFst<fst::StdArc> fst_shortest_path;
    fst::ShortestPath(fst, &fst_shortest_path); // take shortest path of FST.
    std::vector<int32> alignment, words;
    fst::TropicalWeight weight;
    GetLinearSymbolSequence(fst_shortest_path, &alignment, &words, &weight);
    KALDI_ASSERT(alignment.empty()); // we removed the alignment.
    R_ = words;
    L_ = 0.0; // Set current edit-distance to 0 [just so we know
              // when we're on the 1st iter.]
  }

  MbrDecode();

}

◆ MinimumBayesRisk() [2/3]

MinimumBayesRisk ( const CompactLattice &clat,
const std::vector< int32 > &words,
MinimumBayesRiskOptions opts = MinimumBayesRiskOptions()
)

Definition at line 403 of file sausages.cc.

References MinimumBayesRisk::L_, MinimumBayesRisk::MbrDecode(), MinimumBayesRisk::PrepareLatticeAndInitStats(), and MinimumBayesRisk::R_.

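MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in,
                                   const std::vector<int32> &words,
                                   MinimumBayesRiskOptions opts) : opts_(opts) {
  CompactLattice clat(clat_in); // copy.

  PrepareLatticeAndInitStats(&clat);

  R_ = words;
  L_ = 0.0;

  MbrDecode();
}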

◆ MinimumBayesRisk() [3/3]

MinimumBayesRisk ( const CompactLattice &clat,
const std::vector< int32 > &words,
const std::vector< std::pair< BaseFloat, BaseFloat > > &times,
MinimumBayesRiskOptions opts = MinimumBayesRiskOptions()
)

Definition at line 416 of file sausages.cc.

References MinimumBayesRisk::L_, MinimumBayesRisk::MbrDecode(), MinimumBayesRisk::PrepareLatticeAndInitStats(), MinimumBayesRisk::R_, and MinimumBayesRisk::sausage_times_.

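MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in,
                                   const std::vector<int32> &words,
                                   const std::vector<std::pair<BaseFloat,BaseFloat> > &times,
                                   MinimumBayesRiskOptions opts) : opts_(opts) {
  CompactLattice clat(clat_in); // copy.

  PrepareLatticeAndInitStats(&clat);

  R_ = words;
  sausage_times_ = times;
  L_ = 0.0;

  MbrDecode();
}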

Member Function Documentation

◆ AccStats()

void AccStats ( )
private

Figure 5 of the paper. Outputs to gamma_ and L_.

Definition at line 170 of file sausages.cc.

References MinimumBayesRisk::AddToMap(), MinimumBayesRisk::arcs_, MinimumBayesRisk::EditDistance(), kaldi::Exp(), MinimumBayesRisk::gamma_, KALDI_ERR, KALDI_VLOG, KALDI_WARN, MinimumBayesRisk::l(), MinimumBayesRisk::L_, MinimumBayesRisk::Arc::loglike, MinimumBayesRisk::pre_, MinimumBayesRisk::r(), MinimumBayesRisk::R_, MinimumBayesRisk::sausage_times_, VectorBase< Real >::SetZero(), MinimumBayesRisk::Arc::start_node, MinimumBayesRisk::state_times_, MinimumBayesRisk::times_, and MinimumBayesRisk::Arc::word.

Referenced by MinimumBayesRisk::MbrDecode().

{
  using std::map;

  int32 N = static_cast<int32>(pre_.size()) - 1,
      Q = static_cast<int32>(R_.size());

  Vector<double> alpha(N+1); // index (1...N)
  Matrix<double> alpha_dash(N+1, Q+1); // index (1...N, 0...Q)
  Vector<double> alpha_dash_arc(Q+1); // index 0...Q
  Matrix<double> beta_dash(N+1, Q+1); // index (1...N, 0...Q)
  Vector<double> beta_dash_arc(Q+1); // index 0...Q
  std::vector<char> b_arc(Q+1); // integer in {1,2,3}; index 1...Q
  std::vector<map<int32, double> > gamma(Q+1); // temp. form of gamma.
                                               // index 1...Q [word] -> occ.

  // The tau maps below are the sums over arcs with the same word label
  // of the tau_b and tau_e timing quantities mentioned in Appendix C of
  // the paper... we are using these to get averaged times for both the
  // sausage bins and the 1-best output.
  std::vector<map<int32, double> > tau_b(Q+1), tau_e(Q+1);

  double Ltmp = EditDistance(N, Q, alpha, alpha_dash, alpha_dash_arc);
  if (L_ != 0 && Ltmp > L_) { // L_ != 0 is to rule out 1st iter.
    KALDI_WARN << "Edit distance increased: " << Ltmp << " > "
               << L_;
  }
  L_ = Ltmp;
  KALDI_VLOG(2) << "L = " << L_;
  // omit line 10: zero when initialized.
  beta_dash(N, Q) = 1.0; // Line 11.
  for (int32 n = N; n >= 2; n--) {
    for (size_t i = 0; i < pre_[n].size(); i++) {
      const Arc &arc = arcs_[pre_[n][i]];
      int32 s_a = arc.start_node, w_a = arc.word;
      BaseFloat p_a = arc.loglike;
      alpha_dash_arc(0) = alpha_dash(s_a, 0) + l(w_a, 0, true); // line 14.
      for (int32 q = 1; q <= Q; q++) { // this loop == lines 15-18.
        int32 r_q = r(q);
        double a1 = alpha_dash(s_a, q-1) + l(w_a, r_q),
            a2 = alpha_dash(s_a, q) + l(w_a, 0, true),
            a3 = alpha_dash_arc(q-1) + l(0, r_q);
        if (a1 <= a2) {
          if (a1 <= a3) { b_arc[q] = 1; alpha_dash_arc(q) = a1; }
          else { b_arc[q] = 3; alpha_dash_arc(q) = a3; }
        } else {
          if (a2 <= a3) { b_arc[q] = 2; alpha_dash_arc(q) = a2; }
          else { b_arc[q] = 3; alpha_dash_arc(q) = a3; }
        }
      }
      beta_dash_arc.SetZero(); // line 19.
      for (int32 q = Q; q >= 1; q--) {
        // line 21:
        beta_dash_arc(q) += Exp(alpha(s_a) + p_a - alpha(n)) * beta_dash(n, q);
        switch (static_cast<int>(b_arc[q])) { // lines 22 and 23:
          case 1:
            beta_dash(s_a, q-1) += beta_dash_arc(q);
            // next: gamma(q, w(a)) += beta_dash_arc(q)
            AddToMap(w_a, beta_dash_arc(q), &(gamma[q]));
            // next: accumulating times, see decl for tau_b,tau_e
            AddToMap(w_a, state_times_[s_a] * beta_dash_arc(q), &(tau_b[q]));
            AddToMap(w_a, state_times_[n] * beta_dash_arc(q), &(tau_e[q]));
            break;
          case 2:
            beta_dash(s_a, q) += beta_dash_arc(q);
            break;
          case 3:
            beta_dash_arc(q-1) += beta_dash_arc(q);
            // next: gamma(q, epsilon) += beta_dash_arc(q)
            AddToMap(0, beta_dash_arc(q), &(gamma[q]));
            // next: accumulating times, see decl for tau_b,tau_e
            // WARNING: there was an error in Appendix C. If we followed
            // the instructions there the next line would say state_times_[s_a],
            // but it would be wrong. I will try to publish an erratum.
            AddToMap(0, state_times_[n] * beta_dash_arc(q), &(tau_b[q]));
            AddToMap(0, state_times_[n] * beta_dash_arc(q), &(tau_e[q]));
            break;
          default:
            KALDI_ERR << "Invalid b_arc value"; // error in code.
        }
      }
      beta_dash_arc(0) += Exp(alpha(s_a) + p_a - alpha(n)) * beta_dash(n, 0);
      beta_dash(s_a, 0) += beta_dash_arc(0); // line 26.
    }
  }
  beta_dash_arc.SetZero(); // line 29.
  for (int32 q = Q; q >= 1; q--) {
    beta_dash_arc(q) += beta_dash(1, q);
    beta_dash_arc(q-1) += beta_dash_arc(q);
    AddToMap(0, beta_dash_arc(q), &(gamma[q]));
    // the statements below are actually redundant because
    // state_times_[1] is zero.
    AddToMap(0, state_times_[1] * beta_dash_arc(q), &(tau_b[q]));
    AddToMap(0, state_times_[1] * beta_dash_arc(q), &(tau_e[q]));
  }
  for (int32 q = 1; q <= Q; q++) { // a check (line 35)
    double sum = 0.0;
    for (map<int32, double>::iterator iter = gamma[q].begin();
         iter != gamma[q].end(); ++iter) sum += iter->second;
    if (fabs(sum - 1.0) > 0.1)
      KALDI_WARN << "sum of gamma[" << q << ",s] is " << sum;
  }
  // The next part is where we take gamma, and convert
  // to the class member gamma_, which is using a different
  // data structure and indexed from zero, not one.
  gamma_.clear();
  gamma_.resize(Q);
  for (int32 q = 1; q <= Q; q++) {
    for (map<int32, double>::iterator iter = gamma[q].begin();
         iter != gamma[q].end(); ++iter)
      gamma_[q-1].push_back(
          std::make_pair(iter->first, static_cast<BaseFloat>(iter->second)));
    // sort gamma_[q-1] from largest to smallest posterior.
    GammaCompare comp;
    std::sort(gamma_[q-1].begin(), gamma_[q-1].end(), comp);
  }
  // We do the same conversion for the state times tau_b and tau_e:
  // they get turned into the times_ data member, which has zero-based
  // indexing.
  times_.clear();
  times_.resize(Q);
  sausage_times_.clear();
  sausage_times_.resize(Q);
  for (int32 q = 1; q <= Q; q++) {
    double t_b = 0.0, t_e = 0.0;
    for (std::vector<std::pair<int32, BaseFloat>>::iterator iter = gamma_[q-1].begin();
         iter != gamma_[q-1].end(); ++iter) {
      double w_b = tau_b[q][iter->first], w_e = tau_e[q][iter->first];
      if (w_b > w_e)
        KALDI_WARN << "Times out of order"; // this is quite bad.
      times_[q-1].push_back(
          std::make_pair(static_cast<BaseFloat>(w_b / iter->second),
                         static_cast<BaseFloat>(w_e / iter->second)));
      t_b += w_b;
      t_e += w_e;
    }
    sausage_times_[q-1].first = t_b;
    sausage_times_[q-1].second = t_e;
    if (sausage_times_[q-1].first > sausage_times_[q-1].second)
      KALDI_WARN << "Times out of order"; // this is quite bad.
    if (q > 1 && sausage_times_[q-2].second > sausage_times_[q-1].first) {
      // We previously had a warning here, but now we'll just set both
      // those values to their average. It's quite possible for this
      // condition to happen, but it seems like it would have a bad effect
      // on the downstream processing, so we fix it.
      sausage_times_[q-2].second = sausage_times_[q-1].first =
          0.5 * (sausage_times_[q-2].second + sausage_times_[q-1].first);
    }
  }
}
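The last part of AccStats() is plain data-structure work, so it can be sketched without any Kaldi types. The hedged sketch below, using hypothetical names (Bin, ConvertGammaSketch, FixBinOverlapsSketch) and std::int32_t in place of Kaldi's int32, shows the conversion from the one-based gamma maps to sorted zero-based bins, and the midpoint repair of overlapping sausage_times_. The descending-posterior lambda stands in for GammaCompare; its ordering is taken from the comment in the listing, not from the definition of GammaCompare.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

using Bin = std::vector<std::pair<int32_t, float>>;  // (word, posterior)

// Sketch: turn the one-based temporary gamma[1..Q] (word -> posterior maps)
// into zero-based bins, each sorted from largest to smallest posterior.
std::vector<Bin> ConvertGammaSketch(
    const std::vector<std::map<int32_t, double>> &gamma) {
  assert(!gamma.empty());  // index 0 is unused, as in AccStats()
  std::vector<Bin> out(gamma.size() - 1);
  for (size_t q = 1; q < gamma.size(); q++) {
    for (const auto &kv : gamma[q])
      out[q - 1].emplace_back(kv.first, static_cast<float>(kv.second));
    // Stand-in for GammaCompare: highest posterior first.
    std::sort(out[q - 1].begin(), out[q - 1].end(),
              [](const std::pair<int32_t, float> &a,
                 const std::pair<int32_t, float> &b) {
                return a.second > b.second;
              });
  }
  return out;
}

// Sketch of the bin-time repair: when bin q-1 ends after bin q starts, both
// boundaries move to their midpoint, so consecutive bins never overlap.
void FixBinOverlapsSketch(std::vector<std::pair<float, float>> *bin_times) {
  for (size_t q = 1; q < bin_times->size(); q++) {
    std::pair<float, float> &prev = (*bin_times)[q - 1];
    std::pair<float, float> &cur = (*bin_times)[q];
    if (prev.second > cur.first)
      prev.second = cur.first = 0.5f * (prev.second + cur.first);
  }
}
```

Since each bin's posteriors sum to roughly one (the line-35 check), summing the per-word time accumulators over a bin, as AccStats() does for t_b and t_e, already yields a posterior-weighted average, which is why no extra division appears there.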

◆ AddToMap()

static void AddToMap ( int32 i,
double d,
std::map< int32, double > *gamma
)
inline static private

Function used to increment map.

Definition at line 192 of file sausages.h.


Referenced by MinimumBayesRisk::AccStats().

{
  if (d == 0) return;
  std::pair<const int32, double> pr(i, d);
  std::pair<std::map<int32, double>::iterator, bool> ret = gamma->insert(pr);
  if (!ret.second) // not inserted, so add to contents.
    ret.first->second += d;
}
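AddToMap() inserts the key with value d if it is absent and otherwise adds d to the stored value, skipping zero increments. A hedged, stand-alone sketch of an equivalent formulation follows; AddToMapSketch is a hypothetical name, and std::int32_t stands in for Kaldi's int32 typedef.

```cpp
#include <cassert>
#include <cstdint>
#include <map>

// Hypothetical variant of AddToMap(), for illustration only:
// std::map::operator[] value-initializes a missing double to 0.0,
// so "insert, or add if the key already exists" becomes one statement.
void AddToMapSketch(int32_t i, double d, std::map<int32_t, double> *gamma) {
  assert(gamma != nullptr);
  if (d == 0) return;  // as in AddToMap(): zero increments create no entry
  (*gamma)[i] += d;
}
```

Both forms perform a single tree lookup per call; the explicit insert() in AddToMap() simply makes the "insert or add" intent visible.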

◆ delta()

static BaseFloat delta ( )
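inline static private

Returns the small penalty 'delta' from the paper, which l() adds when its 'penalize' argument is true.

Definition at line 188 of file sausages.h.

{ return 1.0e-05; }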

◆ EditDistance()

double EditDistance ( int32 N,
int32 Q,
Vector< double > &alpha,
Matrix< double > &alpha_dash,
Vector< double > &alpha_dash_arc
)
private

Figure 4 of the paper; called from AccStats (Fig. 5)

Definition at line 130 of file sausages.cc.

References MinimumBayesRisk::arcs_, kaldi::Exp(), kaldi::kLogZeroDouble, MinimumBayesRisk::l(), kaldi::LogAdd(), MinimumBayesRisk::Arc::loglike, MinimumBayesRisk::pre_, MinimumBayesRisk::r(), MinimumBayesRisk::Arc::start_node, and MinimumBayesRisk::Arc::word.

Referenced by MinimumBayesRisk::AccStats().

{
  alpha(1) = 0.0; // = log(1). Line 5.
  alpha_dash(1, 0) = 0.0; // Line 5.
  for (int32 q = 1; q <= Q; q++)
    alpha_dash(1, q) = alpha_dash(1, q-1) + l(0, r(q)); // Line 7.
  for (int32 n = 2; n <= N; n++) {
    double alpha_n = kLogZeroDouble;
    for (size_t i = 0; i < pre_[n].size(); i++) {
      const Arc &arc = arcs_[pre_[n][i]];
      alpha_n = LogAdd(alpha_n, alpha(arc.start_node) + arc.loglike);
    }
    alpha(n) = alpha_n; // Line 10.
    // Line 11 omitted: matrix was initialized to zero.
    for (size_t i = 0; i < pre_[n].size(); i++) {
      const Arc &arc = arcs_[pre_[n][i]];
      int32 s_a = arc.start_node, w_a = arc.word;
      BaseFloat p_a = arc.loglike;
      for (int32 q = 0; q <= Q; q++) {
        if (q == 0) {
          alpha_dash_arc(q) = // line 15.
              alpha_dash(s_a, q) + l(w_a, 0, true);
        } else { // a1,a2,a3 are the 3 parts of min expression of line 17.
          int32 r_q = r(q);
          double a1 = alpha_dash(s_a, q-1) + l(w_a, r_q),
              a2 = alpha_dash(s_a, q) + l(w_a, 0, true),
              a3 = alpha_dash_arc(q-1) + l(0, r_q);
          alpha_dash_arc(q) = std::min(a1, std::min(a2, a3));
        }
        // line 19:
        alpha_dash(n, q) += Exp(alpha(s_a) + p_a - alpha(n)) * alpha_dash_arc(q);
      }
    }
  }
  return alpha_dash(N, Q); // line 23.
}
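To make the forward recursion easier to follow, here is a minimal stand-alone sketch that reproduces EditDistance() on plain std:: containers. It assumes the same conventions as the class: 1-based, topologically sorted nodes, node 1 as the start, a single final node N, and pre[n] holding the indices of the arcs entering node n. SketchArc, LogAddSketch, LossSketch and ExpectedEditDistanceSketch are hypothetical names, and the log-add is our own implementation rather than kaldi::LogAdd().

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical arc type mirroring the fields of MinimumBayesRisk::Arc
// that EditDistance() reads.
struct SketchArc {
  int32_t word;        // 0 means epsilon
  int32_t start_node;  // 1-based
  double loglike;      // natural-log arc score, acoustic scaling already applied
};

const double kNegInf = -std::numeric_limits<double>::infinity();

double LogAddSketch(double x, double y) {  // log(exp(x) + exp(y))
  if (x < y) std::swap(x, y);
  if (y == kNegInf) return x;
  return x + std::log1p(std::exp(y - x));
}

double LossSketch(int32_t a, int32_t b, bool penalize = false) {
  if (a == b) return 0.0;
  return penalize ? 1.0 + 1.0e-05 : 1.0;
}

// Returns the (approximate) expected edit distance between the hypothesis
// R (r_1..r_Q, stored zero-based like R_) and the lattice.
double ExpectedEditDistanceSketch(const std::vector<SketchArc> &arcs,
                                  const std::vector<std::vector<int32_t>> &pre,
                                  const std::vector<int32_t> &R) {
  const int32_t N = static_cast<int32_t>(pre.size()) - 1;
  const int32_t Q = static_cast<int32_t>(R.size());
  assert(N >= 1);
  auto r = [&R](int32_t q) { return R[q - 1]; };  // one-based, like r()
  std::vector<double> alpha(N + 1, 0.0), alpha_dash_arc(Q + 1, 0.0);
  std::vector<std::vector<double>> alpha_dash(N + 1,
                                              std::vector<double>(Q + 1, 0.0));
  // alpha[1] = 0 = log(1). At the start node, the first q hypothesis words
  // can only be aligned to epsilon, each costing one error.
  for (int32_t q = 1; q <= Q; q++)
    alpha_dash[1][q] = alpha_dash[1][q - 1] + LossSketch(0, r(q));
  for (int32_t n = 2; n <= N; n++) {
    double alpha_n = kNegInf;  // log of the total probability of reaching n
    for (int32_t idx : pre[n])
      alpha_n = LogAddSketch(alpha_n,
                             alpha[arcs[idx].start_node] + arcs[idx].loglike);
    alpha[n] = alpha_n;
    for (int32_t idx : pre[n]) {
      const SketchArc &arc = arcs[idx];
      const int32_t s_a = arc.start_node, w_a = arc.word;
      // Probability of having come through this arc, given that n is reached.
      const double post = std::exp(alpha[s_a] + arc.loglike - alpha[n]);
      for (int32_t q = 0; q <= Q; q++) {
        if (q == 0) {
          alpha_dash_arc[0] = alpha_dash[s_a][0] + LossSketch(w_a, 0, true);
        } else {
          const int32_t r_q = r(q);
          const double a1 = alpha_dash[s_a][q - 1] + LossSketch(w_a, r_q),
                       a2 = alpha_dash[s_a][q] + LossSketch(w_a, 0, true),
                       a3 = alpha_dash_arc[q - 1] + LossSketch(0, r_q);
          alpha_dash_arc[q] = std::min(a1, std::min(a2, a3));
        }
        alpha_dash[n][q] += post * alpha_dash_arc[q];
      }
    }
  }
  return alpha_dash[N][Q];
}
```

For example, with two equally likely one-word paths from node 1 to node 2, one matching the single hypothesis word and one not, the sketch returns an expected edit distance of about 0.5.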

◆ GetBayesRisk()

BaseFloat GetBayesRisk ( ) const
inline

Returns the expected number of word errors (the expected edit distance) over this sentence, assuming model correctness.

Definition at line 137 of file sausages.h.

Referenced by main().

{ return L_; }

◆ GetOneBest()

const std::vector<int32>& GetOneBest ( ) const
inline

Definition at line 104 of file sausages.h.

Referenced by main().

{ // gets one-best (with no epsilons)
  return R_;
}

◆ GetOneBestConfidences()

const std::vector<BaseFloat>& GetOneBestConfidences ( ) const
inline

Outputs the confidences for the one-best transcript.

Definition at line 132 of file sausages.h.

Referenced by main().

{
  return one_best_confidences_;
}

◆ GetOneBestTimes()

const std::vector<std::pair<BaseFloat, BaseFloat> >& GetOneBestTimes ( ) const
inline

Definition at line 122 of file sausages.h.

Referenced by main().

{
  return one_best_times_; // returns average (start,end) times for each word
  // corresponding to an entry in the one-best output. This is typically the
  // appropriate subset of the times in GetTimes() but can be slightly
  // different if the times for the one-best words overlap, in which case
  // the times returned by this method do not overlap, unlike the times
  // returned by GetTimes().
}

◆ GetSausageStats()

const std::vector<std::vector<std::pair<int32, BaseFloat> > >& GetSausageStats ( ) const
inline

Definition at line 139 of file sausages.h.

Referenced by main().

{
  return gamma_;
}

◆ GetSausageTimes()

const std::vector<std::pair<BaseFloat, BaseFloat> > GetSausageTimes ( ) const
inline

Definition at line 114 of file sausages.h.

Referenced by main().

{
  return sausage_times_; // returns average (start,end) times for each bin.
  // This is typically the weighted average of the times in GetTimes() but can
  // be slightly different if the times for the bins overlap, in which case
  // the times returned by this method do not overlap, unlike the times
  // returned by GetTimes().
}

◆ GetTimes()

const std::vector<std::vector<std::pair<BaseFloat, BaseFloat> > > GetTimes ( ) const
inline

Definition at line 108 of file sausages.h.

{
  return times_; // returns average (start,end) times for each word in each
  // bin. These are raw averages without any processing, i.e. time intervals
  // from different bins can overlap.
}

◆ l()

double l ( int32 a,
int32 b,
bool penalize = false
)
inline private

Without the 'penalize' argument this gives us the basic edit-distance function l(a,b), as in the paper.

With the 'penalize' argument it can be interpreted as the edit distance plus the 'delta' from the paper, except that we make a kind of conceptual bug-fix and only apply the delta if the edit distance was not already zero. This fix was necessary to make all the stats that should show up actually show up, and it makes the sausage stats significantly less sparse.

Definition at line 157 of file sausages.h.

Referenced by MinimumBayesRisk::AccStats(), and MinimumBayesRisk::EditDistance().

{
  if (a == b) return 0.0;
  else return (penalize ? 1.0 + delta() : 1.0);
}
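The cost function can be restated in isolation. The sketch below, assuming symbol 0 is epsilon (as in RemoveEps()), shows that the delta penalty is only ever added on top of a genuine mismatch, never to a match; LossSketch and kDeltaSketch are hypothetical names, and the delta value is copied from delta().

```cpp
#include <cassert>
#include <cstdint>

// Stand-alone restatement of l(a, b, penalize); 0 stands for epsilon.
constexpr double kDeltaSketch = 1.0e-05;  // same value as delta()

double LossSketch(int32_t a, int32_t b, bool penalize = false) {
  if (a == b) return 0.0;  // a match is free, even when penalize is true
  return penalize ? 1.0 + kDeltaSketch : 1.0;  // one error, plus delta if asked
}
```

In EditDistance() and AccStats() the penalized form appears only in the terms that align a lattice word w_a to epsilon, so an epsilon arc aligned to epsilon stays at zero cost.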

◆ MbrDecode()

void MbrDecode ( )
private

Minimum-Bayes-Risk Decode. Top-level algorithm. Figure 6 of the paper.

Definition at line 28 of file sausages.cc.

References MinimumBayesRisk::AccStats(), MinimumBayesRiskOptions::decode_mbr, MinimumBayesRisk::gamma_, KALDI_VLOG, KALDI_WARN, MinimumBayesRisk::NormalizeEps(), MinimumBayesRisk::one_best_confidences_, MinimumBayesRisk::one_best_times_, MinimumBayesRisk::opts_, MinimumBayesRiskOptions::print_silence, MinimumBayesRisk::R_, MinimumBayesRisk::RemoveEps(), and MinimumBayesRisk::times_.

Referenced by MinimumBayesRisk::MinimumBayesRisk().

{

  for (size_t counter = 0; ; counter++) {
    NormalizeEps(&R_);
    AccStats(); // writes to gamma_
    double delta_Q = 0.0; // change in objective function.

    one_best_times_.clear();
    one_best_confidences_.clear();

    // Caution: q in the line below is (q-1) in the algorithm
    // in the paper; both R_ and gamma_ are indexed by q-1.
    for (size_t q = 0; q < R_.size(); q++) {
      if (opts_.decode_mbr) { // This loop updates R_ [indexed same as gamma_].
        // gamma_[i] is sorted in reverse order so most likely one is first.
        const std::vector<std::pair<int32, BaseFloat> > &this_gamma = gamma_[q];
        double old_gamma = 0, new_gamma = this_gamma[0].second;
        int32 rq = R_[q], rhat = this_gamma[0].first; // rq: old word, rhat: new.
        for (size_t j = 0; j < this_gamma.size(); j++)
          if (this_gamma[j].first == rq) old_gamma = this_gamma[j].second;
        delta_Q += (old_gamma - new_gamma); // will be 0 or negative; a bound on
                                            // change in error.
        if (rq != rhat)
          KALDI_VLOG(2) << "Changing word " << rq << " to " << rhat;
        R_[q] = rhat;
      }
      // build the outputs (time, confidences),
      if (R_[q] != 0 || opts_.print_silence) {
        // see which 'item' from the sausage bin we should select
        // (not necessarily the 1st one when MBR decoding is disabled)
        int32 s = 0;
        for (int32 j = 0; j < gamma_[q].size(); j++) {
          if (gamma_[q][j].first == R_[q]) {
            s = j;
            break;
          }
        }
        one_best_times_.push_back(times_[q][s]);
        // post-process the times,
        size_t i = one_best_times_.size();
        if (i > 1 && one_best_times_[i-2].second > one_best_times_[i-1].first) {
          // It's quite possible for this to happen, but it seems like it would
          // have a bad effect on the downstream processing, so we fix it here.
          // We resolve overlaps by redistributing the available time interval.
          BaseFloat prev_right = i > 2 ? one_best_times_[i-3].second : 0.0;
          BaseFloat left = std::max(prev_right,
                                    std::min(one_best_times_[i-2].first,
                                             one_best_times_[i-1].first));
          BaseFloat right = std::max(one_best_times_[i-2].second,
                                     one_best_times_[i-1].second);
          BaseFloat first_dur =
              one_best_times_[i-2].second - one_best_times_[i-2].first;
          BaseFloat second_dur =
              one_best_times_[i-1].second - one_best_times_[i-1].first;
          BaseFloat mid = first_dur > 0 ? left + (right - left) * first_dur /
              (first_dur + second_dur) : left;
          one_best_times_[i-2].first = left;
          one_best_times_[i-2].second = one_best_times_[i-1].first = mid;
          one_best_times_[i-1].second = right;
        }
        BaseFloat confidence = 0.0;
        for (int32 j = 0; j < gamma_[q].size(); j++) {
          if (gamma_[q][j].first == R_[q]) {
            confidence = gamma_[q][j].second;
            break;
          }
        }
        one_best_confidences_.push_back(confidence);
      }
    }
    KALDI_VLOG(2) << "Iter = " << counter << ", delta-Q = " << delta_Q;
    if (delta_Q == 0) break;
    if (counter > 100) {
      KALDI_WARN << "Iterating too many times in MbrDecode; stopping.";
      break;
    }
  }
  if (!opts_.print_silence) RemoveEps(&R_);
}
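Two steps of MbrDecode() are easy to isolate: the per-bin hypothesis update when decode_mbr is true, and the repair of overlapping one-best times. The sketch below restates both on std:: containers; UpdateHypothesisSketch, AppendOneBestTimeSketch, Interval and Bin are hypothetical names, and each bin is assumed to be sorted best-first, as AccStats() leaves gamma_.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using Interval = std::pair<float, float>;            // (start, end) of one word
using Bin = std::vector<std::pair<int32_t, float>>;  // (word, posterior), best first

// Sketch of one MBR update pass: every hypothesis word is replaced by the top
// entry of its bin; delta_Q sums (old posterior - new posterior), so it is <= 0.
double UpdateHypothesisSketch(const std::vector<Bin> &gamma,
                              std::vector<int32_t> *R) {
  assert(gamma.size() == R->size());
  double delta_Q = 0.0;
  for (size_t q = 0; q < R->size(); q++) {
    assert(!gamma[q].empty());
    double old_gamma = 0.0, new_gamma = gamma[q][0].second;
    for (const auto &entry : gamma[q])
      if (entry.first == (*R)[q]) old_gamma = entry.second;
    delta_Q += old_gamma - new_gamma;
    (*R)[q] = gamma[q][0].first;
  }
  return delta_Q;
}

// Sketch of the overlap repair applied as each one-best time is appended: the
// span covered by the last two intervals (clipped on the left by the end of the
// word before them) is split in proportion to their original durations.
void AppendOneBestTimeSketch(const Interval &t, std::vector<Interval> *times) {
  times->push_back(t);
  const size_t i = times->size();
  if (i < 2 || (*times)[i - 2].second <= (*times)[i - 1].first) return;
  Interval &a = (*times)[i - 2];
  Interval &b = (*times)[i - 1];
  const float prev_right = i > 2 ? (*times)[i - 3].second : 0.0f;
  const float left = std::max(prev_right, std::min(a.first, b.first));
  const float right = std::max(a.second, b.second);
  const float first_dur = a.second - a.first;
  const float second_dur = b.second - b.first;
  const float mid =
      first_dur > 0 ? left + (right - left) * first_dur / (first_dur + second_dur)
                    : left;
  a.first = left;
  a.second = b.first = mid;
  b.second = right;
}
```

Because delta_Q is a sum of non-positive terms, the outer loop in MbrDecode() stops once every hypothesis word already has the highest posterior in its bin (delta_Q == 0), or when the iteration counter exceeds 100.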

◆ NormalizeEps()

void NormalizeEps ( std::vector< int32 > *vec )
static private

Definition at line 119 of file sausages.cc.

References MinimumBayesRisk::RemoveEps().

Referenced by MinimumBayesRisk::MbrDecode().

{
  RemoveEps(vec);
  vec->resize(1 + vec->size() * 2);
  int32 s = vec->size();
  for (int32 i = s/2 - 1; i >= 0; i--) {
    (*vec)[i*2 + 1] = (*vec)[i];
    (*vec)[i*2 + 2] = 0;
  }
  (*vec)[0] = 0;
}
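Judging from the listing, NormalizeEps() first strips epsilons and then interleaves them, so a hypothesis of K words becomes 0 w_1 0 w_2 ... 0 w_K 0 (2K+1 entries). Since MbrDecode() updates R_[q] for every q, this form appears to be what lets an epsilon slot turn into a word between two existing words. The sketch below is a stand-alone restatement; NormalizeEpsSketch is a hypothetical name, and std::remove with the value 0 replaces the Int32IsZero predicate used by RemoveEps().

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of RemoveEps() followed by the interleaving in NormalizeEps():
// the result is 0 w1 0 w2 ... 0 wK 0, of length 2K+1.
void NormalizeEpsSketch(std::vector<int32_t> *vec) {
  vec->erase(std::remove(vec->begin(), vec->end(), 0), vec->end());
  const size_t k = vec->size();
  vec->resize(2 * k + 1);
  for (size_t j = k; j-- > 0;) {  // j = k-1, ..., 0
    (*vec)[2 * j + 1] = (*vec)[j];
    (*vec)[2 * j + 2] = 0;
  }
  (*vec)[0] = 0;
  assert(vec->size() % 2 == 1);
}
```

Filling from the back is what makes the in-place expansion safe: every write lands at index 2j+1 or 2j+2, beyond the index j that is still to be read.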

◆ PrepareLatticeAndInitStats()

void PrepareLatticeAndInitStats ( CompactLattice *clat )
private

Definition at line 320 of file sausages.cc.

References MinimumBayesRisk::arcs_, kaldi::CompactLatticeStateTimes(), fst::CreateSuperFinal(), MinimumBayesRisk::Arc::end_node, KALDI_ASSERT, KALDI_ERR, MinimumBayesRisk::Arc::loglike, MinimumBayesRisk::pre_, MinimumBayesRisk::Arc::start_node, MinimumBayesRisk::state_times_, and MinimumBayesRisk::Arc::word.

Referenced by MinimumBayesRisk::MinimumBayesRisk().

{
  KALDI_ASSERT(clat != NULL);

  CreateSuperFinal(clat); // Add super-final state to clat... this is
  // one of the requirements of the MBR algorithm, as mentioned in the
  // paper (i.e. just one final state).

  // Topologically sort the lattice, if not already sorted.
  kaldi::uint64 props = clat->Properties(fst::kFstProperties, false);
  if (!(props & fst::kTopSorted)) {
    if (fst::TopSort(clat) == false)
      KALDI_ERR << "Cycles detected in lattice.";
  }
  CompactLatticeStateTimes(*clat, &state_times_); // work out times of
  // the states in clat
  state_times_.push_back(0); // we'll convert to 1-based numbering.
  for (size_t i = state_times_.size()-1; i > 0; i--)
    state_times_[i] = state_times_[i-1];

  // Now we convert the information in "clat" into a special internal
  // format (pre_ and arcs_) which allows us to access the
  // arcs preceding any given state.
  // Note: in our internal format the states will be numbered from 1,
  // which involves adding 1 to the OpenFst states.
  int32 N = clat->NumStates();
  pre_.resize(N+1);

  // Careful: "Arc" is a class-member struct, not an OpenFst type of arc as one
  // would normally assume.
  for (int32 n = 1; n <= N; n++) {
    for (fst::ArcIterator<CompactLattice> aiter(*clat, n-1);
         !aiter.Done();
         aiter.Next()) {
      const CompactLatticeArc &carc = aiter.Value();
      Arc arc; // in our local format.
      arc.word = carc.ilabel; // == carc.olabel
      arc.start_node = n;
      arc.end_node = carc.nextstate + 1; // convert to 1-based.
      arc.loglike = - (carc.weight.Weight().Value1() +
                       carc.weight.Weight().Value2());
      // loglike: sum graph/LM and acoustic cost, and negate to
      // convert to loglikes. We assume acoustic scaling is already done.

      pre_[arc.end_node].push_back(arcs_.size()); // record index of this arc.
      arcs_.push_back(arc);
    }
  }
}

◆ r()

int32 r ( int32  q)
inline private

Returns r_q, using one-based indexing as in the paper.

Definition at line 163 of file sausages.h.

Referenced by MinimumBayesRisk::AccStats(), and MinimumBayesRisk::EditDistance().

{ return R_[q-1]; }

◆ RemoveEps()

void RemoveEps ( std::vector< int32 > *  vec)
static private

Removes epsilons (symbol 0) from a vector.

Definition at line 112 of file sausages.cc.

Referenced by MinimumBayesRisk::MbrDecode(), and MinimumBayesRisk::NormalizeEps().

{
  Int32IsZero pred;
  vec->erase(std::remove_if(vec->begin(), vec->end(), pred),
             vec->end());
}
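RemoveEps uses the standard erase-remove idiom. The sketch below shows the same idiom in isolation, with a lambda standing in for the `Int32IsZero` predicate; the function name `RemoveZeros` is hypothetical.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Erase-remove idiom: std::remove_if moves the elements to keep to the
// front (preserving their relative order) and returns the new logical end;
// erase() then discards the leftover tail.
void RemoveZeros(std::vector<int32_t>* vec) {
  vec->erase(std::remove_if(vec->begin(), vec->end(),
                            [](int32_t x) { return x == 0; }),
             vec->end());
}
```

Calling std::remove_if alone would leave the vector's size unchanged, which is why the erase() call is required.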

Member Data Documentation

◆ arcs_

std::vector<Arc> arcs_
private

Arcs in the topologically sorted acceptor form of the word-level lattice, with a single final state.

Each entry holds the word symbol, the one-based start and end nodes, and the log-likelihood on the arc (the negated cost). Indexed from zero.

Definition at line 213 of file sausages.h.

Referenced by MinimumBayesRisk::AccStats(), MinimumBayesRisk::EditDistance(), and MinimumBayesRisk::PrepareLatticeAndInitStats().

◆ gamma_

std::vector<std::vector<std::pair<int32, BaseFloat> > > gamma_
private

Definition at line 229 of file sausages.h.

Referenced by MinimumBayesRisk::AccStats(), and MinimumBayesRisk::MbrDecode().

◆ L_

double L_
private

Definition at line 226 of file sausages.h.

Referenced by MinimumBayesRisk::AccStats(), and MinimumBayesRisk::MinimumBayesRisk().

◆ one_best_confidences_

std::vector<BaseFloat> one_best_confidences_
private

Definition at line 250 of file sausages.h.

Referenced by MinimumBayesRisk::MbrDecode().

◆ one_best_times_

std::vector<std::pair<BaseFloat, BaseFloat> > one_best_times_
private

Definition at line 245 of file sausages.h.

Referenced by MinimumBayesRisk::MbrDecode().

◆ opts_

MinimumBayesRiskOptions opts_
private

Definition at line 207 of file sausages.h.

Referenced by MinimumBayesRisk::MbrDecode().

◆ pre_

std::vector<std::vector<int32> > pre_
private

For each node in the lattice, a list of the indices (into arcs_) of the arcs entering that node.

Indexed from 1 (first node == 1); pre_[0] is unused.
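Storing incoming arcs per node makes a forward (alpha) recursion over the topologically sorted nodes straightforward, since each node's score depends only on the arcs listed in `pre_[n]`. The following is an illustrative sketch under stated assumptions, not Kaldi's AccStats(): node 1 is assumed to be the unique start node, every arc's start node is assumed to precede its end node, and the names `ForwardArc`, `LogAdd` and `ForwardLogProbs` are hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical one-based arc; only the fields the recursion needs.
struct ForwardArc {
  int32_t start_node;
  int32_t end_node;
  double loglike;
};

// Numerically stable log(exp(a) + exp(b)).
double LogAdd(double a, double b) {
  const double kNegInf = -std::numeric_limits<double>::infinity();
  if (a == kNegInf) return b;
  if (b == kNegInf) return a;
  const double m = std::max(a, b);
  return m + std::log1p(std::exp(-std::fabs(a - b)));
}

// alpha[n] is the log-sum over all paths from node 1 to node n (one-based;
// alpha[0] is unused). pre[n] lists indices into `arcs` of arcs entering n.
std::vector<double> ForwardLogProbs(
    int32_t N, const std::vector<std::vector<int32_t>>& pre,
    const std::vector<ForwardArc>& arcs) {
  const double kNegInf = -std::numeric_limits<double>::infinity();
  std::vector<double> alpha(N + 1, kNegInf);
  alpha[1] = 0.0;  // assumed unique start node
  // Topological order guarantees alpha[start_node] is final when used.
  for (int32_t n = 2; n <= N; n++) {
    for (int32_t idx : pre[n]) {
      const ForwardArc& arc = arcs[idx];
      alpha[n] = LogAdd(alpha[n], alpha[arc.start_node] + arc.loglike);
    }
  }
  return alpha;
}
```

With two parallel arcs of probability 0.25 from node 1 to node 2, followed by a probability-1 arc to node 3, both alpha[2] and alpha[3] come out as log(0.5).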

Definition at line 217 of file sausages.h.

Referenced by MinimumBayesRisk::AccStats(), MinimumBayesRisk::EditDistance(), and MinimumBayesRisk::PrepareLatticeAndInitStats().

◆ R_

std::vector<int32> R_
private

Definition at line 222 of file sausages.h.

◆ sausage_times_

std::vector<std::pair<BaseFloat, BaseFloat> > sausage_times_
private

Definition at line 240 of file sausages.h.

Referenced by MinimumBayesRisk::AccStats(), and MinimumBayesRisk::MinimumBayesRisk().

◆ state_times_

std::vector<int32> state_times_
private

Definition at line 219 of file sausages.h.

◆ times_

std::vector<std::vector<std::pair<BaseFloat, BaseFloat> > > times_
private

Definition at line 235 of file sausages.h.

Referenced by MinimumBayesRisk::AccStats(), and MinimumBayesRisk::MbrDecode().


The documentation for this class was generated from the following files: