#include <plda.h>
Public Member Functions | |
PldaEstimator (const PldaStats &stats) | |
void | Estimate (const PldaEstimationConfig &config, Plda *output) |
Private Types | |
typedef PldaStats::ClassInfo | ClassInfo |
Private Member Functions | |
double | ComputeObjfPart1 () const |
Returns the part of the objf relating to offsets from the class means. | |
double | ComputeObjfPart2 () const |
Returns the part of the objf relating to the class means (total, not normalized). | |
double | ComputeObjf () const |
Returns the objective-function per sample. | |
int32 | Dim () const |
void | EstimateOneIter () |
void | InitParameters () |
void | ResetPerIterStats () |
void | GetStatsFromIntraClass () |
void | GetStatsFromClassMeans () |
GetStatsFromClassMeans() is the more complicated part of PLDA estimation. | |
void | EstimateFromStats () |
void | GetOutput (Plda *plda) |
KALDI_DISALLOW_COPY_AND_ASSIGN (PldaEstimator) | |
Private Attributes | |
const PldaStats & | stats_ |
SpMatrix< double > | within_var_ |
SpMatrix< double > | between_var_ |
SpMatrix< double > | within_var_stats_ |
double | within_var_count_ |
SpMatrix< double > | between_var_stats_ |
double | between_var_count_ |
typedef PldaStats::ClassInfo ClassInfo | private |
PldaEstimator (const PldaStats &stats) |
Definition at line 338 of file plda.cc.
References PldaEstimator::InitParameters(), PldaStats::IsSorted(), and KALDI_ASSERT.
double ComputeObjf () const | private |
Returns the objective-function per sample.
Definition at line 390 of file plda.cc.
References PldaEstimator::ComputeObjfPart1(), PldaEstimator::ComputeObjfPart2(), PldaStats::example_weight_, KALDI_LOG, and PldaEstimator::stats_.
Referenced by PldaEstimator::EstimateOneIter().
double ComputeObjfPart1 () const | private |
Returns the part of the objf relating to offsets from the class means (total, not normalized).
Definition at line 345 of file plda.cc.
References PldaStats::class_weight_, PldaEstimator::Dim(), PldaStats::example_weight_, SpMatrix< Real >::Invert(), KALDI_ASSERT, M_LOG_2PI, PldaStats::offset_scatter_, PldaEstimator::stats_, kaldi::TraceSpSp(), and PldaEstimator::within_var_.
Referenced by PldaEstimator::ComputeObjf().
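For ComputeObjfPart1(), the references (offset_scatter_, within_var_, M_LOG_2PI, TraceSpSp(), example_weight_ and class_weight_) suggest the usual Gaussian log-likelihood of the within-class offsets. The C++ sketch below assumes that form, -0.5 * (N (log|W| + D log 2pi) + tr(W^{-1} S)), with W the within-class variance, S the offset scatter, D the dimension and N the number of within-class degrees of freedom (presumably example_weight_ - class_weight_). The names `Mat`, `Cholesky` and `WithinClassObjf` are hypothetical, and plain `std::vector` matrices stand in for Kaldi's SpMatrix; this is not Kaldi's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Dense row-major square matrix; a simple stand-in for SpMatrix<double>.
using Mat = std::vector<std::vector<double>>;

// Lower-triangular Cholesky factor L with A = L L^T (A must be positive definite).
Mat Cholesky(const Mat &a) {
  const std::size_t d = a.size();
  Mat l(d, std::vector<double>(d, 0.0));
  for (std::size_t i = 0; i < d; ++i) {
    for (std::size_t j = 0; j <= i; ++j) {
      double sum = a[i][j];
      for (std::size_t k = 0; k < j; ++k) sum -= l[i][k] * l[j][k];
      if (i == j) {
        assert(sum > 0.0);  // fails if A is not positive definite
        l[i][i] = std::sqrt(sum);
      } else {
        l[i][j] = sum / l[j][j];
      }
    }
  }
  return l;
}

// Hypothetical helper: within-class part of the objective,
//   -0.5 * (count * (log|W| + D log(2 pi)) + tr(W^{-1} S)),
// where W = within_var, S = offset_scatter, count = degrees of freedom.
double WithinClassObjf(const Mat &within_var, const Mat &offset_scatter,
                       double count) {
  const std::size_t d = within_var.size();
  const Mat l = Cholesky(within_var);
  double logdet = 0.0;  // log|W| = 2 * sum_i log L_ii
  for (std::size_t i = 0; i < d; ++i) logdet += 2.0 * std::log(l[i][i]);

  // tr(W^{-1} S): solve W x = s_j for each column s_j of S, add x_j.
  double trace = 0.0;
  for (std::size_t j = 0; j < d; ++j) {
    std::vector<double> y(d), x(d);
    for (std::size_t i = 0; i < d; ++i) {  // forward solve L y = s_j
      double sum = offset_scatter[i][j];
      for (std::size_t k = 0; k < i; ++k) sum -= l[i][k] * y[k];
      y[i] = sum / l[i][i];
    }
    for (std::size_t i = d; i-- > 0;) {  // back solve L^T x = y
      double sum = y[i];
      for (std::size_t k = i + 1; k < d; ++k) sum -= l[k][i] * x[k];
      x[i] = sum / l[i][i];
    }
    trace += x[j];
  }
  const double log_2pi = std::log(2.0 * std::acos(-1.0));
  return -0.5 * (count * (logdet + static_cast<double>(d) * log_2pi) + trace);
}
```

A single Cholesky factorization yields both log|W| (twice the sum of the log-diagonal) and the triangular solves needed for the trace, so W^{-1} is never formed explicitly.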
double ComputeObjfPart2 () const | private |
Returns the part of the objf relating to the class means (total, not normalized).
Definition at line 365 of file plda.cc.
References SpMatrix< Real >::AddSp(), VectorBase< Real >::AddVec(), PldaEstimator::between_var_, PldaStats::class_info_, PldaStats::class_weight_, SpMatrix< Real >::CopyFromSp(), PldaEstimator::Dim(), rnnlm::i, SpMatrix< Real >::Invert(), M_LOG_2PI, PldaStats::ClassInfo::mean, rnnlm::n, PldaStats::ClassInfo::num_examples, PldaEstimator::stats_, PldaStats::sum_, kaldi::VecSpVec(), PldaStats::ClassInfo::weight, and PldaEstimator::within_var_.
Referenced by PldaEstimator::ComputeObjf().
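ComputeObjfPart2() scores the class means. A minimal sketch, assuming (as in the GetStatsFromClassMeans() documentation) that each class mean, with the global mean subtracted, is distributed as N(0, between_var + 1/n within_var), and that each term is weighted by the class weight. The `ClassInfo` struct only mirrors the mean, num_examples and weight fields named in the references; `LogGaussZeroMean` and `BetweenClassObjf` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;

// Hypothetical mirror of the PldaStats::ClassInfo fields used here.
struct ClassInfo {
  Vec mean;          // mean of this class's examples
  int num_examples;  // n
  double weight;     // class weight (normally 1)
};

// log N(v; 0, V) for positive-definite V, via the Cholesky factor V = L L^T.
double LogGaussZeroMean(const Vec &v, const Mat &var) {
  const std::size_t d = var.size();
  Mat l(d, Vec(d, 0.0));
  for (std::size_t i = 0; i < d; ++i)
    for (std::size_t j = 0; j <= i; ++j) {
      double sum = var[i][j];
      for (std::size_t k = 0; k < j; ++k) sum -= l[i][k] * l[j][k];
      if (i == j) {
        assert(sum > 0.0);
        l[i][i] = std::sqrt(sum);
      } else {
        l[i][j] = sum / l[j][j];
      }
    }
  double logdet = 0.0, quad = 0.0;
  Vec y(d);
  for (std::size_t i = 0; i < d; ++i) {
    logdet += 2.0 * std::log(l[i][i]);
    double sum = v[i];
    for (std::size_t k = 0; k < i; ++k) sum -= l[i][k] * y[k];
    y[i] = sum / l[i][i];  // y = L^{-1} v, so v^T V^{-1} v = y^T y
    quad += y[i] * y[i];
  }
  const double log_2pi = std::log(2.0 * std::acos(-1.0));
  return -0.5 * (logdet + static_cast<double>(d) * log_2pi + quad);
}

// Class-means part of the objective: each centered class mean is scored
// under N(0, between_var + within_var / n), weighted by the class weight.
double BetweenClassObjf(const Mat &between_var, const Mat &within_var,
                        const std::vector<ClassInfo> &classes,
                        const Vec &global_mean) {
  const std::size_t d = between_var.size();
  double total = 0.0;
  for (const ClassInfo &info : classes) {
    assert(info.num_examples > 0);
    Mat var = between_var;
    for (std::size_t i = 0; i < d; ++i)
      for (std::size_t j = 0; j < d; ++j)
        var[i][j] += within_var[i][j] / info.num_examples;
    Vec m(d);
    for (std::size_t i = 0; i < d; ++i) m[i] = info.mean[i] - global_mean[i];
    total += info.weight * LogGaussZeroMean(m, var);
  }
  return total;
}
```

This sketch rebuilds the combined variance for every class for clarity; with classes sorted by n, it could instead be recomputed only when n changes.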
int32 Dim () const | inline, private |
Definition at line 255 of file plda.h.
References kaldi::GetOutput().
Referenced by PldaEstimator::ComputeObjfPart1(), PldaEstimator::ComputeObjfPart2(), PldaEstimator::GetOutput(), PldaEstimator::GetStatsFromClassMeans(), PldaEstimator::InitParameters(), and PldaEstimator::ResetPerIterStats().
void Estimate (const PldaEstimationConfig &config, Plda *output) |
Definition at line 525 of file plda.cc.
References PldaEstimator::EstimateOneIter(), PldaStats::example_weight_, PldaEstimator::GetOutput(), rnnlm::i, KALDI_ASSERT, KALDI_LOG, PldaEstimationConfig::num_em_iters, and PldaEstimator::stats_.
Referenced by main(), and kaldi::UnitTestPldaEstimation().
void EstimateFromStats () | private |
Definition at line 505 of file plda.cc.
References PldaEstimator::between_var_, PldaEstimator::between_var_count_, PldaEstimator::between_var_stats_, SpMatrix< Real >::CopyFromSp(), KALDI_LOG, PackedMatrix< Real >::Scale(), SpMatrix< Real >::Trace(), PldaEstimator::within_var_, PldaEstimator::within_var_count_, and PldaEstimator::within_var_stats_.
Referenced by PldaEstimator::EstimateOneIter().
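EstimateFromStats() is the M-step. Judging from its references (CopyFromSp(), Scale(), Trace(), KALDI_LOG), it plausibly sets each variance to its accumulated stats divided by the matching count and logs the traces; the sketch below assumes exactly that. `NormalizeStats`, `Variances` and `EstimateFromStatsSketch` are hypothetical names, and the printed message is illustrative rather than Kaldi's.

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

using Mat = std::vector<std::vector<double>>;

// Hypothetical helper mirroring CopyFromSp() followed by Scale(1.0 / count).
Mat NormalizeStats(const Mat &stats, double count) {
  assert(count > 0.0);
  Mat var = stats;
  for (auto &row : var)
    for (double &x : row) x /= count;
  return var;
}

double Trace(const Mat &m) {
  double t = 0.0;
  for (std::size_t i = 0; i < m.size(); ++i) t += m[i][i];
  return t;
}

struct Variances {
  Mat within_var;
  Mat between_var;
};

// M-step sketch: each variance is its accumulated stats divided by its count.
Variances EstimateFromStatsSketch(const Mat &within_var_stats,
                                  double within_var_count,
                                  const Mat &between_var_stats,
                                  double between_var_count) {
  Variances v{NormalizeStats(within_var_stats, within_var_count),
              NormalizeStats(between_var_stats, between_var_count)};
  // Stand-in for the KALDI_LOG calls; the real messages are not reproduced.
  std::cerr << "trace(within_var) = " << Trace(v.within_var)
            << ", trace(between_var) = " << Trace(v.between_var) << '\n';
  return v;
}
```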
void EstimateOneIter () | private |
Definition at line 516 of file plda.cc.
References PldaEstimator::ComputeObjf(), PldaEstimator::EstimateFromStats(), PldaEstimator::GetStatsFromClassMeans(), PldaEstimator::GetStatsFromIntraClass(), KALDI_VLOG, and PldaEstimator::ResetPerIterStats().
Referenced by PldaEstimator::Estimate().
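To show how EstimateOneIter() chains its callees, here is a scalar (D = 1) sketch in which every variance is a positive number. It runs ResetPerIterStats(), GetStatsFromIntraClass(), GetStatsFromClassMeans() and EstimateFromStats() in that order, then evaluates the objective per sample. Assumptions: unit class weights, a global mean held fixed at the average of the class means, and hypothetical names throughout (`ScalarPldaStats`, `AccumulateStats`, `ObjfPerSample`, `EstimateOneIterScalar`).

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical scalar (D = 1) analogue of PldaStats.
struct ScalarClass {
  double mean;       // mean of this class's samples
  int num_examples;  // n
};

struct ScalarPldaStats {
  std::vector<ScalarClass> classes;
  double global_mean = 0.0;     // average of the class means (unit weights)
  double offset_scatter = 0.0;  // sum of squared offsets from class means
  double example_count = 0.0;   // total number of samples
};

ScalarPldaStats AccumulateStats(const std::vector<std::vector<double>> &groups) {
  assert(!groups.empty());
  ScalarPldaStats s;
  for (const std::vector<double> &g : groups) {
    assert(!g.empty());
    double sum = 0.0;
    for (double x : g) sum += x;
    const double mean = sum / g.size();
    for (double x : g) s.offset_scatter += (x - mean) * (x - mean);
    s.classes.push_back({mean, static_cast<int>(g.size())});
    s.global_mean += mean;
    s.example_count += g.size();
  }
  s.global_mean /= s.classes.size();
  return s;
}

// Objective per sample for within-class variance w, between-class variance b.
double ObjfPerSample(const ScalarPldaStats &s, double w, double b) {
  const double log_2pi = std::log(2.0 * std::acos(-1.0));
  const double dof = s.example_count - s.classes.size();  // sum of (n - 1)
  double objf = -0.5 * (dof * (std::log(w) + log_2pi) + s.offset_scatter / w);
  for (const ScalarClass &c : s.classes) {
    const double m = c.mean - s.global_mean;
    const double var = b + w / c.num_examples;  // variance of a class mean
    objf += -0.5 * (std::log(var) + log_2pi + m * m / var);
  }
  return objf / s.example_count;
}

// One EM iteration; updates w and b in place, returns the new objective.
double EstimateOneIterScalar(const ScalarPldaStats &s, double &w, double &b) {
  // ResetPerIterStats()
  double w_stats = 0.0, w_count = 0.0, b_stats = 0.0, b_count = 0.0;
  // GetStatsFromIntraClass(): offsets from class means only involve w.
  w_stats += s.offset_scatter;
  w_count += s.example_count - s.classes.size();
  // GetStatsFromClassMeans(): split each centered mean m into x + y.
  for (const ScalarClass &c : s.classes) {
    const double n = c.num_examples;
    const double m = c.mean - s.global_mean;
    const double mixed = 1.0 / (1.0 / b + n / w);  // posterior variance of x
    const double x_mean = mixed * (n / w) * m;     // posterior mean of x
    b_stats += x_mean * x_mean + mixed;
    b_count += 1.0;
    w_stats += n * ((m - x_mean) * (m - x_mean) + mixed);
    w_count += 1.0;
  }
  // EstimateFromStats()
  w = w_stats / w_count;
  b = b_stats / b_count;
  return ObjfPerSample(s, w, b);
}
```

Estimate() would repeat such an iteration PldaEstimationConfig::num_em_iters times before calling GetOutput(). Starting from unit variances (which InitParameters() appears to set via SetUnit()), a useful sanity check is that the objective never decreases between iterations, as expected of EM.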
void GetOutput (Plda *plda) | private |
Definition at line 537 of file plda.cc.
References SpMatrix< Real >::AddMat2Sp(), MatrixBase< Real >::AddMatMat(), VectorBase< Real >::ApplyFloor(), kaldi::AssertEqual(), PldaEstimator::between_var_, PldaStats::class_weight_, Plda::ComputeDerivedVars(), kaldi::ComputeNormalizingTransform(), PldaEstimator::Dim(), SpMatrix< Real >::Eig(), kaldi::GetVerboseLevel(), SpMatrix< Real >::IsUnit(), KALDI_ASSERT, KALDI_LOG, KALDI_WARN, kaldi::kNoTrans, kaldi::kTrans, Plda::mean_, VectorBase< Real >::Min(), rnnlm::n, VectorBase< Real >::Norm(), Plda::psi_, Matrix< Real >::Resize(), VectorBase< Real >::Scale(), kaldi::SortSvd(), PldaEstimator::stats_, PldaStats::sum_, Plda::transform_, and PldaEstimator::within_var_.
Referenced by PldaEstimator::Estimate().
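GetOutput() turns the two variances into the Plda's transform_ and psi_. The references (ComputeNormalizingTransform(), AddMat2Sp(), Eig(), SortSvd(), ApplyFloor()) point to a simultaneous diagonalization: find T with T within_var T^T = I and T between_var T^T = diag(psi), psi sorted in decreasing order. The 2x2 sketch below assumes that reading. It uses a Cholesky factor for the normalizing step, which may differ from what ComputeNormalizingTransform() does, and it omits the mean_ and ComputeDerivedVars() parts; `Mat2`, `PldaOutput2` and `SimultaneousDiagonalize` are hypothetical names.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>

using Mat2 = std::array<std::array<double, 2>, 2>;

Mat2 Mul(const Mat2 &a, const Mat2 &b) {
  Mat2 r{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k) r[i][j] += a[i][k] * b[k][j];
  return r;
}

Mat2 Transpose(const Mat2 &a) {
  Mat2 r{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j) r[i][j] = a[j][i];
  return r;
}

// Hypothetical stand-ins for Plda::transform_ and Plda::psi_.
struct PldaOutput2 {
  Mat2 transform;             // T with T W T^T = I and T B T^T = diag(psi)
  std::array<double, 2> psi;  // between-class variances, largest first
};

PldaOutput2 SimultaneousDiagonalize(const Mat2 &within, const Mat2 &between) {
  // Step 1: normalizing transform T1 = L^{-1}, where within = L L^T.
  assert(within[0][0] > 0.0);
  const double l11 = std::sqrt(within[0][0]);
  const double l21 = within[1][0] / l11;
  const double r = within[1][1] - l21 * l21;
  assert(r > 0.0);  // within must be positive definite
  const double l22 = std::sqrt(r);
  const Mat2 t1 = {{{1.0 / l11, 0.0}, {-l21 / (l11 * l22), 1.0 / l22}}};

  // Step 2: between-class variance in the normalized space.
  const Mat2 bp = Mul(Mul(t1, between), Transpose(t1));

  // Step 3: eigendecompose the symmetric 2x2 bp = U diag(psi) U^T with a
  // rotation; this choice of angle puts the larger eigenvalue first.
  const double a = bp[0][0], b = bp[0][1], c = bp[1][1];
  const double theta = 0.5 * std::atan2(2.0 * b, a - c);
  const double cs = std::cos(theta), sn = std::sin(theta);
  const Mat2 u = {{{cs, -sn}, {sn, cs}}};  // columns are eigenvectors

  PldaOutput2 out;
  out.psi = {a * cs * cs + 2.0 * b * sn * cs + c * sn * sn,
             a * sn * sn - 2.0 * b * sn * cs + c * cs * cs};
  // Floor tiny negative eigenvalues (rounding) at zero.
  for (double &p : out.psi) p = std::max(p, 0.0);

  // Step 4: final transform U^T T1.
  out.transform = Mul(Transpose(u), t1);
  return out;
}
```

Any normalizing transform works here: replacing T1 by Q T1 for an orthogonal Q rotates the projected between-class variance, and its eigenvectors rotate with it, so the final U^T T1 is unchanged up to the sign of each row when the eigenvalues are distinct.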
void GetStatsFromClassMeans () | private |
GetStatsFromClassMeans() is the more complicated part of PLDA estimation.
Let's suppose the mean of a particular class, after subtracting the global mean, is m, and suppose that that class had n examples. We suppose that m ~ N(0, between_var_ + 1/n within_var_), i.e. m is Gaussian-distributed with zero mean and variance equal to the between-class variance plus 1/n times the within-class variance. Now, m is observed (as stats_.class_info_[something].mean). We're doing an E-M procedure where we treat m as the sum of two variables:
m = x + y, where x ~ N(0, between_var_) and y ~ N(0, 1/n * within_var_).
The distribution of x will contribute to the stats of between_var_, and y to within_var_. Now, y = m - x, so we can focus on working out the distribution of x and then we can very simply get the distribution of y. The following expression also includes the likelihood of y as a function of x. Note: the C is different from line to line.
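$$\log p(x) = C - \tfrac{1}{2}\left( x^T \mathrm{between\_var}^{-1} x + (m - x)^T \left(\tfrac{1}{n}\mathrm{within\_var}\right)^{-1} (m - x) \right) = C - \tfrac{1}{2} x^T \left(\mathrm{between\_var}^{-1} + n\,\mathrm{within\_var}^{-1}\right) x + x^T z$$
where $z = n\,\mathrm{within\_var}^{-1} m$, and we can write this as:
$$\log p(x) = C - \tfrac{1}{2} (x - w)^T \left(\mathrm{between\_var}^{-1} + n\,\mathrm{within\_var}^{-1}\right) (x - w)$$
where $x^T \left(\mathrm{between\_var}^{-1} + n\,\mathrm{within\_var}^{-1}\right) w = x^T z$, i.e. $\left(\mathrm{between\_var}^{-1} + n\,\mathrm{within\_var}^{-1}\right) w = z = n\,\mathrm{within\_var}^{-1} m$, so
$$w = \left(\mathrm{between\_var}^{-1} + n\,\mathrm{within\_var}^{-1}\right)^{-1} n\,\mathrm{within\_var}^{-1} m$$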
We can see that the distribution over x is Gaussian, with mean w and variance (between_var^{-1} + n within_var^{-1})^{-1}. Since y = m - x with m fixed, the distribution over y is Gaussian with the same variance, and mean m - w. So the update to the between-var stats will be:
between-var-stats += w w^T + (between_var^{-1} + n within_var^{-1})^{-1}
and the update to the within-var stats will be:
within-var-stats += n ( (m-w) (m-w)^T + (between_var^{-1} + n within_var^{-1})^{-1} ).
The factor n rescales the second moment of y, whose variance is 1/n within_var, to the scale of within_var.
The drawback of this formulation is that each time we encounter a different value of n (number of examples) we will have to do a different matrix inversion. We'll try to improve on this later using a suitable transform.
Definition at line 470 of file plda.cc.
References SpMatrix< Real >::AddSp(), VectorBase< Real >::AddSpVec(), VectorBase< Real >::AddVec(), SpMatrix< Real >::AddVec2(), PldaEstimator::between_var_, PldaEstimator::between_var_count_, PldaEstimator::between_var_stats_, PldaStats::class_info_, PldaStats::class_weight_, SpMatrix< Real >::CopyFromSp(), PldaEstimator::Dim(), rnnlm::i, SpMatrix< Real >::Invert(), PldaStats::ClassInfo::mean, rnnlm::n, PldaStats::ClassInfo::num_examples, PldaEstimator::stats_, PldaStats::sum_, PldaStats::ClassInfo::weight, PldaEstimator::within_var_, PldaEstimator::within_var_count_, and PldaEstimator::within_var_stats_.
Referenced by PldaEstimator::EstimateOneIter().
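The update derived above can be sketched in C++ for general dimension. Following the remark that each new value of n costs a matrix inversion, and the constructor's check that the stats are sorted (PldaStats::IsSorted()), the sketch assumes classes arrive sorted by num_examples and recomputes mixed_var only when n changes; `num_inversions` counts how often that happens. Weighting each update by the class weight is an assumption suggested by the ClassInfo::weight reference. `InvertSpd`, `MatVec`, `ClassMeanStats` and `GetStatsFromClassMeansSketch` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;

// Inverse of a symmetric positive-definite matrix by Gauss-Jordan elimination
// (pivots of an SPD matrix stay positive, so no row swaps are needed).
Mat InvertSpd(Mat a) {
  const std::size_t d = a.size();
  Mat inv(d, Vec(d, 0.0));
  for (std::size_t i = 0; i < d; ++i) inv[i][i] = 1.0;
  for (std::size_t col = 0; col < d; ++col) {
    const double pivot = a[col][col];
    assert(pivot > 0.0);
    for (std::size_t j = 0; j < d; ++j) {
      a[col][j] /= pivot;
      inv[col][j] /= pivot;
    }
    for (std::size_t row = 0; row < d; ++row) {
      if (row == col) continue;
      const double f = a[row][col];
      for (std::size_t j = 0; j < d; ++j) {
        a[row][j] -= f * a[col][j];
        inv[row][j] -= f * inv[col][j];
      }
    }
  }
  return inv;
}

Vec MatVec(const Mat &a, const Vec &v) {
  Vec r(a.size(), 0.0);
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t j = 0; j < v.size(); ++j) r[i] += a[i][j] * v[j];
  return r;
}

// Hypothetical mirror of the PldaStats::ClassInfo fields used here.
struct ClassInfo {
  Vec mean;
  int num_examples;
  double weight;
};

struct ClassMeanStats {
  Mat between_var_stats, within_var_stats;
  double between_var_count = 0.0, within_var_count = 0.0;
  int num_inversions = 0;  // times mixed_var had to be recomputed
};

// E-step over class means; `classes` is assumed sorted by num_examples.
ClassMeanStats GetStatsFromClassMeansSketch(const Mat &between_var,
                                            const Mat &within_var,
                                            const std::vector<ClassInfo> &classes,
                                            const Vec &global_mean) {
  const std::size_t d = between_var.size();
  ClassMeanStats st;
  st.between_var_stats.assign(d, Vec(d, 0.0));
  st.within_var_stats.assign(d, Vec(d, 0.0));
  const Mat between_inv = InvertSpd(between_var);
  const Mat within_inv = InvertSpd(within_var);
  Mat mixed_var;  // (between_var^{-1} + n within_var^{-1})^{-1}
  int n = -1;
  for (const ClassInfo &info : classes) {
    if (info.num_examples != n) {  // only invert when n changes
      n = info.num_examples;
      Mat sum = between_inv;
      for (std::size_t i = 0; i < d; ++i)
        for (std::size_t j = 0; j < d; ++j) sum[i][j] += n * within_inv[i][j];
      mixed_var = InvertSpd(sum);
      ++st.num_inversions;
    }
    Vec m(d);
    for (std::size_t i = 0; i < d; ++i) m[i] = info.mean[i] - global_mean[i];
    Vec z = MatVec(within_inv, m);
    for (double &v : z) v *= n;          // z = n within_var^{-1} m
    const Vec w = MatVec(mixed_var, z);  // posterior mean of x
    for (std::size_t i = 0; i < d; ++i)
      for (std::size_t j = 0; j < d; ++j) {
        st.between_var_stats[i][j] +=
            info.weight * (w[i] * w[j] + mixed_var[i][j]);
        st.within_var_stats[i][j] +=
            info.weight * n * ((m[i] - w[i]) * (m[j] - w[j]) + mixed_var[i][j]);
      }
    st.between_var_count += info.weight;
    st.within_var_count += info.weight;
  }
  return st;
}
```

For example, with between_var = within_var = 1 (D = 1), n = 1 and m = 2, the posterior variance is 0.5 and w = 1, so both stats grow by 1.5.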
void GetStatsFromIntraClass () | private |
Definition at line 416 of file plda.cc.
References SpMatrix< Real >::AddSp(), PldaStats::class_weight_, PldaStats::example_weight_, PldaStats::offset_scatter_, PldaEstimator::stats_, PldaEstimator::within_var_count_, and PldaEstimator::within_var_stats_.
Referenced by PldaEstimator::EstimateOneIter().
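GetStatsFromIntraClass() folds the within-class offsets into the within-class stats. A minimal sketch of where those numbers come from, assuming offset_scatter_ is the weighted scatter of each sample around its class mean and that the count added is example_weight_ - class_weight_, i.e. weight * (n - 1) summed over classes. In the documented class the scatter arrives precomputed as PldaStats::offset_scatter_; `WithinStats` and `AddIntraClassStats` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;

// Hypothetical accumulator for the within-class part of the statistics.
struct WithinStats {
  Mat within_var_stats;           // sum of weighted offset outer products
  double within_var_count = 0.0;  // weighted within-class degrees of freedom
};

// Adds one class: the scatter of its samples around the class mean, and a
// count of weight * (n - 1), since n offsets that sum to zero leave n - 1
// degrees of freedom.
void AddIntraClassStats(const std::vector<Vec> &samples, double weight,
                        WithinStats *stats) {
  assert(!samples.empty());
  const std::size_t d = samples[0].size();
  if (stats->within_var_stats.empty())
    stats->within_var_stats.assign(d, Vec(d, 0.0));
  Vec mean(d, 0.0);
  for (const Vec &x : samples)
    for (std::size_t i = 0; i < d; ++i) mean[i] += x[i];
  for (double &v : mean) v /= samples.size();
  for (const Vec &x : samples)
    for (std::size_t i = 0; i < d; ++i)
      for (std::size_t j = 0; j < d; ++j)
        stats->within_var_stats[i][j] +=
            weight * (x[i] - mean[i]) * (x[j] - mean[j]);
  stats->within_var_count += weight * (samples.size() - 1.0);
}
```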
void InitParameters () | private |
Definition at line 402 of file plda.cc.
References PldaEstimator::between_var_, PldaEstimator::Dim(), SpMatrix< Real >::Resize(), PackedMatrix< Real >::SetUnit(), and PldaEstimator::within_var_.
Referenced by PldaEstimator::PldaEstimator().
KALDI_DISALLOW_COPY_AND_ASSIGN (PldaEstimator) | private |
void ResetPerIterStats () | private |
Definition at line 409 of file plda.cc.
References PldaEstimator::between_var_count_, PldaEstimator::between_var_stats_, PldaEstimator::Dim(), SpMatrix< Real >::Resize(), PldaEstimator::within_var_count_, and PldaEstimator::within_var_stats_.
Referenced by PldaEstimator::EstimateOneIter().
SpMatrix< double > between_var_ | private |
Definition at line 278 of file plda.h.
Referenced by PldaEstimator::ComputeObjfPart2(), PldaEstimator::EstimateFromStats(), PldaEstimator::GetOutput(), PldaEstimator::GetStatsFromClassMeans(), and PldaEstimator::InitParameters().
double between_var_count_ | private |
Definition at line 284 of file plda.h.
Referenced by PldaEstimator::EstimateFromStats(), PldaEstimator::GetStatsFromClassMeans(), and PldaEstimator::ResetPerIterStats().
SpMatrix< double > between_var_stats_ | private |
Definition at line 283 of file plda.h.
Referenced by PldaEstimator::EstimateFromStats(), PldaEstimator::GetStatsFromClassMeans(), and PldaEstimator::ResetPerIterStats().
const PldaStats & stats_ | private |
Definition at line 275 of file plda.h.
Referenced by PldaEstimator::ComputeObjf(), PldaEstimator::ComputeObjfPart1(), PldaEstimator::ComputeObjfPart2(), PldaEstimator::Estimate(), PldaEstimator::GetOutput(), PldaEstimator::GetStatsFromClassMeans(), and PldaEstimator::GetStatsFromIntraClass().
SpMatrix< double > within_var_ | private |
Definition at line 277 of file plda.h.
Referenced by PldaEstimator::ComputeObjfPart1(), PldaEstimator::ComputeObjfPart2(), PldaEstimator::EstimateFromStats(), PldaEstimator::GetOutput(), PldaEstimator::GetStatsFromClassMeans(), and PldaEstimator::InitParameters().
double within_var_count_ | private |
Definition at line 282 of file plda.h.
Referenced by PldaEstimator::EstimateFromStats(), PldaEstimator::GetStatsFromClassMeans(), PldaEstimator::GetStatsFromIntraClass(), and PldaEstimator::ResetPerIterStats().
SpMatrix< double > within_var_stats_ | private |
Definition at line 281 of file plda.h.
Referenced by PldaEstimator::EstimateFromStats(), PldaEstimator::GetStatsFromClassMeans(), PldaEstimator::GetStatsFromIntraClass(), and PldaEstimator::ResetPerIterStats().