#include <am-sgmm2-project.h>
Public Member Functions
  void ComputeProjection(const AmSgmm2 &sgmm, const Matrix<BaseFloat> &inv_lda_mllt, int32 begin_dim, int32 end_dim, Matrix<BaseFloat> *projection)
  void ApplyProjection(const Matrix<BaseFloat> &total_projection, AmSgmm2 *sgmm)
Private Member Functions
  void ComputeLdaStats(const FullGmm &full_ubm, SpMatrix<double> *between_covar, SpMatrix<double> *within_covar)
  void ProjectVariance(const Matrix<double> &total_projection, bool inverse, SpMatrix<double> *variance)
  void ProjectVariance(const Matrix<double> &total_projection, bool inverse, SpMatrix<float> *variance)
  void ComputeLdaTransform(const SpMatrix<double> &B, const SpMatrix<double> &W, int32 dim_to_retain, Matrix<double> *Projection)
Definition at line 30 of file am-sgmm2-project.h.
void Sgmm2Project::ApplyProjection(const Matrix<BaseFloat> &total_projection, AmSgmm2 *sgmm)
Definition at line 180 of file am-sgmm2-project.cc.
References MatrixBase< Real >::AddMatMat(), FullGmm::ComputeGconsts(), DiagGmm::ComputeGconsts(), DiagGmm::CopyFromFullGmm(), FullGmmNormal::CopyToFullGmm(), AmSgmm2::diag_ubm_, AmSgmm2::FeatureDim(), AmSgmm2::full_ubm_, rnnlm::i, KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, AmSgmm2::M_, FullGmmNormal::means_, AmSgmm2::N_, AmSgmm2::n_, MatrixBase< Real >::NumCols(), AmSgmm2::NumGauss(), MatrixBase< Real >::NumRows(), Sgmm2Project::ProjectVariance(), FullGmm::Resize(), DiagGmm::Resize(), Matrix< Real >::Resize(), AmSgmm2::SigmaInv_, and FullGmmNormal::vars_.
Referenced by main().
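void Sgmm2Project::ComputeLdaStats(const FullGmm &full_ubm, SpMatrix<double> *between_covar, SpMatrix<double> *within_covar)   [private]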
Definition at line 163 of file am-sgmm2-project.cc.
References SpMatrix< Real >::AddSp(), VectorBase< Real >::AddVec(), SpMatrix< Real >::AddVec2(), FullGmm::Dim(), rnnlm::i, FullGmmNormal::means_, FullGmm::NumGauss(), SpMatrix< Real >::Resize(), MatrixBase< Real >::Row(), and FullGmmNormal::vars_.
Referenced by Sgmm2Project::ComputeProjection().
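void Sgmm2Project::ComputeLdaTransform(const SpMatrix<double> &B, const SpMatrix<double> &W, int32 dim_to_retain, Matrix<double> *Projection)   [private]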
Definition at line 111 of file am-sgmm2-project.cc.
References SpMatrix< Real >::AddMat2Sp(), MatrixBase< Real >::AddMatTp(), SpMatrix< Real >::AddTp2Sp(), TpMatrix< Real >::Cholesky(), TpMatrix< Real >::Invert(), SpMatrix< Real >::IsUnit(), KALDI_ASSERT, KALDI_LOG, kaldi::kCopyData, kaldi::kNoTrans, kaldi::kTrans, PackedMatrix< Real >::NumRows(), VectorBase< Real >::Range(), Matrix< Real >::Resize(), and SpMatrix< Real >::SymPosSemiDefEig().
Referenced by Sgmm2Project::ComputeProjection().
void Sgmm2Project::ComputeProjection(const AmSgmm2 &sgmm,
                                     const Matrix<BaseFloat> &inv_lda_mllt,
                                     int32 begin_dim,
                                     int32 end_dim,
                                     Matrix<BaseFloat> *projection)
Definition at line 39 of file am-sgmm2-project.cc.
References SpMatrix< Real >::AddMat2Sp(), Sgmm2Project::ComputeLdaStats(), Sgmm2Project::ComputeLdaTransform(), AmSgmm2::FeatureDim(), AmSgmm2::full_ubm(), rnnlm::i, KALDI_ASSERT, kaldi::kCopyData, kaldi::kNoTrans, MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), MatrixBase< Real >::Range(), SpMatrix< Real >::Resize(), and Matrix< Real >::Resize().
Referenced by main().
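void Sgmm2Project::ProjectVariance(const Matrix<double> &total_projection, bool inverse, SpMatrix<double> *variance)   [private]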
Definition at line 230 of file am-sgmm2-project.cc.
References SpMatrix< Real >::AddMat2Sp(), SpMatrix< Real >::CopyFromSp(), rnnlm::i, SpMatrix< Real >::Invert(), KALDI_ASSERT, kaldi::kCopyData, kaldi::kNoTrans, MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), PackedMatrix< Real >::NumRows(), and SpMatrix< Real >::Resize().
Referenced by Sgmm2Project::ApplyProjection(), and Sgmm2Project::ProjectVariance().
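void Sgmm2Project::ProjectVariance(const Matrix<double> &total_projection, bool inverse, SpMatrix<float> *variance)   [private]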
Definition at line 254 of file am-sgmm2-project.cc.
References SpMatrix< Real >::CopyFromSp(), PackedMatrix< Real >::NumRows(), Sgmm2Project::ProjectVariance(), and SpMatrix< Real >::Resize().