A class for estimating Maximum Likelihood Linear Transform, also known as global Semi-tied Covariance (STC), for GMMs.
#include <mllt.h>

Public Member Functions | |
| MlltAccs () | |
| MlltAccs (int32 dim, BaseFloat rand_prune=0.25) | |
| Need rand_prune >= 0. | |
| void | Init (int32 dim, BaseFloat rand_prune=0.25) |
| Initializes (destroys anything that was there before). | |
| void | Read (std::istream &is, bool binary, bool add=false) |
| void | Write (std::ostream &os, bool binary) const |
| int32 | Dim () |
| void | Update (MatrixBase< BaseFloat > *M, BaseFloat *objf_impr_out, BaseFloat *count_out) const |
| The Update function does the ML update; it requires that M has the right size. | |
| void | AccumulateFromPosteriors (const DiagGmm &gmm, const VectorBase< BaseFloat > &data, const VectorBase< BaseFloat > &posteriors) |
| BaseFloat | AccumulateFromGmm (const DiagGmm &gmm, const VectorBase< BaseFloat > &data, BaseFloat weight) |
| BaseFloat | AccumulateFromGmmPreselect (const DiagGmm &gmm, const std::vector< int32 > &gselect, const VectorBase< BaseFloat > &data, BaseFloat weight) |
Static Public Member Functions | |
| static void | Update (double beta, const std::vector< SpMatrix< double > > &G, MatrixBase< BaseFloat > *M, BaseFloat *objf_impr_out, BaseFloat *count_out) |
Public Attributes | |
| BaseFloat | rand_prune_ |
| rand_prune_ controls randomized pruning; the larger it is, the more pruning we do. | |
| double | beta_ |
| std::vector< SpMatrix< double > > | G_ |
A class for estimating Maximum Likelihood Linear Transform, also known as global Semi-tied Covariance (STC), for GMMs.
The resulting transform left-multiplies the feature vector.
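As a minimal sketch of what "left-multiplies the feature vector" means, the snippet below computes y = M x with plain standard-library containers. The helper name ApplyTransformSketch and the use of std::vector are assumptions made for illustration; they are not part of Kaldi, which would use its own matrix and vector classes here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of Kaldi): returns y = M * x, where M is the
// estimated transform stored row by row and x is one feature vector.
std::vector<float> ApplyTransformSketch(
    const std::vector<std::vector<float> > &M, const std::vector<float> &x) {
  std::vector<float> y(M.size(), 0.0f);
  for (std::size_t r = 0; r < M.size(); ++r) {
    assert(M[r].size() == x.size());  // each row must match the feature dim
    for (std::size_t c = 0; c < x.size(); ++c)
      y[r] += M[r][c] * x[c];         // row r of M dotted with x
  }
  return y;
}
```

Each output dimension is the dot product of one row of M with the input features, which is why the rows of the transform are the natural unit of estimation.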
| MlltAccs | ( | ) | inline |
Definition at line 44 of file mllt.h.
| MlltAccs | ( | int32 | dim, |
| BaseFloat | rand_prune = 0.25 | ||
| ) | inline |
Need rand_prune >= 0.
The larger rand_prune is, the faster accumulation will be; zero is exact. A posterior p < rand_prune is set to rand_prune with probability p/rand_prune, and to zero otherwise. E.g. 10 will give a 10x speedup.
Definition at line 51 of file mllt.h.
References MlltAccs::Init(), MlltAccs::Read(), and MlltAccs::Write().
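The randomized pruning rule described above can be sketched as follows. This is an illustration only, assuming non-negative posteriors; the name RandPruneSketch and the use of std::mt19937 are hypothetical choices and this is not the code of kaldi::RandPrune().

```cpp
#include <cassert>
#include <random>

// Hypothetical helper (not Kaldi's kaldi::RandPrune()): applies the rule
// "if p < rand_prune, return rand_prune with probability p / rand_prune,
// otherwise zero". Assumes p >= 0.
float RandPruneSketch(float p, float rand_prune, std::mt19937 *rng) {
  assert(rand_prune >= 0.0f);
  // Posteriors at or above the threshold pass through unchanged; with
  // rand_prune == 0 every posterior takes this branch, so the result is exact.
  if (p == 0.0f || p >= rand_prune) return p;
  std::uniform_real_distribution<float> unif(0.0f, 1.0f);
  return unif(*rng) < p / rand_prune ? rand_prune : 0.0f;
}
```

The expected value of the pruned posterior is rand_prune * (p / rand_prune) = p, so the accumulated statistics remain unbiased while many small posteriors become exactly zero and can be skipped.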
| BaseFloat AccumulateFromGmm | ( | const DiagGmm & | gmm, |
| const VectorBase< BaseFloat > & | data, | ||
| BaseFloat | weight | ||
| ) |
Definition at line 162 of file mllt.cc.
References MlltAccs::AccumulateFromPosteriors(), DiagGmm::ComponentPosteriors(), and DiagGmm::NumGauss().
Referenced by MlltAccs::Update().
| BaseFloat AccumulateFromGmmPreselect | ( | const DiagGmm & | gmm, |
| const std::vector< int32 > & | gselect, | ||
| const VectorBase< BaseFloat > & | data, | ||
| BaseFloat | weight | ||
| ) |
Definition at line 173 of file mllt.cc.
References MlltAccs::AccumulateFromPosteriors(), rnnlm::i, KALDI_ASSERT, DiagGmm::LogLikelihoodsPreselect(), and DiagGmm::NumGauss().
Referenced by MlltAccs::Update().
| void AccumulateFromPosteriors | ( | const DiagGmm & | gmm, |
| const VectorBase< BaseFloat > & | data, | ||
| const VectorBase< BaseFloat > & | posteriors | ||
| ) |
Definition at line 131 of file mllt.cc.
References MlltAccs::beta_, MlltAccs::Dim(), VectorBase< Real >::Dim(), DiagGmm::Dim(), MlltAccs::G_, rnnlm::i, DiagGmm::inv_vars(), rnnlm::j, KALDI_ASSERT, DiagGmm::means_invvars(), DiagGmm::NumGauss(), MlltAccs::rand_prune_, and kaldi::RandPrune().
Referenced by MlltAccs::AccumulateFromGmm(), MlltAccs::AccumulateFromGmmPreselect(), and MlltAccs::Update().
| int32 Dim | ( | ) | inline |
Definition at line 60 of file mllt.h.
References MlltAccs::G_.
Referenced by MlltAccs::AccumulateFromPosteriors(), and main().
| void Init | ( | int32 | dim, |
| BaseFloat | rand_prune = 0.25 | ||
| ) |
Initializes (destroys anything that was there before).
Definition at line 25 of file mllt.cc.
References MlltAccs::beta_, MlltAccs::G_, rnnlm::i, KALDI_ASSERT, and MlltAccs::rand_prune_.
Referenced by MlltAccs::MlltAccs().
| void Read | ( | std::istream & | is, |
| bool | binary, | ||
| bool | add = false | ||
| ) |
Definition at line 34 of file mllt.cc.
References MlltAccs::beta_, kaldi::ExpectToken(), MlltAccs::G_, rnnlm::i, KALDI_ERR, and kaldi::ReadBasicType().
Referenced by main(), and MlltAccs::MlltAccs().
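| void Update | ( | MatrixBase< BaseFloat > * | M, |
| BaseFloat * | objf_impr_out, | ||
| BaseFloat * | count_out | ||
| ) | const | inline |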
The Update function does the ML update; it requires that M has the right size.
| [in,out] | M | The output transform, will be of dimension Dim() x Dim(). At input, should be the unit transform (the objective function improvement is measured relative to this value). |
| [out] | objf_impr_out | The objective function improvement |
| [out] | count_out | The data-count |
Definition at line 69 of file mllt.h.
References MlltAccs::AccumulateFromGmm(), MlltAccs::AccumulateFromGmmPreselect(), MlltAccs::AccumulateFromPosteriors(), MlltAccs::beta_, and MlltAccs::G_.
Referenced by main().
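To make "objective function improvement measured relative to the input M" concrete, here is a small sketch that evaluates the usual semi-tied covariance auxiliary function, beta * log|det M| - 0.5 * sum_i m_i^T G_i m_i, where m_i is row i of M. It assumes that beta_ plays the role of the data count and that G_ holds one symmetric statistics matrix per row of the transform; this page does not state that explicitly, so treat it as an assumption. The helper MlltObjfSketch, the Mat2 alias, and the restriction to 2x2 matrices are hypothetical simplifications, not Kaldi code.

```cpp
#include <array>
#include <cmath>

// 2x2 matrix stored row by row (illustration only).
using Mat2 = std::array<std::array<double, 2>, 2>;

// Hypothetical helper (not part of Kaldi): evaluates
//   beta * log|det M| - 0.5 * sum_i m_i^T G_i m_i
// for a 2x2 transform M, where m_i is row i of M and G[i] is assumed to be
// the statistics matrix associated with that row.
double MlltObjfSketch(double beta, const std::array<Mat2, 2> &G,
                      const Mat2 &M) {
  double det = M[0][0] * M[1][1] - M[0][1] * M[1][0];
  double objf = beta * std::log(std::fabs(det));
  for (int i = 0; i < 2; ++i) {
    double quad = 0.0;  // m_i^T G_i m_i
    for (int r = 0; r < 2; ++r)
      for (int c = 0; c < 2; ++c)
        quad += M[i][r] * G[i][r][c] * M[i][c];
    objf -= 0.5 * quad;
  }
  return objf;
}
```

Under this reading, the value written to objf_impr_out would correspond to the objective at the updated M minus the objective at the input M, which is why the input should be the unit transform.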
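| static void Update | ( | double | beta, |
| const std::vector< SpMatrix< double > > & | G, | ||
| MatrixBase< BaseFloat > * | M, | ||
| BaseFloat * | objf_impr_out, | ||
| BaseFloat * | count_out | ||
| ) | static |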
Definition at line 66 of file mllt.cc.
References VectorBase< Real >::AddSpVec(), MatrixBase< Real >::CopyFromMat(), rnnlm::i, MatrixBase< Real >::Invert(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_WARN, kaldi::Log(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), Matrix< Real >::Transpose(), kaldi::VecSpVec(), and kaldi::VecVec().
| void Write | ( | std::ostream & | os, |
| bool | binary | ||
| ) | const |
Definition at line 51 of file mllt.cc.
References MlltAccs::beta_, MlltAccs::G_, rnnlm::i, kaldi::WriteBasicType(), and kaldi::WriteToken().
Referenced by main(), and MlltAccs::MlltAccs().
| double beta_ |
Definition at line 108 of file mllt.h.
Referenced by MlltAccs::AccumulateFromPosteriors(), MlltAccs::Init(), MlltAccs::Read(), MlltAccs::Update(), and MlltAccs::Write().
| std::vector<SpMatrix<double> > G_ |
Definition at line 109 of file mllt.h.
Referenced by MlltAccs::AccumulateFromPosteriors(), MlltAccs::Dim(), MlltAccs::Init(), MlltAccs::Read(), MlltAccs::Update(), and MlltAccs::Write().
| BaseFloat rand_prune_ |
rand_prune_ controls randomized pruning; the larger it is, the more pruning we do.
Typical value is 0.1.
Definition at line 107 of file mllt.h.
Referenced by MlltAccs::AccumulateFromPosteriors(), and MlltAccs::Init().