PldaStats Class Reference

#include <plda.h>

Collaboration diagram for PldaStats:

Classes

struct  ClassInfo
 

Public Member Functions

 PldaStats ()
 
void AddSamples (double weight, const Matrix< double > &group)
 The dimension is set up the first time you add samples. More...
 
int32 Dim () const
 
void Init (int32 dim)
 
void Sort ()
 
bool IsSorted () const
 
 ~PldaStats ()
 

Protected Attributes

int32 dim_
 
int64 num_classes_
 
int64 num_examples_
 
double class_weight_
 
double example_weight_
 
Vector< double > sum_
 
SpMatrix< double > offset_scatter_
 
std::vector< ClassInfo > class_info_
 

Private Member Functions

 KALDI_DISALLOW_COPY_AND_ASSIGN (PldaStats)
 

Friends

class PldaEstimator
 

Detailed Description

Definition at line 172 of file plda.h.

Constructor & Destructor Documentation

◆ PldaStats()

PldaStats ( )
inline

Definition at line 174 of file plda.h.

174 : dim_(0) { }
int32 dim_
Definition: plda.h:195

◆ ~PldaStats()

~PldaStats ( )

Definition at line 313 of file plda.cc.

References rnnlm::i.

313  {
314  for (size_t i = 0; i < class_info_.size(); i++)
315  delete class_info_[i].mean;
316 }
std::vector< ClassInfo > class_info_
Definition: plda.h:220

Member Function Documentation

◆ AddSamples()

void AddSamples ( double  weight,
const Matrix< double > &  group 
)

The dimension is set up the first time you add samples.

This function adds training samples corresponding to one class (e.g. a speaker). Each row is a separate sample from this group. The "weight" would normally be 1.0, but you can set it to other values if you want to weight your training samples.

Definition at line 286 of file plda.cc.

References VectorBase< Real >::AddRowSumMat(), KALDI_ASSERT, kaldi::kTrans, rnnlm::n, MatrixBase< Real >::NumCols(), and MatrixBase< Real >::NumRows().

Referenced by main(), and kaldi::UnitTestPldaEstimation().

287  {
288  if (dim_ == 0) {
289  Init(group.NumCols());
290  } else {
291  KALDI_ASSERT(dim_ == group.NumCols());
292  }
293  int32 n = group.NumRows(); // number of examples for this class
294  Vector<double> *mean = new Vector<double>(dim_);
295  mean->AddRowSumMat(1.0 / n, group);
296 
297  offset_scatter_.AddMat2(weight, group, kTrans, 1.0);
298  // the following statement has the same effect as if we
299  // had first subtracted the mean from each element of
300  // the group before the statement above.
301  offset_scatter_.AddVec2(-n * weight, *mean);
302 
303  class_info_.push_back(ClassInfo(weight, mean, n));
304 
305  num_classes_ ++;
306  num_examples_ += n;
307  class_weight_ += weight;
308  example_weight_ += weight * n;
309 
310  sum_.AddVec(weight, *mean);
311 }
void AddMat2(const Real alpha, const MatrixBase< Real > &M, MatrixTransposeType transM, const Real beta)
rank-N update: if (transM == kNoTrans) (*this) = beta*(*this) + alpha * M * M^T, or (if transM == kTrans) (*this) = beta*(*this) + alpha * M^T * M.
Definition: sp-matrix.cc:1110
double example_weight_
Definition: plda.h:199
kaldi::int32 int32
SpMatrix< double > offset_scatter_
Definition: plda.h:204
Vector< double > sum_
Definition: plda.h:201
void Init(int32 dim)
Definition: plda.cc:325
void AddVec2(const Real alpha, const VectorBase< OtherReal > &v)
rank-one update: this <-- this + alpha * v * v'
Definition: sp-matrix.cc:946
double class_weight_
Definition: plda.h:198
struct rnnlm::@11::@12 n
int64 num_classes_
Definition: plda.h:196
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< ClassInfo > class_info_
Definition: plda.h:220
int32 dim_
Definition: plda.h:195
void AddVec(const Real alpha, const VectorBase< OtherReal > &v)
Add vector: *this = *this + alpha * v (with casting between floats and doubles).
int64 num_examples_
Definition: plda.h:197

◆ Dim()

int32 Dim ( ) const
inline

Definition at line 184 of file plda.h.

Referenced by main().

184 { return dim_; }
int32 dim_
Definition: plda.h:195

◆ Init()

void Init ( int32  dim)

Definition at line 325 of file plda.cc.

References KALDI_ASSERT.

325  {
326  KALDI_ASSERT(dim_ == 0);
327  dim_ = dim;
328  num_classes_ = 0;
329  num_examples_ = 0;
330  class_weight_ = 0.0;
331  example_weight_ = 0.0;
332  sum_.Resize(dim);
333  offset_scatter_.Resize(dim);
334  KALDI_ASSERT(class_info_.empty());
335 }
double example_weight_
Definition: plda.h:199
SpMatrix< double > offset_scatter_
Definition: plda.h:204
Vector< double > sum_
Definition: plda.h:201
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
double class_weight_
Definition: plda.h:198
int64 num_classes_
Definition: plda.h:196
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< ClassInfo > class_info_
Definition: plda.h:220
int32 dim_
Definition: plda.h:195
void Resize(MatrixIndexT nRows, MatrixResizeType resize_type=kSetZero)
Definition: sp-matrix.h:81
int64 num_examples_
Definition: plda.h:197

◆ IsSorted()

bool IsSorted ( ) const

Definition at line 318 of file plda.cc.

References rnnlm::i.

Referenced by PldaEstimator::PldaEstimator().

318  {
319  for (size_t i = 0; i + 1 < class_info_.size(); i++)
320  if (class_info_[i+1] < class_info_[i])
321  return false;
322  return true;
323 }
std::vector< ClassInfo > class_info_
Definition: plda.h:220

◆ KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( PldaStats  )
private

◆ Sort()

void Sort ( )
inline

Definition at line 188 of file plda.h.

References kaldi::IsSorted().

Referenced by main(), and kaldi::UnitTestPldaEstimation().

188 { std::sort(class_info_.begin(), class_info_.end()); }
std::vector< ClassInfo > class_info_
Definition: plda.h:220

Friends And Related Function Documentation

◆ PldaEstimator

friend class PldaEstimator
friend

Definition at line 193 of file plda.h.

Member Data Documentation

◆ class_info_

std::vector<ClassInfo> class_info_
protected

Definition at line 220 of file plda.h.

◆ class_weight_

double class_weight_
protected

Definition at line 198 of file plda.h.

◆ dim_

int32 dim_
protected

Definition at line 195 of file plda.h.

◆ example_weight_

double example_weight_
protected

Definition at line 199 of file plda.h.

◆ num_classes_

int64 num_classes_
protected

Definition at line 196 of file plda.h.

◆ num_examples_

int64 num_examples_
protected

Definition at line 197 of file plda.h.

◆ offset_scatter_

SpMatrix<double> offset_scatter_
protected

Definition at line 204 of file plda.h.

◆ sum_

Vector<double> sum_
protected

Definition at line 201 of file plda.h.

The documentation for this class was generated from the following files:

plda.h
plda.cc