30 for (
size_t counter = 0; ; counter++) {
40 for (
size_t q = 0; q <
R_.size(); q++) {
43 const std::vector<std::pair<int32, BaseFloat> > &this_gamma =
gamma_[q];
44 double old_gamma = 0, new_gamma = this_gamma[0].second;
45 int32 rq =
R_[q], rhat = this_gamma[0].first;
46 for (
size_t j = 0;
j < this_gamma.size();
j++)
47 if (this_gamma[
j].first == rq) old_gamma = this_gamma[
j].second;
48 delta_Q += (old_gamma - new_gamma);
51 KALDI_VLOG(2) <<
"Changing word " << rq <<
" to " << rhat;
82 BaseFloat mid = first_dur > 0 ? left + (right - left) * first_dur /
83 (first_dur + second_dur) : left;
91 confidence =
gamma_[q][
j].second;
98 KALDI_VLOG(2) <<
"Iter = " << counter <<
", delta-Q = " << delta_Q;
99 if (delta_Q == 0)
break;
101 KALDI_WARN <<
"Iterating too many times in MbrDecode; stopping.";
109 bool operator() (
int32 i) {
return (i == 0); }
114 vec->erase(std::remove_if (vec->begin(), vec->end(), pred),
121 vec->resize(1 + vec->size() * 2);
122 int32 s = vec->size();
123 for (
int32 i = s/2 - 1;
i >= 0;
i--) {
124 (*vec)[
i*2 + 1] = (*vec)[
i];
135 alpha_dash(1, 0) = 0.0;
136 for (
int32 q = 1; q <= Q; q++)
137 alpha_dash(1, q) = alpha_dash(1, q-1) +
l(0,
r(q));
140 for (
size_t i = 0;
i <
pre_[
n].size();
i++) {
146 for (
size_t i = 0;
i <
pre_[
n].size();
i++) {
150 for (
int32 q = 0; q <= Q; q++) {
153 alpha_dash(s_a, q) +
l(w_a, 0,
true);
156 double a1 = alpha_dash(s_a, q-1) +
l(w_a, r_q),
157 a2 = alpha_dash(s_a, q) +
l(w_a, 0,
true),
158 a3 = alpha_dash_arc(q-1) +
l(0, r_q);
159 alpha_dash_arc(q) = std::min(a1, std::min(a2, a3));
162 alpha_dash(
n, q) +=
Exp(alpha(s_a) + p_a - alpha(
n)) * alpha_dash_arc(q);
166 return alpha_dash(N, Q);
174 Q = static_cast<int32>(
R_.size());
181 std::vector<char> b_arc(Q+1);
182 std::vector<map<int32, double> > gamma(Q+1);
189 std::vector<map<int32, double> > tau_b(Q+1), tau_e(Q+1);
191 double Ltmp =
EditDistance(N, Q, alpha, alpha_dash, alpha_dash_arc);
192 if (
L_ != 0 && Ltmp >
L_) {
193 KALDI_WARN <<
"Edit distance increased: " << Ltmp <<
" > " 199 beta_dash(N, Q) = 1.0;
201 for (
size_t i = 0;
i <
pre_[
n].size();
i++) {
205 alpha_dash_arc(0) = alpha_dash(s_a, 0) +
l(w_a, 0,
true);
206 for (
int32 q = 1; q <= Q; q++) {
208 double a1 = alpha_dash(s_a, q-1) +
l(w_a, r_q),
209 a2 = alpha_dash(s_a, q) +
l(w_a, 0,
true),
210 a3 = alpha_dash_arc(q-1) +
l(0, r_q);
212 if (a1 <= a3) { b_arc[q] = 1; alpha_dash_arc(q) = a1; }
213 else { b_arc[q] = 3; alpha_dash_arc(q) = a3; }
215 if (a2 <= a3) { b_arc[q] = 2; alpha_dash_arc(q) = a2; }
216 else { b_arc[q] = 3; alpha_dash_arc(q) = a3; }
220 for (
int32 q = Q; q >= 1; q--) {
222 beta_dash_arc(q) +=
Exp(alpha(s_a) + p_a - alpha(
n)) * beta_dash(
n, q);
223 switch (static_cast<int>(b_arc[q])) {
225 beta_dash(s_a, q-1) += beta_dash_arc(q);
227 AddToMap(w_a, beta_dash_arc(q), &(gamma[q]));
233 beta_dash(s_a, q) += beta_dash_arc(q);
236 beta_dash_arc(q-1) += beta_dash_arc(q);
238 AddToMap(0, beta_dash_arc(q), &(gamma[q]));
250 beta_dash_arc(0) +=
Exp(alpha(s_a) + p_a - alpha(
n)) * beta_dash(
n, 0);
251 beta_dash(s_a, 0) += beta_dash_arc(0);
255 for (
int32 q = Q; q >= 1; q--) {
256 beta_dash_arc(q) += beta_dash(1, q);
257 beta_dash_arc(q-1) += beta_dash_arc(q);
258 AddToMap(0, beta_dash_arc(q), &(gamma[q]));
264 for (
int32 q = 1; q <= Q; q++) {
266 for (map<int32, double>::iterator iter = gamma[q].begin();
267 iter != gamma[q].end(); ++iter) sum += iter->second;
268 if (fabs(sum - 1.0) > 0.1)
269 KALDI_WARN <<
"sum of gamma[" << q <<
",s] is " << sum;
276 for (
int32 q = 1; q <= Q; q++) {
277 for (map<int32, double>::iterator iter = gamma[q].begin();
278 iter != gamma[q].end(); ++iter)
280 std::make_pair(iter->first, static_cast<BaseFloat>(iter->second)));
283 std::sort(
gamma_[q-1].begin(),
gamma_[q-1].end(), comp);
292 for (
int32 q = 1; q <= Q; q++) {
293 double t_b = 0.0, t_e = 0.0;
294 for (std::vector<std::pair<int32, BaseFloat>>::iterator iter =
gamma_[q-1].begin();
295 iter !=
gamma_[q-1].end(); ++iter) {
296 double w_b = tau_b[q][iter->first], w_e = tau_e[q][iter->first];
300 std::make_pair(static_cast<BaseFloat>(w_b / iter->second),
301 static_cast<BaseFloat>(w_e / iter->second)));
328 kaldi::uint64 props = clat->Properties(fst::kFstProperties,
false);
329 if (!(props & fst::kTopSorted)) {
330 if (fst::TopSort(clat) ==
false)
331 KALDI_ERR <<
"Cycles detected in lattice.";
344 int32 N = clat->NumStates();
350 for (fst::ArcIterator<CompactLattice> aiter(*clat,
n-1);
355 arc.
word = carc.ilabel;
358 arc.
loglike = - (carc.weight.Weight().Value1() +
359 carc.weight.Weight().Value2());
364 arcs_.push_back(arc);
386 fst::VectorFst<fst::StdArc>
fst;
388 fst::VectorFst<fst::StdArc> fst_shortest_path;
389 fst::ShortestPath(fst, &fst_shortest_path);
390 std::vector<int32> alignment,
words;
391 fst::TropicalWeight weight;
404 const std::vector<int32> &
words,
417 const std::vector<int32> &
words,
418 const std::vector<std::pair<BaseFloat,BaseFloat> > ×,
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
static void AddToMap(int32 i, double d, std::map< int32, double > *gamma)
Function used to increment map.
void RemoveAlignmentsFromCompactLattice(MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *fst)
Removes state-level alignments (the strings that are part of the weights).
static void RemoveEps(std::vector< int32 > *vec)
Removes epsilons (symbol 0) from a vector.
double l(int32 a, int32 b, bool penalize=false)
Without the 'penalize' argument this gives us the basic edit-distance function l(a,b), as in the paper.
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
static void NormalizeEps(std::vector< int32 > *vec)
The implementation of the Minimum Bayes Risk decoding method described in "Minimum Bayes Risk decodin...
std::vector< std::vector< std::pair< BaseFloat, BaseFloat > > > times_
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
std::vector< int32 > state_times_
MinimumBayesRiskOptions opts_
bool print_silence
Boolean configuration parameter: if true, the 1-best path will 'keep' the <eps> bins,.
void PrepareLatticeAndInitStats(CompactLattice *clat)
std::vector< BaseFloat > one_best_confidences_
std::vector< std::vector< std::pair< int32, BaseFloat > > > gamma_
int32 r(int32 q)
returns r_q, in one-based indexing, as in the paper.
Arc::StateId CreateSuperFinal(MutableFst< Arc > *fst)
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
fst::VectorFst< LatticeArc > Lattice
MinimumBayesRisk(const CompactLattice &clat, MinimumBayesRiskOptions opts=MinimumBayesRiskOptions())
Initialize with compact lattice– any acoustic scaling etc., is assumed to have been done already...
int32 CompactLatticeStateTimes(const CompactLattice &lat, vector< int32 > *times)
As LatticeStateTimes, but in the CompactLattice format.
fst::VectorFst< CompactLatticeArc > CompactLattice
std::vector< std::pair< BaseFloat, BaseFloat > > sausage_times_
double LogAdd(double x, double y)
double EditDistance(int32 N, int32 Q, Vector< double > &alpha, Matrix< double > &alpha_dash, Vector< double > &alpha_dash_arc)
Figure 4 of the paper; called from AccStats (Fig. 5)
void AccStats()
Figure 5 of the paper. Outputs to gamma_ and L_.
#define KALDI_ASSERT(cond)
fst::ArcTpl< CompactLatticeWeight > CompactLatticeArc
std::vector< std::vector< int32 > > pre_
For each node in the lattice, a list of arcs entering that node.
std::vector< std::pair< BaseFloat, BaseFloat > > one_best_times_
void SetZero()
Set vector to all zeros.
const double kLogZeroDouble
std::vector< Arc > arcs_
Arcs in the topologically sorted acceptor form of the word-level lattice, with one final-state...
void MbrDecode()
Minimum-Bayes-Risk Decode. Top-level algorithm. Figure 6 of the paper.
bool decode_mbr
Boolean configuration parameter: if true, we actually update the hypothesis to do MBR decoding (if fa...