ebw-diag-gmm.cc
Go to the documentation of this file.
1 // gmm/ebw-diag-gmm.cc
2 
3 // Copyright 2009-2011 Arnab Ghoshal, Petr Motlicek
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <algorithm> // for std::max
21 #include <string>
22 #include <vector>
23 
24 #include "gmm/diag-gmm.h"
25 #include "gmm/ebw-diag-gmm.h"
26 
27 namespace kaldi {
28 
29 // This function is used inside the EBW update routines.
30 // returns true if all variances were positive.
31 static bool EBWUpdateGaussian(
32  BaseFloat D,
33  GmmFlagsType flags,
34  const VectorBase<double> &orig_mean,
35  const VectorBase<double> &orig_var,
36  const VectorBase<double> &x_stats,
37  const VectorBase<double> &x2_stats,
38  double occ,
39  VectorBase<double> *mean,
40  VectorBase<double> *var,
41  double *auxf_impr) {
42  if (! (flags&(kGmmMeans|kGmmVariances))) { // nothing to do.
43  if (auxf_impr) *auxf_impr = 0.0;
44  mean->CopyFromVec(orig_mean);
45  var->CopyFromVec(orig_var);
46  return true;
47  }
48  KALDI_ASSERT(!( (flags&kGmmVariances) && !(flags&kGmmMeans))
49  && "We didn't make the update cover this case sensibly (update vars not means)");
50 
51  mean->SetZero();
52  var->SetZero();
53  mean->AddVec(D, orig_mean);
54  var->AddVec2(D, orig_mean);
55  var->AddVec(D, orig_var);
56  mean->AddVec(1.0, x_stats);
57  var->AddVec(1.0, x2_stats);
58  BaseFloat scale = 1.0 / (occ + D);
59  mean->Scale(scale);
60  var->Scale(scale);
61  var->AddVec2(-1.0, *mean);
62 
63  if (!(flags&kGmmVariances)) var->CopyFromVec(orig_var);
64  if (!(flags&kGmmMeans)) mean->CopyFromVec(orig_mean);
65 
66  // Return false if any NaN's.
67  for (int32 i = 0; i < mean->Dim(); i++) {
68  double m = ((*mean)(i)), v = ((*var)(i));
69  if (m!=m || v!=v || m-m != 0 || v-v != 0) {
70  return false;
71  }
72  }
73 
74  if (var->Min() > 0.0) {
75  if (auxf_impr != NULL) {
76  // work out auxf improvement.
77  BaseFloat old_auxf = 0.0, new_auxf = 0.0;
78  int32 dim = orig_mean.Dim();
79  for (int32 i = 0; i < dim; i++) {
80  BaseFloat mean_diff = (*mean)(i) - orig_mean(i);
81  old_auxf += (occ+D) * -0.5 * (Log(orig_var(i)) +
82  ((*var)(i) + mean_diff*mean_diff)
83  / orig_var(i));
84  new_auxf += (occ+D) * -0.5 * (Log((*var)(i)) + 1.0);
85 
86  }
87  *auxf_impr = new_auxf - old_auxf;
88  }
89  return true;
90  } else return false;
91 }
92 
93 // Update Gaussian parameters only (no weights)
94 void UpdateEbwDiagGmm(const AccumDiagGmm &num_stats, // with I-smoothing, if used.
95  const AccumDiagGmm &den_stats,
96  GmmFlagsType flags,
97  const EbwOptions &opts,
98  DiagGmm *gmm,
99  BaseFloat *auxf_change_out,
100  BaseFloat *count_out,
101  int32 *num_floored_out) {
102  GmmFlagsType acc_flags = num_stats.Flags();
103  if (flags & ~acc_flags)
104  KALDI_ERR << "Incompatible flags: you requested to update flags \""
105  << GmmFlagsToString(flags) << "\" but accumulators have only \""
106  << GmmFlagsToString(acc_flags) << '"';
107 
108  // It could be that the num stats actually contain the difference between
109  // num and den (for mean and var stats), and den stats only have the weights.
110  bool den_has_stats;
111  if (den_stats.Flags() != acc_flags) {
112  den_has_stats = false;
113  if (den_stats.Flags() != kGmmWeights)
114  KALDI_ERR << "Incompatible flags: num stats have flags \""
115  << GmmFlagsToString(acc_flags) << "\" vs. den stats \""
116  << GmmFlagsToString(den_stats.Flags()) << '"';
117  } else {
118  den_has_stats = true;
119  }
120  int32 num_comp = num_stats.NumGauss();
121  int32 dim = num_stats.Dim();
122  KALDI_ASSERT(num_stats.NumGauss() == den_stats.NumGauss());
123  KALDI_ASSERT(num_stats.Dim() == gmm->Dim());
124  KALDI_ASSERT(gmm->NumGauss() == num_comp);
125 
126  if ( !(flags & (kGmmMeans | kGmmVariances)) ) {
127  return; // Nothing to update.
128  }
129 
130  // copy DiagGMM model and transform this to the normal case
131  DiagGmmNormal diaggmmnormal;
132  gmm->ComputeGconsts();
133  diaggmmnormal.CopyFromDiagGmm(*gmm);
134 
135  // go over all components
136  Vector<double> mean(dim), var(dim), mean_stats(dim), var_stats(dim);
137 
138  for (int32 g = 0; g < num_comp; g++) {
139  BaseFloat num_count = num_stats.occupancy()(g),
140  den_count = den_stats.occupancy()(g);
141  if (num_count == 0.0 && den_count == 0.0) {
142  KALDI_VLOG(2) << "Not updating Gaussian " << g << " since counts are zero";
143  continue;
144  }
145  mean_stats.CopyFromVec(num_stats.mean_accumulator().Row(g));
146  if (den_has_stats)
147  mean_stats.AddVec(-1.0, den_stats.mean_accumulator().Row(g));
148  if (flags & kGmmVariances) {
149  var_stats.CopyFromVec(num_stats.variance_accumulator().Row(g));
150  if (den_has_stats)
151  var_stats.AddVec(-1.0, den_stats.variance_accumulator().Row(g));
152  }
153  double D = (opts.tau + opts.E * den_count) / 2;
154  if (D+num_count-den_count <= 0.0) {
155  // ensure +ve-- can be problem if num count == 0 and E=2.
156  D = -1.0001*(num_count-den_count) + 1.0e-10;
157  KALDI_ASSERT(D+num_count-den_count > 0.0);
158  }
159  // We initialize to half the value of D that would be dictated by E (and
160  // tau); this is part of the strategy used to ensure that the value of D we
161  // use is at least twice the value that would ensure positive variances.
162 
163  int32 iter, max_iter = 100;
164  for (iter = 0; iter < max_iter; iter++) { // will normally break from the loop
165  // the first time.
166  if (EBWUpdateGaussian(D, flags,
167  diaggmmnormal.means_.Row(g),
168  diaggmmnormal.vars_.Row(g),
169  mean_stats, var_stats, num_count-den_count,
170  &mean, &var, NULL)) {
171  // Succeeded in getting all +ve vars at this value of D.
172  // So double D and commit changes.
173  D *= 2.0;
174  double auxf_impr = 0.0;
175  bool ans = EBWUpdateGaussian(D, flags,
176  diaggmmnormal.means_.Row(g),
177  diaggmmnormal.vars_.Row(g),
178  mean_stats, var_stats, num_count-den_count,
179  &mean, &var, &auxf_impr);
180  if (!ans) {
181  KALDI_WARN << "Something went wrong in the EBW update. Check that your"
182  "previous update phase looks reasonable, probably your model is "
183  "already ruined. Reverting to the old values";
184  } else {
185  if (auxf_change_out) *auxf_change_out += auxf_impr;
186  if (count_out) *count_out += den_count; // The idea is that for MMI, this will
187  // reflect the actual #frames trained on (the numerator one would be I-smoothed).
188  // In general (e.g. for MPE), we won't know the #frames.
189  diaggmmnormal.means_.CopyRowFromVec(mean, g);
190  diaggmmnormal.vars_.CopyRowFromVec(var, g);
191  }
192  break;
193  } else {
194  // small step
195  D *= 1.1;
196  }
197  }
198  if (iter > 0 && num_floored_out != NULL) (*num_floored_out)++;
199  if (iter == max_iter) KALDI_WARN << "Dropped off end of loop, recomputing D. (unexpected.)";
200  }
201  // copy to natural representation according to flags.
202  diaggmmnormal.CopyToDiagGmm(gmm, flags);
203  gmm->ComputeGconsts();
204 }
205 
206 
207 void UpdateEbwWeightsDiagGmm(const AccumDiagGmm &num_stats, // should have no I-smoothing
208  const AccumDiagGmm &den_stats,
209  const EbwWeightOptions &opts,
210  DiagGmm *gmm,
211  BaseFloat *auxf_change_out,
212  BaseFloat *count_out) {
213 
214  DiagGmmNormal diaggmmnormal;
215  gmm->ComputeGconsts();
216  diaggmmnormal.CopyFromDiagGmm(*gmm);
217 
218  Vector<double> weights(diaggmmnormal.weights_),
219  num_occs(num_stats.occupancy()),
220  den_occs(den_stats.occupancy());
221  if (opts.tau == 0.0 &&
222  num_occs.Sum() + den_occs.Sum() < opts.min_num_count_weight_update) {
223  KALDI_LOG << "Not updating weights for this state because total count is "
224  << num_occs.Sum() + den_occs.Sum() << " < "
226  if (count_out)
227  *count_out += num_occs.Sum();
228  return;
229  }
230  num_occs.AddVec(opts.tau, weights);
231  KALDI_ASSERT(weights.Dim() == num_occs.Dim() && num_occs.Dim() == den_occs.Dim());
232  if (weights.Dim() == 1) return; // Nothing to do: only one mixture.
233  double weight_auxf_at_start = 0.0, weight_auxf_at_end = 0.0;
234 
235  int32 num_comp = weights.Dim();
236  for (int32 g = 0; g < num_comp; g++) { // c.f. eq. 4.32 in Dan Povey's thesis.
237  weight_auxf_at_start +=
238  num_occs(g) * log (weights(g))
239  - den_occs(g) * weights(g) / diaggmmnormal.weights_(g);
240  }
241  for (int32 iter = 0; iter < 50; iter++) {
242  Vector<double> k_jm(num_comp); // c.f. eq. 4.35
243  double max_m = 0.0;
244  for (int32 g = 0; g < num_comp; g++)
245  max_m = std::max(max_m, den_occs(g)/diaggmmnormal.weights_(g));
246  for (int32 g = 0; g < num_comp; g++)
247  k_jm(g) = max_m - den_occs(g)/diaggmmnormal.weights_(g);
248  for (int32 g = 0; g < num_comp; g++) // c.f. eq. 4.34
249  weights(g) = num_occs(g) + k_jm(g)*weights(g);
250  weights.Scale(1.0 / weights.Sum()); // c.f. eq. 4.34 (denominator)
251  }
252  for (int32 g = 0; g < num_comp; g++) { // weight flooring.
253  if (weights(g) < opts.min_gaussian_weight)
254  weights(g) = opts.min_gaussian_weight;
255  }
256  weights.Scale(1.0 / weights.Sum()); // renormalize after flooring..
257  // floor won't be exact now but doesn't really matter.
258 
259  for (int32 g = 0; g < num_comp; g++) { // c.f. eq. 4.32 in Dan Povey's thesis.
260  weight_auxf_at_end +=
261  num_occs(g) * log (weights(g))
262  - den_occs(g) * weights(g) / diaggmmnormal.weights_(g);
263  }
264 
265  if (auxf_change_out)
266  *auxf_change_out += weight_auxf_at_end - weight_auxf_at_start;
267  if (count_out)
268  *count_out += num_occs.Sum(); // only really valid for MMI [not MPE, or MMI
269  // with canceled stats]
270 
271  diaggmmnormal.weights_.CopyFromVec(weights);
272 
273  // copy to natural representation
274  diaggmmnormal.CopyToDiagGmm(gmm, kGmmAll);
275  gmm->ComputeGconsts();
276 }
277 
278 void UpdateEbwAmDiagGmm(const AccumAmDiagGmm &num_stats, // with I-smoothing, if used.
279  const AccumAmDiagGmm &den_stats,
280  GmmFlagsType flags,
281  const EbwOptions &opts,
282  AmDiagGmm *am_gmm,
283  BaseFloat *auxf_change_out,
284  BaseFloat *count_out,
285  int32 *num_floored_out) {
286  KALDI_ASSERT(num_stats.NumAccs() == den_stats.NumAccs()
287  && num_stats.NumAccs() == am_gmm->NumPdfs());
288 
289  if (auxf_change_out) *auxf_change_out = 0.0;
290  if (count_out) *count_out = 0.0;
291  if (num_floored_out) *num_floored_out = 0.0;
292 
293  for (int32 pdf = 0; pdf < num_stats.NumAccs(); pdf++)
294  UpdateEbwDiagGmm(num_stats.GetAcc(pdf), den_stats.GetAcc(pdf), flags,
295  opts, &(am_gmm->GetPdf(pdf)), auxf_change_out,
296  count_out, num_floored_out);
297 }
298 
299 
300 void UpdateEbwWeightsAmDiagGmm(const AccumAmDiagGmm &num_stats, // with I-smoothing, if used.
301  const AccumAmDiagGmm &den_stats,
302  const EbwWeightOptions &opts,
303  AmDiagGmm *am_gmm,
304  BaseFloat *auxf_change_out,
305  BaseFloat *count_out) {
306  KALDI_ASSERT(num_stats.NumAccs() == den_stats.NumAccs()
307  && num_stats.NumAccs() == am_gmm->NumPdfs());
308 
309  if (auxf_change_out) *auxf_change_out = 0.0;
310  if (count_out) *count_out = 0.0;
311 
312  for (int32 pdf = 0; pdf < num_stats.NumAccs(); pdf++)
313  UpdateEbwWeightsDiagGmm(num_stats.GetAcc(pdf), den_stats.GetAcc(pdf),
314  opts, &(am_gmm->GetPdf(pdf)), auxf_change_out,
315  count_out);
316 }
317 
318 void IsmoothStatsDiagGmm(const AccumDiagGmm &src_stats,
319  double tau,
320  AccumDiagGmm *dst_stats) {
321  KALDI_ASSERT(src_stats.NumGauss() == dst_stats->NumGauss());
322  int32 dim = src_stats.Dim(), num_gauss = src_stats.NumGauss();
323  for (int32 g = 0; g < num_gauss; g++) {
324  double occ = src_stats.occupancy()(g);
325  if (occ != 0.0) { // can only do this for nonzero occupancies...
326  Vector<double> x_stats(dim), x2_stats(dim);
327  if (dst_stats->Flags() & kGmmMeans)
328  x_stats.CopyFromVec(src_stats.mean_accumulator().Row(g));
329  if (dst_stats->Flags() & kGmmVariances)
330  x2_stats.CopyFromVec(src_stats.variance_accumulator().Row(g));
331  x_stats.Scale(tau / occ);
332  x2_stats.Scale(tau / occ);
333  dst_stats->AddStatsForComponent(g, tau, x_stats, x2_stats);
334  }
335  }
336 }
337 
339 void DiagGmmToStats(const DiagGmm &gmm,
340  GmmFlagsType flags,
341  double state_occ,
342  AccumDiagGmm *dst_stats) {
343  dst_stats->Resize(gmm, AugmentGmmFlags(flags));
344  int32 num_gauss = gmm.NumGauss(), dim = gmm.Dim();
345  DiagGmmNormal gmmnormal(gmm);
346  Vector<double> x_stats(dim), x2_stats(dim);
347  for (int32 g = 0; g < num_gauss; g++) {
348  double occ = state_occ * gmmnormal.weights_(g);
349  x_stats.SetZero();
350  x_stats.AddVec(occ, gmmnormal.means_.Row(g));
351  x2_stats.SetZero();
352  x2_stats.AddVec2(occ, gmmnormal.means_.Row(g));
353  x2_stats.AddVec(occ, gmmnormal.vars_.Row(g));
354  dst_stats->AddStatsForComponent(g, occ, x_stats, x2_stats);
355  }
356 }
357 
358 void IsmoothStatsAmDiagGmm(const AccumAmDiagGmm &src_stats,
359  double tau,
360  AccumAmDiagGmm *dst_stats) {
361  int num_pdfs = src_stats.NumAccs();
362  KALDI_ASSERT(num_pdfs == dst_stats->NumAccs());
363  for (int32 pdf = 0; pdf < num_pdfs; pdf++)
364  IsmoothStatsDiagGmm(src_stats.GetAcc(pdf), tau, &(dst_stats->GetAcc(pdf)));
365 }
366 
368  double tau,
369  AccumAmDiagGmm *dst_stats) {
370  int num_pdfs = src_model.NumPdfs();
371  KALDI_ASSERT(num_pdfs == dst_stats->NumAccs());
372  for (int32 pdf = 0; pdf < num_pdfs; pdf++) {
373  AccumDiagGmm tmp_stats;
374  double occ = 1.0; // its value doesn't matter.
375  DiagGmmToStats(src_model.GetPdf(pdf), kGmmAll, occ, &tmp_stats);
376  IsmoothStatsDiagGmm(tmp_stats, tau, &(dst_stats->GetAcc(pdf)));
377  }
378 }
379 
380 
381 
382 } // End of namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 Dim() const
Returns the dimensionality of the Gaussian mean vectors.
Definition: diag-gmm.h:74
GmmFlagsType AugmentGmmFlags(GmmFlagsType f)
Returns "augmented" version of flags: e.g.
Definition: model-common.cc:52
Definition for Gaussian Mixture Model with diagonal covariances in normal mode: where the parameters ...
void UpdateEbwWeightsDiagGmm(const AccumDiagGmm &num_stats, const AccumDiagGmm &den_stats, const EbwWeightOptions &opts, DiagGmm *gmm, BaseFloat *auxf_change_out, BaseFloat *count_out)
int32 ComputeGconsts()
Sets the gconsts.
Definition: diag-gmm.cc:114
const VectorBase< double > & occupancy() const
Definition: mle-diag-gmm.h:183
kaldi::int32 int32
void CopyFromDiagGmm(const DiagGmm &diaggmm)
Copies from given DiagGmm.
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
Real Min() const
Returns the minimum value of any element, or +infinity for the empty vector.
void IsmoothStatsDiagGmm(const AccumDiagGmm &src_stats, double tau, AccumDiagGmm *dst_stats)
I-Smooth the stats. src_stats and dst_stats do not have to be different.
void UpdateEbwDiagGmm(const AccumDiagGmm &num_stats, const AccumDiagGmm &den_stats, GmmFlagsType flags, const EbwOptions &opts, DiagGmm *gmm, BaseFloat *auxf_change_out, BaseFloat *count_out, int32 *num_floored_out)
Definition: ebw-diag-gmm.cc:94
void AddVec2(const Real alpha, const VectorBase< Real > &v)
Add vector : *this = *this + alpha * rv^2 [element-wise squaring].
void UpdateEbwWeightsAmDiagGmm(const AccumAmDiagGmm &num_stats, const AccumAmDiagGmm &den_stats, const EbwWeightOptions &opts, AmDiagGmm *am_gmm, BaseFloat *auxf_change_out, BaseFloat *count_out)
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
float BaseFloat
Definition: kaldi-types.h:29
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
double Log(double x)
Definition: kaldi-math.h:100
void IsmoothStatsAmDiagGmmFromModel(const AmDiagGmm &src_model, double tau, AccumAmDiagGmm *dst_stats)
This version of the I-smoothing function takes a model as input.
const MatrixBase< double > & variance_accumulator() const
Definition: mle-diag-gmm.h:185
const MatrixBase< double > & mean_accumulator() const
Definition: mle-diag-gmm.h:184
GmmFlagsType Flags() const
Definition: mle-diag-gmm.h:182
void AddStatsForComponent(int32 comp_id, double occ, const VectorBase< double > &x_stats, const VectorBase< double > &x2_stats)
Increment the stats for this component by the specified amount (not all parts may be taken...
static bool EBWUpdateGaussian(BaseFloat D, GmmFlagsType flags, const VectorBase< double > &orig_mean, const VectorBase< double > &orig_var, const VectorBase< double > &x_stats, const VectorBase< double > &x2_stats, double occ, VectorBase< double > *mean, VectorBase< double > *var, double *auxf_impr)
Definition: ebw-diag-gmm.cc:31
void IsmoothStatsAmDiagGmm(const AccumAmDiagGmm &src_stats, double tau, AccumAmDiagGmm *dst_stats)
Smooth "dst_stats" with "src_stats".
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
BaseFloat min_num_count_weight_update
Definition: ebw-diag-gmm.h:48
Matrix< double > vars_
diagonal variance
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: diag-gmm.h:72
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
void Scale(Real alpha)
Multiplies all elements by this constant.
Real Sum() const
Returns sum of the elements.
int32 Dim() const
Returns the dimensionality of the feature vectors.
Definition: mle-diag-gmm.h:126
int32 NumPdfs() const
Definition: am-diag-gmm.h:82
BaseFloat min_gaussian_weight
Definition: ebw-diag-gmm.h:49
Matrix< double > means_
Means.
DiagGmm & GetPdf(int32 pdf_index)
Accessors.
Definition: am-diag-gmm.h:119
const AccumDiagGmm & GetAcc(int32 index) const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void CopyRowFromVec(const VectorBase< Real > &v, const MatrixIndexT row)
Copy vector into specific row of matrix.
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
std::string GmmFlagsToString(GmmFlagsType flags)
Convert GMM flags to string.
Definition: model-common.cc:43
void UpdateEbwAmDiagGmm(const AccumAmDiagGmm &num_stats, const AccumAmDiagGmm &den_stats, GmmFlagsType flags, const EbwOptions &opts, AmDiagGmm *am_gmm, BaseFloat *auxf_change_out, BaseFloat *count_out, int32 *num_floored_out)
Vector< double > weights_
weights (not log).
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void SetZero()
Set vector to all zeros.
void Resize(int32 num_gauss, int32 dim, GmmFlagsType flags)
Allocates memory for accumulators.
#define KALDI_LOG
Definition: kaldi-error.h:153
void AddVec(const Real alpha, const VectorBase< OtherReal > &v)
Add vector : *this = *this + alpha * rv (with casting between floats and doubles) ...
int32 NumGauss() const
Returns the number of mixture components.
Definition: mle-diag-gmm.h:124
void DiagGmmToStats(const DiagGmm &gmm, GmmFlagsType flags, double state_occ, AccumDiagGmm *dst_stats)
Creates stats from the GMM. Resizes them as needed.
void CopyToDiagGmm(DiagGmm *diaggmm, GmmFlagsType flags=kGmmAll) const
Copies to DiagGmm the requested parameters.