#include "gmm/full-gmm.h"
#include "gmm/diag-gmm.h"
#include "gmm/model-common.h"
#include "gmm/mle-full-gmm.h"
#include "gmm/mle-diag-gmm.h"
#include "util/stl-utils.h"
#include "util/kaldi-io.h"
Go to the source code of this file.
Functions
void TestComponentAcc (const FullGmm &gmm, const Matrix< BaseFloat > &feats)
void rand_posdef_spmatrix (size_t dim, SpMatrix< BaseFloat > *matrix, TpMatrix< BaseFloat > *matrix_sqrt=NULL, BaseFloat *logdet=NULL)
BaseFloat GetLogLikeTest (const FullGmm &gmm, const VectorBase< BaseFloat > &feats, bool print_eigs)
void test_flags_driven_update (const FullGmm &gmm, const Matrix< BaseFloat > &feats, GmmFlagsType flags)
void test_io (const FullGmm &gmm, const AccumFullGmm &est_gmm, bool binary, const Matrix< BaseFloat > &feats)
void UnitTestEstimateFullGmm ()
int main ()
BaseFloat GetLogLikeTest (const FullGmm &gmm, const VectorBase< BaseFloat > &feats, bool print_eigs)
Definition at line 113 of file mle-full-gmm-test.cc.
References VectorBase< Real >::AddVec(), FullGmm::Dim(), FullGmm::GetMeans(), rnnlm::i, FullGmm::inv_covars(), SpMatrix< Real >::Invert(), kaldi::Log(), kaldi::LogAdd(), M_LOG_2PI, FullGmm::NumGauss(), PackedMatrix< Real >::NumRows(), MatrixBase< Real >::Row(), SpMatrix< Real >::SymPosSemiDefEig(), kaldi::VecSpVec(), and FullGmm::weights().
Referenced by UnitTestEstimateFullGmm().
int main ()
Definition at line 478 of file mle-full-gmm-test.cc.
References rnnlm::i, and UnitTestEstimateFullGmm().
void rand_posdef_spmatrix (size_t dim, SpMatrix< BaseFloat > *matrix, TpMatrix< BaseFloat > *matrix_sqrt = NULL, BaseFloat *logdet = NULL)
Definition at line 90 of file mle-full-gmm-test.cc.
References SpMatrix< Real >::AddMat2(), TpMatrix< Real >::Cholesky(), MatrixBase< Real >::Cond(), kaldi::kNoTrans, SpMatrix< Real >::LogPosDefDet(), and MatrixBase< Real >::SetRandn().
Referenced by UnitTestEstimateFullGmm().
void test_flags_driven_update (const FullGmm &gmm, const Matrix< BaseFloat > &feats, GmmFlagsType flags)
Definition at line 145 of file mle-full-gmm-test.cc.
References AccumFullGmm::AccumulateFromFull(), kaldi::AssertEqual(), FullGmm::ComputeGconsts(), FullGmm::CopyFromFullGmm(), FullGmm::Dim(), FullGmm::GetCovars(), FullGmm::GetMeans(), rnnlm::i, KALDI_LOG, KALDI_WARN, kaldi::kGmmAll, kaldi::kGmmMeans, kaldi::kGmmVariances, kaldi::kGmmWeights, FullGmm::LogLikelihood(), kaldi::MleFullGmmUpdate(), FullGmm::NumGauss(), MatrixBase< Real >::NumRows(), AccumFullGmm::Resize(), MatrixBase< Real >::Row(), FullGmm::SetInvCovars(), FullGmm::SetMeans(), FullGmm::SetWeights(), AccumFullGmm::SetZero(), and FullGmm::weights().
Referenced by UnitTestEstimateFullGmm().
void test_io (const FullGmm &gmm, const AccumFullGmm &est_gmm, bool binary, const Matrix< BaseFloat > &feats)
Definition at line 219 of file mle-full-gmm-test.cc.
References kaldi::AssertEqual(), FullGmm::CopyFromFullGmm(), FullGmm::Dim(), AccumFullGmm::Flags(), rnnlm::i, kaldi::kGmmAll, FullGmm::LogLikelihood(), kaldi::MleFullGmmUpdate(), FullGmm::NumGauss(), MatrixBase< Real >::NumRows(), AccumFullGmm::Read(), AccumFullGmm::Resize(), MatrixBase< Real >::Row(), AccumFullGmm::Scale(), Input::Stream(), and AccumFullGmm::Write().
Referenced by UnitTestEstimateFullGmm().
void TestComponentAcc (const FullGmm &gmm, const Matrix< BaseFloat > &feats)
Definition at line 31 of file mle-full-gmm-test.cc.
References AccumFullGmm::AccumulateForComponent(), AccumFullGmm::AccumulateFromFull(), kaldi::AssertEqual(), FullGmm::ComponentPosteriors(), FullGmm::Dim(), rnnlm::i, KALDI_ASSERT, KALDI_WARN, kaldi::kGmmAll, FullGmm::LogLikelihood(), kaldi::MleFullGmmUpdate(), FullGmm::NumGauss(), AccumFullGmm::NumGauss(), MatrixBase< Real >::NumRows(), FullGmm::Resize(), AccumFullGmm::Resize(), MatrixBase< Real >::Row(), and AccumFullGmm::SetZero().
Referenced by UnitTestEstimateFullGmm().
void UnitTestEstimateFullGmm ()
Definition at line 260 of file mle-full-gmm-test.cc.
References AccumFullGmm::AccumulateFromFull(), MatrixBase< Real >::AddMatMat(), VectorBase< Real >::AddRowSumMat(), PackedMatrix< Real >::AddToDiag(), MatrixBase< Real >::AddVecVec(), kaldi::ApproxEqual(), kaldi::AssertEqual(), FullGmm::ComputeGconsts(), FullGmmNormal::CopyToFullGmm(), count, rnnlm::d, FullGmm::Dim(), GetLogLikeTest(), rnnlm::i, FullGmm::inv_covars(), KALDI_ASSERT, KALDI_LOG, kaldi::kGmmAll, kaldi::kGmmMeans, kaldi::kGmmVariances, kaldi::kGmmWeights, kaldi::kNoTrans, kaldi::kTrans, MatrixBase< Real >::LogDet(), M_LOG_2PI, FullGmmNormal::means_, FullGmm::means_invcovars(), kaldi::MleFullGmmUpdate(), FullGmm::NumGauss(), MatrixBase< Real >::NumRows(), rand_posdef_spmatrix(), kaldi::RandGauss(), FullGmm::Resize(), AccumFullGmm::Resize(), MatrixBase< Real >::Row(), MatrixBase< Real >::Scale(), VectorBase< Real >::Scale(), FullGmm::SetInvCovarsAndMeans(), FullGmm::SetWeights(), AccumFullGmm::SetZero(), FullGmm::Split(), MatrixBase< Real >::SymPosSemiDefEig(), test_flags_driven_update(), test_io(), TestComponentAcc(), FullGmmNormal::vars_, and FullGmm::weights().
Referenced by main().