regtree-mllr-diag-gmm.cc
Go to the documentation of this file.
1 // transform/regtree-mllr-diag-gmm.cc
2 
3 // Copyright 2009-2011 Saarland University; Jan Silovsky
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <utility>
21 using std::pair;
22 #include <vector>
23 using std::vector;
24 
26 
27 namespace kaldi {
28 
29 void RegtreeMllrDiagGmm::Init(int32 num_xforms, int32 dim) {
30  if (num_xforms == 0) { // empty transform
31  xform_matrices_.clear();
32  dim_ = 0; // non-zero dimension is meaningless with empty transform
33  num_xforms_ = 0;
34  bclass2xforms_.clear();
35  } else {
36  KALDI_ASSERT(dim != 0); // if not empty, dim = 0 is meaningless
37  dim_ = dim;
38  num_xforms_ = num_xforms;
39  xform_matrices_.resize(num_xforms);
40  vector< Matrix<BaseFloat> >::iterator xform_itr = xform_matrices_.begin(),
41  xform_itr_end = xform_matrices_.end();
42  for (; xform_itr != xform_itr_end; ++xform_itr) {
43  xform_itr->Resize(dim, dim+1);
44  xform_itr->SetUnit();
45  }
46  }
47 }
48 
50  vector< Matrix<BaseFloat> >::iterator xform_itr = xform_matrices_.begin(),
51  xform_itr_end = xform_matrices_.end();
52  for (; xform_itr != xform_itr_end; ++xform_itr) {
53  xform_itr->SetUnit();
54  }
55 }
56 
58  AmDiagGmm *am) {
59  KALDI_ASSERT(static_cast<int32>(bclass2xforms_.size()) ==
60  regtree.NumBaseclasses());
61  Vector<BaseFloat> extended_mean(dim_+1), xformed_mean(dim_);
62  for (int32 bclass_index = 0, num_bclasses = regtree.NumBaseclasses();
63  bclass_index < num_bclasses; ++bclass_index) {
64  int32 xform_index;
65  if ((xform_index = bclass2xforms_[bclass_index]) > -1) {
66  KALDI_ASSERT(xform_index < num_xforms_);
67  const vector< pair<int32, int32> > &bclass =
68  regtree.GetBaseclass(bclass_index);
69  for (vector< pair<int32, int32> >::const_iterator itr = bclass.begin(),
70  end = bclass.end(); itr != end; ++itr) {
71  SubVector<BaseFloat> tmp_mean(extended_mean.Range(0, dim_));
72  am->GetGaussianMean(itr->first, itr->second, &tmp_mean);
73  extended_mean(dim_) = 1.0;
74  xformed_mean.AddMatVec(1.0, xform_matrices_[xform_index], kNoTrans,
75  extended_mean, 0.0);
76  am->SetGaussianMean(itr->first, itr->second, xformed_mean);
77  } // end iterating over Gaussians in baseclass
78  } // else keep the means untransformed
79  } // end iterating over all baseclasses
80  am->ComputeGconsts();
81 }
82 
83 
85  const AmDiagGmm &am,
86  int32 pdf_index,
87  MatrixBase<BaseFloat> *out) const {
88  KALDI_ASSERT(static_cast<int32>(bclass2xforms_.size()) ==
89  regtree.NumBaseclasses());
90  int32 num_gauss = am.GetPdf(pdf_index).NumGauss();
91  KALDI_ASSERT(out->NumRows() == num_gauss && out->NumCols() == dim_);
92 
93  Vector<BaseFloat> extended_mean(dim_+1);
94  extended_mean(dim_) = 1.0;
95 
96  for (int32 gauss_index = 0; gauss_index < num_gauss; gauss_index++) {
97  int32 bclass_index = regtree.Gauss2BaseclassId(pdf_index, gauss_index);
98  int32 xform_index = bclass2xforms_[bclass_index];
99  if (xform_index > -1) { // use a transform
100  KALDI_ASSERT(xform_index < num_xforms_);
101  SubVector<BaseFloat> tmp_mean(extended_mean.Range(0, dim_));
102  am.GetGaussianMean(pdf_index, gauss_index, &tmp_mean);
103  SubVector<BaseFloat> out_row(out->Row(gauss_index));
104  out_row.AddMatVec(1.0, xform_matrices_[xform_index], kNoTrans,
105  extended_mean, 0.0);
106  } else { // Copy untransformed mean
107  SubVector<BaseFloat> out_row(out->Row(gauss_index));
108  am.GetGaussianMean(pdf_index, gauss_index, &out_row);
109  }
110  }
111 }
112 
113 
114 void RegtreeMllrDiagGmm::Write(std::ostream &out, bool binary) const {
115  WriteToken(out, binary, "<MLLRXFORM>");
116  WriteToken(out, binary, "<NUMXFORMS>");
117  WriteBasicType(out, binary, num_xforms_);
118  WriteToken(out, binary, "<DIMENSION>");
119  WriteBasicType(out, binary, dim_);
120 
121  vector< Matrix<BaseFloat> >::const_iterator xform_itr =
122  xform_matrices_.begin(), xform_itr_end = xform_matrices_.end();
123  for (; xform_itr != xform_itr_end; ++xform_itr) {
124  WriteToken(out, binary, "<XFORM>");
125  xform_itr->Write(out, binary);
126  }
127 
128  WriteToken(out, binary, "<BCLASS2XFORMS>");
129  WriteIntegerVector(out, binary, bclass2xforms_);
130  WriteToken(out, binary, "</MLLRXFORM>");
131 }
132 
133 
134 void RegtreeMllrDiagGmm::Read(std::istream &in, bool binary) {
135  ExpectToken(in, binary, "<MLLRXFORM>");
136  ExpectToken(in, binary, "<NUMXFORMS>");
137  ReadBasicType(in, binary, &num_xforms_);
138  ExpectToken(in, binary, "<DIMENSION>");
139  ReadBasicType(in, binary, &dim_);
140  KALDI_ASSERT(num_xforms_ >= 0 && dim_ >= 0); // can be 0 for empty xform
141 
143  vector< Matrix<BaseFloat> >::iterator xform_itr = xform_matrices_.begin(),
144  xform_itr_end = xform_matrices_.end();
145  for (; xform_itr != xform_itr_end; ++xform_itr) {
146  ExpectToken(in, binary, "<XFORM>");
147  xform_itr->Read(in, binary);
148  KALDI_ASSERT(xform_itr->NumRows() == (xform_itr->NumCols() - 1)
149  && xform_itr->NumRows() == dim_);
150  }
151 
152  ExpectToken(in, binary, "<BCLASS2XFORMS>");
153  ReadIntegerVector(in, binary, &bclass2xforms_);
154  ExpectToken(in, binary, "</MLLRXFORM>");
155 }
156 
157 // ************************************************************************
158 
159 void RegtreeMllrDiagGmmAccs::Init(int32 num_bclass, int32 dim) {
160  if (num_bclass == 0) { // empty stats
161  DeletePointers(&baseclass_stats_);
162  baseclass_stats_.clear();
163  num_baseclasses_ = 0;
164  dim_ = 0; // non-zero dimension is meaningless in empty stats
165  } else {
166  KALDI_ASSERT(dim != 0); // if not empty, dim = 0 is meaningless
167  num_baseclasses_ = num_bclass;
168  dim_ = dim;
169  baseclass_stats_.resize(num_baseclasses_);
170  for (vector<AffineXformStats*>::iterator it = baseclass_stats_.begin(),
171  end = baseclass_stats_.end(); it != end; ++it) {
172  *it = new AffineXformStats();
173  (*it)->Init(dim_, dim_);
174  }
175  }
176 }
177 
179  for (vector<AffineXformStats*>::iterator it = baseclass_stats_.begin(),
180  end = baseclass_stats_.end(); it != end; ++it) {
181  (*it)->SetZero();
182  }
183 }
184 
186  const RegressionTree &regtree, const AmDiagGmm &am,
187  const VectorBase<BaseFloat> &data, int32 pdf_index, BaseFloat weight) {
188  const DiagGmm &pdf = am.GetPdf(pdf_index);
189  int32 num_comp = static_cast<int32>(pdf.NumGauss());
190  Vector<BaseFloat> posterior(num_comp);
191  BaseFloat loglike = pdf.ComponentPosteriors(data, &posterior);
192  posterior.Scale(weight);
193  Vector<double> posterior_d(posterior);
194 
195  Vector<double> data_d(data);
196  Vector<double> inv_var_x(dim_);
197  Vector<double> extended_mean(dim_+1);
198  SpMatrix<double> mean_scatter(dim_+1);
199 
200  for (int32 m = 0; m < num_comp; m++) {
201  unsigned bclass = regtree.Gauss2BaseclassId(pdf_index, m);
202  inv_var_x.CopyFromVec(pdf.inv_vars().Row(m));
203  inv_var_x.MulElements(data_d);
204 
205  // Using SubVector to stop compiler warning
206  SubVector<double> tmp_mean(extended_mean, 0, dim_);
207  pdf.GetComponentMean(m, &tmp_mean); // modifies extended_mean
208  extended_mean(dim_) = 1.0;
209  mean_scatter.SetZero();
210  mean_scatter.AddVec2(1.0, extended_mean);
211 
212  baseclass_stats_[bclass]->beta_ += posterior_d(m);
213  baseclass_stats_[bclass]->K_.AddVecVec(posterior_d(m), inv_var_x,
214  extended_mean);
215  vector< SpMatrix<double> > &G = baseclass_stats_[bclass]->G_;
216  for (int32 d = 0; d < dim_; d++)
217  G[d].AddSp((posterior_d(m) * pdf.inv_vars()(m, d)), mean_scatter);
218  }
219  return loglike;
220 }
221 
223  const RegressionTree &regtree, const AmDiagGmm &am,
224  const VectorBase<BaseFloat> &data, int32 pdf_index, int32 gauss_index,
225  BaseFloat weight) {
226  const DiagGmm &pdf = am.GetPdf(pdf_index);
227  Vector<double> data_d(data);
228  Vector<double> inv_var_x(dim_);
229  Vector<double> extended_mean(dim_+1);
230  double weight_d = static_cast<double>(weight);
231 
232  unsigned bclass = regtree.Gauss2BaseclassId(pdf_index, gauss_index);
233  inv_var_x.CopyFromVec(pdf.inv_vars().Row(gauss_index));
234  inv_var_x.MulElements(data_d);
235 
236  // Using SubVector to stop compiler warning
237  SubVector<double> tmp_mean(extended_mean, 0, dim_);
238  pdf.GetComponentMean(gauss_index, &tmp_mean); // modifies extended_mean
239  extended_mean(dim_) = 1.0;
240  SpMatrix<double> mean_scatter(dim_+1);
241  mean_scatter.AddVec2(1.0, extended_mean);
242 
243  baseclass_stats_[bclass]->beta_ += weight_d;
244  baseclass_stats_[bclass]->K_.AddVecVec(weight_d, inv_var_x, extended_mean);
245  vector< SpMatrix<double> > &G = baseclass_stats_[bclass]->G_;
246  for (int32 d = 0; d < dim_; d++)
247  G[d].AddSp((weight_d * pdf.inv_vars()(gauss_index, d)), mean_scatter);
248 }
249 
250 void RegtreeMllrDiagGmmAccs::Write(std::ostream &out, bool binary) const {
251  WriteToken(out, binary, "<MLLRACCS>");
252  WriteToken(out, binary, "<NUMBASECLASSES>");
253  WriteBasicType(out, binary, num_baseclasses_);
254  WriteToken(out, binary, "<DIMENSION>");
255  WriteBasicType(out, binary, dim_);
256  WriteToken(out, binary, "<STATS>");
257  vector<AffineXformStats*>::const_iterator itr = baseclass_stats_.begin(),
258  end = baseclass_stats_.end();
259  for ( ; itr != end; ++itr)
260  (*itr)->Write(out, binary);
261  WriteToken(out, binary, "</MLLRACCS>");
262 }
263 
264 void RegtreeMllrDiagGmmAccs::Read(std::istream &in, bool binary, bool add) {
265  ExpectToken(in, binary, "<MLLRACCS>");
266  ExpectToken(in, binary, "<NUMBASECLASSES>");
267  ReadBasicType(in, binary, &num_baseclasses_);
268  ExpectToken(in, binary, "<DIMENSION>");
269  ReadBasicType(in, binary, &dim_);
270  KALDI_ASSERT(num_baseclasses_ > 0 && dim_ > 0);
271  baseclass_stats_.resize(num_baseclasses_);
272  ExpectToken(in, binary, "<STATS>");
273  vector<AffineXformStats*>::iterator itr = baseclass_stats_.begin(),
274  end = baseclass_stats_.end();
275  for ( ; itr != end; ++itr) {
276  *itr = new AffineXformStats();
277  (*itr)->Init(dim_, dim_);
278  (*itr)->Read(in, binary, add);
279  }
280  ExpectToken(in, binary, "</MLLRACCS>");
281 }
282 
283 static void ComputeMllrMatrix(const Matrix<double> &K,
284  const vector< SpMatrix<double> > &G,
285  Matrix<BaseFloat> *out) {
286  int32 dim = G.size();
287  Matrix<double> tmp_out(dim, dim+1);
288  for (int32 d = 0; d < dim; d++) {
289  if (G[d].Cond() > 1.0e+9) {
290  KALDI_WARN << "Dim " << d << ": Badly conditioned stats. Setting MLLR "
291  << "transform to unit.";
292  tmp_out.SetUnit();
293  break;
294  }
295  SpMatrix<double> inv_g(G[d]);
296 // KALDI_LOG << "Dim " << d << ": G: max = " << inv_g.Max() << ", min = "
297 // << inv_g.Min() << ", log det = " << inv_g.LogDet(NULL)
298 // << ", cond = " << inv_g.Cond();
299  inv_g.Invert();
300 // KALDI_LOG << "Inv G: max = " << inv_g.Max() << ", min = " << inv_g.Min()
301 // << ", log det = " << inv_g.LogDet(NULL) << ", cond = "
302 // << inv_g.Cond();
303  tmp_out.Row(d).AddSpVec(1.0, inv_g, K.Row(d), 0.0);
304  }
305  out->CopyFromMat(tmp_out, kNoTrans);
306 }
307 
309  const AffineXformStats &stats) {
310  int32 dim = stats.G_.size();
311  Matrix<double> xform_d(xform);
312  Vector<double> xform_row_g(dim + 1);
313  SubMatrix<double> A(xform_d, 0, dim, 0, dim);
314  double obj = TraceMatMat(xform_d, stats.K_, kTrans);
315  for (int32 d = 0; d < dim; d++) {
316  xform_row_g.AddSpVec(1.0, stats.G_[d], xform_d.Row(d), 0.0);
317  obj -= 0.5 * VecVec(xform_row_g, xform_d.Row(d));
318  }
319  return obj;
320 }
321 
323  const RegtreeMllrOptions &opts,
324  RegtreeMllrDiagGmm *out_mllr,
325  BaseFloat *auxf_impr,
326  BaseFloat *t) const {
327  BaseFloat tot_auxf_impr = 0, tot_t = 0;
328  Matrix<BaseFloat> xform_mat(dim_, dim_ + 1);
329  if (opts.use_regtree) { // estimate transforms using a regression tree
330  vector<AffineXformStats*> regclass_stats;
331  vector<int32> base2regclass;
332  bool update_xforms = regtree.GatherStats(baseclass_stats_, opts.min_count,
333  &base2regclass, &regclass_stats);
334  out_mllr->set_bclass2xforms(base2regclass);
335  // If update_xforms == true, none should be negative, else all should be -1
336  if (update_xforms) {
337  out_mllr->Init(regclass_stats.size(), dim_);
338  for (int32 rclass_index = 0, num_rclass = regclass_stats.size();
339  rclass_index < num_rclass; ++rclass_index) {
340  KALDI_ASSERT(regclass_stats[rclass_index]->beta_ >= opts.min_count);
341  xform_mat.SetUnit();
342  BaseFloat obj_old = MllrAuxFunction(xform_mat,
343  *(regclass_stats[rclass_index]));
344  ComputeMllrMatrix(regclass_stats[rclass_index]->K_,
345  regclass_stats[rclass_index]->G_, &xform_mat);
346  out_mllr->SetParameters(xform_mat, rclass_index);
347  BaseFloat obj_new = MllrAuxFunction(xform_mat,
348  *(regclass_stats[rclass_index]));
349  KALDI_LOG << "MLLR: regclass " << (rclass_index)
350  << ": Objective function impr per frame is "
351  << ((obj_new - obj_old)/regclass_stats[rclass_index]->beta_)
352  << " over " << regclass_stats[rclass_index]->beta_
353  << " frames.";
354  KALDI_ASSERT(obj_new >= obj_old - (std::abs(obj_new)+std::abs(obj_old))*1.0e-05);
355  tot_t += regclass_stats[rclass_index]->beta_;
356  tot_auxf_impr += obj_new - obj_old;
357  }
358  } else {
359  out_mllr->Init(1, dim_); // Use a unit transform at the root.
360  }
361  DeletePointers(&regclass_stats);
362  // end of estimation using regression tree
363  } else { // estimate 1 transform per baseclass (if enough count)
364  out_mllr->Init(num_baseclasses_, dim_);
365  vector<int32> base2xforms(num_baseclasses_, -1);
366  for (int32 bclass_index = 0; bclass_index < num_baseclasses_;
367  ++bclass_index) {
368  if (baseclass_stats_[bclass_index]->beta_ > opts.min_count) {
369  base2xforms[bclass_index] = bclass_index;
370  xform_mat.SetUnit();
371  BaseFloat obj_old = MllrAuxFunction(xform_mat,
372  *(baseclass_stats_[bclass_index]));
373  ComputeMllrMatrix(baseclass_stats_[bclass_index]->K_,
374  baseclass_stats_[bclass_index]->G_, &xform_mat);
375  out_mllr->SetParameters(xform_mat, bclass_index);
376  BaseFloat obj_new = MllrAuxFunction(xform_mat,
377  *(baseclass_stats_[bclass_index]));
378  KALDI_LOG << "MLLR: base-class " << (bclass_index)
379  << ": Auxiliary function impr per frame is "
380  << ((obj_new-obj_old)/baseclass_stats_[bclass_index]->beta_);
381  KALDI_ASSERT(obj_new >= obj_old - (std::abs(obj_new)+std::abs(obj_old))*1.0e-05);
382  tot_t += baseclass_stats_[bclass_index]->beta_;
383  tot_auxf_impr += obj_new - obj_old;
384  } else {
385  KALDI_WARN << "For baseclass " << (bclass_index) << " count = "
386  << (baseclass_stats_[bclass_index]->beta_) << " < "
387  << opts.min_count << ": not updating MLLR";
388  tot_t += baseclass_stats_[bclass_index]->beta_;
389  }
390  } // end looping over all baseclasses
391  out_mllr->set_bclass2xforms(base2xforms);
392  } // end of estimating one transform per baseclass
393  if (auxf_impr != NULL) *auxf_impr = tot_auxf_impr;
394  if (t != NULL) *t = tot_t;
395 }
396 
397 } // namespace kaldi
398 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void AccumulateForGaussian(const RegressionTree &regtree, const AmDiagGmm &am, const VectorBase< BaseFloat > &data, int32 pdf_index, int32 gauss_index, BaseFloat weight)
Accumulate stats for a single Gaussian component in the model.
Matrix< double > K_
K_ is the summed outer product of [mean times inverse variance] with [extended data], scaled by the occupation counts; dimension is dim by (dim+1)
int32 Gauss2BaseclassId(size_t pdf_id, size_t gauss_id) const
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
An MLLR mean transformation is an affine transformation of Gaussian means.
std::vector< Matrix< BaseFloat > > xform_matrices_
Transform matrices: size() = num_xforms_.
void Init(int32 num_xforms, int32 dim)
Allocates memory for transform matrix & bias vector.
int32 dim_
Dimension of feature vectors.
int32 ComputeGconsts()
Sets the gconsts for all the PDFs.
Definition: am-diag-gmm.cc:90
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
const std::vector< std::pair< int32, int32 > > & GetBaseclass(int32 bclass) const
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
void TransformModel(const RegressionTree &regtree, AmDiagGmm *am)
Apply the transform(s) to all the Gaussian means in the model.
Configuration variables for FMLLR transforms.
void GetComponentMean(int32 gauss, VectorBase< Real > *out) const
Accessor for single component mean.
Definition: diag-gmm-inl.h:135
bool use_regtree
If 'true', find transforms to generate using regression tree.
BaseFloat min_count
Minimum occupancy for computing a transform.
void set_bclass2xforms(const std::vector< int32 > &in)
kaldi::int32 int32
std::vector< int32 > bclass2xforms_
For each baseclass index of which transform to use; -1 => no xform.
void AddSpVec(const Real alpha, const SpMatrix< Real > &M, const VectorBase< Real > &v, const Real beta)
Add symmetric positive definite matrix times vector: this <-- beta*this + alpha*M*v.
void Init(int32 num_bclass, int32 dim)
void SetGaussianMean(int32 pdf_index, int32 gauss_index, const VectorBase< BaseFloat > &in)
Mutators.
Definition: am-diag-gmm.h:145
void CopyFromMat(const MatrixBase< OtherReal > &M, MatrixTransposeType trans=kNoTrans)
Copy given matrix. (no resize is done).
void SetUnit()
Sets to zero, except ones along diagonal [for non-square matrices too].
bool GatherStats(const std::vector< AffineXformStats *> &stats_in, double min_count, std::vector< int32 > *regclasses_out, std::vector< AffineXformStats *> *stats_out) const
Parses the regression tree and finds the nodes whose occupancies (read from stats_in) are greater tha...
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
void Read(std::istream &in_stream, bool binary, bool add)
void AddVec2(const Real alpha, const VectorBase< OtherReal > &v)
rank-one update, this <-- this + alpha v v'
Definition: sp-matrix.cc:946
BaseFloat ComponentPosteriors(const VectorBase< BaseFloat > &data, Vector< BaseFloat > *posteriors) const
Computes the posterior probabilities of all Gaussian components given a data point.
Definition: diag-gmm.cc:601
int32 NumBaseclasses() const
Accessors (const)
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
void ReadIntegerVector(std::istream &is, bool binary, std::vector< T > *v)
Function for reading STL vector of integer types.
Definition: io-funcs-inl.h:232
void MulElements(const VectorBase< Real > &v)
Multiply element-by-element by another vector.
void SetParameters(const MatrixBase< BaseFloat > &mat, int32 regclass)
Mutators.
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
A regression tree is a clustering of Gaussian densities in an acoustic model, such that the group of ...
void SetUnit()
Initialize transform matrix to identity and bias vector to zero.
static void ComputeMllrMatrix(const Matrix< double > &K, const vector< SpMatrix< double > > &G, Matrix< BaseFloat > *out)
void Read(std::istream &in_stream, bool binary)
#define KALDI_WARN
Definition: kaldi-error.h:150
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: diag-gmm.h:72
void Scale(Real alpha)
Multiplies all elements by this constant.
void AddMatVec(const Real alpha, const MatrixBase< Real > &M, const MatrixTransposeType trans, const VectorBase< Real > &v, const Real beta)
Add matrix times vector: this <-- beta*this + alpha*M*v.
Definition: kaldi-vector.cc:92
std::vector< SpMatrix< double > > G_
G_ is the outer product of extended-data, scaled by inverse variance, for each dimension.
void GetTransformedMeans(const RegressionTree &regtree, const AmDiagGmm &am, int32 pdf_index, MatrixBase< BaseFloat > *out) const
Get all the transformed means for a given pdf.
DiagGmm & GetPdf(int32 pdf_index)
Accessors.
Definition: am-diag-gmm.h:119
int32 num_xforms_
Number of transforms == xform_matrices_.size()
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void GetGaussianMean(int32 pdf_index, int32 gauss, VectorBase< BaseFloat > *out) const
Definition: am-diag-gmm.h:131
void Write(std::ostream &out_stream, bool binary) const
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
Definition: io-funcs-inl.h:198
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
static BaseFloat MllrAuxFunction(const Matrix< BaseFloat > &xform, const AffineXformStats &stats)
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
BaseFloat AccumulateForGmm(const RegressionTree &regtree, const AmDiagGmm &am, const VectorBase< BaseFloat > &data, int32 pdf_index, BaseFloat weight)
Accumulate stats for a single GMM in the model; returns log likelihood.
void Update(const RegressionTree &regtree, const RegtreeMllrOptions &opts, RegtreeMllrDiagGmm *out_mllr, BaseFloat *auxf_impr, BaseFloat *t) const
void Invert(Real *logdet=NULL, Real *det_sign=NULL, bool inverse_needed=true)
matrix inverse.
Definition: sp-matrix.cc:219
#define KALDI_LOG
Definition: kaldi-error.h:153
Real VecVec(const VectorBase< Real > &a, const VectorBase< Real > &b)
Returns dot product between v1 and v2.
Definition: kaldi-vector.cc:37
Sub-matrix representation.
Definition: kaldi-matrix.h:988
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
void Write(std::ostream &out_stream, bool binary) const
const Matrix< BaseFloat > & inv_vars() const
Definition: diag-gmm.h:180