AccumAmDiagGmm Class Reference

#include <mle-am-diag-gmm.h>

Public Member Functions

 AccumAmDiagGmm ()
 
 ~AccumAmDiagGmm ()
 
void Read (std::istream &in_stream, bool binary, bool add=false)
 
void Write (std::ostream &out_stream, bool binary) const
 
void Init (const AmDiagGmm &model, GmmFlagsType flags)
 Initializes accumulators for each GMM based on the number of components and dimension.
 
void Init (const AmDiagGmm &model, int32 dim, GmmFlagsType flags)
 Initialization using a different dimension than the model.
 
void SetZero (GmmFlagsType flags)
 
BaseFloat AccumulateForGmm (const AmDiagGmm &model, const VectorBase< BaseFloat > &data, int32 gmm_index, BaseFloat weight)
 Accumulate stats for a single GMM in the model; returns log likelihood.
 
BaseFloat AccumulateForGmmTwofeats (const AmDiagGmm &model, const VectorBase< BaseFloat > &data1, const VectorBase< BaseFloat > &data2, int32 gmm_index, BaseFloat weight)
 Accumulate stats for a single GMM in the model; uses data1 for getting posteriors and data2 for stats.
 
void AccumulateFromPosteriors (const AmDiagGmm &model, const VectorBase< BaseFloat > &data, int32 gmm_index, const VectorBase< BaseFloat > &posteriors)
 Accumulates stats for a single GMM in the model using pre-computed Gaussian posteriors.
 
void AccumulateForGaussian (const AmDiagGmm &am, const VectorBase< BaseFloat > &data, int32 gmm_index, int32 gauss_index, BaseFloat weight)
 Accumulate stats for a single Gaussian component in the model.
 
int32 NumAccs ()
 
int32 NumAccs () const
 
BaseFloat TotStatsCount () const
 
BaseFloat TotCount () const
 
BaseFloat TotLogLike () const
 
const AccumDiagGmm & GetAcc (int32 index) const
 
AccumDiagGmm & GetAcc (int32 index)
 
void Add (BaseFloat scale, const AccumAmDiagGmm &other)
 
void Scale (BaseFloat scale)
 
int32 Dim () const
 

Private Member Functions

 KALDI_DISALLOW_COPY_AND_ASSIGN (AccumAmDiagGmm)
 

Private Attributes

std::vector< AccumDiagGmm * > gmm_accumulators_
 MLE accumulators and update methods for the GMMs.
 
double total_frames_
 Total counts & likelihood (for diagnostics)
 
double total_log_like_
 

Detailed Description

Definition at line 34 of file mle-am-diag-gmm.h.

Constructor & Destructor Documentation

AccumAmDiagGmm ( )
inline

Definition at line 36 of file mle-am-diag-gmm.h.

: total_frames_(0.0), total_log_like_(0.0) {}

~AccumAmDiagGmm ( )

Definition at line 37 of file mle-am-diag-gmm.cc.

References kaldi::DeletePointers(), and AccumAmDiagGmm::gmm_accumulators_.

{
  DeletePointers(&gmm_accumulators_);
}
void DeletePointers(std::vector< A * > *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:198

Member Function Documentation

void AccumulateForGaussian ( const AmDiagGmm &  am,
const VectorBase< BaseFloat > &  data,
int32  gmm_index,
int32  gauss_index,
BaseFloat  weight 
)

Accumulate stats for a single Gaussian component in the model.

Definition at line 108 of file mle-am-diag-gmm.cc.

References AmDiagGmm::GetPdf(), AccumAmDiagGmm::gmm_accumulators_, KALDI_ASSERT, AccumAmDiagGmm::NumAccs(), and DiagGmm::NumGauss().

{
  KALDI_ASSERT(gmm_index >= 0 && gmm_index < NumAccs());
  KALDI_ASSERT(gauss_index >= 0
      && gauss_index < am.GetPdf(gmm_index).NumGauss());
  gmm_accumulators_[gmm_index]->AccumulateForComponent(data, gauss_index, weight);
}
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169

BaseFloat AccumulateForGmm ( const AmDiagGmm &  model,
const VectorBase< BaseFloat > &  data,
int32  gmm_index,
BaseFloat  weight 
)

Accumulate stats for a single GMM in the model; returns log likelihood.

This does not work with multiple feature transforms.

Definition at line 69 of file mle-am-diag-gmm.cc.

References AmDiagGmm::GetPdf(), AccumAmDiagGmm::gmm_accumulators_, KALDI_ASSERT, AccumAmDiagGmm::total_frames_, and AccumAmDiagGmm::total_log_like_.

Referenced by main(), TestAmDiagGmmAccsIO(), and kaldi::UnitTestRegtreeFmllrDiagGmm().

{
  KALDI_ASSERT(static_cast<size_t>(gmm_index) < gmm_accumulators_.size());
  BaseFloat log_like =
      gmm_accumulators_[gmm_index]->AccumulateFromDiag(model.GetPdf(gmm_index),
                                                       data, weight);
  total_log_like_ += log_like * weight;
  total_frames_ += weight;
  return log_like;
}
float BaseFloat
Definition: kaldi-types.h:29

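The bookkeeping in AccumulateForGmm can be illustrated without Kaldi. The following is a minimal sketch, not the library's code: `ToyTotals`, `Accumulate` and `AvgLogLike` are hypothetical names, and the per-frame log-likelihood is passed in rather than computed from a GMM. It shows that the running likelihood grows by `log_like * weight` and the frame count by `weight`, while the return value stays unweighted.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stand-in for the two diagnostic totals kept by the class.
struct ToyTotals {
  double total_frames = 0.0;
  double total_log_like = 0.0;

  // Same update as in the listing above: the totals are weighted, but the
  // returned value is the unweighted log-likelihood of the frame.
  double Accumulate(double log_like, double weight) {
    assert(weight >= 0.0);  // assumption of this sketch, not checked by Kaldi
    total_log_like += log_like * weight;
    total_frames += weight;
    return log_like;
  }

  // Hypothetical helper: average log-likelihood per weighted frame,
  // or 0 if nothing has been accumulated yet.
  double AvgLogLike() const {
    return total_frames > 0.0 ? total_log_like / total_frames : 0.0;
  }
};
```

Dividing the two totals, as `AvgLogLike` does here, is one way such diagnostics could be reported; the class itself only exposes them through TotLogLike() and TotCount().
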
BaseFloat AccumulateForGmmTwofeats ( const AmDiagGmm &  model,
const VectorBase< BaseFloat > &  data1,
const VectorBase< BaseFloat > &  data2,
int32  gmm_index,
BaseFloat  weight 
)

Accumulate stats for a single GMM in the model; uses data1 for getting posteriors and data2 for stats.

Returns log likelihood.

Definition at line 81 of file mle-am-diag-gmm.cc.

References AccumDiagGmm::AccumulateFromPosteriors(), AmDiagGmm::GetPdf(), AccumAmDiagGmm::gmm_accumulators_, KALDI_ASSERT, VectorBase< Real >::Scale(), AccumAmDiagGmm::total_frames_, and AccumAmDiagGmm::total_log_like_.

Referenced by main().

{
  KALDI_ASSERT(static_cast<size_t>(gmm_index) < gmm_accumulators_.size());
  const DiagGmm &gmm = model.GetPdf(gmm_index);
  AccumDiagGmm &acc = *(gmm_accumulators_[gmm_index]);
  Vector<BaseFloat> posteriors;
  BaseFloat log_like = gmm.ComponentPosteriors(data1, &posteriors);
  posteriors.Scale(weight);
  acc.AccumulateFromPosteriors(data2, posteriors);
  total_log_like_ += log_like * weight;
  total_frames_ += weight;
  return log_like;
}

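The two-feature scheme (posteriors from one feature vector, statistics from another) can be sketched on a toy model. The code below is an illustration under simplifying assumptions, not Kaldi's implementation: `ToyGmm`, `ToyAcc`, `ComponentPosteriors` and `AccumulateTwoFeats` are hypothetical, the features are scalars instead of vectors, and only occupancy and first-order statistics are kept.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical one-dimensional GMM; real DiagGmm objects are multivariate.
struct ToyGmm {
  std::vector<double> weights, means, vars;
};

// Hypothetical accumulator with zeroth- and first-order statistics.
struct ToyAcc {
  std::vector<double> occupancy, mean_acc;
  explicit ToyAcc(std::size_t n) : occupancy(n, 0.0), mean_acc(n, 0.0) {}
};

// Fills *post with the component posteriors of x and returns log p(x).
double ComponentPosteriors(const ToyGmm &gmm, double x,
                           std::vector<double> *post) {
  const double kLog2Pi = std::log(2.0 * std::acos(-1.0));
  const std::size_t n = gmm.weights.size();
  post->assign(n, 0.0);
  double max_ll = -std::numeric_limits<double>::infinity();
  for (std::size_t i = 0; i < n; i++) {
    double d = x - gmm.means[i];
    (*post)[i] = std::log(gmm.weights[i]) -
        0.5 * (kLog2Pi + std::log(gmm.vars[i]) + d * d / gmm.vars[i]);
    if ((*post)[i] > max_ll) max_ll = (*post)[i];
  }
  double sum = 0.0;  // log-sum-exp with the maximum factored out
  for (std::size_t i = 0; i < n; i++) {
    (*post)[i] = std::exp((*post)[i] - max_ll);
    sum += (*post)[i];
  }
  for (std::size_t i = 0; i < n; i++) (*post)[i] /= sum;
  return max_ll + std::log(sum);
}

// Posteriors come from x1; the statistics are gathered from x2.
double AccumulateTwoFeats(const ToyGmm &gmm, double x1, double x2,
                          double weight, ToyAcc *acc) {
  assert(acc->occupancy.size() == gmm.weights.size());
  std::vector<double> post;
  double log_like = ComponentPosteriors(gmm, x1, &post);
  for (std::size_t i = 0; i < post.size(); i++) {
    acc->occupancy[i] += weight * post[i];
    acc->mean_acc[i] += weight * post[i] * x2;
  }
  return log_like;
}
```

With two equally weighted components at -1 and 1 and `x1 = 0`, both posteriors are 0.5, so a call with `weight = 2` and `x2 = 3` adds 1 to each occupancy and 3 to each first-order statistic.
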
void AccumulateFromPosteriors ( const AmDiagGmm &  model,
const VectorBase< BaseFloat > &  data,
int32  gmm_index,
const VectorBase< BaseFloat > &  posteriors 
)

Accumulates stats for a single GMM in the model using pre-computed Gaussian posteriors.

Definition at line 100 of file mle-am-diag-gmm.cc.

References AccumAmDiagGmm::gmm_accumulators_, KALDI_ASSERT, AccumAmDiagGmm::NumAccs(), VectorBase< Real >::Sum(), and AccumAmDiagGmm::total_frames_.

{
  KALDI_ASSERT(gmm_index >= 0 && gmm_index < NumAccs());
  gmm_accumulators_[gmm_index]->AccumulateFromPosteriors(data, posteriors);
  total_frames_ += posteriors.Sum();
}
Real Sum() const
Returns sum of the elements.

void Add ( BaseFloat  scale,
const AccumAmDiagGmm &  other 
)

Definition at line 279 of file mle-am-diag-gmm.cc.

References AccumAmDiagGmm::gmm_accumulators_, rnnlm::i, KALDI_ASSERT, AccumAmDiagGmm::NumAccs(), AccumAmDiagGmm::total_frames_, and AccumAmDiagGmm::total_log_like_.

Referenced by main().

{
  total_frames_ += scale * other.total_frames_;
  total_log_like_ += scale * other.total_log_like_;

  int32 num_accs = NumAccs();
  KALDI_ASSERT(num_accs == other.NumAccs());
  for (int32 i = 0; i < num_accs; i++)
    gmm_accumulators_[i]->Add(scale, *(other.gmm_accumulators_[i]));
}

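Scaled addition of this kind is what allows statistics gathered separately to be merged. The sketch below is a hedged, Kaldi-free illustration: `ToyAccs` is hypothetical and reduces each per-GMM accumulator to a single count. Unlike the listing above, it checks the sizes before touching any totals.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical simplification: one count per pdf instead of a full
// AccumDiagGmm, plus the two diagnostic totals.
struct ToyAccs {
  std::vector<double> counts;
  double total_frames = 0.0;
  double total_log_like = 0.0;

  // Adds scale * other to this object, element by element.
  void Add(double scale, const ToyAccs &other) {
    assert(counts.size() == other.counts.size());
    total_frames += scale * other.total_frames;
    total_log_like += scale * other.total_log_like;
    for (std::size_t i = 0; i < counts.size(); i++)
      counts[i] += scale * other.counts[i];
  }
};
```

A scale of 1.0 corresponds to plain summation; other values down-weight or up-weight the statistics being merged.
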
int32 Dim ( ) const
inline

Definition at line 94 of file mle-am-diag-gmm.h.

References AccumAmDiagGmm::gmm_accumulators_.

Referenced by kaldi::MapAmDiagGmmUpdate(), and kaldi::MleAmDiagGmmUpdate().

{
  return (gmm_accumulators_.empty() || !gmm_accumulators_[0] ?
          0 : gmm_accumulators_[0]->Dim());
}

const AccumDiagGmm & GetAcc ( int32  index) const
AccumDiagGmm & GetAcc ( int32  index)

Definition at line 32 of file mle-am-diag-gmm.cc.

References AccumAmDiagGmm::gmm_accumulators_, and KALDI_ASSERT.

{
  KALDI_ASSERT(index >= 0 && index < static_cast<int32>(gmm_accumulators_.size()));
  return *(gmm_accumulators_[index]);
}

void Init ( const AmDiagGmm &  model,
GmmFlagsType  flags 
)

Initializes accumulators for each GMM based on the number of components and dimension.

Definition at line 41 of file mle-am-diag-gmm.cc.

References kaldi::DeletePointers(), AmDiagGmm::GetPdf(), AccumAmDiagGmm::gmm_accumulators_, rnnlm::i, and AmDiagGmm::NumPdfs().

Referenced by kaldi::GetStatsDerivative(), main(), TestAmDiagGmmAccsIO(), and kaldi::UnitTestRegtreeFmllrDiagGmm().

{
  DeletePointers(&gmm_accumulators_);  // in case was non-empty when called.
  gmm_accumulators_.resize(model.NumPdfs(), NULL);
  for (int32 i = 0; i < model.NumPdfs(); i++) {
    gmm_accumulators_[i] = new AccumDiagGmm();
    gmm_accumulators_[i]->Resize(model.GetPdf(i), flags);
  }
}

void Init ( const AmDiagGmm &  model,
int32  dim,
GmmFlagsType  flags 
)

Initialization using different dimension than model.

Definition at line 51 of file mle-am-diag-gmm.cc.

References kaldi::DeletePointers(), AmDiagGmm::GetPdf(), AccumAmDiagGmm::gmm_accumulators_, rnnlm::i, KALDI_ASSERT, DiagGmm::NumGauss(), and AmDiagGmm::NumPdfs().

{
  KALDI_ASSERT(dim > 0);
  DeletePointers(&gmm_accumulators_);  // in case was non-empty when called.
  gmm_accumulators_.resize(model.NumPdfs(), NULL);
  for (int32 i = 0; i < model.NumPdfs(); i++) {
    gmm_accumulators_[i] = new AccumDiagGmm();
    gmm_accumulators_[i]->Resize(model.GetPdf(i).NumGauss(),
                                 dim, flags);
  }
}

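Both Init overloads follow the same pattern: discard any existing accumulators, then allocate one per pdf, taking the number of Gaussians from the model and, in this overload, the dimension from the caller. The sketch below illustrates that pattern only. It swaps in `std::unique_ptr` for the raw pointers and DeletePointers() used by the class, and `ToyAcc`, `ToyAccs` and `gauss_per_pdf` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical per-pdf accumulator: num_gauss x dim first-order statistics.
struct ToyAcc {
  int num_gauss, dim;
  std::vector<double> mean_acc;
  ToyAcc(int g, int d)
      : num_gauss(g), dim(d),
        mean_acc(static_cast<std::size_t>(g) * static_cast<std::size_t>(d), 0.0) {}
};

class ToyAccs {
 public:
  // gauss_per_pdf[i] plays the role of model.GetPdf(i).NumGauss().
  void Init(const std::vector<int> &gauss_per_pdf, int dim) {
    assert(dim > 0);
    accs_.clear();  // drops accumulators left over from an earlier Init call
    accs_.reserve(gauss_per_pdf.size());
    for (int g : gauss_per_pdf)
      accs_.push_back(std::unique_ptr<ToyAcc>(new ToyAcc(g, dim)));
  }
  std::size_t NumAccs() const { return accs_.size(); }
  // Like Dim() in the class: 0 when there are no accumulators.
  int Dim() const { return accs_.empty() ? 0 : accs_[0]->dim; }
  const ToyAcc &GetAcc(std::size_t i) const {
    assert(i < accs_.size());
    return *accs_[i];
  }

 private:
  std::vector<std::unique_ptr<ToyAcc>> accs_;
};
```

Calling `Init` a second time replaces the earlier accumulators entirely, which matches the "in case was non-empty when called" comment in the listing.
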
KALDI_DISALLOW_COPY_AND_ASSIGN ( AccumAmDiagGmm  )
private

int32 NumAccs ( ) const
inline

Definition at line 77 of file mle-am-diag-gmm.h.

References AccumAmDiagGmm::gmm_accumulators_.

{ return gmm_accumulators_.size(); }

void Read ( std::istream &  in_stream,
bool  binary,
bool  add = false 
)

Definition at line 117 of file mle-am-diag-gmm.cc.

References kaldi::ExpectToken(), AccumAmDiagGmm::gmm_accumulators_, KALDI_ASSERT, KALDI_ERR, kaldi::ReadBasicType(), AccumAmDiagGmm::total_frames_, and AccumAmDiagGmm::total_log_like_.

Referenced by main(), and TestAmDiagGmmAccsIO().

{
  int32 num_pdfs;
  ExpectToken(in_stream, binary, "<NUMPDFS>");
  ReadBasicType(in_stream, binary, &num_pdfs);
  KALDI_ASSERT(num_pdfs > 0);
  if (!add || (add && gmm_accumulators_.empty())) {
    gmm_accumulators_.resize(num_pdfs, NULL);
    for (std::vector<AccumDiagGmm*>::iterator it = gmm_accumulators_.begin(),
             end = gmm_accumulators_.end(); it != end; ++it) {
      delete *it;
      *it = new AccumDiagGmm();
      (*it)->Read(in_stream, binary, add);
    }
  } else {
    if (gmm_accumulators_.size() != static_cast<size_t> (num_pdfs))
      KALDI_ERR << "Adding accumulators but num-pdfs do not match: "
                << (gmm_accumulators_.size()) << " vs. "
                << (num_pdfs);
    for (std::vector<AccumDiagGmm*>::iterator it = gmm_accumulators_.begin(),
             end = gmm_accumulators_.end(); it != end; ++it)
      (*it)->Read(in_stream, binary, add);
  }
  // TODO(arnab): Bad hack! Need to make this self-delimiting.
  in_stream.peek();  // This will set the EOF bit for older accs.
  if (!in_stream.eof()) {
    double like, frames;
    ExpectToken(in_stream, binary, "<total_like>");
    ReadBasicType(in_stream, binary, &like);
    total_log_like_ = (add)? total_log_like_ + like : like;
    ExpectToken(in_stream, binary, "<total_frames>");
    ReadBasicType(in_stream, binary, &frames);
    total_frames_ = (add)? total_frames_ + frames : frames;
  }
}
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:188
#define KALDI_ERR
Definition: kaldi-error.h:127

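Two behaviours of Read are easy to miss: the `add` flag either replaces or sums the stored statistics, and the trailing totals are optional, detected by peeking for end-of-file so that older files without them still load. The sketch below reproduces both on a made-up text format. Everything in it is an assumption for illustration: `ToyAccs`, `Expect` and `ReadToy` are hypothetical, each accumulator is reduced to one number, and the explicit `std::ws` skip is specific to this whitespace-separated toy format.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

struct ToyAccs {
  std::vector<double> counts;  // one number per pdf (hypothetical)
  double total_frames = 0.0;
  double total_log_like = 0.0;
};

// Reads one token and checks it (a stand-in for kaldi::ExpectToken).
void Expect(std::istream &in, const std::string &token) {
  std::string t;
  if (!(in >> t) || t != token) throw std::runtime_error("expected " + token);
}

void ReadToy(std::istream &in, bool add, ToyAccs *accs) {
  Expect(in, "<NUMPDFS>");
  int num_pdfs = 0;
  in >> num_pdfs;
  assert(num_pdfs > 0);
  if (!add || accs->counts.empty()) {
    // Fresh read: discard whatever was stored before.
    accs->counts.assign(static_cast<std::size_t>(num_pdfs), 0.0);
  } else if (accs->counts.size() != static_cast<std::size_t>(num_pdfs)) {
    throw std::runtime_error("adding accumulators but num-pdfs do not match");
  }
  for (int i = 0; i < num_pdfs; i++) {
    double c = 0.0;
    in >> c;
    accs->counts[i] += c;  // equals c when the vector was just reset
  }
  in >> std::ws;  // skip trailing whitespace so that peek() can reach EOF
  in.peek();      // sets the EOF bit when no totals follow
  if (!in.eof()) {
    double like = 0.0, frames = 0.0;
    Expect(in, "<total_like>");
    in >> like;
    Expect(in, "<total_frames>");
    in >> frames;
    accs->total_log_like = add ? accs->total_log_like + like : like;
    accs->total_frames = add ? accs->total_frames + frames : frames;
  }
}
```

As the TODO in the listing notes, relying on end-of-file makes the format not self-delimiting: nothing else can safely follow the accumulators in the same stream.
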
void Scale ( BaseFloat  scale)

Definition at line 270 of file mle-am-diag-gmm.cc.

References AccumDiagGmm::Flags(), AccumAmDiagGmm::GetAcc(), rnnlm::i, AccumAmDiagGmm::NumAccs(), AccumDiagGmm::Scale(), AccumAmDiagGmm::total_frames_, and AccumAmDiagGmm::total_log_like_.

Referenced by main().

{
  for (int32 i = 0; i < NumAccs(); i++) {
    AccumDiagGmm &acc = GetAcc(i);
    acc.Scale(scale, acc.Flags());
  }
  total_frames_ *= scale;
  total_log_like_ *= scale;
}

void SetZero ( GmmFlagsType  flags)

Definition at line 63 of file mle-am-diag-gmm.cc.

References AccumAmDiagGmm::gmm_accumulators_, and rnnlm::i.

Referenced by main().

{
  for (size_t i = 0; i < gmm_accumulators_.size(); i++) {
    gmm_accumulators_[i]->SetZero(flags);
  }
}

BaseFloat TotCount ( ) const
inline

Definition at line 83 of file mle-am-diag-gmm.h.

References AccumAmDiagGmm::total_frames_.

Referenced by main(), and TestAmDiagGmmAccsIO().

{ return total_frames_; }

BaseFloat TotLogLike ( ) const
inline

Definition at line 84 of file mle-am-diag-gmm.h.

References AccumAmDiagGmm::total_log_like_.

Referenced by main(), and TestAmDiagGmmAccsIO().

{ return total_log_like_; }

BaseFloat TotStatsCount ( ) const

Definition at line 261 of file mle-am-diag-gmm.cc.

References AccumAmDiagGmm::GetAcc(), rnnlm::i, AccumAmDiagGmm::NumAccs(), and AccumDiagGmm::occupancy().

Referenced by main().

{
  double ans = 0.0;
  for (int32 i = 0; i < NumAccs(); i++) {
    const AccumDiagGmm &acc = GetAcc(i);
    ans += acc.occupancy().Sum();
  }
  return ans;
}

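TotStatsCount() and TotCount() need not agree. TotStatsCount() sums the occupancies actually stored in the per-GMM accumulators, whereas TotCount() returns `total_frames_`, which the listings on this page show being updated by AccumulateForGmm, AccumulateForGmmTwofeats and AccumulateFromPosteriors but not by AccumulateForGaussian. The sketch below is a hypothetical, Kaldi-free model of that difference (`ToyAccs` keeps only occupancies).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

class ToyAccs {
 public:
  ToyAccs(std::size_t num_pdfs, std::size_t num_gauss)
      : occupancy_(num_pdfs, std::vector<double>(num_gauss, 0.0)) {}

  // Like AccumulateFromPosteriors: updates the stats and the frame total.
  void AccumulateFromPosteriors(std::size_t pdf,
                                const std::vector<double> &post) {
    assert(pdf < occupancy_.size() && post.size() == occupancy_[pdf].size());
    for (std::size_t g = 0; g < post.size(); g++) {
      occupancy_[pdf][g] += post[g];
      total_frames_ += post[g];
    }
  }

  // Like AccumulateForGaussian: updates the stats only.
  void AccumulateForGaussian(std::size_t pdf, std::size_t gauss,
                             double weight) {
    assert(pdf < occupancy_.size() && gauss < occupancy_[pdf].size());
    occupancy_[pdf][gauss] += weight;
  }

  double TotCount() const { return total_frames_; }

  // Sum of all stored occupancies, as in the listing above.
  double TotStatsCount() const {
    double ans = 0.0;
    for (const std::vector<double> &occ : occupancy_)
      for (double o : occ) ans += o;
    return ans;
  }

 private:
  std::vector<std::vector<double>> occupancy_;
  double total_frames_ = 0.0;
};
```

After one posterior-based update the two counts match; a later per-Gaussian update raises only `TotStatsCount()`.
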
void Write ( std::ostream &  out_stream,
bool  binary 
) const

Definition at line 153 of file mle-am-diag-gmm.cc.

References AccumAmDiagGmm::gmm_accumulators_, AccumAmDiagGmm::total_frames_, AccumAmDiagGmm::total_log_like_, kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by main(), and TestAmDiagGmmAccsIO().

{
  int32 num_pdfs = gmm_accumulators_.size();
  WriteToken(out_stream, binary, "<NUMPDFS>");
  WriteBasicType(out_stream, binary, num_pdfs);
  for (std::vector<AccumDiagGmm*>::const_iterator it =
           gmm_accumulators_.begin(), end = gmm_accumulators_.end(); it != end; ++it) {
    (*it)->Write(out_stream, binary);
  }
  WriteToken(out_stream, binary, "<total_like>");
  WriteBasicType(out_stream, binary, total_log_like_);

  WriteToken(out_stream, binary, "<total_frames>");
  WriteBasicType(out_stream, binary, total_frames_);
}
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34

Member Data Documentation


The documentation for this class was generated from the following files: