#include <estimate-am-sgmm2.h>
Public Member Functions | |
MleAmSgmm2Updater (const MleAmSgmm2Options &options) | |
void | Reconfigure (const MleAmSgmm2Options &options) |
void | Update (const MleAmSgmm2Accs &accs, AmSgmm2 *model, SgmmUpdateFlagsType flags) |
Private Member Functions | |
void | UpdatePhoneVectorsInternal (const MleAmSgmm2Accs &accs, const std::vector< SpMatrix< double > > &H, const std::vector< Matrix< double > > &log_a, AmSgmm2 *model, double *auxf_impr, int32 num_threads, int32 thread_id) const |
double | UpdatePhoneVectors (const MleAmSgmm2Accs &accs, const std::vector< SpMatrix< double > > &H, const std::vector< Matrix< double > > &log_a, AmSgmm2 *model) const |
In this update, smoothing terms are not supported. More... | |
double | UpdateM (const MleAmSgmm2Accs &accs, const std::vector< SpMatrix< double > > &Q, const Vector< double > &gamma_i, AmSgmm2 *model) |
void | RenormalizeV (const MleAmSgmm2Accs &accs, AmSgmm2 *model, const Vector< double > &gamma_i, const std::vector< SpMatrix< double > > &H) |
double | UpdateN (const MleAmSgmm2Accs &accs, const Vector< double > &gamma_i, AmSgmm2 *model) |
void | RenormalizeN (const MleAmSgmm2Accs &accs, const Vector< double > &gamma_i, AmSgmm2 *model) |
double | UpdateVars (const MleAmSgmm2Accs &accs, const std::vector< SpMatrix< double > > &S_means, const Vector< double > &gamma_i, AmSgmm2 *model) |
double | UpdateW (const MleAmSgmm2Accs &accs, const std::vector< Matrix< double > > &log_a, const Vector< double > &gamma_i, AmSgmm2 *model) |
double | UpdateU (const MleAmSgmm2Accs &accs, const Vector< double > &gamma_i, AmSgmm2 *model) |
double | UpdateSubstateWeights (const MleAmSgmm2Accs &accs, AmSgmm2 *model) |
void | ComputeMPrior (AmSgmm2 *model) |
double | MapUpdateM (const MleAmSgmm2Accs &accs, const std::vector< SpMatrix< double > > &Q, const Vector< double > &gamma_i, AmSgmm2 *model) |
KALDI_DISALLOW_COPY_AND_ASSIGN (MleAmSgmm2Updater) | |
MleAmSgmm2Updater () | |
Static Private Member Functions | |
static void | ComputeQ (const MleAmSgmm2Accs &accs, const AmSgmm2 &model, std::vector< SpMatrix< double > > *Q) |
Compute the Q_i quantities (Eq. 64). More... | |
static void | ComputeSMeans (const MleAmSgmm2Accs &accs, const AmSgmm2 &model, std::vector< SpMatrix< double > > *S_means) |
Compute the S_means quantities, minus the sum $(Y_i M_i^T + M_i Y_i^T)$. More...
static void | UpdateWGetStats (const MleAmSgmm2Accs &accs, const AmSgmm2 &model, const Matrix< double > &w, const std::vector< Matrix< double > > &log_a, Matrix< double > *F_i, Matrix< double > *g_i, double *tot_like, int32 num_threads, int32 thread_id) |
Called, multithreaded, inside UpdateW. More... | |
static void | ComputeLogA (const MleAmSgmm2Accs &accs, std::vector< Matrix< double > > *log_a) |
Private Attributes | |
MleAmSgmm2Options | options_ |
Friends | |
class | UpdateWClass |
class | UpdatePhoneVectorsClass |
class | EbwEstimateAmSgmm2 |
class | EbwAmSgmm2Updater |
Definition at line 246 of file estimate-am-sgmm2.h.
|
inline explicit |
Definition at line 248 of file estimate-am-sgmm2.h.
|
inline private |
Definition at line 341 of file estimate-am-sgmm2.h.
|
static private |
Definition at line 806 of file estimate-am-sgmm2.cc.
References MleAmSgmm2Accs::a_, MleAmSgmm2Accs::gamma_, KALDI_ASSERT, KALDI_WARN, MleAmSgmm2Accs::num_gaussians_, and MleAmSgmm2Accs::num_groups_.
Referenced by EbwAmSgmm2Updater::UpdateW().
|
private |
Definition at line 1086 of file estimate-am-sgmm2.cc.
References MatrixBase< Real >::AddMat(), SpMatrix< Real >::AddMat2Sp(), AmSgmm2::col_cov_inv_, MatrixBase< Real >::CopyFromMat(), AmSgmm2::FeatureDim(), rnnlm::i, KALDI_ASSERT, KALDI_LOG, kaldi::kNoTrans, kaldi::kTrans, kaldi::Log(), AmSgmm2::M_, M_PI, AmSgmm2::M_prior_, AmSgmm2::NumGauss(), AmSgmm2::PhoneSpaceDim(), AmSgmm2::row_cov_inv_, MatrixBase< Real >::Scale(), and kaldi::TraceSpSp().
|
static private |
Compute the Q_i quantities (Eq. 64).
Definition at line 686 of file estimate-am-sgmm2.cc.
References MleAmSgmm2Accs::gamma_, rnnlm::i, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::num_groups_, AmSgmm2::NumSubstatesForGroup(), MleAmSgmm2Accs::phn_space_dim_, and AmSgmm2::v_.
Referenced by EbwAmSgmm2Updater::Update().
|
static private |
Compute the S_means quantities, minus the sum $(Y_i M_i^T + M_i Y_i^T)$.
Definition at line 706 of file estimate-am-sgmm2.cc.
References MatrixBase< Real >::AddMat(), MatrixBase< Real >::AddMatMat(), VectorBase< Real >::AddMatVec(), MleAmSgmm2Accs::feature_dim_, MleAmSgmm2Accs::gamma_, rnnlm::i, KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, kaldi::kUndefined, AmSgmm2::M_, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::num_groups_, AmSgmm2::NumSubstatesForGroup(), AmSgmm2::v_, and MleAmSgmm2Accs::Y_.
Referenced by EbwAmSgmm2Updater::Update().
|
private |
|
private |
Definition at line 1183 of file estimate-am-sgmm2.cc.
References MatrixBase< Real >::AddMat(), MatrixBase< Real >::AddMatSp(), MatrixBase< Real >::AddSpMat(), AmSgmm2::col_cov_inv_, MatrixBase< Real >::CopyFromMat(), SolverOptions::eps, MleAmSgmm2Accs::feature_dim_, AmSgmm2::FeatureDim(), rnnlm::i, SolverOptions::K, KALDI_LOG, KALDI_WARN, kaldi::kNoTrans, kaldi::kSetZero, AmSgmm2::M_, AmSgmm2::M_prior_, SolverOptions::name, AmSgmm2::NumGauss(), AmSgmm2::PhoneSpaceDim(), AmSgmm2::row_cov_inv_, PackedMatrix< Real >::Scale(), AmSgmm2::SigmaInv_, kaldi::SolveDoubleQuadraticMatrixProblem(), and MleAmSgmm2Accs::Y_.
|
inline |
Definition at line 250 of file estimate-am-sgmm2.h.
|
private |
Definition at line 1525 of file estimate-am-sgmm2.cc.
References MatrixBase< Real >::AddMatMat(), SpMatrix< Real >::AddSp(), MleAmSgmm2Accs::feature_dim_, rnnlm::i, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, kaldi::kNoTrans, MatrixBase< Real >::MulColsVec(), AmSgmm2::N_, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::R_, PackedMatrix< Real >::Scale(), MleAmSgmm2Accs::spk_space_dim_, VectorBase< Real >::Sum(), and SpMatrix< Real >::SymPosSemiDefEig().
|
private |
Definition at line 944 of file estimate-am-sgmm2.cc.
References SpMatrix< Real >::AddMat2Sp(), MatrixBase< Real >::AddMatMat(), VectorBase< Real >::AddMatVec(), SpMatrix< Real >::AddSp(), SpMatrix< Real >::AddVec2(), TpMatrix< Real >::Cholesky(), MatrixBase< Real >::CopyFromTp(), count, MleAmSgmm2Accs::feature_dim_, rnnlm::i, TpMatrix< Real >::Invert(), MatrixBase< Real >::Invert(), SpMatrix< Real >::IsDiagonal(), SpMatrix< Real >::IsPosDef(), SpMatrix< Real >::IsUnit(), KALDI_ASSERT, KALDI_LOG, kaldi::kNoTrans, kaldi::kTrans, AmSgmm2::M_, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::num_groups_, AmSgmm2::NumSubstatesForGroup(), MleAmSgmm2Accs::phn_space_dim_, SpMatrix< Real >::PrintEigs(), MatrixBase< Real >::Row(), PackedMatrix< Real >::Scale(), VectorBase< Real >::Sum(), SpMatrix< Real >::SymPosSemiDefEig(), AmSgmm2::v_, and AmSgmm2::w_.
void Update | ( | const MleAmSgmm2Accs & | accs, |
AmSgmm2 * | model, | ||
SgmmUpdateFlagsType | flags | ||
) |
Definition at line 612 of file estimate-am-sgmm2.cc.
References MleAmSgmm2Accs::a_, VectorBase< Real >::AddRowSumMat(), AmSgmm2::ComputeH(), MleAmSgmm2Accs::gamma_, KALDI_LOG, kaldi::kSgmmCovarianceMatrix, kaldi::kSgmmPhoneProjections, kaldi::kSgmmPhoneVectors, kaldi::kSgmmPhoneWeightProjections, kaldi::kSgmmSpeakerProjections, kaldi::kSgmmSpeakerWeightProjections, kaldi::kSgmmSubstateWeights, AmSgmm2::n_, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::num_groups_, MleAmSgmm2Accs::total_frames_, MleAmSgmm2Accs::total_like_, and AmSgmm2::w_jmi_.
Referenced by main(), and TestSgmm2AccsIO().
|
private |
Definition at line 1040 of file estimate-am-sgmm2.cc.
References SolverOptions::eps, MleAmSgmm2Accs::feature_dim_, rnnlm::i, SolverOptions::K, KALDI_LOG, KALDI_VLOG, KALDI_WARN, AmSgmm2::M_, SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, AmSgmm2::SigmaInv_, kaldi::SolveQuadraticMatrixProblem(), and MleAmSgmm2Accs::Y_.
|
private |
Definition at line 1486 of file estimate-am-sgmm2.cc.
References SolverOptions::eps, rnnlm::i, SolverOptions::K, KALDI_ERR, KALDI_LOG, KALDI_WARN, AmSgmm2::N_, SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::R_, AmSgmm2::SigmaInv_, kaldi::SolveQuadraticMatrixProblem(), MleAmSgmm2Accs::spk_space_dim_, and MleAmSgmm2Accs::Z_.
|
private |
In this update, smoothing terms are not supported.
However, it does compute the auxiliary function after the update, and it backtracks if the auxiliary function did not increase (because of the weight terms, an increase is not mathematically guaranteed).
Definition at line 782 of file estimate-am-sgmm2.cc.
References count, MleAmSgmm2Accs::gamma_, KALDI_LOG, MleAmSgmm2Accs::num_groups_, and kaldi::RunMultiThreaded().
|
private |
Definition at line 838 of file estimate-am-sgmm2.cc.
References VectorBase< Real >::Add(), VectorBase< Real >::AddMatVec(), SpMatrix< Real >::AddSp(), VectorBase< Real >::AddVec(), SpMatrix< Real >::AddVec2(), VectorBase< Real >::ApplyExp(), SolverOptions::eps, MleAmSgmm2Accs::gamma_, rnnlm::i, SolverOptions::K, KALDI_LOG, KALDI_WARN, kaldi::kNoTrans, VectorBase< Real >::LogSumExp(), SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::num_groups_, AmSgmm2::NumSubstatesForGroup(), MleAmSgmm2Accs::phn_space_dim_, MatrixBase< Real >::Row(), kaldi::SolveQuadraticProblem(), AmSgmm2::v_, kaldi::VecSpVec(), kaldi::VecVec(), AmSgmm2::w_, and MleAmSgmm2Accs::y_.
|
private |
Definition at line 1682 of file estimate-am-sgmm2.cc.
References VectorBase< Real >::Add(), AmSgmm2::c_, MleAmSgmm2Accs::gamma_c_, KALDI_LOG, KALDI_WARN, kaldi::Log(), MleAmSgmm2Accs::num_pdfs_, AmSgmm2::NumSubstatesForPdf(), and VectorBase< Real >::Sum().
|
private |
Definition at line 1439 of file estimate-am-sgmm2.cc.
References VectorBase< Real >::AddVec(), SolverOptions::eps, rnnlm::i, SolverOptions::K, KALDI_LOG, KALDI_WARN, SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, MatrixBase< Real >::Row(), VectorBase< Real >::Scale(), kaldi::SolveQuadraticProblem(), MleAmSgmm2Accs::spk_space_dim_, VectorBase< Real >::Sum(), MleAmSgmm2Accs::t_, MleAmSgmm2Accs::U_, and AmSgmm2::u_.
|
private |
Definition at line 1566 of file estimate-am-sgmm2.cc.
References SpMatrix< Real >::AddSp(), VectorBase< Real >::ApplyFloor(), SpMatrix< Real >::CopyFromSp(), rnnlm::d, MleAmSgmm2Accs::feature_dim_, rnnlm::i, rnnlm::j, KALDI_ASSERT, KALDI_LOG, KALDI_VLOG, KALDI_WARN, SpMatrix< Real >::LimitCondDouble(), MleAmSgmm2Accs::num_gaussians_, PackedMatrix< Real >::NumRows(), MleAmSgmm2Accs::S_, PackedMatrix< Real >::Scale(), PackedMatrix< Real >::SetUnit(), AmSgmm2::SigmaInv_, VectorBase< Real >::Sum(), and kaldi::TraceSpSp().
|
private |
Definition at line 1321 of file estimate-am-sgmm2.cc.
References VectorBase< Real >::Add(), MatrixBase< Real >::AddMat(), MatrixBase< Real >::AddMatMat(), MatrixBase< Real >::CopyFromMat(), PackedMatrix< Real >::CopyFromVec(), SolverOptions::eps, MleAmSgmm2Accs::gamma_, rnnlm::i, SolverOptions::K, KALDI_ASSERT, KALDI_LOG, KALDI_VLOG, KALDI_WARN, kaldi::kNoTrans, kaldi::kTrans, VectorBase< Real >::LogSumExp(), SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::num_groups_, AmSgmm2::NumSubstatesForGroup(), MleAmSgmm2Accs::phn_space_dim_, MatrixBase< Real >::Row(), kaldi::RunMultiThreaded(), VectorBase< Real >::Scale(), MatrixBase< Real >::SetZero(), kaldi::SolveQuadraticProblem(), VectorBase< Real >::Sum(), kaldi::TraceMatMat(), AmSgmm2::v_, kaldi::VecSpVec(), kaldi::VecVec(), AmSgmm2::w_, and AmSgmm2::w_jmi_.
|
static private |
Called, multithreaded, inside UpdateW.
This function gets stats used inside UpdateW, where it accumulates the F_i and g_i quantities.
Note: F_i is viewed as a vector of SpMatrix (one for each i); each row of F_i is viewed as an SpMatrix even though it is stored as a vector. Note: on the first iteration, w is just a double-precision copy of the matrix model->w_; thereafter it may differ. log_a relates to the SSGMM.
Definition at line 1258 of file estimate-am-sgmm2.cc.
References VectorBase< Real >::Add(), MatrixBase< Real >::AddMat(), MatrixBase< Real >::AddMatMat(), SpMatrix< Real >::AddVec2(), VectorBase< Real >::ApplyExp(), MleAmSgmm2Accs::gamma_, rnnlm::i, kaldi::kNoTrans, kaldi::kTrans, VectorBase< Real >::LogSumExp(), MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::num_groups_, AmSgmm2::NumSubstatesForGroup(), MleAmSgmm2Accs::phn_space_dim_, MatrixBase< Real >::Row(), PackedMatrix< Real >::SetZero(), AmSgmm2::v_, and kaldi::VecVec().
Referenced by UpdateWClass::operator()().
|
friend |
Definition at line 272 of file estimate-am-sgmm2.h.
|
friend |
Definition at line 261 of file estimate-am-sgmm2.h.
|
friend |
Definition at line 260 of file estimate-am-sgmm2.h.
|
friend |
Definition at line 259 of file estimate-am-sgmm2.h.
|
private |
Definition at line 274 of file estimate-am-sgmm2.h.