A class for estimating the Maximum Likelihood Linear Transform (MLLT), also known as global Semi-tied Covariance (STC), for GMMs.
#include <mllt.h>
Public Member Functions | |
MlltAccs () | |
MlltAccs (int32 dim, BaseFloat rand_prune=0.25) | |
Requires rand_prune >= 0. | |
void | Init (int32 dim, BaseFloat rand_prune=0.25) |
Initializes the object, destroying anything that was there before. | |
void | Read (std::istream &is, bool binary, bool add=false) |
void | Write (std::ostream &os, bool binary) const |
int32 | Dim () |
void | Update (MatrixBase< BaseFloat > *M, BaseFloat *objf_impr_out, BaseFloat *count_out) const |
The Update function does the ML update; it requires that M has the right size. | |
void | AccumulateFromPosteriors (const DiagGmm &gmm, const VectorBase< BaseFloat > &data, const VectorBase< BaseFloat > &posteriors) |
BaseFloat | AccumulateFromGmm (const DiagGmm &gmm, const VectorBase< BaseFloat > &data, BaseFloat weight) |
BaseFloat | AccumulateFromGmmPreselect (const DiagGmm &gmm, const std::vector< int32 > &gselect, const VectorBase< BaseFloat > &data, BaseFloat weight) |
Static Public Member Functions | |
static void | Update (double beta, const std::vector< SpMatrix< double > > &G, MatrixBase< BaseFloat > *M, BaseFloat *objf_impr_out, BaseFloat *count_out) |
Public Attributes | |
BaseFloat | rand_prune_ |
rand_prune_ controls randomized pruning; the larger it is, the more pruning we do. | |
double | beta_ |
std::vector< SpMatrix< double > > | G_ |
A class for estimating the Maximum Likelihood Linear Transform (MLLT), also known as global Semi-tied Covariance (STC), for GMMs.
The resulting transform left-multiplies the feature vector.
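To make "left-multiplies" concrete, here is a minimal sketch of applying an estimated transform M to a feature vector x as y = M x. `ApplyTransform` is a hypothetical helper written for illustration and is not part of Kaldi; plain `std::vector` stands in for Kaldi's matrix and vector types.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not a Kaldi function): returns y = M * x for a
// row-major matrix M, i.e. the transform left-multiplies the feature vector.
std::vector<float> ApplyTransform(const std::vector<std::vector<float>> &M,
                                  const std::vector<float> &x) {
  std::vector<float> y(M.size(), 0.0f);
  for (std::size_t r = 0; r < M.size(); ++r)
    for (std::size_t c = 0; c < x.size(); ++c)
      y[r] += M[r][c] * x[c];  // row r of M dotted with x
  return y;
}
```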
MlltAccs () [inline]
Definition at line 44 of file mllt.h.
MlltAccs (int32 dim, BaseFloat rand_prune=0.25) [inline]
Requires rand_prune >= 0.
The larger rand_prune is, the faster accumulation will be; zero is exact. A posterior p < rand_prune is set to rand_prune with probability p/rand_prune and to zero otherwise. E.g. 10 will give a 10x speedup.
Definition at line 51 of file mllt.h.
References MlltAccs::Init(), MlltAccs::Read(), and MlltAccs::Write().
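The randomized pruning rule described for rand_prune can be sketched as follows. `RandPruneSketch` is a hypothetical stand-in written from that description, not the signature of `kaldi::RandPrune()`; the use of `std::mt19937` and the assumption that posteriors are non-negative are choices made for this example.

```cpp
#include <random>

// Hypothetical stand-in for the pruning rule (name and RNG are assumptions).
// Assumes post >= 0. Posteriors at or above the threshold pass through
// unchanged, which also covers prune_thresh == 0 (the exact case). A smaller
// posterior p becomes prune_thresh with probability p / prune_thresh and
// zero otherwise, so its expected value is still p.
double RandPruneSketch(double post, double prune_thresh, std::mt19937 &rng) {
  if (post == 0.0 || post >= prune_thresh) return post;
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  return (uniform(rng) < post / prune_thresh) ? prune_thresh : 0.0;
}
```

Because the pruned value is unbiased, the accumulated statistics keep the same expectation while many small posteriors are skipped entirely, which is where the speedup would come from.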
BaseFloat AccumulateFromGmm (const DiagGmm &gmm, const VectorBase< BaseFloat > &data, BaseFloat weight)
Definition at line 162 of file mllt.cc.
References MlltAccs::AccumulateFromPosteriors(), DiagGmm::ComponentPosteriors(), and DiagGmm::NumGauss().
Referenced by MlltAccs::Update().
BaseFloat AccumulateFromGmmPreselect (const DiagGmm &gmm, const std::vector< int32 > &gselect, const VectorBase< BaseFloat > &data, BaseFloat weight)
Definition at line 173 of file mllt.cc.
References MlltAccs::AccumulateFromPosteriors(), rnnlm::i, KALDI_ASSERT, DiagGmm::LogLikelihoodsPreselect(), and DiagGmm::NumGauss().
Referenced by MlltAccs::Update().
void AccumulateFromPosteriors (const DiagGmm &gmm, const VectorBase< BaseFloat > &data, const VectorBase< BaseFloat > &posteriors)
Definition at line 131 of file mllt.cc.
References MlltAccs::beta_, MlltAccs::Dim(), VectorBase< Real >::Dim(), DiagGmm::Dim(), MlltAccs::G_, rnnlm::i, DiagGmm::inv_vars(), rnnlm::j, KALDI_ASSERT, DiagGmm::means_invvars(), DiagGmm::NumGauss(), MlltAccs::rand_prune_, and kaldi::RandPrune().
Referenced by MlltAccs::AccumulateFromGmm(), MlltAccs::AccumulateFromGmmPreselect(), and MlltAccs::Update().
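This page lists the statistics (beta_ and G_) but not how they are filled. The sketch below shows the standard semi-tied covariance accumulation for one Gaussian and one frame, which is assumed here to be what AccumulateFromPosteriors() does per component. `MlltStatsSketch` and its members are hypothetical names, and full matrices replace the packed `SpMatrix<double>` storage for simplicity.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical, simplified accumulator (not the Kaldi class).
struct MlltStatsSketch {
  double beta = 0.0;  // total posterior mass seen so far
  std::vector<std::vector<std::vector<double>>> G;  // one dim x dim matrix per row j

  explicit MlltStatsSketch(std::size_t dim)
      : G(dim, std::vector<std::vector<double>>(dim, std::vector<double>(dim, 0.0))) {}

  // Adds one Gaussian's contribution for one frame:
  //   G[j] += post * inv_var[j] * (x - mean)(x - mean)^T  for every j,
  //   beta += post.
  void Accumulate(const std::vector<double> &x, const std::vector<double> &mean,
                  const std::vector<double> &inv_var, double post) {
    const std::size_t dim = G.size();
    std::vector<double> offset(dim);
    for (std::size_t d = 0; d < dim; ++d) offset[d] = x[d] - mean[d];
    for (std::size_t j = 0; j < dim; ++j)
      for (std::size_t r = 0; r < dim; ++r)
        for (std::size_t c = 0; c < dim; ++c)
          G[j][r][c] += post * inv_var[j] * offset[r] * offset[c];
    beta += post;
  }
};
```

In this picture, `post` would be the posterior after randomized pruning, so components whose pruned posterior is zero can be skipped before the inner loops run.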
int32 Dim () [inline]
Definition at line 60 of file mllt.h.
References MlltAccs::G_.
Referenced by MlltAccs::AccumulateFromPosteriors(), and main().
void Init (int32 dim, BaseFloat rand_prune=0.25)
Initializes the object, destroying anything that was there before.
Definition at line 25 of file mllt.cc.
References MlltAccs::beta_, MlltAccs::G_, rnnlm::i, KALDI_ASSERT, and MlltAccs::rand_prune_.
Referenced by MlltAccs::MlltAccs().
void Read (std::istream &is, bool binary, bool add=false)
Definition at line 34 of file mllt.cc.
References MlltAccs::beta_, kaldi::ExpectToken(), MlltAccs::G_, rnnlm::i, KALDI_ERR, and kaldi::ReadBasicType().
Referenced by main(), and MlltAccs::MlltAccs().
void Update (MatrixBase< BaseFloat > *M, BaseFloat *objf_impr_out, BaseFloat *count_out) const [inline]
The Update function does the ML update; it requires that M has the right size.
Parameters
[in,out] | M | The output transform, of dimension Dim() x Dim(). On input it should be the unit transform (i.e. the identity matrix); the objective function improvement is measured relative to this value. |
[out] | objf_impr_out | The objective function improvement. |
[out] | count_out | The data count. |
Definition at line 69 of file mllt.h.
References MlltAccs::AccumulateFromGmm(), MlltAccs::AccumulateFromGmmPreselect(), MlltAccs::AccumulateFromPosteriors(), MlltAccs::beta_, and MlltAccs::G_.
Referenced by main().
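For orientation, the textbook semi-tied covariance objective and row update can be written as below. This is the standard formulation, assumed rather than confirmed by this page to be what Update() maximizes. Here $m_i$ is row $i$ of M, $c_i$ is row $i$ of the cofactor matrix of M, $D$ is Dim(), and $\beta$ and $G_i$ correspond to beta_ and G_.

```latex
\begin{align}
Q(M) &= \beta \log \lvert \det M \rvert
        - \tfrac{1}{2} \sum_{i=1}^{D} m_i \, G_i \, m_i^{\top}, \\
m_i &\leftarrow c_i \, G_i^{-1} \sqrt{\frac{\beta}{c_i \, G_i^{-1} \, c_i^{\top}}}
\end{align}
```

Under this reading, objf_impr_out would be the change in $Q$ relative to the input value of M, which is why the input is expected to be the unit transform.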
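void Update (double beta, const std::vector< SpMatrix< double > > &G, MatrixBase< BaseFloat > *M, BaseFloat *objf_impr_out, BaseFloat *count_out) [static]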
Definition at line 66 of file mllt.cc.
References VectorBase< Real >::AddSpVec(), MatrixBase< Real >::CopyFromMat(), rnnlm::i, MatrixBase< Real >::Invert(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_WARN, kaldi::Log(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), Matrix< Real >::Transpose(), kaldi::VecSpVec(), and kaldi::VecVec().
void Write (std::ostream &os, bool binary) const
Definition at line 51 of file mllt.cc.
References MlltAccs::beta_, MlltAccs::G_, rnnlm::i, kaldi::WriteBasicType(), and kaldi::WriteToken().
Referenced by main(), and MlltAccs::MlltAccs().
double beta_
Definition at line 108 of file mllt.h.
Referenced by MlltAccs::AccumulateFromPosteriors(), MlltAccs::Init(), MlltAccs::Read(), MlltAccs::Update(), and MlltAccs::Write().
std::vector<SpMatrix<double> > G_
Definition at line 109 of file mllt.h.
Referenced by MlltAccs::AccumulateFromPosteriors(), MlltAccs::Dim(), MlltAccs::Init(), MlltAccs::Read(), MlltAccs::Update(), and MlltAccs::Write().
BaseFloat rand_prune_
rand_prune_ controls randomized pruning; the larger it is, the more pruning we do.
A typical value is 0.1.
Definition at line 107 of file mllt.h.
Referenced by MlltAccs::AccumulateFromPosteriors(), and MlltAccs::Init().