RegtreeMllrDiagGmmAccs Class Reference

Accumulates the sufficient statistics for regression-tree MLLR, and estimates the resulting mean transforms, for an acoustic model that uses diagonal Gaussian mixture models as emission densities.

#include <regtree-mllr-diag-gmm.h>

Public Member Functions

 RegtreeMllrDiagGmmAccs ()
 
 ~RegtreeMllrDiagGmmAccs ()
 
void Init (int32 num_bclass, int32 dim)
 
void SetZero ()
 
BaseFloat AccumulateForGmm (const RegressionTree &regtree, const AmDiagGmm &am, const VectorBase< BaseFloat > &data, int32 pdf_index, BaseFloat weight)
 Accumulate stats for a single GMM in the model; returns log likelihood.
 
void AccumulateForGaussian (const RegressionTree &regtree, const AmDiagGmm &am, const VectorBase< BaseFloat > &data, int32 pdf_index, int32 gauss_index, BaseFloat weight)
 Accumulate stats for a single Gaussian component in the model.
 
void Update (const RegressionTree &regtree, const RegtreeMllrOptions &opts, RegtreeMllrDiagGmm *out_mllr, BaseFloat *auxf_impr, BaseFloat *t) const
 
void Write (std::ostream &out_stream, bool binary) const
 
void Read (std::istream &in_stream, bool binary, bool add)
 
int32 Dim () const
 Accessors.
 
int32 NumBaseClasses () const
 
const std::vector< AffineXformStats * > & baseclass_stats () const
 

Private Member Functions

BaseFloat MllrObjFunction (const Matrix< BaseFloat > &xform, int32 bclass_id) const
 Returns the MLLR objective function for a given transform and baseclass.
 
 KALDI_DISALLOW_COPY_AND_ASSIGN (RegtreeMllrDiagGmmAccs)
 

Private Attributes

std::vector< AffineXformStats * > baseclass_stats_
 Per-baseclass stats; used for accumulation.
 
int32 num_baseclasses_
 Number of baseclasses.
 
int32 dim_
 Dimension of feature vectors.
 

Detailed Description

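Accumulates the sufficient statistics for regression-tree MLLR, and estimates the resulting mean transforms, for an acoustic model that uses diagonal Gaussian mixture models as emission densities. Statistics (an occupancy count beta_, a matrix K_ and one symmetric matrix G_[d] per feature dimension) are kept separately for each baseclass; Update() then turns them into one transform per regression class or per baseclass, depending on the options.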

Definition at line 103 of file regtree-mllr-diag-gmm.h.

Constructor & Destructor Documentation

◆ RegtreeMllrDiagGmmAccs()

Definition at line 105 of file regtree-mllr-diag-gmm.h.

105 {}

◆ ~RegtreeMllrDiagGmmAccs()

Definition at line 106 of file regtree-mllr-diag-gmm.h.

References kaldi::DeletePointers().

Member Function Documentation

◆ AccumulateForGaussian()

void AccumulateForGaussian ( const RegressionTree &  regtree,
const AmDiagGmm &  am,
const VectorBase< BaseFloat > &  data,
int32  pdf_index,
int32  gauss_index,
BaseFloat  weight 
)

Accumulate stats for a single Gaussian component in the model.

Definition at line 222 of file regtree-mllr-diag-gmm.cc.

References SpMatrix< Real >::AddVec2(), VectorBase< Real >::CopyFromVec(), rnnlm::d, RegtreeMllrDiagGmm::dim_, RegressionTree::Gauss2BaseclassId(), DiagGmm::GetComponentMean(), AmDiagGmm::GetPdf(), DiagGmm::inv_vars(), VectorBase< Real >::MulElements(), and MatrixBase< Real >::Row().

225  {
226    const DiagGmm &pdf = am.GetPdf(pdf_index);
227    Vector<double> data_d(data);
228    Vector<double> inv_var_x(dim_);
229    Vector<double> extended_mean(dim_+1);
230    double weight_d = static_cast<double>(weight);
231 
232    unsigned bclass = regtree.Gauss2BaseclassId(pdf_index, gauss_index);
233    inv_var_x.CopyFromVec(pdf.inv_vars().Row(gauss_index));
234    inv_var_x.MulElements(data_d);
235 
236    // Using SubVector to stop compiler warning
237    SubVector<double> tmp_mean(extended_mean, 0, dim_);
238    pdf.GetComponentMean(gauss_index, &tmp_mean);  // modifies extended_mean
239    extended_mean(dim_) = 1.0;
240    SpMatrix<double> mean_scatter(dim_+1);
241    mean_scatter.AddVec2(1.0, extended_mean);
242 
243    baseclass_stats_[bclass]->beta_ += weight_d;
244    baseclass_stats_[bclass]->K_.AddVecVec(weight_d, inv_var_x, extended_mean);
245    vector< SpMatrix<double> > &G = baseclass_stats_[bclass]->G_;
246    for (int32 d = 0; d < dim_; d++)
247      G[d].AddSp((weight_d * pdf.inv_vars()(gauss_index, d)), mean_scatter);
248  }
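
The update rule in the listing can be restated without Kaldi types. The following is a minimal sketch, not the library code: ToyStats and AccumulateOneGaussian are hypothetical names, dense row-major std::vector storage stands in for Matrix and SpMatrix, and the caller is assumed to have already looked up the Gaussian's mean and inverse variances. With the extended mean xi = [mean; 1], one call adds weight to beta, weight * (inv_var .* x) * xi^T to K, and weight * inv_var[d] * xi * xi^T to each G[d].

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for AffineXformStats (dense, row-major storage).
struct ToyStats {
  int dim;
  double beta;                           // total occupancy count
  std::vector<double> K;                 // dim x (dim+1)
  std::vector<std::vector<double> > G;   // dim matrices, each (dim+1) x (dim+1)
  explicit ToyStats(int d)
      : dim(d), beta(0.0), K(d * (d + 1), 0.0),
        G(d, std::vector<double>((d + 1) * (d + 1), 0.0)) {}
};

// Adds the contribution of one frame x to the stats, for one diagonal
// Gaussian with the given mean and inverse variances.
void AccumulateOneGaussian(const std::vector<double> &x,
                           const std::vector<double> &mean,
                           const std::vector<double> &inv_var,
                           double weight, ToyStats *stats) {
  const int dim = stats->dim;
  assert(x.size() == static_cast<std::size_t>(dim));
  assert(mean.size() == static_cast<std::size_t>(dim));
  assert(inv_var.size() == static_cast<std::size_t>(dim));

  // Extended mean xi = [mean; 1].
  std::vector<double> xi(mean);
  xi.push_back(1.0);

  stats->beta += weight;

  // K += weight * (inv_var .* x) * xi^T
  for (int i = 0; i < dim; ++i) {
    const double inv_var_x = inv_var[i] * x[i];
    for (int j = 0; j <= dim; ++j)
      stats->K[i * (dim + 1) + j] += weight * inv_var_x * xi[j];
  }

  // G[d] += weight * inv_var[d] * xi * xi^T, for every feature dimension d.
  for (int d = 0; d < dim; ++d) {
    const double scale = weight * inv_var[d];
    for (int i = 0; i <= dim; ++i)
      for (int j = 0; j <= dim; ++j)
        stats->G[d][i * (dim + 1) + j] += scale * xi[i] * xi[j];
  }
}
```

The sketch stores each G[d] as a full square matrix for readability; the listing uses SpMatrix, which keeps only one triangle of the symmetric matrix.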

◆ AccumulateForGmm()

BaseFloat AccumulateForGmm ( const RegressionTree &  regtree,
const AmDiagGmm &  am,
const VectorBase< BaseFloat > &  data,
int32  pdf_index,
BaseFloat  weight 
)

Accumulate stats for a single GMM in the model; returns log likelihood.

This does not work with multiple feature transforms.

Definition at line 185 of file regtree-mllr-diag-gmm.cc.

References SpMatrix< Real >::AddVec2(), DiagGmm::ComponentPosteriors(), VectorBase< Real >::CopyFromVec(), rnnlm::d, RegtreeMllrDiagGmm::dim_, RegressionTree::Gauss2BaseclassId(), DiagGmm::GetComponentMean(), AmDiagGmm::GetPdf(), DiagGmm::inv_vars(), VectorBase< Real >::MulElements(), DiagGmm::NumGauss(), MatrixBase< Real >::Row(), VectorBase< Real >::Scale(), and PackedMatrix< Real >::SetZero().

Referenced by main(), and UnitTestRegtreeMllrDiagGmm().

187  {
188    const DiagGmm &pdf = am.GetPdf(pdf_index);
189    int32 num_comp = static_cast<int32>(pdf.NumGauss());
190    Vector<BaseFloat> posterior(num_comp);
191    BaseFloat loglike = pdf.ComponentPosteriors(data, &posterior);
192    posterior.Scale(weight);
193    Vector<double> posterior_d(posterior);
194 
195    Vector<double> data_d(data);
196    Vector<double> inv_var_x(dim_);
197    Vector<double> extended_mean(dim_+1);
198    SpMatrix<double> mean_scatter(dim_+1);
199 
200    for (int32 m = 0; m < num_comp; m++) {
201      unsigned bclass = regtree.Gauss2BaseclassId(pdf_index, m);
202      inv_var_x.CopyFromVec(pdf.inv_vars().Row(m));
203      inv_var_x.MulElements(data_d);
204 
205      // Using SubVector to stop compiler warning
206      SubVector<double> tmp_mean(extended_mean, 0, dim_);
207      pdf.GetComponentMean(m, &tmp_mean);  // modifies extended_mean
208      extended_mean(dim_) = 1.0;
209      mean_scatter.SetZero();
210      mean_scatter.AddVec2(1.0, extended_mean);
211 
212      baseclass_stats_[bclass]->beta_ += posterior_d(m);
213      baseclass_stats_[bclass]->K_.AddVecVec(posterior_d(m), inv_var_x,
214                                             extended_mean);
215      vector< SpMatrix<double> > &G = baseclass_stats_[bclass]->G_;
216      for (int32 d = 0; d < dim_; d++)
217        G[d].AddSp((posterior_d(m) * pdf.inv_vars()(m, d)), mean_scatter);
218    }
219    return loglike;
220  }
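
Compared with AccumulateForGaussian(), the only new ingredient here is the per-component weight: each Gaussian's posterior given the frame, scaled by the frame weight, while the returned log-likelihood stays unscaled. The sketch below shows one way such posteriors can be computed. It is an illustration under assumptions: WeightedPosteriors is a hypothetical name, its input is assumed to be the per-component log-likelihoods (log mixture weight plus log density), and the log-sum-exp normalisation is a standard choice rather than a description of what DiagGmm::ComponentPosteriors() does internally.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Converts per-component log-likelihoods into posteriors scaled by
// frame_weight. Returns the (unscaled) GMM log-likelihood.
double WeightedPosteriors(const std::vector<double> &comp_loglikes,
                          double frame_weight,
                          std::vector<double> *post) {
  assert(!comp_loglikes.empty());

  // Subtract the maximum before exponentiating to avoid underflow.
  double max_ll = comp_loglikes[0];
  for (std::size_t m = 1; m < comp_loglikes.size(); ++m)
    if (comp_loglikes[m] > max_ll) max_ll = comp_loglikes[m];

  double sum = 0.0;
  for (std::size_t m = 0; m < comp_loglikes.size(); ++m)
    sum += std::exp(comp_loglikes[m] - max_ll);
  const double loglike = max_ll + std::log(sum);

  // Posterior of component m, times the frame weight.
  post->resize(comp_loglikes.size());
  for (std::size_t m = 0; m < comp_loglikes.size(); ++m)
    (*post)[m] = frame_weight * std::exp(comp_loglikes[m] - loglike);
  return loglike;
}
```

The scaled posteriors sum to frame_weight, so the beta_ counts accumulated over all components of a GMM grow by exactly the frame weight per frame.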

◆ baseclass_stats()

const std::vector<AffineXformStats*>& baseclass_stats ( ) const
inline

Definition at line 135 of file regtree-mllr-diag-gmm.h.

135  {
136    return baseclass_stats_;
137  }

◆ Dim()

int32 Dim ( ) const
inline

Accessors.

Definition at line 133 of file regtree-mllr-diag-gmm.h.

133 { return dim_; }

◆ Init()

void Init ( int32  num_bclass,
int32  dim 
)

Definition at line 159 of file regtree-mllr-diag-gmm.cc.

References kaldi::DeletePointers(), RegtreeMllrDiagGmm::dim_, and KALDI_ASSERT.

Referenced by main(), and UnitTestRegtreeMllrDiagGmm().

159  {
160    if (num_bclass == 0) {  // empty stats
162      baseclass_stats_.clear();
163      num_baseclasses_ = 0;
164      dim_ = 0;  // non-zero dimension is meaningless in empty stats
165    } else {
166      KALDI_ASSERT(dim != 0);  // if not empty, dim = 0 is meaningless
167      num_baseclasses_ = num_bclass;
168      dim_ = dim;
170      for (vector<AffineXformStats*>::iterator it = baseclass_stats_.begin(),
171           end = baseclass_stats_.end(); it != end; ++it) {
172        *it = new AffineXformStats();
173        (*it)->Init(dim_, dim_);
174      }
175    }
176  }
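
Init() has two cases: zero baseclasses means empty stats with dimension 0, and any other count requires a non-zero dimension and allocates one stats object per baseclass. The sketch below restates that contract in a self-contained form. It is not the Kaldi implementation: ToyAccs and ToyStats are hypothetical, and std::unique_ptr is swapped in for the raw pointers and DeletePointers() of the original so that re-initialising releases the old objects automatically.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical per-baseclass statistics; only the fields needed here.
struct ToyStats {
  int dim;
  double beta;
  explicit ToyStats(int d) : dim(d), beta(0.0) {}
};

class ToyAccs {
 public:
  ToyAccs() : dim_(0) {}

  void Init(int num_bclass, int dim) {
    stats_.clear();  // unique_ptr frees any previously allocated stats
    if (num_bclass == 0) {
      dim_ = 0;      // a non-zero dimension is meaningless for empty stats
      return;
    }
    assert(dim != 0);  // non-empty stats need a real dimension
    dim_ = dim;
    stats_.reserve(static_cast<std::size_t>(num_bclass));
    for (int b = 0; b < num_bclass; ++b)
      stats_.push_back(std::unique_ptr<ToyStats>(new ToyStats(dim)));
  }

  int Dim() const { return dim_; }
  int NumBaseClasses() const { return static_cast<int>(stats_.size()); }

 private:
  std::vector<std::unique_ptr<ToyStats> > stats_;
  int dim_;
};
```

Owning the objects through smart pointers also makes an explicit destructor unnecessary in the sketch, which is a design difference from the class documented here.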

◆ KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( RegtreeMllrDiagGmmAccs  )
private

◆ MllrObjFunction()

BaseFloat MllrObjFunction ( const Matrix< BaseFloat > &  xform,
int32  bclass_id 
) const
private

Returns the MLLR objective function for a given transform and baseclass.

◆ NumBaseClasses()

int32 NumBaseClasses ( ) const
inline

Definition at line 134 of file regtree-mllr-diag-gmm.h.

134 { return num_baseclasses_; }

◆ Read()

void Read ( std::istream &  in_stream,
bool  binary,
bool  add 
)

Definition at line 264 of file regtree-mllr-diag-gmm.cc.

References RegtreeMllrDiagGmm::dim_, kaldi::ExpectToken(), KALDI_ASSERT, and kaldi::ReadBasicType().

Referenced by TestMllrAccsIO().

264  {
265    ExpectToken(in, binary, "<MLLRACCS>");
266    ExpectToken(in, binary, "<NUMBASECLASSES>");
267    ReadBasicType(in, binary, &num_baseclasses_);
268    ExpectToken(in, binary, "<DIMENSION>");
269    ReadBasicType(in, binary, &dim_);
270    KALDI_ASSERT(num_baseclasses_ > 0 && dim_ > 0);
272    ExpectToken(in, binary, "<STATS>");
273    vector<AffineXformStats*>::iterator itr = baseclass_stats_.begin(),
274        end = baseclass_stats_.end();
275    for ( ; itr != end; ++itr) {
276      *itr = new AffineXformStats();
277      (*itr)->Init(dim_, dim_);
278      (*itr)->Read(in, binary, add);
279    }
280    ExpectToken(in, binary, "</MLLRACCS>");
281  }

◆ SetZero()

void SetZero ( )

Definition at line 178 of file regtree-mllr-diag-gmm.cc.

Referenced by main().

178  {
179    for (vector<AffineXformStats*>::iterator it = baseclass_stats_.begin(),
180         end = baseclass_stats_.end(); it != end; ++it) {
181      (*it)->SetZero();
182    }
183  }

◆ Update()

void Update ( const RegressionTree &  regtree,
const RegtreeMllrOptions &  opts,
RegtreeMllrDiagGmm *  out_mllr,
BaseFloat *  auxf_impr,
BaseFloat *  t 
) const

Definition at line 322 of file regtree-mllr-diag-gmm.cc.

References kaldi::ComputeMllrMatrix(), kaldi::DeletePointers(), RegtreeMllrDiagGmm::dim_, RegressionTree::GatherStats(), RegtreeMllrDiagGmm::Init(), KALDI_ASSERT, KALDI_LOG, KALDI_WARN, RegtreeMllrOptions::min_count, kaldi::MllrAuxFunction(), RegtreeMllrDiagGmm::set_bclass2xforms(), RegtreeMllrDiagGmm::SetParameters(), MatrixBase< Real >::SetUnit(), and RegtreeMllrOptions::use_regtree.

Referenced by main(), TestMllrAccsIO(), and TestXformMean().

326  {
327    BaseFloat tot_auxf_impr = 0, tot_t = 0;
328    Matrix<BaseFloat> xform_mat(dim_, dim_ + 1);
329    if (opts.use_regtree) {  // estimate transforms using a regression tree
330      vector<AffineXformStats*> regclass_stats;
331      vector<int32> base2regclass;
332      bool update_xforms = regtree.GatherStats(baseclass_stats_, opts.min_count,
333                                               &base2regclass, &regclass_stats);
334      out_mllr->set_bclass2xforms(base2regclass);
335      // If update_xforms == true, none should be negative, else all should be -1
336      if (update_xforms) {
337        out_mllr->Init(regclass_stats.size(), dim_);
338        for (int32 rclass_index = 0, num_rclass = regclass_stats.size();
339             rclass_index < num_rclass; ++rclass_index) {
340          KALDI_ASSERT(regclass_stats[rclass_index]->beta_ >= opts.min_count);
341          xform_mat.SetUnit();
342          BaseFloat obj_old = MllrAuxFunction(xform_mat,
343                                              *(regclass_stats[rclass_index]));
344          ComputeMllrMatrix(regclass_stats[rclass_index]->K_,
345                            regclass_stats[rclass_index]->G_, &xform_mat);
346          out_mllr->SetParameters(xform_mat, rclass_index);
347          BaseFloat obj_new = MllrAuxFunction(xform_mat,
348                                              *(regclass_stats[rclass_index]));
349          KALDI_LOG << "MLLR: regclass " << (rclass_index)
350                    << ": Objective function impr per frame is "
351                    << ((obj_new - obj_old)/regclass_stats[rclass_index]->beta_)
352                    << " over " << regclass_stats[rclass_index]->beta_
353                    << " frames.";
354          KALDI_ASSERT(obj_new >= obj_old - (std::abs(obj_new)+std::abs(obj_old))*1.0e-05);
355          tot_t += regclass_stats[rclass_index]->beta_;
356          tot_auxf_impr += obj_new - obj_old;
357        }
358      } else {
359        out_mllr->Init(1, dim_);  // Use a unit transform at the root.
360      }
361      DeletePointers(&regclass_stats);
362      // end of estimation using regression tree
363    } else {  // estimate 1 transform per baseclass (if enough count)
364      out_mllr->Init(num_baseclasses_, dim_);
365      vector<int32> base2xforms(num_baseclasses_, -1);
366      for (int32 bclass_index = 0; bclass_index < num_baseclasses_;
367           ++bclass_index) {
368        if (baseclass_stats_[bclass_index]->beta_ > opts.min_count) {
369          base2xforms[bclass_index] = bclass_index;
370          xform_mat.SetUnit();
371          BaseFloat obj_old = MllrAuxFunction(xform_mat,
372                                              *(baseclass_stats_[bclass_index]));
373          ComputeMllrMatrix(baseclass_stats_[bclass_index]->K_,
374                            baseclass_stats_[bclass_index]->G_, &xform_mat);
375          out_mllr->SetParameters(xform_mat, bclass_index);
376          BaseFloat obj_new = MllrAuxFunction(xform_mat,
377                                              *(baseclass_stats_[bclass_index]));
378          KALDI_LOG << "MLLR: base-class " << (bclass_index)
379                    << ": Auxiliary function impr per frame is "
380                    << ((obj_new-obj_old)/baseclass_stats_[bclass_index]->beta_);
381          KALDI_ASSERT(obj_new >= obj_old - (std::abs(obj_new)+std::abs(obj_old))*1.0e-05);
382          tot_t += baseclass_stats_[bclass_index]->beta_;
383          tot_auxf_impr += obj_new - obj_old;
384        } else {
385          KALDI_WARN << "For baseclass " << (bclass_index) << " count = "
386                     << (baseclass_stats_[bclass_index]->beta_) << " < "
387                     << opts.min_count << ": not updating MLLR";
388          tot_t += baseclass_stats_[bclass_index]->beta_;
389        }
390      }  // end looping over all baseclasses
391      out_mllr->set_bclass2xforms(base2xforms);
392    }  // end of estimating one transform per baseclass
393    if (auxf_impr != NULL) *auxf_impr = tot_auxf_impr;
394    if (t != NULL) *t = tot_t;
395  }
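
For every class with enough data, the listing evaluates the auxiliary function at the unit transform, estimates a new transform from K_ and G_, re-evaluates, and asserts that the value did not drop by more than a small relative tolerance. The bodies of ComputeMllrMatrix() and MllrAuxFunction() are not shown on this page, so the sketch below assumes the textbook closed form for mean MLLR with diagonal covariances: row d of the transform solves G_d w_d = k_d, and the auxiliary function is the sum over d of w_d . k_d - 0.5 * w_d^T G_d w_d. All names (SolveLinear, ToyAuxf, ToyMllrUpdate, AuxfDidNotDecrease) are hypothetical, and the solver is plain Gaussian elimination rather than whatever the library uses.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

typedef std::vector<double> Vec;
typedef std::vector<Vec> Mat;

// Solves A * w = b by Gaussian elimination with partial pivoting.
// A and b are taken by value because they are modified in place.
Vec SolveLinear(Mat A, Vec b) {
  const std::size_t n = b.size();
  for (std::size_t col = 0; col < n; ++col) {
    std::size_t piv = col;
    for (std::size_t r = col + 1; r < n; ++r)
      if (std::fabs(A[r][col]) > std::fabs(A[piv][col])) piv = r;
    assert(std::fabs(A[piv][col]) > 0.0);  // singular systems are not handled
    std::swap(A[col], A[piv]);
    std::swap(b[col], b[piv]);
    for (std::size_t r = col + 1; r < n; ++r) {
      const double f = A[r][col] / A[col][col];
      for (std::size_t c = col; c < n; ++c) A[r][c] -= f * A[col][c];
      b[r] -= f * b[col];
    }
  }
  Vec w(n);
  for (std::size_t i = n; i-- > 0;) {  // back substitution
    double s = b[i];
    for (std::size_t c = i + 1; c < n; ++c) s -= A[i][c] * w[c];
    w[i] = s / A[i][i];
  }
  return w;
}

// Assumed auxiliary function: sum_d ( w_d . k_d - 0.5 * w_d^T G_d w_d ).
double ToyAuxf(const Mat &W, const Mat &K, const std::vector<Mat> &G) {
  double obj = 0.0;
  for (std::size_t d = 0; d < W.size(); ++d) {
    for (std::size_t j = 0; j < W[d].size(); ++j) obj += W[d][j] * K[d][j];
    for (std::size_t i = 0; i < W[d].size(); ++i)
      for (std::size_t j = 0; j < W[d].size(); ++j)
        obj -= 0.5 * W[d][i] * G[d][i][j] * W[d][j];
  }
  return obj;
}

// Assumed row-wise update: w_d = G_d^{-1} k_d.
Mat ToyMllrUpdate(const Mat &K, const std::vector<Mat> &G) {
  Mat W(K.size());
  for (std::size_t d = 0; d < K.size(); ++d) W[d] = SolveLinear(G[d], K[d]);
  return W;
}

// Same relative tolerance as the KALDI_ASSERT in the listing.
bool AuxfDidNotDecrease(double obj_old, double obj_new) {
  return obj_new >=
         obj_old - (std::fabs(obj_new) + std::fabs(obj_old)) * 1.0e-05;
}
```

The tolerance is relative to the magnitudes of both values, so the check still passes when single-precision rounding makes the new value marginally smaller than the old one.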

◆ Write()

void Write ( std::ostream &  out_stream,
bool  binary 
) const

Definition at line 250 of file regtree-mllr-diag-gmm.cc.

References RegtreeMllrDiagGmm::dim_, kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by TestMllrAccsIO().

250  {
251    WriteToken(out, binary, "<MLLRACCS>");
252    WriteToken(out, binary, "<NUMBASECLASSES>");
253    WriteBasicType(out, binary, num_baseclasses_);
254    WriteToken(out, binary, "<DIMENSION>");
255    WriteBasicType(out, binary, dim_);
256    WriteToken(out, binary, "<STATS>");
257    vector<AffineXformStats*>::const_iterator itr = baseclass_stats_.begin(),
258        end = baseclass_stats_.end();
259    for ( ; itr != end; ++itr)
260      (*itr)->Write(out, binary);
261    WriteToken(out, binary, "</MLLRACCS>");
262  }
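
Write() and Read() frame the data with matching tokens, and Read() rejects input whose tokens or header values are not what it expects. The round trip below illustrates that framing in text mode only. It is a sketch under assumptions: PutToken, NeedToken, ToyHeader, WriteHeader and ReadHeader are hypothetical helpers, the per-baseclass <STATS> payload is left out, and failures are reported with std::runtime_error rather than Kaldi's own error handling.

```cpp
#include <cassert>
#include <istream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>

// Writes a token followed by a separating space (text mode only).
void PutToken(std::ostream &os, const std::string &token) {
  os << token << ' ';
}

// Reads the next whitespace-delimited token and checks it.
void NeedToken(std::istream &is, const std::string &expected) {
  std::string got;
  if (!(is >> got) || got != expected)
    throw std::runtime_error("expected token " + expected + ", got " + got);
}

// Hypothetical header: the two scalars stored before the stats.
struct ToyHeader {
  int num_baseclasses;
  int dim;
};

void WriteHeader(std::ostream &os, const ToyHeader &h) {
  PutToken(os, "<MLLRACCS>");
  PutToken(os, "<NUMBASECLASSES>");
  os << h.num_baseclasses << ' ';
  PutToken(os, "<DIMENSION>");
  os << h.dim << ' ';
  PutToken(os, "</MLLRACCS>");
}

ToyHeader ReadHeader(std::istream &is) {
  ToyHeader h = {0, 0};
  NeedToken(is, "<MLLRACCS>");
  NeedToken(is, "<NUMBASECLASSES>");
  is >> h.num_baseclasses;
  NeedToken(is, "<DIMENSION>");
  is >> h.dim;
  // Mirrors the sanity check in Read(): both values must be positive.
  if (!is || h.num_baseclasses <= 0 || h.dim <= 0)
    throw std::runtime_error("bad header values");
  NeedToken(is, "</MLLRACCS>");
  return h;
}
```

Writing a header to a std::ostringstream and reading it back from a std::istringstream should return the same two values, while input that starts with any other token makes ReadHeader() throw.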

Member Data Documentation

◆ baseclass_stats_

std::vector<AffineXformStats*> baseclass_stats_
private

Per-baseclass stats; used for accumulation.

Definition at line 141 of file regtree-mllr-diag-gmm.h.

◆ dim_

int32 dim_
private

Dimension of feature vectors.

Definition at line 143 of file regtree-mllr-diag-gmm.h.

◆ num_baseclasses_

int32 num_baseclasses_
private

Number of baseclasses.

Definition at line 142 of file regtree-mllr-diag-gmm.h.


The documentation for this class was generated from the following files:
regtree-mllr-diag-gmm.h
regtree-mllr-diag-gmm.cc