MlltAccs Class Reference

A class for estimating a Maximum Likelihood Linear Transform (MLLT), also known as a global Semi-tied Covariance (STC) transform, for GMMs. More...

#include <mllt.h>

Public Member Functions

 MlltAccs ()
 
 MlltAccs (int32 dim, BaseFloat rand_prune=0.25)
 Need rand_prune >= 0. More...
 
void Init (int32 dim, BaseFloat rand_prune=0.25)
 Initializes (destroys anything that was there before). More...
 
void Read (std::istream &is, bool binary, bool add=false)
 
void Write (std::ostream &os, bool binary) const
 
int32 Dim ()
 
void Update (MatrixBase< BaseFloat > *M, BaseFloat *objf_impr_out, BaseFloat *count_out) const
 The Update function does the ML update; it requires that M has the right size. More...
 
void AccumulateFromPosteriors (const DiagGmm &gmm, const VectorBase< BaseFloat > &data, const VectorBase< BaseFloat > &posteriors)
 
BaseFloat AccumulateFromGmm (const DiagGmm &gmm, const VectorBase< BaseFloat > &data, BaseFloat weight)
 
BaseFloat AccumulateFromGmmPreselect (const DiagGmm &gmm, const std::vector< int32 > &gselect, const VectorBase< BaseFloat > &data, BaseFloat weight)
 

Static Public Member Functions

static void Update (double beta, const std::vector< SpMatrix< double > > &G, MatrixBase< BaseFloat > *M, BaseFloat *objf_impr_out, BaseFloat *count_out)
 

Public Attributes

BaseFloat rand_prune_
 rand_prune_ controls randomized pruning; the larger it is, the more pruning we do. More...
 
double beta_
 
std::vector< SpMatrix< double > > G_
 

Detailed Description

A class for estimating a Maximum Likelihood Linear Transform (MLLT), also known as a global Semi-tied Covariance (STC) transform, for GMMs.

The resulting transform left-multiplies the feature vector.

Definition at line 42 of file mllt.h.
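
A minimal usage sketch (not taken from the Kaldi sources; the header paths and the helper name EstimateAndApplyMllt are illustrative assumptions): accumulate statistics from a trained DiagGmm over a matrix of training features, run the update starting from the unit transform, and left-multiply the features by the result.

    #include "gmm/diag-gmm.h"
    #include "matrix/kaldi-matrix.h"
    #include "transform/mllt.h"

    using namespace kaldi;

    // Estimate an MLLT/STC transform from a trained diagonal GMM and a matrix of
    // training features (one feature vector per row), then apply it.
    Matrix<BaseFloat> EstimateAndApplyMllt(const DiagGmm &gmm,
                                           const Matrix<BaseFloat> &feats) {
      int32 dim = gmm.Dim();
      MlltAccs mllt_accs(dim, 0.25);  // rand_prune = 0.25 (the default)
      for (int32 t = 0; t < feats.NumRows(); t++)
        mllt_accs.AccumulateFromGmm(gmm, feats.Row(t), 1.0);  // weight 1.0 per frame

      Matrix<BaseFloat> mllt_mat(dim, dim);
      mllt_mat.SetUnit();  // Update() measures the improvement relative to this.
      BaseFloat objf_impr, count;
      mllt_accs.Update(&mllt_mat, &objf_impr, &count);

      // The transform left-multiplies each (column) feature vector, so the
      // transformed rows are feats * M^T.
      Matrix<BaseFloat> new_feats(feats.NumRows(), dim);
      new_feats.AddMatMat(1.0, feats, kNoTrans, mllt_mat, kTrans, 0.0);
      return new_feats;
    }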

Constructor & Destructor Documentation

◆ MlltAccs() [1/2]

MlltAccs ( )
inline

Definition at line 44 of file mllt.h.

44 : rand_prune_(0.0), beta_(0.0) { }

◆ MlltAccs() [2/2]

MlltAccs ( int32  dim,
BaseFloat  rand_prune = 0.25 
)
inline

Need rand_prune >= 0.

The larger it is, the faster it will be. Zero is exact. If a posterior p < rand_prune, it will be set to rand_prune with probability (p/rand_prune), and otherwise to zero. E.g. 10 will give 10x speedup.

Definition at line 51 of file mllt.h.

References MlltAccs::Init(), MlltAccs::Read(), and MlltAccs::Write().

51 { Init(dim, rand_prune); }
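
The pruning rule described above can be sketched as follows (an illustrative re-implementation, not the kaldi::RandPrune used by the class): a posterior below rand_prune is either promoted to rand_prune, with probability p / rand_prune, or zeroed, so its expected value is unchanged while most small posteriors are dropped.

    #include <cstdlib>

    // Sketch of randomized pruning: for p < rand_prune, return rand_prune with
    // probability p / rand_prune, else 0, so that E[result] == p.
    float RandPruneSketch(float p, float rand_prune) {
      if (rand_prune <= 0.0f || p >= rand_prune) return p;    // nothing to prune
      float u = std::rand() / static_cast<float>(RAND_MAX);   // uniform in [0, 1]
      return (u < p / rand_prune) ? rand_prune : 0.0f;
    }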

Member Function Documentation

◆ AccumulateFromGmm()

BaseFloat AccumulateFromGmm ( const DiagGmm &  gmm,
const VectorBase< BaseFloat > &  data,
BaseFloat  weight 
)

Definition at line 162 of file mllt.cc.

References MlltAccs::AccumulateFromPosteriors(), DiagGmm::ComponentPosteriors(), and DiagGmm::NumGauss().

Referenced by MlltAccs::Update().

164  { // e.g. weight = 1.0
165  Vector<BaseFloat> posteriors(gmm.NumGauss());
166  BaseFloat ans = gmm.ComponentPosteriors(data, &posteriors);
167  posteriors.Scale(weight);
168  AccumulateFromPosteriors(gmm, data, posteriors);
169  return ans;
170 }

◆ AccumulateFromGmmPreselect()

BaseFloat AccumulateFromGmmPreselect ( const DiagGmm &  gmm,
const std::vector< int32 > &  gselect,
const VectorBase< BaseFloat > &  data,
BaseFloat  weight 
)

Definition at line 173 of file mllt.cc.

References MlltAccs::AccumulateFromPosteriors(), rnnlm::i, KALDI_ASSERT, DiagGmm::LogLikelihoodsPreselect(), and DiagGmm::NumGauss().

Referenced by MlltAccs::Update().

177  { // e.g. weight = 1.0
178  KALDI_ASSERT(!gselect.empty());
179  Vector<BaseFloat> loglikes;
180  gmm.LogLikelihoodsPreselect(data, gselect, &loglikes);
181  BaseFloat loglike = loglikes.ApplySoftMax();
182  // now "loglikes" is a vector of posteriors, indexed
183  // by the same index as gselect.
184  Vector<BaseFloat> posteriors(gmm.NumGauss());
185  for (size_t i = 0; i < gselect.size(); i++)
186  posteriors(gselect[i]) = loglikes(i) * weight;
187  AccumulateFromPosteriors(gmm, data, posteriors);
188  return loglike;
189 }
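
A hedged usage sketch: the gselect indices would typically come from Gaussian preselection on the same GMM (DiagGmm::GaussianSelection is assumed to be available; mllt_accs, gmm and feat are assumed to exist), after which only the selected components receive posterior mass.

    // Preselect the highest-scoring Gaussians for this frame, then accumulate
    // MLLT statistics using only those components.
    std::vector<int32> gselect;
    int32 num_gselect = 20;  // e.g. top-20 Gaussians (illustrative value)
    gmm.GaussianSelection(feat, num_gselect, &gselect);
    mllt_accs.AccumulateFromGmmPreselect(gmm, gselect, feat, 1.0);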

◆ AccumulateFromPosteriors()

void AccumulateFromPosteriors ( const DiagGmm &  gmm,
const VectorBase< BaseFloat > &  data,
const VectorBase< BaseFloat > &  posteriors 
)

Definition at line 131 of file mllt.cc.

References MlltAccs::beta_, MlltAccs::Dim(), VectorBase< Real >::Dim(), DiagGmm::Dim(), MlltAccs::G_, rnnlm::i, DiagGmm::inv_vars(), rnnlm::j, KALDI_ASSERT, DiagGmm::means_invvars(), DiagGmm::NumGauss(), MlltAccs::rand_prune_, and kaldi::RandPrune().

Referenced by MlltAccs::AccumulateFromGmm(), MlltAccs::AccumulateFromGmmPreselect(), and MlltAccs::Update().

133  {
134  KALDI_ASSERT(data.Dim() == gmm.Dim());
135  KALDI_ASSERT(data.Dim() == Dim());
136  KALDI_ASSERT(posteriors.Dim() == gmm.NumGauss());
137  const Matrix<BaseFloat> &means_invvars = gmm.means_invvars();
138  const Matrix<BaseFloat> &inv_vars = gmm.inv_vars();
139  Vector<BaseFloat> mean(data.Dim());
140  SpMatrix<double> tmp(data.Dim());
141  Vector<double> offset_dbl(data.Dim());
142  double this_beta_ = 0.0;
143  KALDI_ASSERT(rand_prune_ >= 0.0);
144  for (int32 i = 0; i < posteriors.Dim(); i++) { // for each mixcomp..
145  BaseFloat posterior = RandPrune(posteriors(i), rand_prune_);
146  if (posterior == 0.0) continue;
147  SubVector<BaseFloat> mean_invvar(means_invvars, i);
148  SubVector<BaseFloat> inv_var(inv_vars, i);
149  mean.AddVecDivVec(1.0, mean_invvar, inv_var, 0.0); // get mean.
150  mean.AddVec(-1.0, data); // get offset
151  offset_dbl.CopyFromVec(mean); // make it double.
152  tmp.SetZero();
153  tmp.AddVec2(1.0, offset_dbl);
154  for (int32 j = 0; j < data.Dim(); j++)
155  G_[j].AddSp(inv_var(j)*posterior, tmp);
156  this_beta_ += posterior;
157  }
158  beta_ += this_beta_;
159  Vector<double> data_dbl(data);
160 }
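
Written out as math (following the code above, with \(\gamma_i(t)\) the possibly-pruned posterior of Gaussian \(i\) at frame \(t\), \(\boldsymbol{\mu}_i\) its mean, \(\sigma^2_{i,j}\) its \(j\)'th diagonal variance, and \(\mathbf{x}_t\) the frame), the accumulated statistics are

\[ \beta = \sum_t \sum_i \gamma_i(t), \qquad
   \mathbf{G}_j = \sum_t \sum_i \frac{\gamma_i(t)}{\sigma^2_{i,j}}
   (\boldsymbol{\mu}_i - \mathbf{x}_t)(\boldsymbol{\mu}_i - \mathbf{x}_t)^{\top},
   \quad j = 1, \dots, d. \]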

◆ Dim()

int32 Dim ( )
inline

Definition at line 60 of file mllt.h.

References MlltAccs::G_.

Referenced by MlltAccs::AccumulateFromPosteriors(), and main().

60 { return G_.size(); }; // returns model dimension.

◆ Init()

void Init ( int32  dim,
BaseFloat  rand_prune = 0.25 
)

Initializes (destroys anything that was there before).

Definition at line 25 of file mllt.cc.

References MlltAccs::beta_, MlltAccs::G_, rnnlm::i, KALDI_ASSERT, and MlltAccs::rand_prune_.

Referenced by MlltAccs::MlltAccs().

25  { // initializes (destroys anything that was there before).
26  KALDI_ASSERT(dim > 0);
27  beta_ = 0;
28  rand_prune_ = rand_prune;
29  G_.resize(dim);
30  for (int32 i = 0; i < dim; i++)
31  G_[i].Resize(dim); // will zero it too.
32 }

◆ Read()

void Read ( std::istream &  is,
bool  binary,
bool  add = false 
)

Definition at line 34 of file mllt.cc.

References MlltAccs::beta_, kaldi::ExpectToken(), MlltAccs::G_, rnnlm::i, KALDI_ERR, and kaldi::ReadBasicType().

Referenced by main(), and MlltAccs::MlltAccs().

34  {
35  ExpectToken(is, binary, "<MlltAccs>");
36  double beta;
37  int32 dim;
38  ReadBasicType(is, binary, &beta);
39  if (!add) beta_ = beta;
40  else beta_ += beta;
41  ReadBasicType(is, binary, &dim);
42  if (add && G_.size() != 0 && static_cast<size_t>(dim) != G_.size())
43  KALDI_ERR << "MlltAccs::Read, summing accs of different size.";
44  if (!add || G_.empty()) G_.resize(dim);
45  ExpectToken(is, binary, "<G>");
46  for (size_t i = 0; i < G_.size(); i++)
47  G_[i].Read(is, binary, add);
48  ExpectToken(is, binary, "</MlltAccs>");
49 }

◆ Update() [1/2]

void Update ( MatrixBase< BaseFloat > *  M,
BaseFloat *  objf_impr_out,
BaseFloat *  count_out 
) const
inline

The Update function does the ML update; it requires that M has the right size.

Parameters
    [in,out]  M              The output transform; will be of dimension Dim() x Dim(). At input, should be the unit transform (the objective function improvement is measured relative to this value).
    [out]     objf_impr_out  The objective function improvement
    [out]     count_out      The data-count

Definition at line 69 of file mllt.h.

References MlltAccs::AccumulateFromGmm(), MlltAccs::AccumulateFromGmmPreselect(), MlltAccs::AccumulateFromPosteriors(), MlltAccs::beta_, and MlltAccs::G_.

Referenced by main().

71  {
72  Update(beta_, G_, M, objf_impr_out, count_out);
73  }

◆ Update() [2/2]

void Update ( double  beta,
const std::vector< SpMatrix< double > > &  G,
MatrixBase< BaseFloat > *  M,
BaseFloat *  objf_impr_out,
BaseFloat *  count_out 
)
static

Definition at line 66 of file mllt.cc.

References VectorBase< Real >::AddSpVec(), MatrixBase< Real >::CopyFromMat(), rnnlm::i, MatrixBase< Real >::Invert(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_WARN, kaldi::Log(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), Matrix< Real >::Transpose(), kaldi::VecSpVec(), and kaldi::VecVec().

70  {
71  int32 dim = G.size();
72  KALDI_ASSERT(dim != 0 && M_ptr != NULL
73  && M_ptr->NumRows() == dim
74  && M_ptr->NumCols() == dim);
75  if (beta < 10*dim) { // not really enough data to estimate.
76  // don't bother with min-count parameter etc., as MLLT is typically
77  // global.
78  if (beta > 2*dim)
79  KALDI_WARN << "Mllt:Update, very small count " << beta;
80  else
81  KALDI_WARN << "Mllt:Update, insufficient count " << beta;
82  }
83  int32 num_iters = 200; // may later make this an option.
84  Matrix<double> M(dim, dim), Minv(dim, dim);
85  M.CopyFromMat(*M_ptr);
86  std::vector<SpMatrix<double> > Ginv(dim);
87  for (int32 i = 0; i < dim; i++) {
88  Ginv[i].Resize(dim);
89  Ginv[i].CopyFromSp(G[i]);
90  Ginv[i].Invert();
91  }
92 
93  double tot_objf_impr = 0.0;
94  for (int32 p = 0; p < num_iters; p++) {
95  for (int32 i = 0; i < dim; i++) { // for each row
96  SubVector<double> row(M, i);
97  // work out cofactor (actually cofactor times a constant which
98  // doesn't affect anything):
99  Minv.CopyFromMat(M);
100  Minv.Invert();
101  Minv.Transpose();
102  SubVector<double> cofactor(Minv, i);
103  // Objf is: beta log(|row . cofactor|) -0.5 row^T G[i] row
104  // optimized by (c.f. Mark Gales's techreport "semitied covariance matrices
105  // for hidden markov models, eq. (22)),
106  // row = G_i^{-1} cofactor sqrt(beta / cofactor^T G_i^{-1} cofactor). (1)
107  // here, "row" and "cofactor" are considered as column vectors.
108  double objf_before = beta * Log(std::abs(VecVec(row, cofactor)))
109  -0.5 * VecSpVec(row, G[i], row);
110  // do eq. (1) above:
111  row.AddSpVec(std::sqrt(beta / VecSpVec(cofactor, Ginv[i], cofactor)),
112  Ginv[i], cofactor, 0.0);
113  double objf_after = beta * Log(std::abs(VecVec(row, cofactor)))
114  -0.5 * VecSpVec(row, G[i], row);
115  if (objf_after < objf_before - fabs(objf_before)*0.00001)
116  KALDI_ERR << "Objective decrease in MLLT update.";
117  tot_objf_impr += objf_after - objf_before;
118  }
119  if (p < 10 || p % 10 == 0)
120  KALDI_LOG << "MLLT objective improvement per frame by " << p
121  << "'th iteration is " << (tot_objf_impr/beta) << " per frame "
122  << "over " << beta << " frames.";
123  }
124  if (objf_impr_out)
125  *objf_impr_out = tot_objf_impr;
126  if (count_out)
127  *count_out = beta;
128  M_ptr->CopyFromMat(M);
129 }
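
For reference, the per-row objective and its closed-form maximizer used in the loop above (cf. the comments in the listing, following eq. (22) of Gales's semi-tied covariance tech report), with \(\mathbf{m}_i\) the \(i\)'th row of \(\mathbf{M}\) and \(\mathbf{c}_i\) the corresponding row of the cofactor matrix, both treated as column vectors:

\[ \mathcal{F}(\mathbf{m}_i) = \beta \log \lvert \mathbf{m}_i^{\top} \mathbf{c}_i \rvert
   - \tfrac{1}{2}\, \mathbf{m}_i^{\top} \mathbf{G}_i \mathbf{m}_i, \qquad
   \mathbf{m}_i \leftarrow \mathbf{G}_i^{-1} \mathbf{c}_i
   \sqrt{\frac{\beta}{\mathbf{c}_i^{\top} \mathbf{G}_i^{-1} \mathbf{c}_i}}. \]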

◆ Write()

void Write ( std::ostream &  os,
bool  binary 
) const

Definition at line 51 of file mllt.cc.

References MlltAccs::beta_, MlltAccs::G_, rnnlm::i, kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by main(), and MlltAccs::MlltAccs().

51  {
52  WriteToken(os, binary, "<MlltAccs>");
53  if(!binary) os << '\n';
54  WriteBasicType(os, binary, beta_);
55  int32 dim = G_.size();
56  WriteBasicType(os, binary, dim);
57  WriteToken(os, binary, "<G>");
58  if(!binary) os << '\n';
59  for (size_t i = 0; i < G_.size(); i++)
60  G_[i].Write(os, binary);
61  WriteToken(os, binary, "</MlltAccs>");
62  if(!binary) os << '\n';
63 }
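
A hedged I/O sketch using the usual Kaldi Output/Input wrappers (assumed from util/kaldi-io.h; the filename is illustrative and an existing MlltAccs object mllt_accs is assumed) to save the accumulators and later re-load them, optionally summing with existing statistics via add = true:

    #include "util/kaldi-io.h"

    // Write the accumulators in binary form.
    {
      bool binary = true;
      Output ko("mllt.acc", binary);
      mllt_accs.Write(ko.Stream(), binary);
    }

    // Read them back; pass add = true to sum into existing statistics.
    {
      bool binary_in;
      Input ki("mllt.acc", &binary_in);
      mllt_accs.Read(ki.Stream(), binary_in, false /* add */);
    }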

Member Data Documentation

◆ beta_

double beta_

Definition at line 108 of file mllt.h.

Referenced by MlltAccs::AccumulateFromPosteriors(), MlltAccs::Init(), MlltAccs::Read(), MlltAccs::Update(), and MlltAccs::Write().

◆ G_

std::vector< SpMatrix< double > > G_

Definition at line 109 of file mllt.h.

Referenced by MlltAccs::AccumulateFromPosteriors(), MlltAccs::Dim(), MlltAccs::Init(), MlltAccs::Read(), MlltAccs::Update(), and MlltAccs::Write().

◆ rand_prune_

BaseFloat rand_prune_

rand_prune_ controls randomized pruning; the larger it is, the more pruning we do.

Typical value is 0.1.

Definition at line 107 of file mllt.h.

Referenced by MlltAccs::AccumulateFromPosteriors(), and MlltAccs::Init().


The documentation for this class was generated from the following files:

    mllt.h
    mllt.cc