indirect-diff-diag-gmm.cc
// gmm/indirect-diff-diag-gmm.cc

// Copyright 2012 Johns Hopkins University (Author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "gmm/indirect-diff-diag-gmm.h"

namespace kaldi {


void GetSingleStatsDerivative(
    double ml_count, double ml_x_stats, double ml_x2_stats,
    double disc_count, double disc_x_stats, double disc_x2_stats,
    double model_mean, double model_var, BaseFloat min_variance,
    double *ml_x_stats_deriv, double *ml_x2_stats_deriv) {

  double model_inv_var = 1.0/model_var,
      model_inv_var_sq = model_inv_var*model_inv_var,
      model_mean_sq = model_mean*model_mean;

  // First get the derivative of the discriminative objective function w.r.t.
  // the model mean and variance.
  // Below: eqs. 11 and 13 in the 2005 ICASSP paper on fMPE. Note: the factor
  // of kappa (in the fMPE case) is assumed to have been accounted for by
  // scaling the num and den accs at the command-line level. We substituted
  // eq. 12 into eq. 13 and rearranged to get the second expression.
  double diff_wrt_model_mean =
      (1.0/model_var) * (disc_x_stats - model_mean*disc_count),
      diff_wrt_model_var =
      0.5 * ((disc_x2_stats - 2*model_mean*disc_x_stats + disc_count*model_mean_sq)
             * model_inv_var_sq
             - disc_count*model_inv_var);
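  // (Added note, a sketch of where these expressions come from): the
  // discriminative stats enter the objective through the Gaussian
  // log-likelihood written in sufficient-statistics form,
  //   F = -0.5 * disc_count * log(2*pi*model_var)
  //       - 0.5 * (disc_x2_stats - 2*model_mean*disc_x_stats
  //                + disc_count*model_mean_sq) / model_var,
  // so differentiating,
  //   dF/d(model_mean) = (disc_x_stats - disc_count*model_mean) / model_var,
  //   dF/d(model_var)  = 0.5 * (disc_x2_stats - 2*model_mean*disc_x_stats
  //                             + disc_count*model_mean_sq) / model_var^2
  //                      - 0.5 * disc_count / model_var,
  // which matches the two expressions above.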

  double stats_mean = ml_x_stats / ml_count,
      stats_var = ml_x2_stats / ml_count
                  - (ml_x_stats / ml_count) * (ml_x_stats / ml_count);

  // We assume the "rescaling" update will be as follows. Apologies if this is
  // a bit confusing. The idea is that if the mean and var of the stats differ
  // from those of the model, we assume the model will be updated with
  // DoRescalingUpdate(), which takes two sets of ML accs (old and new). The
  // old ML accs given to the update will be the current ML accumulators we
  // have here in this function, and the new ML accs will be affected by the
  // change in the fMPE transform. The update in DoRescalingUpdate() will
  // preserve any current difference between the ML stats and the model
  // [represented as a shift in the mean and a factor in the variance].
  // Concretely, the update in DoRescalingUpdate() will do:
  //
  // new_model_mean := old_model_mean + new_stats_mean - old_stats_mean   (eq. 1)
  // new_model_var := max(min_variance,
  //                      old_model_var * new_stats_var / old_stats_var)  (eq. 2)
  //
  // We're differentiating back through this process to the new stats (mean
  // and variance). If the model and the stats were actually the same (e.g. we
  // had been doing ML updates), then all this is equivalent to what was in
  // the original fMPE paper; it's just extended to make sense outside of the
  // scenario where you're doing ML.
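  // (Added note): differentiating through eqs. 1 and 2 above, assuming the
  // variance floor is not active, gives
  //   d(new_model_mean)/d(new_stats_mean) = 1,
  //   d(new_model_var)/d(new_stats_var)   = old_model_var / old_stats_var,
  // which is where the two branches below come from.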

  double diff_wrt_stats_mean = diff_wrt_model_mean;  // This comes from eq. 1 above.
  double diff_wrt_stats_var;
  if (model_var <= min_variance*1.01) {
    diff_wrt_stats_var = 0.0;  // model would be "pinned" at minimum variance.
    KALDI_VLOG(2) << "Variance derivative is zero (min variance)";
  } else {
    diff_wrt_stats_var = diff_wrt_model_var * model_var / stats_var;
    // Note: the factor "model_var / stats_var" comes from
    // "old_model_var / old_stats_var" in eq. 2. Also note: the {old_,new_}
    // versions of the variables are numerically the same here, at the point
    // where we're differentiating.
  }

  // The next equations don't appear in the paper but represent the
  // backpropagation of the derivative through the equations:
  //   stats_mean := ml_x_stats / ml_count
  //   stats_var := ml_x2_stats / ml_count - (ml_x_stats/ml_count)^2
  // [we use stats_mean = ml_x_stats/ml_count here].
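  // (Added note, chain rule spelled out):
  //   d(stats_mean)/d(ml_x_stats) = 1/ml_count,
  //   d(stats_var)/d(ml_x_stats)  = -2*ml_x_stats/ml_count^2
  //                               = -2*stats_mean/ml_count,
  //   d(stats_var)/d(ml_x2_stats) = 1/ml_count,
  // which gives the two assignments below.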
  *ml_x_stats_deriv = diff_wrt_stats_mean / ml_count
                      - 2 * diff_wrt_stats_var * stats_mean / ml_count;
  *ml_x2_stats_deriv = diff_wrt_stats_var / ml_count;
}


// The function for just one GMM. We don't export it as it's not currently
// needed outside this file.
void GetStatsDerivative(const DiagGmm &gmm,
                        const AccumDiagGmm &num_acc,
                        const AccumDiagGmm &den_acc,
                        const AccumDiagGmm &ml_acc,
                        BaseFloat min_variance,
                        BaseFloat min_gaussian_occupancy,
                        AccumDiagGmm *out_accs) {
  out_accs->Resize(gmm, kGmmAll);
  int32 num_gauss = gmm.NumGauss(), dim = gmm.Dim();
  KALDI_ASSERT(num_gauss == num_acc.NumGauss() && dim == num_acc.Dim());
  KALDI_ASSERT(num_gauss == den_acc.NumGauss());  // Don't check the den dim:
  // in the "compressed" form of the stats (where the num acc stores the
  // diff), it could be zero.
  KALDI_ASSERT(num_gauss == ml_acc.NumGauss() && dim == ml_acc.Dim());

  KALDI_ASSERT((ml_acc.Flags() & (kGmmMeans|kGmmVariances)) ==
               (kGmmMeans|kGmmVariances));
  KALDI_ASSERT((num_acc.Flags() & (kGmmMeans|kGmmVariances)) ==
               (kGmmMeans|kGmmVariances));
  DiagGmmNormal gmm_normal(gmm);

  // If have_den_stats == false, we assume the num and den have been
  // "compressed" by putting the difference in the mean and var stats in num.
  bool have_den_stats = ((den_acc.Flags() & (kGmmMeans|kGmmVariances)) != 0);

  for (int32 gauss = 0; gauss < num_gauss; gauss++) {
    Vector<double> x_stats_deriv(dim), x2_stats_deriv(dim);
    double num_count = num_acc.occupancy()(gauss),
        den_count = den_acc.occupancy()(gauss),
        ml_count = ml_acc.occupancy()(gauss);

    if (ml_count <= min_gaussian_occupancy) {
      // This Gaussian won't be updated since it has a very small count.
      KALDI_WARN << "Skipping Gaussian because of very small ML count: "
                 << "(num,den,ml) = " << num_count << ", " << den_count
                 << ", " << ml_count;
    } else {
      double disc_count = num_count - den_count;
      for (int32 d = 0; d < dim; d++) {
        double disc_x_acc = num_acc.mean_accumulator()(gauss, d)
            - (have_den_stats ? den_acc.mean_accumulator()(gauss, d) : 0.0),
            disc_x2_acc = num_acc.variance_accumulator()(gauss, d)
            - (have_den_stats ? den_acc.variance_accumulator()(gauss, d) : 0.0),
            ml_x_acc = ml_acc.mean_accumulator()(gauss, d),
            ml_x2_acc = ml_acc.variance_accumulator()(gauss, d),
            model_mean = gmm_normal.means_(gauss, d),
            model_var = gmm_normal.vars_(gauss, d);

        double x_acc_deriv = 0.0, x2_acc_deriv = 0.0;
        GetSingleStatsDerivative(ml_count, ml_x_acc, ml_x2_acc,
                                 disc_count, disc_x_acc, disc_x2_acc,
                                 model_mean, model_var, min_variance,
                                 &x_acc_deriv, &x2_acc_deriv);

        x_stats_deriv(d) = x_acc_deriv;
        x2_stats_deriv(d) = x2_acc_deriv;
      }
      // Set the stats to these quantities (we're adding, but the stats
      // are currently zero).
      out_accs->AddStatsForComponent(gauss, 0.0, x_stats_deriv, x2_stats_deriv);
    }
  }
}

void GetStatsDerivative(const AmDiagGmm &gmm,
                        const AccumAmDiagGmm &num_accs,  // for MMI, would equal the ML accs.
                        const AccumAmDiagGmm &den_accs,
                        const AccumAmDiagGmm &ml_accs,
                        BaseFloat min_variance,
                        BaseFloat min_gaussian_occupancy,
                        AccumAmDiagGmm *out_accs) {
  out_accs->Init(gmm, kGmmAll);
  int32 num_pdfs = gmm.NumPdfs();
  KALDI_ASSERT(num_accs.NumAccs() == num_pdfs);
  KALDI_ASSERT(den_accs.NumAccs() == num_pdfs);
  KALDI_ASSERT(ml_accs.NumAccs() == num_pdfs);
  for (int32 pdf = 0; pdf < num_pdfs; pdf++)
    GetStatsDerivative(gmm.GetPdf(pdf), num_accs.GetAcc(pdf),
                       den_accs.GetAcc(pdf), ml_accs.GetAcc(pdf),
                       min_variance, min_gaussian_occupancy,
                       &(out_accs->GetAcc(pdf)));
}


void DoRescalingUpdate(const AccumDiagGmm &old_ml_acc,
                       const AccumDiagGmm &new_ml_acc,
                       BaseFloat min_variance,
                       BaseFloat min_gaussian_occupancy,
                       DiagGmm *gmm,
                       double *tot_count,
                       double *tot_divergence) {
  int32 num_gauss = gmm->NumGauss(), dim = gmm->Dim();
  KALDI_ASSERT(old_ml_acc.NumGauss() == num_gauss &&
               old_ml_acc.Dim() == dim);
  KALDI_ASSERT(new_ml_acc.NumGauss() == num_gauss &&
               new_ml_acc.Dim() == dim);
  KALDI_ASSERT((old_ml_acc.Flags() & (kGmmMeans|kGmmVariances)) ==
               (kGmmMeans|kGmmVariances));
  KALDI_ASSERT((new_ml_acc.Flags() & (kGmmMeans|kGmmVariances)) ==
               (kGmmMeans|kGmmVariances));

  DiagGmmNormal gmm_normal(*gmm);
  for (int32 gauss = 0; gauss < num_gauss; gauss++) {
    double old_ml_count = old_ml_acc.occupancy()(gauss),
        new_ml_count = new_ml_acc.occupancy()(gauss);
    if (old_ml_count <= min_gaussian_occupancy ||
        new_ml_count <= min_gaussian_occupancy) {
      KALDI_WARN << "Gaussian being skipped because it has a small count: "
                 << "(old,new) = " << old_ml_count << ", " << new_ml_count;
      continue;
    }
    *tot_count += new_ml_count;
    for (int32 d = 0; d < dim; d++) {
      double old_model_mean = gmm_normal.means_(gauss, d),
          old_model_var = gmm_normal.vars_(gauss, d),
          old_ml_mean = old_ml_acc.mean_accumulator()(gauss, d) / old_ml_count,
          old_ml_var = old_ml_acc.variance_accumulator()(gauss, d) / old_ml_count
                       - old_ml_mean*old_ml_mean,
          new_ml_mean = new_ml_acc.mean_accumulator()(gauss, d) / new_ml_count,
          new_ml_var = new_ml_acc.variance_accumulator()(gauss, d) / new_ml_count
                       - new_ml_mean*new_ml_mean,
          new_model_mean = old_model_mean + new_ml_mean - old_ml_mean,
          new_model_var = std::max(static_cast<double>(min_variance),
                                   old_model_var * new_ml_var / old_ml_var);
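      // (Added note): the quantity below is the per-dimension KL divergence
      // KL(N(new_model_mean, new_model_var) || N(old_model_mean, old_model_var))
      //  = 0.5 * [ ((mu_new - mu_old)^2 + var_new) / var_old - 1
      //            + log(var_old / var_new) ],
      // rearranged by folding the -1 into the first term.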
      double divergence =
          0.5 * (((new_model_mean-old_model_mean)*(new_model_mean-old_model_mean)
                  + new_model_var - old_model_var) / old_model_var
                 + Log(old_model_var / new_model_var));
      if (divergence < 0.0)
        KALDI_WARN << "Negative divergence " << divergence;
      *tot_divergence += divergence * new_ml_count;
      gmm_normal.means_(gauss, d) = new_model_mean;
      gmm_normal.vars_(gauss, d) = new_model_var;
    }
  }
  gmm_normal.CopyToDiagGmm(gmm);
}


void DoRescalingUpdate(const AccumAmDiagGmm &old_ml_accs,
                       const AccumAmDiagGmm &new_ml_accs,
                       BaseFloat min_variance,
                       BaseFloat min_gaussian_occupancy,
                       AmDiagGmm *am_gmm) {
  int32 num_pdfs = am_gmm->NumPdfs();
  KALDI_ASSERT(old_ml_accs.NumAccs() == num_pdfs);
  KALDI_ASSERT(new_ml_accs.NumAccs() == num_pdfs);
  double tot_count = 0.0, tot_divergence = 0.0;
  for (int32 pdf = 0; pdf < num_pdfs; pdf++)
    DoRescalingUpdate(old_ml_accs.GetAcc(pdf), new_ml_accs.GetAcc(pdf),
                      min_variance, min_gaussian_occupancy,
                      &am_gmm->GetPdf(pdf), &tot_count, &tot_divergence);
  KALDI_LOG << "K-L divergence from old to new model is "
            << (tot_divergence/tot_count) << " over "
            << tot_count << " frames.";
  am_gmm->ComputeGconsts();
}

}  // End of namespace kaldi
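
For context, here is a minimal usage sketch (not part of the original file) of how the two exported functions fit together in an fMPE-style indirect differential computation. The names GetStatsDerivative and DoRescalingUpdate and their signatures come from the listing above; the driver function names, the threshold values, and the assumption that the accumulators were filled elsewhere are all illustrative.

// Hypothetical driver code; assumes the accumulators were accumulated
// elsewhere (e.g. read from disk). Threshold values are illustrative.
#include "gmm/am-diag-gmm.h"
#include "gmm/mle-am-diag-gmm.h"
#include "gmm/indirect-diff-diag-gmm.h"

void ExampleIndirectDiff(const kaldi::AmDiagGmm &am_gmm,
                         const kaldi::AccumAmDiagGmm &num_accs,
                         const kaldi::AccumAmDiagGmm &den_accs,
                         const kaldi::AccumAmDiagGmm &ml_accs,
                         kaldi::AccumAmDiagGmm *deriv_accs) {
  kaldi::BaseFloat min_variance = 0.01, min_gaussian_occupancy = 10.0;
  // Turn the (num - den) discriminative stats into derivatives w.r.t. the
  // ML stats, stored in the same accumulator format:
  kaldi::GetStatsDerivative(am_gmm, num_accs, den_accs, ml_accs,
                            min_variance, min_gaussian_occupancy, deriv_accs);
}

// Later, once the feature transform has changed and new ML stats have been
// accumulated, the model is moved along with the stats:
void ExampleRescale(const kaldi::AccumAmDiagGmm &old_ml_accs,
                    const kaldi::AccumAmDiagGmm &new_ml_accs,
                    kaldi::AmDiagGmm *am_gmm) {
  kaldi::DoRescalingUpdate(old_ml_accs, new_ml_accs,
                           0.01 /* min_variance */,
                           10.0 /* min_gaussian_occupancy */, am_gmm);
}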