mle-diag-gmm-test.cc
Go to the documentation of this file.
1 // gmm/mle-diag-gmm-test.cc
2 
3 // Copyright 2009-2011 Georg Stemmer; Jan Silovsky; Saarland University;
4 // Microsoft Corporation; Yanmin Qian
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "gmm/diag-gmm.h"
22 #include "gmm/diag-gmm-normal.h"
23 #include "gmm/mle-diag-gmm.h"
24 #include "util/kaldi-io.h"
25 
26 using namespace kaldi;
27 
28 void TestComponentAcc(const DiagGmm &gmm, const Matrix<BaseFloat> &feats) {
29  MleDiagGmmOptions config;
30  AccumDiagGmm est_atonce; // updates all components
31  AccumDiagGmm est_compwise; // updates single components
32 
33  // Initialize estimators
34  est_atonce.Resize(gmm.NumGauss(), gmm.Dim(), kGmmAll);
35  est_atonce.SetZero(kGmmAll);
36  est_compwise.Resize(gmm.NumGauss(),
37  gmm.Dim(), kGmmAll);
38  est_compwise.SetZero(kGmmAll);
39 
40  // accumulate estimators
41  for (int32 i = 0; i < feats.NumRows(); i++) {
42  est_atonce.AccumulateFromDiag(gmm, feats.Row(i), 1.0F);
43  Vector<BaseFloat> post(gmm.NumGauss());
44  gmm.ComponentPosteriors(feats.Row(i), &post);
45  for (int32 m = 0; m < gmm.NumGauss(); m++) {
46  est_compwise.AccumulateForComponent(feats.Row(i), m, post(m));
47  }
48  }
49 
50  DiagGmm gmm_atonce; // model with all components accumulated together
51  DiagGmm gmm_compwise; // model with each component accumulated separately
52  gmm_atonce.Resize(gmm.NumGauss(), gmm.Dim());
53  gmm_compwise.Resize(gmm.NumGauss(), gmm.Dim());
54 
55  MleDiagGmmUpdate(config, est_atonce, kGmmAll, &gmm_atonce, NULL, NULL);
56  MleDiagGmmUpdate(config, est_compwise, kGmmAll, &gmm_compwise, NULL, NULL);
57 
58  // the two ways of updating should result in the same model
59  double loglike0 = 0.0;
60  double loglike1 = 0.0;
61  double loglike2 = 0.0;
62  for (int32 i = 0; i < feats.NumRows(); i++) {
63  loglike0 += static_cast<double>(gmm.LogLikelihood(feats.Row(i)));
64  loglike1 += static_cast<double>(gmm_atonce.LogLikelihood(feats.Row(i)));
65  loglike2 += static_cast<double>(gmm_compwise.LogLikelihood(feats.Row(i)));
66  }
67 
68  std::cout << "Per-frame log-likelihood before update = "
69  << (loglike0/feats.NumRows()) << '\n';
70  std::cout << "Per-frame log-likelihood (accumulating at once) = "
71  << (loglike1/feats.NumRows()) << '\n';
72  std::cout << "Per-frame log-likelihood (accumulating component-wise) = "
73  << (loglike2/feats.NumRows()) << '\n';
74 
75  AssertEqual(loglike1, loglike2, 1.0e-6);
76 
77  if (est_atonce.NumGauss() != gmm.NumGauss()) {
78  KALDI_WARN << "Unable to pass test_update_flags() test because of "
79  "component removal during Update() call (this is normal)";
80  return;
81  } else {
82  KALDI_ASSERT(loglike1 >= loglike0 - (std::abs(loglike1)+std::abs(loglike0))*1.0e-06);
83  KALDI_ASSERT(loglike2 >= loglike0 - (std::abs(loglike2)+std::abs(loglike0))*1.0e-06);
84  }
85 }
86 
88  const Matrix<BaseFloat> &feats,
89  GmmFlagsType flags) {
90  MleDiagGmmOptions config;
91  AccumDiagGmm est_gmm_allp; // updates all params
92  // let's trust that all-params update works
93  AccumDiagGmm est_gmm_somep; // updates params indicated by flags
94 
95  // warm-up estimators
96  est_gmm_allp.Resize(gmm.NumGauss(),
97  gmm.Dim(), kGmmAll);
98  est_gmm_allp.SetZero(kGmmAll);
99  est_gmm_somep.Resize(gmm.NumGauss(),
100  gmm.Dim(), flags);
101  est_gmm_somep.SetZero(flags);
102 
103  // accumulate estimators
104  for (int32 i = 0; i < feats.NumRows(); i++) {
105  est_gmm_allp.AccumulateFromDiag(gmm, feats.Row(i), 1.0F);
106  est_gmm_somep.AccumulateFromDiag(gmm, feats.Row(i), 1.0F);
107  }
108 
109  DiagGmm gmm_all_update; // model with all params updated
110  DiagGmm gmm_some_update; // model with some params updated
111  gmm_all_update.CopyFromDiagGmm(gmm); // init with orig. model
112  gmm_some_update.CopyFromDiagGmm(gmm); // init with orig. model
113 
114  MleDiagGmmUpdate(config, est_gmm_allp, kGmmAll, &gmm_all_update, NULL, NULL);
115  MleDiagGmmUpdate(config, est_gmm_somep, flags, &gmm_some_update, NULL, NULL);
116 
117  if (est_gmm_allp.NumGauss() != gmm.NumGauss()) {
118  KALDI_WARN << "Unable to pass test_update_flags() test because of "
119  "component removal during Update() call (this is normal)";
120  return;
121  }
122 
123  // now back-off the gmm_all_update params that were not updated
124  // in gmm_some_update to orig.
125  if (~flags & kGmmWeights)
126  gmm_all_update.SetWeights(gmm.weights());
127  if (~flags & kGmmMeans) {
128  Matrix<BaseFloat> means(gmm.NumGauss(), gmm.Dim());
129  gmm.GetMeans(&means);
130  gmm_all_update.SetMeans(means);
131  }
132  if (~flags & kGmmVariances) {
133  Matrix<BaseFloat> vars(gmm.NumGauss(), gmm.Dim());
134  gmm.GetVars(&vars);
135  vars.InvertElements();
136  gmm_all_update.SetInvVars(vars);
137  }
138  gmm_all_update.ComputeGconsts();
139 
140  // now both models gmm_all_update, gmm_all_update have the same params updated
141  // compute loglike for models for check
142  double loglike0 = 0.0;
143  double loglike1 = 0.0;
144  double loglike2 = 0.0;
145  for (int32 i = 0; i < feats.NumRows(); i++) {
146  loglike0 += static_cast<double>(
147  gmm.LogLikelihood(feats.Row(i)));
148  loglike1 += static_cast<double>(
149  gmm_all_update.LogLikelihood(feats.Row(i)));
150  loglike2 += static_cast<double>(
151  gmm_some_update.LogLikelihood(feats.Row(i)));
152  }
153  if ((flags & kGmmVariances) && !(flags & kGmmMeans))
154  return; // Don't run the test as the variance update gives a different
155  // answer if you don't update the mean.
156 
157  AssertEqual(loglike1, loglike2, 1.0e-6);
158 }
159 
160 void
161 test_io(const DiagGmm &gmm, const AccumDiagGmm &est_gmm, bool binary,
162  const Matrix<BaseFloat> &feats) {
163  std::cout << "Testing I/O, binary = " << binary << '\n';
164 
165  est_gmm.Write(Output("tmp_stats", binary).Stream(), binary);
166 
167  bool binary_in;
168  AccumDiagGmm est_gmm2;
169  est_gmm2.Resize(est_gmm.NumGauss(),
170  est_gmm.Dim(), kGmmAll);
171  Input ki("tmp_stats", &binary_in);
172  est_gmm2.Read(ki.Stream(), binary_in, false); // not adding
173 
174  Input ki2("tmp_stats", &binary_in);
175  est_gmm2.Read(ki2.Stream(), binary_in, true); // adding
176 
177  est_gmm2.Scale(0.5, kGmmAll);
178  // 0.5 -> make it same as what it would have been if we read just once.
179  // [may affect it due to removal of components with small counts].
180 
181  MleDiagGmmOptions config;
182  DiagGmm gmm1;
183  DiagGmm gmm2;
184  gmm1.CopyFromDiagGmm(gmm);
185  gmm2.CopyFromDiagGmm(gmm);
186  MleDiagGmmUpdate(config, est_gmm, est_gmm.Flags(), &gmm1, NULL, NULL);
187  MleDiagGmmUpdate(config, est_gmm2, est_gmm2.Flags(), &gmm2, NULL, NULL);
188 
189  BaseFloat loglike1 = 0.0;
190  BaseFloat loglike2 = 0.0;
191  for (int32 i = 0; i < feats.NumRows(); i++) {
192  loglike1 += gmm1.LogLikelihood(feats.Row(i));
193  loglike2 += gmm2.LogLikelihood(feats.Row(i));
194  }
195 
196  AssertEqual(loglike1, loglike2, 1.0e-6);
197 
198  unlink("tmp_stats");
199 }
200 
201 void
203  size_t dim = 15; // dimension of the gmm
204  size_t nMix = 9; // number of mixtures in the data
205  size_t maxiterations = 20; // number of iterations for estimation
206 
207  // maximum number of densities in the GMM
208  // larger than the number of mixtures in the data
209  // so that we can test the removal of unseen components
210  int32 maxcomponents = 10;
211 
212  // generate random feature vectors
213  Matrix<BaseFloat> means_f(nMix, dim), vars_f(nMix, dim);
214  // first, generate random mean and variance vectors
215  for (size_t m = 0; m < nMix; m++) {
216  for (size_t d= 0; d < dim; d++) {
217  means_f(m, d) = kaldi::RandGauss()*100.0F;
218  vars_f(m, d) = Exp(kaldi::RandGauss())*1000.0F+ 1.0F;
219  }
220 // std::cout << "Gauss " << m << ": Mean = " << means_f.Row(m) << '\n'
221 // << "Vars = " << vars_f.Row(m) << '\n';
222  }
223  // second, generate 1000 feature vectors for each of the mixture components
224  size_t counter = 0, multiple = 200;
225  Matrix<BaseFloat> feats(nMix*multiple, dim);
226  for (size_t m = 0; m < nMix; m++) {
227  for (size_t i = 0; i < multiple; i++) {
228  for (size_t d = 0; d < dim; d++) {
229  feats(counter, d) = means_f(m, d) + kaldi::RandGauss() *
230  std::sqrt(vars_f(m, d));
231  }
232  counter++;
233  }
234  }
235  // Compute the global mean and variance
236  Vector<BaseFloat> mean_acc(dim);
237  Vector<BaseFloat> var_acc(dim);
238  Vector<BaseFloat> featvec(dim);
239  for (size_t i = 0; i < counter; i++) {
240  featvec.CopyRowFromMat(feats, i);
241  mean_acc.AddVec(1.0, featvec);
242  featvec.ApplyPow(2.0);
243  var_acc.AddVec(1.0, featvec);
244  }
245  mean_acc.Scale(1.0F/counter);
246  var_acc.Scale(1.0F/counter);
247  var_acc.AddVec2(-1.0, mean_acc);
248 // std::cout << "Mean acc = " << mean_acc << '\n' << "Var acc = "
249 // << var_acc << '\n';
250 
251  // write the feature vectors to a file
252  // std::ofstream of("tmpfeats");
253  // of.precision(10);
254  // of << feats;
255  // of.close();
256 
257  // now generate randomly initial values for the GMM
258  Vector<BaseFloat> weights(1);
259  Matrix<BaseFloat> means(1, dim), vars(1, dim), invvars(1, dim);
260  for (size_t d= 0; d < dim; d++) {
261  means(0, d) = kaldi::RandGauss()*100.0F;
262  vars(0, d) = Exp(kaldi::RandGauss()) *10.0F + 1e-5F;
263  }
264  weights(0) = 1.0F;
265  invvars.CopyFromMat(vars);
266  invvars.InvertElements();
267 
268  // new GMM
269  DiagGmm *gmm = new DiagGmm();
270  gmm->Resize(1, dim);
271  gmm->SetWeights(weights);
272  gmm->SetInvVarsAndMeans(invvars, means);
273  gmm->ComputeGconsts();
274 
275  {
276  KALDI_LOG << "Testing natural<>normal conversion";
277  DiagGmmNormal ngmm(*gmm);
278  DiagGmm rgmm;
279  rgmm.Resize(1, dim);
280  ngmm.CopyToDiagGmm(&rgmm);
281 
282  // check contents
283  KALDI_ASSERT(ApproxEqual(weights(0), 1.0F, 1e-6));
284  KALDI_ASSERT(ApproxEqual(gmm->weights()(0), rgmm.weights()(0), 1e-6));
285  for (int32 d = 0; d < dim; d++) {
286  KALDI_ASSERT(ApproxEqual(means.Row(0)(d), ngmm.means_.Row(0)(d), 1e-6));
287  KALDI_ASSERT(ApproxEqual(1./invvars.Row(0)(d), ngmm.vars_.Row(0)(d), 1e-6));
288  KALDI_ASSERT(ApproxEqual(gmm->means_invvars().Row(0)(d), rgmm.means_invvars().Row(0)(d), 1e-6));
289  KALDI_ASSERT(ApproxEqual(gmm->inv_vars().Row(0)(d), rgmm.inv_vars().Row(0)(d), 1e-6));
290  }
291  KALDI_LOG << "OK";
292  }
293 
294  AccumDiagGmm est_gmm;
295 // var_acc.Scale(0.1);
296 // est_gmm.config_.p_variance_floor_vector = &var_acc;
297 
298  MleDiagGmmOptions config;
299  config.min_variance = 0.01;
300  GmmFlagsType flags = kGmmAll; // Should later try reducing this.
301 
302  est_gmm.Resize(gmm->NumGauss(), gmm->Dim(), flags);
303 
304  // iterate
305  size_t iteration = 0;
306  float lastloglike = 0.0;
307  int32 lastloglike_nM = 0;
308 
309  while (iteration < maxiterations) {
310  Vector<BaseFloat> featvec(dim);
311  est_gmm.Resize(gmm->NumGauss(), gmm->Dim(), flags);
312  est_gmm.SetZero(flags);
313  double loglike = 0.0;
314  for (size_t i = 0; i < counter; i++) {
315  featvec.CopyRowFromMat(feats, i);
316  loglike += static_cast<double>(est_gmm.AccumulateFromDiag(*gmm,
317  featvec, 1.0F));
318  }
319  std::cout << "Loglikelihood before iteration " << iteration << " : "
320  << std::scientific << loglike << " number of components: "
321  << gmm->NumGauss() << '\n';
322 
323  // every 5th iteration check loglike change and update lastloglike
324  if (iteration % 5 == 0) {
325  // likelihood should be increasing on the long term
326  if ((iteration > 0) && (gmm->NumGauss() >= lastloglike_nM)) {
327  KALDI_ASSERT(loglike - lastloglike >= -1.0);
328  }
329  lastloglike = loglike;
330  lastloglike_nM = gmm->NumGauss();
331  }
332 
333  // binary write
334  est_gmm.Write(Output("tmp_stats", true).Stream(), true);
335 
336  // binary read
337  bool binary_in;
338  Input ki("tmp_stats", &binary_in);
339  est_gmm.Read(ki.Stream(), binary_in, false); // false = not adding.
340 
341  BaseFloat obj, count;
342  MleDiagGmmUpdate(config, est_gmm, flags, gmm, &obj, &count);
343  KALDI_LOG <<"ML objective function change = " << (obj/count)
344  << " per frame, over " << (count) << " frames.";
345 
346  if ((iteration % 3 == 1) && (gmm->NumGauss() * 2 <= maxcomponents)) {
347  gmm->Split(gmm->NumGauss() * 2, 0.001);
348  }
349 
350  if (iteration == 5) { // run following tests with not too overfitted model
351  std::cout << "Testing flags-driven updates" << '\n';
352  test_flags_driven_update(*gmm, feats, kGmmAll);
354  test_flags_driven_update(*gmm, feats, kGmmMeans);
357  std::cout << "Testing component-wise accumulation" << '\n';
358  TestComponentAcc(*gmm, feats);
359  }
360 
361  iteration++;
362  }
363 
364  { // I/O tests
365  GmmFlagsType flags_all = kGmmAll;
366  est_gmm.Resize(gmm->NumGauss(),
367  gmm->Dim(), flags_all);
368  est_gmm.SetZero(flags_all);
369  float loglike = 0.0;
370  for (size_t i = 0; i < counter; i++) {
371  loglike += est_gmm.AccumulateFromDiag(*gmm, feats.Row(i), 1.0F);
372  }
373  test_io(*gmm, est_gmm, false, feats); // ASCII mode
374  test_io(*gmm, est_gmm, true, feats); // Binary mode
375  }
376 
377  { // Test multi-threaded update.
378  GmmFlagsType flags_all = kGmmAll;
379  est_gmm.Resize(gmm->NumGauss(),
380  gmm->Dim(), flags_all);
381  est_gmm.SetZero(flags_all);
382 
383  Vector<BaseFloat> weights(counter);
384  for (size_t i = 0; i < counter; i++)
385  weights(i) = 0.5 + 0.1 * (Rand() % 10);
386 
387 
388  float loglike = 0.0;
389  for (size_t i = 0; i < counter; i++) {
390  loglike += weights(i) *
391  est_gmm.AccumulateFromDiag(*gmm, feats.Row(i), weights(i));
392  }
393  AccumDiagGmm est_gmm2(*gmm, flags_all);
394  int32 num_threads = 2;
395  float loglike2 =
396  est_gmm2.AccumulateFromDiagMultiThreaded(*gmm, feats, weights, num_threads);
397  AssertEqual(loglike, loglike2);
398  est_gmm.AssertEqual(est_gmm2);
399  }
400 
401 
402  delete gmm;
403 
404  unlink("tmp_stats");
405 }
406 
407 int main() {
408  // repeat the test five times
409  for (int i = 0; i < 2; i++)
411  std::cout << "Test OK.\n";
412 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
double Exp(double x)
Definition: kaldi-math.h:83
int32 Dim() const
Returns the dimensionality of the Gaussian mean vectors.
Definition: diag-gmm.h:74
void CopyFromDiagGmm(const DiagGmm &diaggmm)
Copies from given DiagGmm.
Definition: diag-gmm.cc:83
void SetInvVarsAndMeans(const MatrixBase< Real > &invvars, const MatrixBase< Real > &means)
Use SetInvVarsAndMeans if updating both means and (inverse) variances.
Definition: diag-gmm-inl.h:63
BaseFloat AccumulateFromDiagMultiThreaded(const DiagGmm &gmm, const MatrixBase< BaseFloat > &data, const VectorBase< BaseFloat > &frame_weights, int32 num_threads)
This does the same job as AccumulateFromDiag, but using multiple threads.
void Split(int32 target_components, float perturb_factor, std::vector< int32 > *history=NULL)
Split the components and remember the order in which the components were split.
Definition: diag-gmm.cc:154
Definition for Gaussian Mixture Model with diagonal covariances in normal mode: where the parameters ...
const Matrix< BaseFloat > & means_invvars() const
Definition: diag-gmm.h:179
void test_flags_driven_update(const DiagGmm &gmm, const Matrix< BaseFloat > &feats, GmmFlagsType flags)
void UnitTestEstimateDiagGmm()
void MleDiagGmmUpdate(const MleDiagGmmOptions &config, const AccumDiagGmm &diag_gmm_acc, GmmFlagsType flags, DiagGmm *gmm, BaseFloat *obj_change_out, BaseFloat *count_out, int32 *floored_elements_out, int32 *floored_gaussians_out, int32 *removed_gaussians_out)
for computing the maximum-likelihood estimates of the parameters of a Gaussian mixture model...
void TestComponentAcc(const DiagGmm &gmm, const Matrix< BaseFloat > &feats)
void Resize(int32 nMix, int32 dim)
Resizes arrays to this dim. Does not initialize data.
Definition: diag-gmm.cc:66
int32 ComputeGconsts()
Sets the gconsts.
Definition: diag-gmm.cc:114
float RandGauss(struct RandomState *state=NULL)
Definition: kaldi-math.h:155
kaldi::int32 int32
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
void CopyFromMat(const MatrixBase< OtherReal > &M, MatrixTransposeType trans=kNoTrans)
Copy given matrix. (no resize is done).
void SetMeans(const MatrixBase< Real > &m)
Use SetMeans to update only the Gaussian means (and not variances)
Definition: diag-gmm-inl.h:43
double min_variance
Minimum allowed variance in any dimension (if no variance floor) It is in double since the variance i...
Definition: mle-diag-gmm.h:50
void Scale(BaseFloat f, GmmFlagsType flags)
void CopyRowFromMat(const MatrixBase< Real > &M, MatrixIndexT row)
Extracts a row of the matrix M.
void GetVars(Matrix< Real > *v) const
Accessor for covariances.
Definition: diag-gmm-inl.h:115
void AddVec2(const Real alpha, const VectorBase< Real > &v)
Add vector: *this = *this + alpha * v^2 [element-wise squaring].
const size_t count
std::istream & Stream()
Definition: kaldi-io.cc:826
BaseFloat ComponentPosteriors(const VectorBase< BaseFloat > &data, Vector< BaseFloat > *posteriors) const
Computes the posterior probabilities of all Gaussian components given a data point.
Definition: diag-gmm.cc:601
float BaseFloat
Definition: kaldi-types.h:29
int main()
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
BaseFloat AccumulateFromDiag(const DiagGmm &gmm, const VectorBase< BaseFloat > &data, BaseFloat frame_posterior)
Accumulate for all components given a diagonal-covariance GMM.
GmmFlagsType Flags() const
Definition: mle-diag-gmm.h:182
BaseFloat LogLikelihood(const VectorBase< BaseFloat > &data) const
Returns the log-likelihood of a data point (vector) given the GMM.
Definition: diag-gmm.cc:517
void GetMeans(Matrix< Real > *m) const
Accessor for means.
Definition: diag-gmm-inl.h:123
void Write(std::ostream &out_stream, bool binary) const
Definition: mle-diag-gmm.cc:77
void AccumulateForComponent(const VectorBase< BaseFloat > &data, int32 comp_index, BaseFloat weight)
Accumulate for a single component, given the posterior.
#define KALDI_WARN
Definition: kaldi-error.h:150
const Vector< BaseFloat > & weights() const
Definition: diag-gmm.h:178
Matrix< double > vars_
diagonal variance
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: diag-gmm.h:72
Configuration variables like variance floor, minimum occupancy, etc.
Definition: mle-diag-gmm.h:38
void Scale(Real alpha)
Multiplies all elements by this constant.
void Read(std::istream &in_stream, bool binary, bool add)
Definition: mle-diag-gmm.cc:33
void SetZero(GmmFlagsType flags)
int Rand(struct RandomState *state)
Definition: kaldi-math.cc:45
void SetInvVars(const MatrixBase< Real > &v)
Set the (inverse) variances and recompute means_invvars_.
Definition: diag-gmm-inl.h:78
int32 Dim() const
Returns the dimensionality of the feature vectors.
Definition: mle-diag-gmm.h:126
Matrix< double > means_
Means.
void InvertElements()
Inverts all the elements of the matrix.
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void ApplyPow(Real power)
Take all elements of vector to a power.
Definition: kaldi-vector.h:179
void test_io(const DiagGmm &gmm, const AccumDiagGmm &est_gmm, bool binary, const Matrix< BaseFloat > &feats)
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
Definition: kaldi-math.h:276
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
void SetWeights(const VectorBase< Real > &w)
Mutators for both float and double.
Definition: diag-gmm-inl.h:28
void Resize(int32 num_gauss, int32 dim, GmmFlagsType flags)
Allocates memory for accumulators.
#define KALDI_LOG
Definition: kaldi-error.h:153
void AddVec(const Real alpha, const VectorBase< OtherReal > &v)
Add vector: *this = *this + alpha * v (with casting between floats and doubles) ...
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
Definition: kaldi-math.h:265
void AssertEqual(const AccumDiagGmm &other)
int32 NumGauss() const
Returns the number of mixture components.
Definition: mle-diag-gmm.h:124
const Matrix< BaseFloat > & inv_vars() const
Definition: diag-gmm.h:180
void CopyToDiagGmm(DiagGmm *diaggmm, GmmFlagsType flags=kGmmAll) const
Copies to DiagGmm the requested parameters.