#include <estimate-am-sgmm2-ebw.h>
Public Member Functions | |
EbwAmSgmm2Updater (const EbwAmSgmm2Options &options) | |
void | Update (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, AmSgmm2 *model, SgmmUpdateFlagsType flags, BaseFloat *auxf_change_out, BaseFloat *count_out) |
Private Member Functions | |
double | UpdatePhoneVectors (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const std::vector< SpMatrix< double > > &H, AmSgmm2 *model) const |
void | UpdatePhoneVectorsInternal (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const std::vector< SpMatrix< double > > &H, AmSgmm2 *model, double *auxf_impr, int32 num_threads, int32 thread_id) const |
double | UpdateM (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const std::vector< SpMatrix< double > > &Q_num, const std::vector< SpMatrix< double > > &Q_den, const Vector< double > &gamma_num, const Vector< double > &gamma_den, AmSgmm2 *model) const |
double | UpdateN (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const Vector< double > &gamma_num, const Vector< double > &gamma_den, AmSgmm2 *model) const |
double | UpdateVars (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const Vector< double > &gamma_num, const Vector< double > &gamma_den, const std::vector< SpMatrix< double > > &S_means, AmSgmm2 *model) const |
double | UpdateW (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const Vector< double > &gamma_num, const Vector< double > &gamma_den, AmSgmm2 *model) |
Note: in the discriminative case we do just one iteration of updating the w quantities. More... | |
double | UpdateU (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const Vector< double > &gamma_num, const Vector< double > &gamma_den, AmSgmm2 *model) |
double | UpdateSubstateWeights (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, AmSgmm2 *model) |
KALDI_DISALLOW_COPY_AND_ASSIGN (EbwAmSgmm2Updater) | |
EbwAmSgmm2Updater () | |
Static Private Member Functions | |
static void | ComputePhoneVecStats (const MleAmSgmm2Accs &accs, const AmSgmm2 &model, const std::vector< SpMatrix< double > > &H, int32 j1, int32 m, const Vector< double > &w_jm, double gamma_jm, Vector< double > *g_jm, SpMatrix< double > *H_jm) |
Private Attributes | |
EbwAmSgmm2Options | options_ |
Vector< double > | gamma_j_ |
State occupancies. More... | |
Friends | |
class | EbwUpdateWClass |
class | EbwUpdatePhoneVectorsClass |
Definition at line 147 of file estimate-am-sgmm2-ebw.h.
|
inline explicit |
Definition at line 149 of file estimate-am-sgmm2-ebw.h.
|
inline private |
Definition at line 235 of file estimate-am-sgmm2-ebw.h.
|
static private |
Definition at line 146 of file estimate-am-sgmm2-ebw.cc.
References MleAmSgmm2Accs::a_, SpMatrix< Real >::AddSp(), VectorBase< Real >::AddVec(), SpMatrix< Real >::AddVec2(), VectorBase< Real >::CopyFromVec(), MleAmSgmm2Accs::gamma_, rnnlm::i, VectorBase< Real >::MulElements(), MleAmSgmm2Accs::num_gaussians_, MatrixBase< Real >::Row(), VectorBase< Real >::Scale(), VectorBase< Real >::Sum(), AmSgmm2::v_, kaldi::VecVec(), AmSgmm2::w_, and MleAmSgmm2Accs::y_.
Referenced by EbwAmSgmm2Updater::UpdatePhoneVectorsInternal().
|
private |
void Update | ( | const MleAmSgmm2Accs & | num_accs, |
const MleAmSgmm2Accs & | den_accs, | ||
AmSgmm2 * | model, | ||
SgmmUpdateFlagsType | flags, | ||
BaseFloat * | auxf_change_out, | ||
BaseFloat * | count_out | ||
) |
Definition at line 27 of file estimate-am-sgmm2-ebw.cc.
References VectorBase< Real >::AddRowSumMat(), AmSgmm2::ComputeH(), AmSgmm2::ComputeNormalizers(), MleAmSgmm2Updater::ComputeQ(), MleAmSgmm2Updater::ComputeSMeans(), MleAmSgmm2Accs::gamma_, rnnlm::i, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, kaldi::kSgmmCovarianceMatrix, kaldi::kSgmmPhoneProjections, kaldi::kSgmmPhoneVectors, kaldi::kSgmmPhoneWeightProjections, kaldi::kSgmmSpeakerProjections, kaldi::kSgmmSpeakerWeightProjections, kaldi::kSgmmSubstateWeights, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::num_groups_, MleAmSgmm2Accs::total_frames_, MleAmSgmm2Accs::total_like_, EbwAmSgmm2Updater::UpdateM(), EbwAmSgmm2Updater::UpdateN(), EbwAmSgmm2Updater::UpdatePhoneVectors(), EbwAmSgmm2Updater::UpdateSubstateWeights(), EbwAmSgmm2Updater::UpdateU(), EbwAmSgmm2Updater::UpdateVars(), and EbwAmSgmm2Updater::UpdateW().
Referenced by main().
|
private |
Definition at line 283 of file estimate-am-sgmm2-ebw.cc.
References MatrixBase< Real >::AddMat(), MatrixBase< Real >::AddMatSp(), SpMatrix< Real >::AddSp(), SolverOptions::eps, EbwAmSgmm2Options::epsilon, AmSgmm2::FeatureDim(), rnnlm::i, SolverOptions::K, KALDI_LOG, KALDI_VLOG, KALDI_WARN, kaldi::kNoTrans, EbwAmSgmm2Options::lrate_M, AmSgmm2::M_, EbwAmSgmm2Options::max_cond, SolverOptions::name, AmSgmm2::NumGauss(), EbwAmSgmm2Updater::options_, AmSgmm2::PhoneSpaceDim(), PackedMatrix< Real >::Scale(), AmSgmm2::SigmaInv_, kaldi::SolveQuadraticMatrixProblem(), VectorBase< Real >::Sum(), EbwAmSgmm2Options::tau_M, and MleAmSgmm2Accs::Y_.
Referenced by EbwAmSgmm2Updater::Update().
|
private |
Definition at line 524 of file estimate-am-sgmm2-ebw.cc.
References MatrixBase< Real >::AddMat(), MatrixBase< Real >::AddMatSp(), SpMatrix< Real >::AddSp(), SolverOptions::eps, EbwAmSgmm2Options::epsilon, MleAmSgmm2Accs::feature_dim_, rnnlm::i, SolverOptions::K, KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, kaldi::kNoTrans, EbwAmSgmm2Options::lrate_N, EbwAmSgmm2Options::max_cond, AmSgmm2::N_, SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, EbwAmSgmm2Updater::options_, MleAmSgmm2Accs::R_, PackedMatrix< Real >::Scale(), AmSgmm2::SigmaInv_, kaldi::SolveQuadraticMatrixProblem(), MleAmSgmm2Accs::spk_space_dim_, VectorBase< Real >::Sum(), EbwAmSgmm2Options::tau_N, and MleAmSgmm2Accs::Z_.
Referenced by EbwAmSgmm2Updater::Update().
|
private |
Definition at line 261 of file estimate-am-sgmm2-ebw.cc.
References count, MleAmSgmm2Accs::gamma_, KALDI_LOG, MleAmSgmm2Accs::num_groups_, and kaldi::RunMultiThreaded().
Referenced by EbwAmSgmm2Updater::Update().
|
private |
Definition at line 178 of file estimate-am-sgmm2-ebw.cc.
References VectorBase< Real >::AddMatVec(), SpMatrix< Real >::AddSp(), VectorBase< Real >::AddSpVec(), VectorBase< Real >::AddVec(), VectorBase< Real >::ApplySoftMax(), EbwAmSgmm2Updater::ComputePhoneVecStats(), SolverOptions::eps, EbwAmSgmm2Options::epsilon, MleAmSgmm2Accs::gamma_, SolverOptions::K, KALDI_LOG, kaldi::kNoTrans, EbwAmSgmm2Options::lrate_v, EbwAmSgmm2Options::max_cond, SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::num_groups_, AmSgmm2::NumSubstatesForGroup(), EbwAmSgmm2Updater::options_, MleAmSgmm2Accs::phn_space_dim_, PackedMatrix< Real >::Scale(), kaldi::SolveQuadraticProblem(), EbwAmSgmm2Options::tau_v, AmSgmm2::v_, and AmSgmm2::w_.
|
private |
Definition at line 672 of file estimate-am-sgmm2-ebw.cc.
References AmSgmm2::c_, VectorBase< Real >::Dim(), MleAmSgmm2Accs::gamma_c_, KALDI_LOG, EbwAmSgmm2Options::min_substate_weight, MleAmSgmm2Accs::num_pdfs_, AmSgmm2::NumSubstatesForPdf(), EbwAmSgmm2Updater::options_, VectorBase< Real >::Scale(), VectorBase< Real >::Sum(), and EbwAmSgmm2Options::tau_c.
Referenced by EbwAmSgmm2Updater::Update().
|
private |
Definition at line 463 of file estimate-am-sgmm2-ebw.cc.
References SpMatrix< Real >::AddSp(), VectorBase< Real >::AddVec(), SolverOptions::eps, EbwAmSgmm2Options::epsilon, rnnlm::i, SolverOptions::K, KALDI_LOG, KALDI_WARN, EbwAmSgmm2Options::lrate_u, EbwAmSgmm2Options::max_cond, EbwAmSgmm2Options::max_impr_u, SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, EbwAmSgmm2Updater::options_, MatrixBase< Real >::Row(), PackedMatrix< Real >::Scale(), VectorBase< Real >::Scale(), kaldi::SolveQuadraticProblem(), MleAmSgmm2Accs::spk_space_dim_, VectorBase< Real >::Sum(), MleAmSgmm2Accs::t_, EbwAmSgmm2Options::tau_u, MleAmSgmm2Accs::U_, and AmSgmm2::u_.
Referenced by EbwAmSgmm2Updater::Update().
|
private |
Definition at line 597 of file estimate-am-sgmm2-ebw.cc.
References SpMatrix< Real >::AddSp(), SpMatrix< Real >::ApplyFloor(), count, EbwAmSgmm2Options::cov_min_value, rnnlm::i, SpMatrix< Real >::Invert(), KALDI_ASSERT, KALDI_LOG, KALDI_VLOG, SpMatrix< Real >::LogDet(), EbwAmSgmm2Options::lrate_Sigma, MleAmSgmm2Accs::num_gaussians_, EbwAmSgmm2Updater::options_, MleAmSgmm2Accs::S_, PackedMatrix< Real >::Scale(), AmSgmm2::SigmaInv_, EbwAmSgmm2Options::tau_Sigma, and kaldi::TraceSpSp().
Referenced by EbwAmSgmm2Updater::Update().
|
private |
Note: in the discriminative case we do just one iteration of updating the w quantities.
Definition at line 370 of file estimate-am-sgmm2-ebw.cc.
References VectorBase< Real >::AddVec(), MleAmSgmm2Updater::ComputeLogA(), PackedMatrix< Real >::CopyFromVec(), SolverOptions::eps, EbwAmSgmm2Options::epsilon, AmSgmm2::HasSpeakerDependentWeights(), rnnlm::i, SolverOptions::K, KALDI_LOG, KALDI_VLOG, EbwAmSgmm2Options::lrate_w, EbwAmSgmm2Options::max_cond, SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, EbwAmSgmm2Updater::options_, MleAmSgmm2Accs::phn_space_dim_, MatrixBase< Real >::Row(), kaldi::RunMultiThreaded(), kaldi::SolveQuadraticProblem(), VectorBase< Real >::Sum(), EbwAmSgmm2Options::tau_w, and AmSgmm2::w_.
Referenced by EbwAmSgmm2Updater::Update().
|
friend |
Definition at line 163 of file estimate-am-sgmm2-ebw.h.
|
friend |
Definition at line 162 of file estimate-am-sgmm2-ebw.h.
|
private |
State occupancies.
Definition at line 167 of file estimate-am-sgmm2-ebw.h.
|
private |
Definition at line 165 of file estimate-am-sgmm2-ebw.h.
Referenced by EbwAmSgmm2Updater::UpdateM(), EbwAmSgmm2Updater::UpdateN(), EbwAmSgmm2Updater::UpdatePhoneVectorsInternal(), EbwAmSgmm2Updater::UpdateSubstateWeights(), EbwAmSgmm2Updater::UpdateU(), EbwAmSgmm2Updater::UpdateVars(), and EbwAmSgmm2Updater::UpdateW().