GaussClusterable Class Reference

GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms.

#include <clusterable-classes.h>


Public Member Functions

 GaussClusterable ()
 
 GaussClusterable (int32 dim, BaseFloat var_floor)
 
 GaussClusterable (const Vector< BaseFloat > &x_stats, const Vector< BaseFloat > &x2_stats, BaseFloat var_floor, BaseFloat count)
 
virtual std::string Type () const
 Return a string that describes the inherited type.
 
void AddStats (const VectorBase< BaseFloat > &vec, BaseFloat weight=1.0)
 
virtual BaseFloat Objf () const
 Return the objective function associated with the stats [assuming ML estimation].
 
virtual void SetZero ()
 Set stats to empty.
 
virtual void Add (const Clusterable &other_in)
 Add other stats.
 
virtual void Sub (const Clusterable &other_in)
 Subtract other stats.
 
virtual BaseFloat Normalizer () const
 Return the normalizer (typically, count) associated with the stats.
 
virtual Clusterable * Copy () const
 Return a copy of this object.
 
virtual void Scale (BaseFloat f)
 Scale the stats by a positive number f [not mandatory to supply this].
 
virtual void Write (std::ostream &os, bool binary) const
 Write data to stream.
 
virtual Clusterable * ReadNew (std::istream &is, bool binary) const
 Read data from a stream and return the corresponding object (const function; it's a class member because we need access to the vtable so generic code can read derived types).
 
virtual ~GaussClusterable ()
 
BaseFloat count () const
 
SubVector< double > x_stats () const
 
SubVector< double > x2_stats () const
 
- Public Member Functions inherited from Clusterable
virtual ~Clusterable ()
 
virtual BaseFloat ObjfPlus (const Clusterable &other) const
 Return the objective function of the combined object this + other.
 
virtual BaseFloat ObjfMinus (const Clusterable &other) const
 Return the objective function of the subtracted object this - other.
 
virtual BaseFloat Distance (const Clusterable &other) const
 Return the objective function decrease from merging the two clusters, negated to be a positive number (or zero).
 

Private Member Functions

void Read (std::istream &is, bool binary)
 

Private Attributes

double count_
 
Matrix< double > stats_
 
double var_floor_
 

Detailed Description

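GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms. The statistics are the total weight count_ and a 2 x dim matrix stats_ whose row 0 holds the weighted sum of the data vectors and row 1 the weighted sum of their element-wise squares (see AddStats()); var_floor_ is the variance floor applied when Objf() computes the objective function.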

Definition at line 65 of file clusterable-classes.h.
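As a rough illustration of that layout, the standalone sketch below keeps the same three pieces of state in plain std::vector<double> members and accumulates them the way the AddStats() listing does. GaussStatsSketch and its members are hypothetical names used only for illustration; the real class stores both rows in a single Matrix<double> and takes Kaldi vector types.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for GaussClusterable's private data:
// count_ plus a 2 x dim matrix whose row 0 is sum(w * x)
// and row 1 is sum(w * x^2), element-wise.
struct GaussStatsSketch {
  double count = 0.0;
  std::vector<double> x_stats;   // plays the role of row 0 of stats_
  std::vector<double> x2_stats;  // plays the role of row 1 of stats_
  double var_floor = 0.0;

  GaussStatsSketch(std::size_t dim, double floor)
      : x_stats(dim, 0.0), x2_stats(dim, 0.0), var_floor(floor) {}

  // Mirrors GaussClusterable::AddStats(vec, weight).
  void AddStats(const std::vector<double> &vec, double weight = 1.0) {
    assert(vec.size() == x_stats.size());
    count += weight;
    for (std::size_t d = 0; d < vec.size(); d++) {
      x_stats[d] += weight * vec[d];
      x2_stats[d] += weight * vec[d] * vec[d];
    }
  }
};
```

Keeping only these sums, rather than the data points themselves, is what makes combining two clusters a matter of adding their accumulators.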

Constructor & Destructor Documentation

◆ GaussClusterable() [1/3]

GaussClusterable ( )
inline

Definition at line 67 of file clusterable-classes.h.

◆ GaussClusterable() [2/3]

GaussClusterable ( int32  dim,
BaseFloat  var_floor 
)
inline

Definition at line 68 of file clusterable-classes.h.

References count.

68  :
69  count_(0.0), stats_(2, dim), var_floor_(var_floor) {}

◆ GaussClusterable() [3/3]

GaussClusterable ( const Vector< BaseFloat > &  x_stats,
const Vector< BaseFloat > &  x2_stats,
BaseFloat  var_floor,
BaseFloat  count 
)
inline

Definition at line 107 of file clusterable-classes.h.

References MatrixBase< Real >::Row(), and GaussClusterable::stats_.

109  :
110  count_(count), stats_(2, x_stats.Dim()), var_floor_(var_floor) {
111  stats_.Row(0).CopyFromVec(x_stats);
112  stats_.Row(1).CopyFromVec(x2_stats);
113 }
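For callers that already have the sums, the three-argument constructor copies them into rows 0 and 1 of stats_ and stores count as given. The sketch below, with hypothetical names and std::vector in place of Kaldi vectors, shows one way such sums might be produced for unit-weight data before construction. The count passed in should equal the total weight behind the sums; otherwise the mean and variance later derived from them will be wrong.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the (x_stats, x2_stats, var_floor, count) constructor.
struct GaussStatsSketch {
  double count;
  std::vector<double> x_stats, x2_stats;
  double var_floor;
  GaussStatsSketch(const std::vector<double> &x, const std::vector<double> &x2,
                   double floor, double cnt)
      : count(cnt), x_stats(x), x2_stats(x2), var_floor(floor) {
    assert(x.size() == x2.size());  // both rows must have dimension dim
  }
};

// Hypothetical helper: unit-weight sums of x and x^2 over a data set,
// passed to the constructor together with the matching count.
GaussStatsSketch FromData(const std::vector<std::vector<double>> &data,
                          std::size_t dim, double var_floor) {
  std::vector<double> x(dim, 0.0), x2(dim, 0.0);
  for (const auto &v : data) {
    for (std::size_t d = 0; d < dim; d++) {
      x[d] += v[d];
      x2[d] += v[d] * v[d];
    }
  }
  return GaussStatsSketch(x, x2, var_floor, static_cast<double>(data.size()));
}
```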

◆ ~GaussClusterable()

virtual ~GaussClusterable ( )
inline virtual

Definition at line 86 of file clusterable-classes.h.

86 {}

Member Function Documentation

◆ Add()

void Add ( const Clusterable &  other_in)
virtual

Add other stats.

Implements Clusterable.

Definition at line 144 of file clusterable-classes.cc.

References GaussClusterable::count_, KALDI_ASSERT, and Clusterable::Type().

Referenced by GaussClusterable::Copy(), and VectorClusterable::SetZero().

144  {
145  KALDI_ASSERT(other_in.Type() == "gauss");
146  const GaussClusterable *other =
147  static_cast<const GaussClusterable*>(&other_in);
148  count_ += other->count_;
149  stats_.AddMat(1.0, other->stats_);
150 }
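Add() and Sub() follow a check-then-downcast pattern: assert that Type() returns "gauss", static_cast the Clusterable reference to GaussClusterable, then combine the counts and the statistics (Sub() passes -1.0 to AddMat). The hypothetical ClusterableSketch/GaussSketch pair below reproduces that pattern with plain C++ types; it is a sketch, not Kaldi's class hierarchy.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical minimal base class standing in for Clusterable.
class ClusterableSketch {
 public:
  virtual ~ClusterableSketch() = default;
  virtual std::string Type() const = 0;
  virtual void Add(const ClusterableSketch &other_in) = 0;
  virtual void Sub(const ClusterableSketch &other_in) = 0;
};

class GaussSketch : public ClusterableSketch {
 public:
  explicit GaussSketch(std::size_t dim) : x_(dim, 0.0), x2_(dim, 0.0) {}
  std::string Type() const override { return "gauss"; }
  void AddStats(const std::vector<double> &v, double w = 1.0) {
    count_ += w;
    for (std::size_t d = 0; d < v.size(); d++) {
      x_[d] += w * v[d];
      x2_[d] += w * v[d] * v[d];
    }
  }
  void Add(const ClusterableSketch &other_in) override { Accumulate(other_in, 1.0); }
  // Sub is Add with a scale of -1, as in stats_.AddMat(-1.0, ...).
  void Sub(const ClusterableSketch &other_in) override { Accumulate(other_in, -1.0); }
  double count() const { return count_; }
  const std::vector<double> &x_stats() const { return x_; }

 private:
  // Check the dynamic type via Type(), then static_cast, then combine.
  void Accumulate(const ClusterableSketch &other_in, double alpha) {
    assert(other_in.Type() == "gauss");
    const GaussSketch *other = static_cast<const GaussSketch *>(&other_in);
    assert(other->x_.size() == x_.size());
    count_ += alpha * other->count_;
    for (std::size_t d = 0; d < x_.size(); d++) {
      x_[d] += alpha * other->x_[d];
      x2_[d] += alpha * other->x2_[d];
    }
  }
  double count_ = 0.0;
  std::vector<double> x_, x2_;
};
```

Because the accumulators are plain sums, Sub() undoes a previous Add() of the same statistics, up to floating-point rounding.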

◆ AddStats()

void AddStats ( const VectorBase< BaseFloat > &  vec,
BaseFloat  weight = 1.0 
)

Definition at line 137 of file clusterable-classes.cc.

Referenced by kaldi::GenRandStats().

138  {
139  count_ += weight;
140  stats_.Row(0).AddVec(weight, vec);
141  stats_.Row(1).AddVec2(weight, vec);
142 }
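In equation form, one call AddStats(vec, weight) with weight w and vector v updates the three accumulators as follows, where c is count_ and s^{(0)}, s^{(1)} are rows 0 and 1 of stats_ (notation introduced here for illustration; AddVec2 adds w times the element-wise square):

```latex
\begin{aligned}
c &\leftarrow c + w, \\
s^{(0)}_d &\leftarrow s^{(0)}_d + w\, v_d, \\
s^{(1)}_d &\leftarrow s^{(1)}_d + w\, v_d^2, \qquad d = 0, \dots, D-1.
\end{aligned}
```

Starting from zero, after calls with weights w_n and vectors v_n this gives c = \sum_n w_n, s^{(0)}_d = \sum_n w_n v_{n,d} and s^{(1)}_d = \sum_n w_n v_{n,d}^2.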

◆ Copy()

Clusterable * Copy ( ) const
virtual

Return a copy of this object.

Implements Clusterable.

Definition at line 160 of file clusterable-classes.cc.

References GaussClusterable::Add(), and KALDI_ASSERT.

Referenced by VectorClusterable::Normalizer().

160  {
161  KALDI_ASSERT(stats_.NumRows() == 2);
162  GaussClusterable *ans = new GaussClusterable(stats_.NumCols(), var_floor_);
163  ans->Add(*this);
164  return ans;
165 }
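Copy() works by adding this object's statistics into a freshly created GaussClusterable (ans->Add(*this) in the listing) and returning it; the caller owns the returned pointer. The sketch below shows the same "empty object of the same shape, then Add" idea, assuming the new object gets the same dimension and variance floor. It returns std::unique_ptr rather than a raw owning pointer, and GaussStatsSketch is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in; only what is needed to show the Copy() pattern.
struct GaussStatsSketch {
  double count = 0.0;
  std::vector<double> x, x2;
  double var_floor;
  GaussStatsSketch(std::size_t dim, double floor)
      : x(dim, 0.0), x2(dim, 0.0), var_floor(floor) {}

  void Add(const GaussStatsSketch &other) {
    count += other.count;
    for (std::size_t d = 0; d < x.size(); d++) {
      x[d] += other.x[d];
      x2[d] += other.x2[d];
    }
  }

  // Make an empty object with the same dimension and variance floor,
  // then add this object's stats into it.
  std::unique_ptr<GaussStatsSketch> Copy() const {
    auto ans = std::make_unique<GaussStatsSketch>(x.size(), var_floor);
    ans->Add(*this);
    return ans;
  }
};
```

Reusing Add() keeps the copy logic in one place: anything Add() combines is automatically copied.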

◆ count()

BaseFloat count ( ) const

◆ Normalizer()

virtual BaseFloat Normalizer ( ) const
inline virtual

Return the normalizer (typically, count) associated with the stats.

Implements Clusterable.

Definition at line 81 of file clusterable-classes.h.

References GaussClusterable::count_.

81 { return count_; }

◆ Objf()

BaseFloat Objf ( ) const
virtual

Return the objective function associated with the stats [assuming ML estimation].
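Reading the Objf() listing as a formula (notation introduced here: c is count_, s^{(0)}_d and s^{(1)}_d are entries of rows 0 and 1 of stats_, \varphi is var_floor_, D is the dimension), the value returned for c > 0 is:

```latex
\begin{aligned}
\mu_d &= \frac{s^{(0)}_d}{c}, \qquad
\sigma^2_d = \frac{s^{(1)}_d}{c} - \mu_d^2, \qquad
\tilde\sigma^2_d = \max(\sigma^2_d, \varphi), \\
\mathcal{F} &= -\frac{c}{2} \sum_{d=0}^{D-1}
  \left( \frac{\sigma^2_d}{\tilde\sigma^2_d} + \log \tilde\sigma^2_d + \log 2\pi \right).
\end{aligned}
```

When no dimension is floored each ratio is 1, and \mathcal{F} = -\frac{c}{2} \sum_d (1 + \log 2\pi\sigma^2_d), the log-likelihood of the accumulated data under the ML-fitted diagonal Gaussian. For c \le 0 the function returns 0 (warning if c < -0.1), and it also returns 0 if the result is NaN.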

Implements Clusterable.

Definition at line 193 of file clusterable-classes.cc.

References KALDI_ISNAN, KALDI_WARN, M_LOG_2PI, and VectorBase< Real >::SumLog().

Referenced by VectorClusterable::Type().

193  {
194  if (count_ <= 0.0) {
195  if (count_ < -0.1) {
196  KALDI_WARN << "GaussClusterable::Objf(), count is negative " << count_;
197  }
198  return 0.0;
199  } else {
200  size_t dim = stats_.NumCols();
201  Vector<double> vars(dim);
202  double objf_per_frame = 0.0;
203  for (size_t d = 0; d < dim; d++) {
204  double mean(stats_(0, d) / count_), var = stats_(1, d) / count_ - mean
205  * mean, floored_var = std::max(var, var_floor_);
206  vars(d) = floored_var;
207  objf_per_frame += -0.5 * var / floored_var;
208  }
209  objf_per_frame += -0.5 * (vars.SumLog() + M_LOG_2PI * dim);
210  if (KALDI_ISNAN(objf_per_frame)) {
211  KALDI_WARN << "GaussClusterable::Objf(), objf is NaN";
212  return 0.0;
213  }
214  // KALDI_VLOG(2) << "count = " << count_ << ", objf_per_frame = "<< objf_per_frame
215  // << ", returning " << (objf_per_frame*count_) << ", floor = " << var_floor_;
216  return objf_per_frame * count_;
217  }
218 }
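The same computation, plus a merge cost, can be sketched without Kaldi types. Objf() below follows the listing step by step, with std::log(2 * pi) standing in for M_LOG_2PI and without the negative-count warning. MergeCost() assumes the base-class Distance() equals Objf(this) + Objf(other) - Objf(this + other), clamped at zero, which is how the member summary describes it. GaussStats, Objf, and MergeCost are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct GaussStats {
  double count;               // total weight (count_)
  std::vector<double> x, x2;  // sums of w*v and w*v^2 (rows 0 and 1 of stats_)
};

double Objf(const GaussStats &s, double var_floor) {
  if (s.count <= 0.0) return 0.0;  // empty or negative-count stats
  const double log_2pi = std::log(2.0 * std::acos(-1.0));
  double objf_per_frame = 0.0;
  for (std::size_t d = 0; d < s.x.size(); d++) {
    double mean = s.x[d] / s.count;
    double var = s.x2[d] / s.count - mean * mean;
    double floored_var = std::max(var, var_floor);
    objf_per_frame += -0.5 * (var / floored_var + std::log(floored_var) + log_2pi);
  }
  if (std::isnan(objf_per_frame)) return 0.0;  // mirrors the KALDI_ISNAN check
  return objf_per_frame * s.count;
}

// Objective-function decrease from merging a and b, clamped at zero.
double MergeCost(const GaussStats &a, const GaussStats &b, double var_floor) {
  GaussStats sum = a;
  sum.count += b.count;
  for (std::size_t d = 0; d < sum.x.size(); d++) {
    sum.x[d] += b.x[d];
    sum.x2[d] += b.x2[d];
  }
  double ans = Objf(a, var_floor) + Objf(b, var_floor) - Objf(sum, var_floor);
  return std::max(ans, 0.0);
}
```

Merging two clusters with identical mean and variance costs nothing, while merging two single points forces their floored variances up to the pooled variance.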

◆ Read()

void Read ( std::istream &  is,
bool  binary 
)
private

Definition at line 186 of file clusterable-classes.cc.

References kaldi::ExpectToken(), and kaldi::ReadBasicType().

Referenced by GaussClusterable::ReadNew().

186  {
187  ExpectToken(is, binary, "GCL"); // magic string.
188  ReadBasicType(is, binary, &count_);
189  ReadBasicType(is, binary, &var_floor_);
190  stats_.Read(is, binary);
191 }

◆ ReadNew()

Clusterable * ReadNew ( std::istream &  is,
bool  binary 
) const
virtual

Read data from a stream and return the corresponding object (const function; it's a class member because we need access to the vtable so generic code can read derived types).

Implements Clusterable.

Definition at line 180 of file clusterable-classes.cc.

References GaussClusterable::Read().

Referenced by VectorClusterable::Normalizer().

180  {
181  GaussClusterable *gc = new GaussClusterable();
182  gc->Read(is, binary);
183  return gc;
184 }
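ReadNew() is const and virtual so that generic code holding only a Clusterable reference can read a new object of the same derived type: the call dispatches through the vtable, and the listing then calls the private Read() on a new object gc and returns it. The sketch below shows that "prototype reads a sibling" pattern with hypothetical class names and a simplified text format (only the "GCL" token and a count); it is not Kaldi's serialization code.

```cpp
#include <cassert>
#include <istream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical base mirroring the idea behind Clusterable::ReadNew().
class ClusterableSketch {
 public:
  virtual ~ClusterableSketch() = default;
  virtual std::string Type() const = 0;
  virtual std::unique_ptr<ClusterableSketch> ReadNew(std::istream &is) const = 0;
};

class GaussSketch : public ClusterableSketch {
 public:
  std::string Type() const override { return "gauss"; }
  std::unique_ptr<ClusterableSketch> ReadNew(std::istream &is) const override {
    auto gc = std::make_unique<GaussSketch>();  // fresh default-constructed object
    gc->Read(is);                                // then fill it from the stream
    return gc;
  }
  double count() const { return count_; }

 private:
  // Simplified text read: expects the "GCL" token, then the count.
  void Read(std::istream &is) {
    std::string token;
    if (!(is >> token) || token != "GCL")
      throw std::runtime_error("expected token GCL");
    if (!(is >> count_)) throw std::runtime_error("could not read count");
  }
  double count_ = 0.0;
};

// Generic code: knows only the base class and a prototype object.
std::unique_ptr<ClusterableSketch> ReadLike(const ClusterableSketch &prototype,
                                            const std::string &text) {
  std::istringstream is(text);
  return prototype.ReadNew(is);
}
```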

◆ Scale()

void Scale ( BaseFloat  f)
virtual

Scale the stats by a positive number f [not mandatory to supply this].

Reimplemented from Clusterable.

Definition at line 167 of file clusterable-classes.cc.

References KALDI_ASSERT.

Referenced by VectorClusterable::Normalizer().

167  {
168  KALDI_ASSERT(f >= 0.0);
169  count_ *= f;
170  stats_.Scale(f);
171 }
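Because Scale(f) multiplies count_ and both rows of stats_ by the same factor, the ML mean and variance are unchanged for f > 0, so the objective function scales linearly. With c for count_ and s^{(0)}, s^{(1)} for the rows of stats_ (notation introduced here), and assuming c > 0:

```latex
\mu'_d = \frac{f\, s^{(0)}_d}{f\, c} = \mu_d, \qquad
\sigma'^2_d = \frac{f\, s^{(1)}_d}{f\, c} - \mu_d^2 = \sigma^2_d, \qquad
\mathrm{Objf}' = (f c) \cdot \frac{\mathrm{Objf}}{c} = f \cdot \mathrm{Objf}.
```

For f = 0, which the assertion also permits, the count becomes zero and Objf() returns 0, again equal to f times the original objective.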

◆ SetZero()

void SetZero ( )
inline virtual

Set stats to empty.

Implements Clusterable.

Definition at line 102 of file clusterable-classes.h.

References GaussClusterable::count_.

102  {
103  count_ = 0;
104  stats_.SetZero();
105 }

◆ Sub()

void Sub ( const Clusterable &  other_in)
virtual

Subtract other stats.

Implements Clusterable.

Definition at line 152 of file clusterable-classes.cc.

References GaussClusterable::count_, KALDI_ASSERT, and Clusterable::Type().

Referenced by VectorClusterable::SetZero().

152  {
153  KALDI_ASSERT(other_in.Type() == "gauss");
154  const GaussClusterable *other =
155  static_cast<const GaussClusterable*>(&other_in);
156  count_ -= other->count_;
157  stats_.AddMat(-1.0, other->stats_);
158 }

◆ Type()

virtual std::string Type ( ) const
inline virtual

Return a string that describes the inherited type.

Implements Clusterable.

Definition at line 75 of file clusterable-classes.h.


75 { return "gauss"; }

◆ Write()

void Write ( std::ostream &  os,
bool  binary 
) const
virtual

Write data to stream.

Implements Clusterable.

Definition at line 173 of file clusterable-classes.cc.

References kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by VectorClusterable::Normalizer().

173  {
174  WriteToken(os, binary, "GCL"); // magic string.
175  WriteBasicType(os, binary, count_);
176  WriteBasicType(os, binary, var_floor_);
177  stats_.Write(os, binary);
178 }
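Write() emits the magic token "GCL", then count_, var_floor_, and finally the stats_ matrix, and the private Read() consumes them in the same order. The sketch below shows that field ordering with its own simplified, whitespace-separated text encoding; it is not Kaldi's actual text or binary format (which comes from WriteToken, WriteBasicType, and Matrix::Write), and GaussRecord, WriteRecord, and ReadRecord are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical plain-data stand-in; field order follows GaussClusterable::Write().
struct GaussRecord {
  double count = 0.0;
  double var_floor = 0.0;
  std::vector<double> x, x2;  // rows 0 and 1 of stats_
};

// Simplified encoding: GCL <count> <var_floor> <dim> <x values> <x2 values>
void WriteRecord(std::ostream &os, const GaussRecord &r) {
  os.precision(17);  // enough significant digits for doubles to round-trip
  os << "GCL " << r.count << ' ' << r.var_floor << ' ' << r.x.size();
  for (double v : r.x) os << ' ' << v;
  for (double v : r.x2) os << ' ' << v;
  os << '\n';
}

GaussRecord ReadRecord(std::istream &is) {
  std::string token;
  is >> token;
  if (token != "GCL") throw std::runtime_error("expected token GCL");
  GaussRecord r;
  std::size_t dim = 0;
  is >> r.count >> r.var_floor >> dim;
  r.x.resize(dim);
  r.x2.resize(dim);
  for (double &v : r.x) is >> v;
  for (double &v : r.x2) is >> v;
  if (!is) throw std::runtime_error("truncated record");
  return r;
}
```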

◆ x2_stats()

SubVector<double> x2_stats ( ) const
inline

Definition at line 91 of file clusterable-classes.h.

Referenced by kaldi::ClusterGaussiansToUbm(), DiagGmm::DiagGmm(), and DiagGmm::MergeKmeans().

91 { return stats_.Row(1); }

◆ x_stats()

SubVector<double> x_stats ( ) const
inline

Definition at line 90 of file clusterable-classes.h.

Referenced by kaldi::ClusterGaussiansToUbm(), DiagGmm::DiagGmm(), and DiagGmm::MergeKmeans().

90 { return stats_.Row(0); }
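x_stats() and x2_stats() return rows 0 and 1 of stats_, so together with count() they are enough to recover a diagonal Gaussian, which is presumably why DiagGmm and ClusterGaussiansToUbm use them. The hypothetical helper below computes the mean and floored variance the same way Objf() does; whether those callers apply the same flooring is not shown on this page.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type: per-dimension mean and variance.
struct DiagGaussSketch {
  std::vector<double> mean, var;
};

// Turn accumulated stats (count, x_stats, x2_stats) into a diagonal
// Gaussian, flooring each variance at var_floor.
DiagGaussSketch MeanAndVar(double count, const std::vector<double> &x_stats,
                           const std::vector<double> &x2_stats, double var_floor) {
  assert(count > 0.0 && x_stats.size() == x2_stats.size());
  DiagGaussSketch g;
  for (std::size_t d = 0; d < x_stats.size(); d++) {
    double m = x_stats[d] / count;
    g.mean.push_back(m);
    g.var.push_back(std::max(x2_stats[d] / count - m * m, var_floor));
  }
  return g;
}
```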

Member Data Documentation

◆ count_

double count_
private

Definition at line 93 of file clusterable-classes.h.

Referenced by GaussClusterable::Add(), and GaussClusterable::Sub().

◆ stats_

Matrix<double> stats_
private

◆ var_floor_

double var_floor_
private

Definition at line 95 of file clusterable-classes.h.


The documentation for this class was generated from the following files:

clusterable-classes.h
clusterable-classes.cc