EbwAmSgmm2Updater Class Reference

#include <estimate-am-sgmm2-ebw.h>

Public Member Functions

 EbwAmSgmm2Updater (const EbwAmSgmm2Options &options)
 
void Update (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, AmSgmm2 *model, SgmmUpdateFlagsType flags, BaseFloat *auxf_change_out, BaseFloat *count_out)
 

Private Member Functions

double UpdatePhoneVectors (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const std::vector< SpMatrix< double > > &H, AmSgmm2 *model) const
 
void UpdatePhoneVectorsInternal (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const std::vector< SpMatrix< double > > &H, AmSgmm2 *model, double *auxf_impr, int32 num_threads, int32 thread_id) const
 
double UpdateM (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const std::vector< SpMatrix< double > > &Q_num, const std::vector< SpMatrix< double > > &Q_den, const Vector< double > &gamma_num, const Vector< double > &gamma_den, AmSgmm2 *model) const
 
double UpdateN (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const Vector< double > &gamma_num, const Vector< double > &gamma_den, AmSgmm2 *model) const
 
double UpdateVars (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const Vector< double > &gamma_num, const Vector< double > &gamma_den, const std::vector< SpMatrix< double > > &S_means, AmSgmm2 *model) const
 
double UpdateW (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const Vector< double > &gamma_num, const Vector< double > &gamma_den, AmSgmm2 *model)
 Note: in the discriminative case we do just one iteration of updating the w quantities.
 
double UpdateU (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, const Vector< double > &gamma_num, const Vector< double > &gamma_den, AmSgmm2 *model)
 
double UpdateSubstateWeights (const MleAmSgmm2Accs &num_accs, const MleAmSgmm2Accs &den_accs, AmSgmm2 *model)
 
 KALDI_DISALLOW_COPY_AND_ASSIGN (EbwAmSgmm2Updater)
 
 EbwAmSgmm2Updater ()
 

Static Private Member Functions

static void ComputePhoneVecStats (const MleAmSgmm2Accs &accs, const AmSgmm2 &model, const std::vector< SpMatrix< double > > &H, int32 j1, int32 m, const Vector< double > &w_jm, double gamma_jm, Vector< double > *g_jm, SpMatrix< double > *H_jm)
 

Private Attributes

EbwAmSgmm2Options options_
 
Vector< double > gamma_j_
 State occupancies.
 

Friends

class EbwUpdateWClass
 
class EbwUpdatePhoneVectorsClass
 

Detailed Description

Definition at line 147 of file estimate-am-sgmm2-ebw.h.

Constructor & Destructor Documentation

◆ EbwAmSgmm2Updater() [1/2]

EbwAmSgmm2Updater ( const EbwAmSgmm2Options &  options)
inline explicit

Definition at line 149 of file estimate-am-sgmm2-ebw.h.

149  explicit EbwAmSgmm2Updater(const EbwAmSgmm2Options &options):
150  options_(options) {}

◆ EbwAmSgmm2Updater() [2/2]

EbwAmSgmm2Updater ( )
inline private

Definition at line 235 of file estimate-am-sgmm2-ebw.h.

235  EbwAmSgmm2Updater() {} // Prevent unconfigured updater.

Member Function Documentation

◆ ComputePhoneVecStats()

void ComputePhoneVecStats ( const MleAmSgmm2Accs &  accs,
const AmSgmm2 &  model,
const std::vector< SpMatrix< double > > &  H,
int32  j1,
int32  m,
const Vector< double > &  w_jm,
double  gamma_jm,
Vector< double > *  g_jm,
SpMatrix< double > *  H_jm 
)
static private

Definition at line 146 of file estimate-am-sgmm2-ebw.cc.

References MleAmSgmm2Accs::a_, SpMatrix< Real >::AddSp(), VectorBase< Real >::AddVec(), SpMatrix< Real >::AddVec2(), VectorBase< Real >::CopyFromVec(), MleAmSgmm2Accs::gamma_, rnnlm::i, VectorBase< Real >::MulElements(), MleAmSgmm2Accs::num_gaussians_, MatrixBase< Real >::Row(), VectorBase< Real >::Scale(), VectorBase< Real >::Sum(), AmSgmm2::v_, kaldi::VecVec(), AmSgmm2::w_, and MleAmSgmm2Accs::y_.

Referenced by EbwAmSgmm2Updater::UpdatePhoneVectorsInternal().

155  {
156  Vector<double> w_jm(w_jm_in);
157  if (!accs.a_.empty() && accs.a_[j1](m, 0) != 0) { // [SSGMM]
158  w_jm.MulElements(accs.a_[j1].Row(m)); // multiply by "a" quantities..
159  w_jm.Scale(1.0 / w_jm.Sum()); // renormalize.
160  }
161  g_jm->CopyFromVec(accs.y_[j1].Row(m));
162  for (int32 i = 0; i < accs.num_gaussians_; i++) {
163  double gamma_jmi = accs.gamma_[j1](m, i);
164  double quadratic_term = std::max(gamma_jmi, gamma_jm * w_jm(i));
165  double scalar = gamma_jmi - gamma_jm * w_jm(i) + quadratic_term
166  * VecVec(model.w_.Row(i), model.v_[j1].Row(m));
167  g_jm->AddVec(scalar, model.w_.Row(i));
168  if (gamma_jmi != 0.0)
169  H_jm->AddSp(gamma_jmi, H[i]); // The most important term..
170  if (quadratic_term > 1.0e-10)
171  H_jm->AddVec2(static_cast<BaseFloat>(quadratic_term), model.w_.Row(i));
172  }
173 }
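
The weighting rule in the listing above can be isolated as a small standalone sketch. The helper name QuadraticTermWeights and the use of std::vector in place of Kaldi's Vector<double> are assumptions made for illustration. The sketch only reproduces the std::max(gamma_jmi, gamma_jm * w_jm(i)) step that sets the weight on each rank-one term; it does not accumulate into g_jm or H_jm.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of Kaldi): for each Gaussian i, return the
// weight applied to the rank-one term w_i w_i^T, i.e.
// max(gamma_jmi, gamma_jm * w_jm(i)), as in the listing above.
std::vector<double> QuadraticTermWeights(const std::vector<double> &gamma_jmi,
                                         const std::vector<double> &w_jm,
                                         double gamma_jm) {
  assert(gamma_jmi.size() == w_jm.size());
  std::vector<double> weights(w_jm.size());
  for (std::size_t i = 0; i < w_jm.size(); i++)
    weights[i] = std::max(gamma_jmi[i], gamma_jm * w_jm[i]);
  return weights;
}
```

With per-Gaussian counts {3, 1}, a sub-state count of 4 and uniform weights {0.5, 0.5}, the first Gaussian keeps its observed count (3) while the second is raised to the expected count (2).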

◆ KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( EbwAmSgmm2Updater  )
private

◆ Update()

void Update ( const MleAmSgmm2Accs &  num_accs,
const MleAmSgmm2Accs &  den_accs,
AmSgmm2 *  model,
SgmmUpdateFlagsType  flags,
BaseFloat *  auxf_change_out,
BaseFloat *  count_out 
)

Definition at line 27 of file estimate-am-sgmm2-ebw.cc.

References VectorBase< Real >::AddRowSumMat(), AmSgmm2::ComputeH(), AmSgmm2::ComputeNormalizers(), MleAmSgmm2Updater::ComputeQ(), MleAmSgmm2Updater::ComputeSMeans(), MleAmSgmm2Accs::gamma_, rnnlm::i, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, kaldi::kSgmmCovarianceMatrix, kaldi::kSgmmPhoneProjections, kaldi::kSgmmPhoneVectors, kaldi::kSgmmPhoneWeightProjections, kaldi::kSgmmSpeakerProjections, kaldi::kSgmmSpeakerWeightProjections, kaldi::kSgmmSubstateWeights, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::num_groups_, MleAmSgmm2Accs::total_frames_, MleAmSgmm2Accs::total_like_, EbwAmSgmm2Updater::UpdateM(), EbwAmSgmm2Updater::UpdateN(), EbwAmSgmm2Updater::UpdatePhoneVectors(), EbwAmSgmm2Updater::UpdateSubstateWeights(), EbwAmSgmm2Updater::UpdateU(), EbwAmSgmm2Updater::UpdateVars(), and EbwAmSgmm2Updater::UpdateW().

Referenced by main().

32  {
33 
34  // Various quantities need to be computed at the start, before we
35  // change any of the model parameters.
36  std::vector< SpMatrix<double> > Q_num, Q_den, H, S_means;
37 
38  if (flags & kSgmmPhoneProjections) {
39  MleAmSgmm2Updater::ComputeQ(num_accs, *model, &Q_num);
40  MleAmSgmm2Updater::ComputeQ(den_accs, *model, &Q_den);
41  }
42  if (flags & kSgmmCovarianceMatrix) { // compute the difference between
43  // the num and den S_means matrices... this is what we will need.
44  MleAmSgmm2Updater::ComputeSMeans(num_accs, *model, &S_means);
45  std::vector< SpMatrix<double> > S_means_tmp;
46  MleAmSgmm2Updater::ComputeSMeans(den_accs, *model, &S_means_tmp);
47  for (size_t i = 0; i < S_means.size(); i++)
48  S_means[i].AddSp(-1.0, S_means_tmp[i]);
49  }
50  if (flags & kSgmmPhoneVectors)
51  model->ComputeH(&H);
52 
53  Vector<double> gamma_num(num_accs.num_gaussians_);
54  for (int32 j1 = 0; j1 < num_accs.num_groups_; j1++)
55  gamma_num.AddRowSumMat(1.0, num_accs.gamma_[j1]);
56  Vector<double> gamma_den(den_accs.num_gaussians_);
57  for (int32 j1 = 0; j1 < den_accs.num_groups_; j1++)
58  gamma_den.AddRowSumMat(1.0, den_accs.gamma_[j1]);
59 
60  BaseFloat tot_impr = 0.0;
61 
62  if (flags & kSgmmPhoneVectors)
63  tot_impr += UpdatePhoneVectors(num_accs, den_accs, H, model);
64 
65  if (flags & kSgmmPhoneProjections)
66  tot_impr += UpdateM(num_accs, den_accs, Q_num, Q_den,
67  gamma_num, gamma_den, model);
68 
69  if (flags & kSgmmPhoneWeightProjections)
70  tot_impr += UpdateW(num_accs, den_accs, gamma_num, gamma_den, model);
71 
72  if (flags & kSgmmSpeakerWeightProjections)
73  tot_impr += UpdateU(num_accs, den_accs, gamma_num, gamma_den, model);
74 
75  if (flags & kSgmmCovarianceMatrix)
76  tot_impr += UpdateVars(num_accs, den_accs,
77  gamma_num, gamma_den, S_means, model);
78 
79  if (flags & kSgmmSubstateWeights)
80  tot_impr += UpdateSubstateWeights(num_accs, den_accs, model);
81 
82  if (flags & kSgmmSpeakerProjections)
83  tot_impr += UpdateN(num_accs, den_accs, gamma_num, gamma_den, model);
84 
85 
86  if (auxf_change_out) *auxf_change_out = tot_impr * num_accs.total_frames_;
87  if (count_out) *count_out = num_accs.total_frames_;
88 
89  if (fabs(num_accs.total_frames_ - den_accs.total_frames_) >
90  0.01*(num_accs.total_frames_ + den_accs.total_frames_))
91  KALDI_WARN << "Num and den frame counts differ, "
92  << num_accs.total_frames_ << " vs. " << den_accs.total_frames_;
93 
94  BaseFloat like_diff = num_accs.total_like_ - den_accs.total_like_;
95 
96  KALDI_LOG << "***Averaged differenced likelihood per frame is "
97  << (like_diff/num_accs.total_frames_)
98  << " over " << (num_accs.total_frames_) << " frames.";
99  KALDI_LOG << "***Note: for this to be at all meaningful, if you use "
100  << "\"canceled\" stats you will have to renormalize this over "
101  << "the \"real\" frame count.";
102  KALDI_ASSERT(num_accs.total_frames_ > 0 && den_accs.total_frames_ > 0);
103 
104  model->ComputeNormalizers();
105 }
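
The frame-count sanity check near the end of Update() can be read as a relative-difference test. The sketch below restates it as a free function; the name FrameCountsDiffer is hypothetical, and the 1% threshold is taken directly from the listing.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper (not part of Kaldi): returns true when the numerator
// and denominator frame counts differ by more than 1% of their sum, which
// is the condition that triggers the KALDI_WARN in the listing above.
bool FrameCountsDiffer(double num_frames, double den_frames) {
  assert(num_frames >= 0.0 && den_frames >= 0.0);
  return std::fabs(num_frames - den_frames) >
         0.01 * (num_frames + den_frames);
}
```

For example, counts of 1000 and 990 pass the check (difference 10 against a threshold of 19.9), while 1000 and 900 would trigger the warning.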

◆ UpdateM()

double UpdateM ( const MleAmSgmm2Accs &  num_accs,
const MleAmSgmm2Accs &  den_accs,
const std::vector< SpMatrix< double > > &  Q_num,
const std::vector< SpMatrix< double > > &  Q_den,
const Vector< double > &  gamma_num,
const Vector< double > &  gamma_den,
AmSgmm2 *  model 
) const
private

Definition at line 283 of file estimate-am-sgmm2-ebw.cc.

References MatrixBase< Real >::AddMat(), MatrixBase< Real >::AddMatSp(), SpMatrix< Real >::AddSp(), SolverOptions::eps, EbwAmSgmm2Options::epsilon, AmSgmm2::FeatureDim(), rnnlm::i, SolverOptions::K, KALDI_LOG, KALDI_VLOG, KALDI_WARN, kaldi::kNoTrans, EbwAmSgmm2Options::lrate_M, AmSgmm2::M_, EbwAmSgmm2Options::max_cond, SolverOptions::name, AmSgmm2::NumGauss(), EbwAmSgmm2Updater::options_, AmSgmm2::PhoneSpaceDim(), PackedMatrix< Real >::Scale(), AmSgmm2::SigmaInv_, kaldi::SolveQuadraticMatrixProblem(), VectorBase< Real >::Sum(), EbwAmSgmm2Options::tau_M, and MleAmSgmm2Accs::Y_.

Referenced by EbwAmSgmm2Updater::Update().

289  {
290  int32 S = model->PhoneSpaceDim(),
291  D = model->FeatureDim(),
292  I = model->NumGauss();
293 
294  Vector<double> impr_vec(I);
295 
296  for (int32 i = 0; i < I; i++) {
297  double gamma_i_num = gamma_num(i), gamma_i_den = gamma_den(i);
298 
299  if (gamma_i_num + gamma_i_den == 0.0) {
300  KALDI_WARN << "Not updating phonetic basis for i = " << i
301  << " because count is zero. ";
302  continue;
303  }
304 
305  Matrix<double> Mi(model->M_[i]);
306  Matrix<double> L(D, S); // this is something like the Y quantity, which
307  // represents the linear term in the objf on M-- except that we make it the local
308  // derivative about the current value, instead of the derivative around zero.
309  // But it's not exactly the derivative w.r.t. M, due to the factor of Sigma_i.
310  // The auxiliary function is Q(x) = tr(M^T P Y) - 0.5 tr(P M Q M^T),
311  // where P is Sigma_i^{-1}. The quantity L we define here will be Y - M Q,
312  // and you can think of this as like the local derivative, except there is
313  // a term P in there.
314  L.AddMat(1.0, num_accs.Y_[i]);
315  L.AddMatSp(-1.0, Mi, kNoTrans, Q_num[i], 1.0);
316  L.AddMat(-1.0, den_accs.Y_[i]);
317  L.AddMatSp(-1.0*-1.0, Mi, kNoTrans, Q_den[i], 1.0);
318 
319  SpMatrix<double> Q(S); // This is a combination of the Q's for the numerator and denominator.
320  Q.AddSp(1.0, Q_num[i]);
321  Q.AddSp(1.0, Q_den[i]);
322 
323  double state_count = 1.0e-10 + gamma_i_num + gamma_i_den; // the count
324  // represented by the quadratic part of the stats.
325  Q.Scale( (state_count + options_.tau_M) / state_count );
326  Q.Scale( 1.0 / (options_.lrate_M + 1.0e-10) );
327 
328 
329  SolverOptions opts;
330  opts.name = "M";
331  opts.K = options_.max_cond;
332  opts.eps = options_.epsilon;
333 
334  Matrix<double> deltaM(D, S);
335  double impr =
336  SolveQuadraticMatrixProblem(Q, L,
337  SpMatrix<double>(model->SigmaInv_[i]),
338  opts, &deltaM);
339 
340  impr_vec(i) = impr;
341  Mi.AddMat(1.0, deltaM);
342  model->M_[i].CopyFromMat(Mi);
343  if (i < 10 || impr / state_count > 3.0) {
344  KALDI_VLOG(2) << "Objf impr for projection M for i = " << i << ", is "
345  << (impr/(gamma_i_num + 1.0e-20)) << " over " << gamma_i_num
346  << " frames";
347  }
348  }
349  BaseFloat tot_count = gamma_num.Sum(), tot_impr = impr_vec.Sum();
350 
351  tot_impr /= (tot_count + 1.0e-20);
352  KALDI_LOG << "Overall auxiliary function improvement for model projections "
353  << "M is " << tot_impr << " over " << tot_count << " frames";
354 
355  KALDI_VLOG(1) << "Updating M: num-count is " << gamma_num;
356  KALDI_VLOG(1) << "Updating M: den-count is " << gamma_den;
357  KALDI_VLOG(1) << "Updating M: objf-impr is " << impr_vec;
358 
359  return tot_impr;
360 }
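
The effect of tau_M and lrate_M on the step size is easier to see in one dimension. The sketch below is a scalar analogue only, under the assumption that the linear term l and the quadratic term q are plain numbers; it is not SolveQuadraticMatrixProblem(), which additionally handles SigmaInv and conditioning. The names ScalarStep and SmoothedScalarStep are hypothetical, and the 1.0e-10 guards from the listing are replaced by assertions.

```cpp
#include <cassert>
#include <cmath>

// Result of the one-dimensional update: the step and the value of the
// (scaled) quadratic auxiliary function at that step.
struct ScalarStep {
  double delta;
  double impr;
};

// Hypothetical scalar analogue of the update in UpdateM(): maximize
// f(d) = l * d - 0.5 * q_scaled * d^2, where the quadratic term is first
// inflated by (count + tau) / count and then by 1 / lrate.
ScalarStep SmoothedScalarStep(double l, double q, double count,
                              double tau, double lrate) {
  assert(q > 0.0 && count > 0.0 && tau >= 0.0 && lrate > 0.0);
  double q_scaled = q * (count + tau) / count;  // smoothing with tau
  q_scaled /= lrate;                            // learning-rate scaling
  ScalarStep step;
  step.delta = l / q_scaled;                    // argmax of f(d)
  step.impr = 0.5 * l * l / q_scaled;           // f(delta)
  return step;
}
```

A larger tau or a smaller lrate inflates the quadratic term and therefore shrinks the step: with l = 2, q = 4, count = 10, tau = 10 and lrate = 0.5 the step is 0.125, compared with 0.5 when tau = 0 and lrate = 1.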

◆ UpdateN()

double UpdateN ( const MleAmSgmm2Accs &  num_accs,
const MleAmSgmm2Accs &  den_accs,
const Vector< double > &  gamma_num,
const Vector< double > &  gamma_den,
AmSgmm2 *  model 
) const
private

Definition at line 524 of file estimate-am-sgmm2-ebw.cc.

References MatrixBase< Real >::AddMat(), MatrixBase< Real >::AddMatSp(), SpMatrix< Real >::AddSp(), SolverOptions::eps, EbwAmSgmm2Options::epsilon, MleAmSgmm2Accs::feature_dim_, rnnlm::i, SolverOptions::K, KALDI_ERR, KALDI_LOG, KALDI_VLOG, KALDI_WARN, kaldi::kNoTrans, EbwAmSgmm2Options::lrate_N, EbwAmSgmm2Options::max_cond, AmSgmm2::N_, SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, EbwAmSgmm2Updater::options_, MleAmSgmm2Accs::R_, PackedMatrix< Real >::Scale(), AmSgmm2::SigmaInv_, kaldi::SolveQuadraticMatrixProblem(), MleAmSgmm2Accs::spk_space_dim_, VectorBase< Real >::Sum(), EbwAmSgmm2Options::tau_N, and MleAmSgmm2Accs::Z_.

Referenced by EbwAmSgmm2Updater::Update().

528  {
529  if (num_accs.spk_space_dim_ == 0 || num_accs.R_.size() == 0 ||
530  num_accs.Z_.size() == 0) {
531  KALDI_ERR << "Speaker subspace dim is zero or no stats accumulated";
532  }
533 
534  int32 I = num_accs.num_gaussians_, D = num_accs.feature_dim_,
535  T = num_accs.spk_space_dim_;
536 
537  Vector<double> impr_vec(I);
538 
539  for (int32 i = 0; i < I; i++) {
540  double gamma_i_num = gamma_num(i), gamma_i_den = gamma_den(i);
541  if (gamma_i_num + gamma_i_den == 0.0) {
542  KALDI_WARN << "Not updating speaker basis for i = " << i
543  << " because count is zero. ";
544  continue;
545  }
546  Matrix<double> Ni(model->N_[i]);
547  // See comment near declaration of L in UpdateM(). This update is the
548  // same, but change M->N, Y->Z and Q->R.
549 
550  Matrix<double> L(D, T);
551  L.AddMat(1.0, num_accs.Z_[i]);
552  L.AddMatSp(-1.0, Ni, kNoTrans, num_accs.R_[i], 1.0);
553  L.AddMat(-1.0, den_accs.Z_[i]);
554  L.AddMatSp(-1.0*-1.0, Ni, kNoTrans, den_accs.R_[i], 1.0);
555 
556  SpMatrix<double> R(T); // combination of the numerator and denominator R's.
557  R.AddSp(1.0, num_accs.R_[i]);
558  R.AddSp(1.0, den_accs.R_[i]);
559 
560  double state_count = 1.0e-10 + gamma_i_num + gamma_i_den; // the count
561  // represented by the quadratic part of the stats.
562  R.Scale( (state_count + options_.tau_N) / state_count );
563  R.Scale( 1.0 / (options_.lrate_N + 1.0e-10) );
564 
565  Matrix<double> deltaN(D, T);
566 
567  SolverOptions opts;
568  opts.name = "N";
569  opts.K = options_.max_cond;
570  opts.eps = options_.epsilon;
571 
572  double impr =
573  SolveQuadraticMatrixProblem(R, L,
574  SpMatrix<double>(model->SigmaInv_[i]),
575  opts, &deltaN);
576  impr_vec(i) = impr;
577  Ni.AddMat(1.0, deltaN);
578  model->N_[i].CopyFromMat(Ni);
579  if (i < 10 || impr / (state_count+1.0e-20) > 3.0) {
580  KALDI_LOG << "Objf impr for spk projection N for i = " << (i)
581  << ", is " << (impr / (gamma_i_num + 1.0e-20)) << " over "
582  << gamma_i_num << " frames";
583  }
584  }
585 
586  KALDI_VLOG(1) << "Updating N: numerator count is " << gamma_num;
587  KALDI_VLOG(1) << "Updating N: denominator count is " << gamma_den;
588  KALDI_VLOG(1) << "Updating N: objf-impr is " << impr_vec;
589 
590  double tot_count = gamma_num.Sum(), tot_impr = impr_vec.Sum();
591  tot_impr /= (tot_count + 1.0e-20);
592  KALDI_LOG << "**Overall auxf impr for N is " << tot_impr
593  << " over " << tot_count << " frames";
594  return tot_impr;
595 }

◆ UpdatePhoneVectors()

double UpdatePhoneVectors ( const MleAmSgmm2Accs &  num_accs,
const MleAmSgmm2Accs &  den_accs,
const std::vector< SpMatrix< double > > &  H,
AmSgmm2 *  model 
) const
private

Definition at line 261 of file estimate-am-sgmm2-ebw.cc.

References count, MleAmSgmm2Accs::gamma_, KALDI_LOG, MleAmSgmm2Accs::num_groups_, and kaldi::RunMultiThreaded().

Referenced by EbwAmSgmm2Updater::Update().

264  {
265  KALDI_LOG << "Updating phone vectors.";
266 
267  double count = 0.0, auxf_impr = 0.0;
268 
269  int32 J1 = num_accs.num_groups_;
270  for (int32 j1 = 0; j1 < J1; j1++) count += num_accs.gamma_[j1].Sum();
271 
272  EbwUpdatePhoneVectorsClass c(this, num_accs, den_accs, H, model, &auxf_impr);
273  RunMultiThreaded(c);
274 
275  auxf_impr /= count;
276 
277  KALDI_LOG << "**Overall auxf improvement for v is " << auxf_impr
278  << " over " << count << " frames";
279  return auxf_impr;
280 }

◆ UpdatePhoneVectorsInternal()

void UpdatePhoneVectorsInternal ( const MleAmSgmm2Accs &  num_accs,
const MleAmSgmm2Accs &  den_accs,
const std::vector< SpMatrix< double > > &  H,
AmSgmm2 *  model,
double *  auxf_impr,
int32  num_threads,
int32  thread_id 
) const
private

Definition at line 178 of file estimate-am-sgmm2-ebw.cc.

References VectorBase< Real >::AddMatVec(), SpMatrix< Real >::AddSp(), VectorBase< Real >::AddSpVec(), VectorBase< Real >::AddVec(), VectorBase< Real >::ApplySoftMax(), EbwAmSgmm2Updater::ComputePhoneVecStats(), SolverOptions::eps, EbwAmSgmm2Options::epsilon, MleAmSgmm2Accs::gamma_, SolverOptions::K, KALDI_LOG, kaldi::kNoTrans, EbwAmSgmm2Options::lrate_v, EbwAmSgmm2Options::max_cond, SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, MleAmSgmm2Accs::num_groups_, AmSgmm2::NumSubstatesForGroup(), EbwAmSgmm2Updater::options_, MleAmSgmm2Accs::phn_space_dim_, PackedMatrix< Real >::Scale(), kaldi::SolveQuadraticProblem(), EbwAmSgmm2Options::tau_v, AmSgmm2::v_, and AmSgmm2::w_.

185  {
186 
187  int32 block_size = (num_accs.num_groups_ + (num_threads-1)) / num_threads,
188  j1_start = block_size * thread_id,
189  j1_end = std::min(num_accs.num_groups_, j1_start + block_size);
190 
191  int32 S = num_accs.phn_space_dim_, I = num_accs.num_gaussians_;
192 
193  for (int32 j1 = j1_start; j1 < j1_end; j1++) {
194  double num_state_count = 0.0,
195  state_auxf_impr = 0.0;
196  Vector<double> w_jm(I);
197  for (int32 m = 0; m < model->NumSubstatesForGroup(j1); m++) {
198  double gamma_jm_num = num_accs.gamma_[j1].Row(m).Sum();
199  double gamma_jm_den = den_accs.gamma_[j1].Row(m).Sum();
200  num_state_count += gamma_jm_num;
201  Vector<double> g_jm_num(S); // computed using eq. 58 of SGMM paper [for numerator stats]
202  SpMatrix<double> H_jm_num(S); // computed using eq. 59 of SGMM paper [for numerator stats]
203  Vector<double> g_jm_den(S); // same, but for denominator stats.
204  SpMatrix<double> H_jm_den(S);
205 
206  // Compute the weights for this sub-state.
207  // w_jm = softmax([w_{k1}^T ... w_{kD}^T] * v_{jkm}) eq.(7)
208  w_jm.AddMatVec(1.0, Matrix<double>(model->w_), kNoTrans,
209  Vector<double>(model->v_[j1].Row(m)), 0.0);
210  w_jm.ApplySoftMax();
211  // Note: in the ML code, in the SSGMM case, at this point the w_jm would
212  // be modified with the "a" quantities to get the "\tilde{w}_{jm}" of the
213  // SSGMM techreport. But in this code, it gets done inside ComputePhoneVecStats.
214 
215  ComputePhoneVecStats(num_accs, *model, H, j1, m, w_jm, gamma_jm_num,
216  &g_jm_num, &H_jm_num);
217  ComputePhoneVecStats(den_accs, *model, H, j1, m, w_jm, gamma_jm_den,
218  &g_jm_den, &H_jm_den);
219 
220  Vector<double> v_jm(model->v_[j1].Row(m));
221  Vector<double> local_derivative(S); // difference of derivative of numerator
222  // and denominator objective function.
223  local_derivative.AddVec(1.0, g_jm_num);
224  local_derivative.AddSpVec(-1.0, H_jm_num, v_jm, 1.0);
225  local_derivative.AddVec(-1.0, g_jm_den);
226  local_derivative.AddSpVec(-1.0 * -1.0, H_jm_den, v_jm, 1.0);
227 
228  SpMatrix<double> quadratic_term(H_jm_num);
229  quadratic_term.AddSp(1.0, H_jm_den);
230  double substate_count = 1.0e-10 + gamma_jm_num + gamma_jm_den;
231  quadratic_term.Scale( (substate_count + options_.tau_v) / substate_count);
232  quadratic_term.Scale(1.0 / (options_.lrate_v + 1.0e-10) );
233 
234  Vector<double> delta_v_jm(S);
235 
236  SolverOptions opts;
237  opts.name = "v";
238  opts.K = options_.max_cond;
239  opts.eps = options_.epsilon;
240 
241  double auxf_impr =
242  ((gamma_jm_num + gamma_jm_den == 0) ? 0.0 :
243  SolveQuadraticProblem(quadratic_term,
244  local_derivative,
245  opts, &delta_v_jm));
246 
247  v_jm.AddVec(1.0, delta_v_jm);
248  model->v_[j1].Row(m).CopyFromVec(v_jm);
249  state_auxf_impr += auxf_impr;
250  }
251 
252  *auxf_impr += state_auxf_impr;
253  if (j1 < 10 && thread_id == 0) {
254  KALDI_LOG << "Objf impr for group j = " << j1 << " is "
255  << (state_auxf_impr / (num_state_count + 1.0e-10))
256  << " over " << num_state_count << " frames";
257  }
258  }
259 }
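
The first three lines of the listing above split the state groups into contiguous blocks, one per thread. A minimal standalone sketch of that partitioning follows; the names Range and ThreadBlock are hypothetical, and plain int is used in place of int32.

```cpp
#include <algorithm>
#include <cassert>

// Half-open range [begin, end) of group indices handled by one thread.
struct Range {
  int begin;
  int end;
};

// Hypothetical helper (not part of Kaldi): ceil-divide the groups among
// the threads, as done with block_size, j1_start and j1_end in the listing.
Range ThreadBlock(int num_groups, int num_threads, int thread_id) {
  assert(num_groups >= 0 && num_threads > 0);
  assert(thread_id >= 0 && thread_id < num_threads);
  int block_size = (num_groups + (num_threads - 1)) / num_threads;
  Range r;
  r.begin = block_size * thread_id;
  r.end = std::min(num_groups, r.begin + block_size);
  return r;
}
```

When the group count is not a multiple of the thread count, trailing threads get a short block. A trailing thread can also get a range whose begin is not less than its end, in which case a loop of the form `for (j1 = begin; j1 < end; j1++)` simply does no work.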

◆ UpdateSubstateWeights()

double UpdateSubstateWeights ( const MleAmSgmm2Accs &  num_accs,
const MleAmSgmm2Accs &  den_accs,
AmSgmm2 *  model 
)
private

Definition at line 672 of file estimate-am-sgmm2-ebw.cc.

References AmSgmm2::c_, VectorBase< Real >::Dim(), MleAmSgmm2Accs::gamma_c_, KALDI_LOG, EbwAmSgmm2Options::min_substate_weight, MleAmSgmm2Accs::num_pdfs_, AmSgmm2::NumSubstatesForPdf(), EbwAmSgmm2Updater::options_, VectorBase< Real >::Scale(), VectorBase< Real >::Sum(), and EbwAmSgmm2Options::tau_c.

Referenced by EbwAmSgmm2Updater::Update().

675  {
676  KALDI_LOG << "Updating substate mixture weights";
677 
678  double tot_count = 0.0, tot_impr = 0.0;
679  for (int32 j2 = 0; j2 < num_accs.num_pdfs_; j2++) {
680  int32 M = model->NumSubstatesForPdf(j2);
681  Vector<double> num_occs(M), den_occs(M),
682  orig_weights(model->c_[j2]), weights(model->c_[j2]);
683 
684  for (int32 m = 0; m < M; m++) {
685  num_occs(m) = num_accs.gamma_c_[j2](m)
686  + options_.tau_c * weights(m);
687  den_occs(m) = den_accs.gamma_c_[j2](m);
688  }
689 
690  if (weights.Dim() > 1) {
691  double begin_auxf = 0.0, end_auxf = 0.0;
692  for (int32 m = 0; m < M; m++) { // see eq. 4.32, Dan Povey's PhD thesis.
693  begin_auxf += num_occs(m) * log (weights(m))
694  - den_occs(m) * weights(m) / orig_weights(m);
695  }
696  for (int32 iter = 0; iter < 50; iter++) {
697  Vector<double> k_jm(M);
698  double max_m = 0.0;
699  for (int32 m = 0; m < M; m++)
700  max_m = std::max(max_m, den_occs(m)/orig_weights(m));
701  for (int32 m = 0; m < M; m++)
702  k_jm(m) = max_m - den_occs(m)/orig_weights(m);
703  for (int32 m = 0; m < M; m++)
704  weights(m) = num_occs(m) + k_jm(m)*weights(m);
705  weights.Scale(1.0 / weights.Sum());
706  }
707  for (int32 m = 0; m < M; m++)
708  weights(m) = std::max(weights(m),
709  static_cast<double>(options_.min_substate_weight));
710  weights.Scale(1.0 / weights.Sum()); // renormalize.
711 
712  for (int32 m = 0; m < M; m++) {
713  end_auxf += num_occs(m) * log (weights(m))
714  - den_occs(m) * weights(m) / orig_weights(m);
715  }
716  tot_impr += end_auxf - begin_auxf;
717  double this_impr = ((end_auxf - begin_auxf) / num_occs.Sum());
718  if (j2 < 10 || this_impr > 0.5) {
719  KALDI_LOG << "Updating substate weights: auxf impr for pdf " << j2
720  << " is " << this_impr << " per frame over " << num_occs.Sum()
721  << " frames (den count is " << den_occs.Sum() << ")";
722  }
723  }
724  model->c_[j2].CopyFromVec(weights);
725  tot_count += den_occs.Sum(); // Note: num and den occs should be the
726  // same, except num occs are smoothed, so this is what we want.
727  }
728 
729  tot_impr /= (tot_count + 1.0e-20);
730 
731  KALDI_LOG << "**Overall auxf impr for c is " << tot_impr
732  << " over " << tot_count << " frames";
733  return tot_impr;
734 }
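
The inner weight-update loop can be tried out in isolation. The sketch below follows the same steps as the listing for a single pdf, with std::vector standing in for Vector<double>. The function name UpdateWeightsSketch is hypothetical, num_occs is assumed to already include the tau_c smoothing term, and the k_jm values are computed once before the loop because they do not change between iterations.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of Kaldi): iterative substate-weight update
// for one pdf, followed by flooring and renormalization.
std::vector<double> UpdateWeightsSketch(const std::vector<double> &num_occs,
                                        const std::vector<double> &den_occs,
                                        const std::vector<double> &orig_weights,
                                        double min_weight,
                                        int num_iters = 50) {
  const std::size_t M = orig_weights.size();
  assert(M > 0 && num_occs.size() == M && den_occs.size() == M);

  // k_m = max_m'(den_m' / orig_m') - den_m / orig_m, which is nonnegative.
  double max_ratio = 0.0;
  for (std::size_t m = 0; m < M; m++) {
    assert(orig_weights[m] > 0.0);
    max_ratio = std::max(max_ratio, den_occs[m] / orig_weights[m]);
  }
  std::vector<double> k(M);
  for (std::size_t m = 0; m < M; m++)
    k[m] = max_ratio - den_occs[m] / orig_weights[m];

  std::vector<double> weights(orig_weights);
  for (int iter = 0; iter < num_iters; iter++) {
    double sum = 0.0;
    for (std::size_t m = 0; m < M; m++) {
      weights[m] = num_occs[m] + k[m] * weights[m];
      sum += weights[m];
    }
    assert(sum > 0.0);
    for (std::size_t m = 0; m < M; m++) weights[m] /= sum;  // renormalize
  }

  // Apply the minimum weight, then renormalize once more.
  double sum = 0.0;
  for (std::size_t m = 0; m < M; m++) {
    weights[m] = std::max(weights[m], min_weight);
    sum += weights[m];
  }
  for (std::size_t m = 0; m < M; m++) weights[m] /= sum;
  return weights;
}
```

With zero denominator counts every k_m is zero, so the update reduces to normalizing the numerator counts: counts of {3, 1} give weights {0.75, 0.25}.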

◆ UpdateU()

double UpdateU ( const MleAmSgmm2Accs &  num_accs,
const MleAmSgmm2Accs &  den_accs,
const Vector< double > &  gamma_num,
const Vector< double > &  gamma_den,
AmSgmm2 *  model 
)
private

Definition at line 463 of file estimate-am-sgmm2-ebw.cc.

References SpMatrix< Real >::AddSp(), VectorBase< Real >::AddVec(), SolverOptions::eps, EbwAmSgmm2Options::epsilon, rnnlm::i, SolverOptions::K, KALDI_LOG, KALDI_WARN, EbwAmSgmm2Options::lrate_u, EbwAmSgmm2Options::max_cond, EbwAmSgmm2Options::max_impr_u, SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, EbwAmSgmm2Updater::options_, MatrixBase< Real >::Row(), PackedMatrix< Real >::Scale(), VectorBase< Real >::Scale(), kaldi::SolveQuadraticProblem(), MleAmSgmm2Accs::spk_space_dim_, VectorBase< Real >::Sum(), MleAmSgmm2Accs::t_, EbwAmSgmm2Options::tau_u, MleAmSgmm2Accs::U_, and AmSgmm2::u_.

Referenced by EbwAmSgmm2Updater::Update().

467  {
468  int32 T = num_accs.spk_space_dim_;
469  double tot_impr = 0.0;
470  for (int32 i = 0; i < num_accs.num_gaussians_; i++) {
471  if (gamma_num(i) < 200.0) {
472  KALDI_LOG << "Numerator count is small " << gamma_num(i) << " for gaussian "
473  << i << ", not updating u_i.";
474  continue;
475  }
476  Vector<double> u_i(model->u_.Row(i));
477  Vector<double> delta_u(T);
478  Vector<double> t(T); // derivative.
479  t.AddVec(1.0, num_accs.t_.Row(i));
480  t.AddVec(-1.0, den_accs.t_.Row(i));
481  SpMatrix<double> U(T); // quadratic term.
482  U.AddSp(1.0, num_accs.U_[i]);
483  U.AddSp(1.0, den_accs.U_[i]);
484 
485  double state_count = gamma_num(i) + gamma_den(i);
486  U.Scale((state_count + options_.tau_u) / (state_count + 1.0e-10));
487  U.Scale(1.0 / (options_.lrate_u + 1.0e-10) );
488 
489  SolverOptions opts;
490  opts.name = "u";
491  opts.K = options_.max_cond;
492  opts.eps = options_.epsilon;
493 
494  double impr = SolveQuadraticProblem(U, t, opts, &delta_u);
495  double impr_per_frame = impr / gamma_num(i);
496  if (impr_per_frame > options_.max_impr_u) {
497  KALDI_WARN << "Updating speaker weight projections u, for Gaussian index "
498  << i << ", impr/frame is " << impr_per_frame << " over "
499  << gamma_num(i) << " frames, scaling back to not exceed "
500  << options_.max_impr_u;
501  double scale = options_.max_impr_u / impr_per_frame;
502  impr *= scale;
503  delta_u.Scale(scale);
504  // Note: a linear scaling of "impr" with "scale" is not quite accurate
505  // in depicting how the quadratic auxiliary function varies as we change
506  // the scale on "delta", but this does not really matter-- the goal is
507  // to limit the auxiliary-function change to not be too large.
508  }
509  if (i < 10) {
510  KALDI_LOG << "Objf impr for spk weight-projection u for i = " << (i)
511  << ", is " << (impr / (gamma_num(i) + 1.0e-20)) << " over "
512  << gamma_num(i) << " frames";
513  }
514  u_i.AddVec(1.0, delta_u);
515  model->u_.Row(i).CopyFromVec(u_i);
516  tot_impr += impr;
517  }
518  KALDI_LOG << "**Overall objf impr for u is " << (tot_impr/gamma_num.Sum())
519  << ", over " << gamma_num.Sum() << " frames";
520  return tot_impr;
521 }
BaseFloat max_cond
is allowed to change.
double SolveQuadraticProblem(const SpMatrix< double > &H, const VectorBase< double > &g, const SolverOptions &opts, VectorBase< double > *x)
Definition: sp-matrix.cc:635
BaseFloat max_impr_u
Maximum improvement per frame allowed for u [0.25, carried over from ML update].
BaseFloat tau_u
Tau value for smoothing update of speaker-subspace weight projections (u).
kaldi::int32 int32
BaseFloat lrate_u
Learning rate used in updating u (default 1.0).
#define KALDI_WARN
Definition: kaldi-error.h:150
Real Sum() const
Returns sum of the elements.
BaseFloat epsilon
very small value used in SolveQuadraticProblem; workaround
#define KALDI_LOG
Definition: kaldi-error.h:153
void AddVec(const Real alpha, const VectorBase< OtherReal > &v)
Add vector : *this = *this + alpha * rv (with casting between floats and doubles) ...
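
The scaling and capping steps in the UpdateU() listing above can be sketched in one dimension. This is only an illustration, not Kaldi code: SolveScalarStep and ScalarStep are hypothetical names, and the closed-form solution assumes that the solver maximizes g*x - 0.5*H*x^2, which in the scalar case gives x = g/H and an improvement of 0.5*g^2/H.

```cpp
#include <cassert>

// Hypothetical scalar stand-in for one Gaussian's update (not part of Kaldi).
struct ScalarStep {
  double delta;  // change applied to the parameter
  double impr;   // predicted auxiliary-function improvement
};

ScalarStep SolveScalarStep(double derivative, double quad_num, double quad_den,
                           double num_count, double den_count, double tau,
                           double lrate, double max_impr_per_frame) {
  assert(num_count > 0.0 && quad_num + quad_den > 0.0);
  double state_count = num_count + den_count;
  // Quadratic terms are added; only the linear terms are differenced.
  double quad = quad_num + quad_den;
  // Smoothing with tau and the learning rate, as in the listing.
  quad *= (state_count + tau) / (state_count + 1.0e-10);
  quad *= 1.0 / (lrate + 1.0e-10);
  // Assumed scalar solution of: maximize derivative*x - 0.5*quad*x^2.
  double delta = derivative / quad;
  double impr = 0.5 * derivative * derivative / quad;
  // Scale the step back if the per-frame improvement exceeds the limit.
  double impr_per_frame = impr / num_count;
  if (impr_per_frame > max_impr_per_frame) {
    double scale = max_impr_per_frame / impr_per_frame;
    impr *= scale;
    delta *= scale;
  }
  return {delta, impr};
}
```

A larger tau or a smaller learning rate inflates the quadratic term and therefore shrinks the step. The cap acts afterwards and limits only the reported improvement per numerator frame.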

◆ UpdateVars()

double UpdateVars ( const MleAmSgmm2Accs &  num_accs,
const MleAmSgmm2Accs &  den_accs,
const Vector< double > &  gamma_num,
const Vector< double > &  gamma_den,
const std::vector< SpMatrix< double > > &  S_means,
AmSgmm2 *  model 
) const
private

Definition at line 597 of file estimate-am-sgmm2-ebw.cc.

References SpMatrix< Real >::AddSp(), SpMatrix< Real >::ApplyFloor(), count, EbwAmSgmm2Options::cov_min_value, rnnlm::i, SpMatrix< Real >::Invert(), KALDI_ASSERT, KALDI_LOG, KALDI_VLOG, SpMatrix< Real >::LogDet(), EbwAmSgmm2Options::lrate_Sigma, MleAmSgmm2Accs::num_gaussians_, EbwAmSgmm2Updater::options_, MleAmSgmm2Accs::S_, PackedMatrix< Real >::Scale(), AmSgmm2::SigmaInv_, EbwAmSgmm2Options::tau_Sigma, and kaldi::TraceSpSp().

Referenced by EbwAmSgmm2Updater::Update().

602  {
603  // Note: S_means contains not only the quantity S_means in the paper,
604  // but also has a term - (Y_i M_i^T + M_i Y_i^T). Plus, it is differenced
605  // between numerator and denominator. We don't calculate it here,
606  // because it had to be computed with the original model, before we
607  // changed the M quantities.
608  int32 I = num_accs.num_gaussians_;
609  KALDI_ASSERT(S_means.size() == I);
610  Vector<double> impr_vec(I);
611 
612  for (int32 i = 0; i < I; i++) {
613  double num_count = gamma_num(i), den_count = gamma_den(i);
614 
615  SpMatrix<double> SigmaStats(S_means[i]);
616  SigmaStats.AddSp(1.0, num_accs.S_[i]);
617  SigmaStats.AddSp(-1.0, den_accs.S_[i]);
618  // SigmaStats now contain the stats for estimating Sigma (as in the main SGMM paper),
619  // differenced between num and den.
620  SpMatrix<double> SigmaInvOld(model->SigmaInv_[i]), SigmaOld(model->SigmaInv_[i]);
621  SigmaOld.Invert();
622  double count = num_count - den_count;
624  double inv_lrate = 1.0 / options_.lrate_Sigma;
625  // These formulas assure that the objective function behaves in
626  // a roughly symmetric way w.r.t. num and den counts.
627  double E_den = 1.0 + inv_lrate, E_num = inv_lrate - 1.0;
628 
629  double smoothing_count =
630  (options_.tau_Sigma * inv_lrate) + // multiply tau_Sigma by inverse-lrate
631  (E_den * den_count) + // for compatibility with other updates.
632  (E_num * num_count) +
633  1.0e-10;
634  SigmaStats.AddSp(smoothing_count, SigmaOld);
635  count += smoothing_count;
636  SigmaStats.Scale(1.0 / count);
637  SpMatrix<double> SigmaInv(SigmaStats); // before floor and ceiling. Currently sigma,
638  // not its inverse.
639  bool verbose = false;
640  int n_floor = SigmaInv.ApplyFloor(SigmaOld, options_.cov_min_value, verbose);
641  SigmaInv.Invert(); // make it inverse variance.
642  int n_ceiling = SigmaInv.ApplyFloor(SigmaInvOld, options_.cov_min_value, verbose);
643 
644  // this auxf_change.
645  double auxf_change = -0.5 * count *(TraceSpSp(SigmaInv, SigmaStats)
646  - TraceSpSp(SigmaInvOld, SigmaStats)
647  - SigmaInv.LogDet()
648  + SigmaInvOld.LogDet());
649 
650  model->SigmaInv_[i].CopyFromSp(SigmaInv);
651  impr_vec(i) = auxf_change;
652  if (i < 10 || auxf_change / (num_count+den_count+1.0e-10) > 2.0
653  || n_floor+n_ceiling > 0) {
654  KALDI_LOG << "Updating variance: Auxf change per frame for Gaussian "
655  << i << " is " << (auxf_change / num_count) << " over "
656  << num_count << " frames " << "(den count was " << den_count
657  << "), #floor,ceil was " << n_floor << ", " << n_ceiling;
658  }
659  }
660  KALDI_VLOG(1) << "Updating Sigma: numerator count is " << gamma_num;
661  KALDI_VLOG(1) << "Updating Sigma: denominator count is " << gamma_den;
662  KALDI_VLOG(1) << "Updating Sigma: objf-impr is " << impr_vec;
663 
664  double tot_count = gamma_num.Sum(), tot_impr = impr_vec.Sum();
665  tot_impr /= tot_count+1.0e-20;
666  KALDI_LOG << "**Overall auxf impr for Sigma is " << tot_impr
667  << " over " << tot_count << " frames";
668  return tot_impr;
669 }
BaseFloat tau_Sigma
Tau value for smoothing covariance-matrices Sigma.
BaseFloat lrate_Sigma
Learning rate used in updating Sigma (default 0.5).
double TraceSpSp(const SpMatrix< double > &A, const SpMatrix< double > &B)
Definition: sp-matrix.cc:326
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
#define KALDI_LOG
Definition: kaldi-error.h:153
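
The smoothing-count arithmetic in the UpdateVars() listing above can be checked with a one-dimensional variance. The sketch below is illustrative only: UpdateScalarVariance and VarUpdate are hypothetical names, stats_diff stands in for the differenced SigmaStats, and the flooring and ceiling steps are left out.

```cpp
#include <cassert>

// Hypothetical scalar analogue of the Sigma update (not part of Kaldi).
struct VarUpdate {
  double smoothing_count;
  double total_count;
  double new_var;
};

VarUpdate UpdateScalarVariance(double stats_diff, double old_var,
                               double num_count, double den_count,
                               double tau, double lrate) {
  assert(lrate > 0.0 && old_var > 0.0);
  double inv_lrate = 1.0 / lrate;
  double e_den = 1.0 + inv_lrate, e_num = inv_lrate - 1.0;
  double smoothing_count = tau * inv_lrate + e_den * den_count +
                           e_num * num_count + 1.0e-10;
  // num - den + smoothing simplifies to inv_lrate * (tau + num + den),
  // so the effective count is symmetric in num and den and stays positive.
  double total_count = (num_count - den_count) + smoothing_count;
  // Interpolate the differenced statistics with the old variance.
  double new_var = (stats_diff + smoothing_count * old_var) / total_count;
  return {smoothing_count, total_count, new_var};
}
```

With num_count = 100, den_count = 40, tau = 50 and lrate = 0.5, the smoothing count is about 320 and the total count about 380, which equals 2 * (50 + 100 + 40). If stats_diff equals (num_count - den_count) * old_var, the variance is left unchanged.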

◆ UpdateW()

double UpdateW ( const MleAmSgmm2Accs &  num_accs,
const MleAmSgmm2Accs &  den_accs,
const Vector< double > &  gamma_num,
const Vector< double > &  gamma_den,
AmSgmm2 *  model 
)
private

Note: in the discriminative case we do just one iteration of updating the w quantities.

Definition at line 370 of file estimate-am-sgmm2-ebw.cc.

References VectorBase< Real >::AddVec(), MleAmSgmm2Updater::ComputeLogA(), PackedMatrix< Real >::CopyFromVec(), SolverOptions::eps, EbwAmSgmm2Options::epsilon, AmSgmm2::HasSpeakerDependentWeights(), rnnlm::i, SolverOptions::K, KALDI_LOG, KALDI_VLOG, EbwAmSgmm2Options::lrate_w, EbwAmSgmm2Options::max_cond, SolverOptions::name, MleAmSgmm2Accs::num_gaussians_, EbwAmSgmm2Updater::options_, MleAmSgmm2Accs::phn_space_dim_, MatrixBase< Real >::Row(), kaldi::RunMultiThreaded(), kaldi::SolveQuadraticProblem(), VectorBase< Real >::Sum(), EbwAmSgmm2Options::tau_w, and AmSgmm2::w_.

Referenced by EbwAmSgmm2Updater::Update().

374  {
375  KALDI_LOG << "Updating weight projections";
376 
377  int32 I = num_accs.num_gaussians_, S = num_accs.phn_space_dim_;
378 
379  Matrix<double> g_i_num(I, S), g_i_den(I, S);
380 
381  // View F_i_{num,den} as vectors of SpMatrix [i.e. symmetric matrices,
382  // linearized into vectors]
383  Matrix<double> F_i_num(I, (S*(S+1))/2), F_i_den(I, (S*(S+1))/2);
384 
385  Vector<double> impr_vec(I);
386 
387  // Get the F_i and g_i quantities-- this is done in parallel (multi-core),
388  // using the same code we use in the ML update [except we get it for
389  // numerator and denominator separately.]
390  Matrix<double> w(model->w_);
391  {
392  std::vector<Matrix<double> > log_a_num;
393  if (model->HasSpeakerDependentWeights())
394  MleAmSgmm2Updater::ComputeLogA(num_accs, &log_a_num);
395  double garbage;
396  UpdateWClass c_num(num_accs, *model, w, log_a_num, &F_i_num, &g_i_num, &garbage);
397  RunMultiThreaded(c_num);
398  }
399  {
400  std::vector<Matrix<double> > log_a_den;
401  if (model->HasSpeakerDependentWeights())
402  MleAmSgmm2Updater::ComputeLogA(den_accs, &log_a_den);
403  double garbage;
404  UpdateWClass c_den(den_accs, *model, w, log_a_den, &F_i_den, &g_i_den, &garbage);
405  RunMultiThreaded(c_den);
406  }
407 
408  for (int32 i = 0; i < I; i++) {
409 
410  // auxf was originally formulated in terms of the change in w (i.e. the
411  // g quantities are the local derivatives), so there is less hassle than
412  // with some of the other updates, in changing it to be discriminative.
413  // we essentially just difference the linear terms and add the quadratic
414  // terms.
415 
416  Vector<double> derivative(g_i_num.Row(i));
417  derivative.AddVec(-1.0, g_i_den.Row(i));
418  // F_i_num quadratic_term is a bit like the negated 2nd derivative
419  // of the numerator stats-- actually it's not the actual 2nd deriv,
420  // but an upper bound on it.
421  SpMatrix<double> quadratic_term(S), tmp_F(S);
422  quadratic_term.CopyFromVec(F_i_num.Row(i));
423  tmp_F.CopyFromVec(F_i_den.Row(i)); // tmp_F is used for Vector->SpMatrix conversion.
424  quadratic_term.AddSp(1.0, tmp_F);
425 
426  double state_count = gamma_num(i) + gamma_den(i);
427 
428  quadratic_term.Scale((state_count + options_.tau_w) / (state_count + 1.0e-10));
429  quadratic_term.Scale(1.0 / (options_.lrate_w + 1.0e-10) );
430 
431  Vector<double> delta_w(S);
432 
433  SolverOptions opts;
434  opts.name = "w";
435  opts.K = options_.max_cond;
436  opts.eps = options_.epsilon;
437 
438  double objf_impr =
439  SolveQuadraticProblem(quadratic_term, derivative, opts, &delta_w);
440 
441  impr_vec(i) = objf_impr;
442  if (i < 10 || objf_impr / (gamma_num(i) + 1.0e-10) > 2.0) {
443  KALDI_LOG << "Predicted objf impr for w per frame is "
444  << (objf_impr / (gamma_num(i) + 1.0e-10))
445  << " over " << gamma_num(i) << " frames.";
446  }
447  model->w_.Row(i).AddVec(1.0, delta_w);
448  }
449  KALDI_VLOG(1) << "Updating w: numerator count is " << gamma_num;
450  KALDI_VLOG(1) << "Updating w: denominator count is " << gamma_den;
451  KALDI_VLOG(1) << "Updating w: objf-impr is " << impr_vec;
452 
453  double tot_num_count = gamma_num.Sum(), tot_impr = impr_vec.Sum();
454  tot_impr /= tot_num_count;
455 
456  KALDI_LOG << "**Overall objf impr for w per frame is "
457  << tot_impr << " over " << tot_num_count
458  << " frames.";
459  return tot_impr;
460 }
BaseFloat lrate_w
Learning rate used in updating w (default 1.0).
void RunMultiThreaded(const C &c_in)
Here, class C should inherit from MultiThreadable.
Definition: kaldi-thread.h:151
BaseFloat tau_w
Tau value for smoothing update of phonetic-subspace weight projections (w).
static void ComputeLogA(const MleAmSgmm2Accs &accs, std::vector< Matrix< double > > *log_a)
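
The UpdateW() listing above stores each symmetric S x S matrix F_i as one row of length S*(S+1)/2. The sketch below shows one way such a linearization can work. It assumes a row-major lower-triangular packing; the helper names PackedSize, PackedIndex and SymmetricAt are hypothetical and are not Kaldi functions.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Number of stored elements for a symmetric dim x dim matrix.
std::size_t PackedSize(std::size_t dim) { return dim * (dim + 1) / 2; }

// Index of element (r, c) in the packed vector, assuming the lower
// triangle is stored row by row: (0,0), (1,0), (1,1), (2,0), ...
std::size_t PackedIndex(std::size_t r, std::size_t c) {
  if (r < c) std::swap(r, c);  // symmetry: (r, c) and (c, r) share storage
  return r * (r + 1) / 2 + c;
}

// Read element (r, c) of a symmetric matrix from its packed form.
double SymmetricAt(const std::vector<double> &packed, std::size_t dim,
                   std::size_t r, std::size_t c) {
  assert(packed.size() == PackedSize(dim) && r < dim && c < dim);
  return packed[PackedIndex(r, c)];
}
```

For a phonetic subspace dimension of S = 40, each row under this scheme would hold 820 values instead of 1600.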

Friends And Related Function Documentation

◆ EbwUpdatePhoneVectorsClass

friend class EbwUpdatePhoneVectorsClass
friend

Definition at line 163 of file estimate-am-sgmm2-ebw.h.

◆ EbwUpdateWClass

friend class EbwUpdateWClass
friend

Definition at line 162 of file estimate-am-sgmm2-ebw.h.

Member Data Documentation

◆ gamma_j_

Vector<double> gamma_j_
private

State occupancies.

Definition at line 167 of file estimate-am-sgmm2-ebw.h.

◆ options_


The documentation for this class was generated from the following files: