GaussClusterable Class Reference

GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms.

#include <clusterable-classes.h>


Public Member Functions

 GaussClusterable ()
 
 GaussClusterable (int32 dim, BaseFloat var_floor)
 
 GaussClusterable (const Vector< BaseFloat > &x_stats, const Vector< BaseFloat > &x2_stats, BaseFloat var_floor, BaseFloat count)
 
virtual std::string Type () const
 Return a string that describes the inherited type.
 
void AddStats (const VectorBase< BaseFloat > &vec, BaseFloat weight=1.0)
 
virtual BaseFloat Objf () const
 Return the objective function associated with the stats [assuming ML estimation].
 
virtual void SetZero ()
 Set stats to empty.
 
virtual void Add (const Clusterable &other_in)
 Add other stats.
 
virtual void Sub (const Clusterable &other_in)
 Subtract other stats.
 
virtual BaseFloat Normalizer () const
 Return the normalizer (typically, count) associated with the stats.
 
virtual Clusterable * Copy () const
 Return a copy of this object.
 
virtual void Scale (BaseFloat f)
 Scale the stats by a non-negative factor f (subclasses are not required to override this).
 
virtual void Write (std::ostream &os, bool binary) const
 Write data to stream.
 
virtual Clusterable * ReadNew (std::istream &is, bool binary) const
 Read data from a stream and return the corresponding object (const function; it's a class member because we need access to the vtable so generic code can read derived types).
 
virtual ~GaussClusterable ()
 
BaseFloat count () const
 
SubVector< double > x_stats () const
 
SubVector< double > x2_stats () const
 
- Public Member Functions inherited from Clusterable
virtual ~Clusterable ()
 
virtual BaseFloat ObjfPlus (const Clusterable &other) const
 Return the objective function of the combined object this + other.
 
virtual BaseFloat ObjfMinus (const Clusterable &other) const
 Return the objective function of the subtracted object this - other.
 
virtual BaseFloat Distance (const Clusterable &other) const
 Return the objective function decrease from merging the two clusters, negated to be a positive number (or zero).
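
Taken together, the ObjfPlus() and Distance() descriptions above amount to the relation below. This is a restatement for illustration, not a quote of the source; here a is this object, b is other, and a + b denotes their pooled statistics:

```latex
\mathrm{Distance}(a, b) = \mathrm{Objf}(a) + \mathrm{Objf}(b) - \mathrm{ObjfPlus}(a, b),
\qquad \mathrm{ObjfPlus}(a, b) = \mathrm{Objf}(a + b).
```

For GaussClusterable, Objf() is the log-likelihood of the data under its best-fitting diagonal Gaussian whose variances are at least var_floor_, and pooled data cannot fit one such Gaussian better than two separately fitted ones, so this difference is non-negative apart from rounding.
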
 

Private Member Functions

void Read (std::istream &is, bool binary)
 

Private Attributes

double count_
 
Matrix< double > stats_
 
double var_floor_
 

Detailed Description

GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms.

Definition at line 65 of file clusterable-classes.h.

Constructor & Destructor Documentation

GaussClusterable (int32 dim, BaseFloat var_floor)
inline

Creates empty (zero-count) statistics of dimension dim with variance floor var_floor.

Definition at line 68 of file clusterable-classes.h.

GaussClusterable(int32 dim, BaseFloat var_floor):
    count_(0.0), stats_(2, dim), var_floor_(var_floor) {}

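GaussClusterable (const Vector< BaseFloat > &x_stats, const Vector< BaseFloat > &x2_stats, BaseFloat var_floor, BaseFloat count)
inline

Initializes the statistics from precomputed sums: x_stats becomes row 0 of stats_, x2_stats becomes row 1, and count becomes count_.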

Definition at line 107 of file clusterable-classes.h.

References MatrixBase< Real >::Row(), and GaussClusterable::stats_.

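GaussClusterable(const Vector<BaseFloat> &x_stats,
                 const Vector<BaseFloat> &x2_stats,
                 BaseFloat var_floor, BaseFloat count):
    count_(count), stats_(2, x_stats.Dim()), var_floor_(var_floor) {
  stats_.Row(0).CopyFromVec(x_stats);
  stats_.Row(1).CopyFromVec(x2_stats);
}
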
virtual ~GaussClusterable ()
inline virtual

Definition at line 86 of file clusterable-classes.h.

virtual ~GaussClusterable() {}

Member Function Documentation

void Add (const Clusterable &other_in)
virtual

Add other stats.

Implements Clusterable.

Definition at line 144 of file clusterable-classes.cc.

References MatrixBase< Real >::AddMat(), GaussClusterable::count_, KALDI_ASSERT, GaussClusterable::stats_, and Clusterable::Type().

void GaussClusterable::Add(const Clusterable &other_in) {
  KALDI_ASSERT(other_in.Type() == "gauss");
  const GaussClusterable *other =
      static_cast<const GaussClusterable*>(&other_in);
  count_ += other->count_;
  stats_.AddMat(1.0, other->stats_);
}

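void AddStats (const VectorBase< BaseFloat > &vec, BaseFloat weight = 1.0)

Accumulates one observation vec with the given weight: count_ += weight, row 0 of stats_ += weight * vec, and row 1 of stats_ += weight * vec^2 (element-wise).
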
Definition at line 137 of file clusterable-classes.cc.

References GaussClusterable::count_, MatrixBase< Real >::Row(), and GaussClusterable::stats_.

Referenced by kaldi::GenRandStats().

void GaussClusterable::AddStats(const VectorBase<BaseFloat> &vec,
                                BaseFloat weight) {
  count_ += weight;
  stats_.Row(0).AddVec(weight, vec);
  stats_.Row(1).AddVec2(weight, vec);
}

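AddStats() keeps only sufficient statistics (a count plus weighted sums of x and x^2), so merging clusters with Add() and removing them with Sub() reduce to element-wise arithmetic. The standalone sketch below mirrors that layout with std::vector instead of Kaldi's Matrix; GaussStatsSketch and its Accumulate() method are hypothetical names, not Kaldi API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the stats held by GaussClusterable:
// a count plus the two rows of the 2 x dim stats matrix.
struct GaussStatsSketch {
  double count = 0.0;
  std::vector<double> x_stats;   // row 0: sum of weight * x
  std::vector<double> x2_stats;  // row 1: sum of weight * x^2 (element-wise)

  explicit GaussStatsSketch(std::size_t dim)
      : x_stats(dim, 0.0), x2_stats(dim, 0.0) {}

  // Mirrors AddStats(): count += w, x += w * v, x2 += w * v^2.
  void AddStats(const std::vector<double> &v, double weight = 1.0) {
    assert(v.size() == x_stats.size());
    count += weight;
    for (std::size_t d = 0; d < v.size(); ++d) {
      x_stats[d] += weight * v[d];
      x2_stats[d] += weight * v[d] * v[d];
    }
  }

  // Mirrors Add() (sign = +1) and Sub() (sign = -1): because the stats
  // are plain sums, merging or removing a cluster is element-wise.
  void Accumulate(const GaussStatsSketch &other, double sign) {
    assert(other.x_stats.size() == x_stats.size());
    count += sign * other.count;
    for (std::size_t d = 0; d < x_stats.size(); ++d) {
      x_stats[d] += sign * other.x_stats[d];
      x2_stats[d] += sign * other.x2_stats[d];
    }
  }
};
```

Because the stats are additive, accumulating two clusters separately and then merging them gives exactly the stats of one object that saw every point, and subtracting a cluster undoes the merge.
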
Clusterable * Copy ( ) const
virtual

Return a copy of this object.

Implements Clusterable.

Definition at line 160 of file clusterable-classes.cc.

References GaussClusterable::GaussClusterable(), KALDI_ASSERT, MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), GaussClusterable::stats_, and GaussClusterable::var_floor_.

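Clusterable* GaussClusterable::Copy() const {
  KALDI_ASSERT(stats_.NumRows() == 2);
  GaussClusterable *ans = new GaussClusterable(stats_.NumCols(), var_floor_);
  ans->Add(*this);
  return ans;
}
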
BaseFloat count ( ) const
inline
virtual BaseFloat Normalizer ( ) const
inline virtual

Return the normalizer (typically, count) associated with the stats.

Implements Clusterable.

Definition at line 81 of file clusterable-classes.h.

References GaussClusterable::count_.

virtual BaseFloat Normalizer() const { return count_; }

BaseFloat Objf ( ) const
virtual

Return the objective function associated with the stats [assuming ML estimation].

Implements Clusterable.

Definition at line 193 of file clusterable-classes.cc.

References GaussClusterable::count_, KALDI_ISNAN, KALDI_WARN, M_LOG_2PI, MatrixBase< Real >::NumCols(), GaussClusterable::stats_, VectorBase< Real >::SumLog(), and GaussClusterable::var_floor_.

BaseFloat GaussClusterable::Objf() const {
  if (count_ <= 0.0) {
    if (count_ < -0.1) {
      KALDI_WARN << "GaussClusterable::Objf(), count is negative " << count_;
    }
    return 0.0;
  } else {
    size_t dim = stats_.NumCols();
    Vector<double> vars(dim);
    double objf_per_frame = 0.0;
    for (size_t d = 0; d < dim; d++) {
      double mean(stats_(0, d) / count_),
          var = stats_(1, d) / count_ - mean * mean,
          floored_var = std::max(var, var_floor_);
      vars(d) = floored_var;
      objf_per_frame += -0.5 * var / floored_var;
    }
    objf_per_frame += -0.5 * (vars.SumLog() + M_LOG_2PI * dim);
    if (KALDI_ISNAN(objf_per_frame)) {
      KALDI_WARN << "GaussClusterable::Objf(), objf is NaN";
      return 0.0;
    }
    // KALDI_VLOG(2) << "count = " << count_ << ", objf_per_frame = " << objf_per_frame
    //   << ", returning " << (objf_per_frame*count_) << ", floor = " << var_floor_;
    return objf_per_frame * count_;
  }
}

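Written out, the loop above computes the following for count_ > 0. The symbols are introduced here for illustration: n is count_, D is stats_.NumCols(), s_d and q_d are stats_(0, d) and stats_(1, d), and f is var_floor_; M_LOG_2PI stands for log 2π.

```latex
\begin{aligned}
\mu_d &= \frac{s_d}{n}, \qquad
\sigma_d^2 = \frac{q_d}{n} - \mu_d^2, \qquad
\tilde\sigma_d^2 = \max(\sigma_d^2, f), \\
\mathrm{Objf} &= n \left[ -\frac{1}{2} \sum_{d=1}^{D} \frac{\sigma_d^2}{\tilde\sigma_d^2}
  - \frac{1}{2} \left( \sum_{d=1}^{D} \log \tilde\sigma_d^2 + D \log 2\pi \right) \right].
\end{aligned}
```

When no dimension is floored, each ratio is 1 and the expression reduces to -(n/2)(Σ_d log σ_d² + D(1 + log 2π)), the total log-likelihood of the data under its maximum-likelihood diagonal Gaussian.
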
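To check the objective numerically outside Kaldi, here is a standalone sketch of the same computation over std::vector stats. GaussObjfSketch is a hypothetical helper, not Kaldi API; it drops the KALDI_WARN logging and computes log 2π directly instead of using M_LOG_2PI.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical re-implementation of the GaussClusterable::Objf() logic:
// per-dimension mean and variance from the sums, variance floored, then
// the diagonal-Gaussian log-likelihood scaled by the count.
double GaussObjfSketch(const std::vector<double> &x_stats,
                       const std::vector<double> &x2_stats,
                       double count, double var_floor) {
  assert(x_stats.size() == x2_stats.size());
  if (count <= 0.0) return 0.0;  // empty (or negative-count) stats score 0
  const double kLog2Pi = std::log(2.0 * std::acos(-1.0));
  double objf_per_frame = 0.0, sum_log_var = 0.0;
  for (std::size_t d = 0; d < x_stats.size(); ++d) {
    double mean = x_stats[d] / count;
    double var = x2_stats[d] / count - mean * mean;
    double floored_var = std::max(var, var_floor);
    sum_log_var += std::log(floored_var);
    objf_per_frame += -0.5 * var / floored_var;  // -0.5 unless floored
  }
  objf_per_frame += -0.5 * (sum_log_var + kLog2Pi * x_stats.size());
  if (std::isnan(objf_per_frame)) return 0.0;  // same NaN guard as Objf()
  return objf_per_frame * count;
}
```

For example, the one-dimensional samples x = 1 and x = 3 give count 2, x_stats {4} and x2_stats {10}, so the mean is 2, the variance is 1, and the result is -1 - log 2π.
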
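void Read (std::istream &is, bool binary)
private

Reads the fields in the order Write() emits them: the magic token "GCL", count_, var_floor_, then stats_.
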
Definition at line 186 of file clusterable-classes.cc.

References GaussClusterable::count_, kaldi::ExpectToken(), Matrix< Real >::Read(), kaldi::ReadBasicType(), GaussClusterable::stats_, and GaussClusterable::var_floor_.

Referenced by GaussClusterable::ReadNew().

void GaussClusterable::Read(std::istream &is, bool binary) {
  ExpectToken(is, binary, "GCL");  // magic string.
  ReadBasicType(is, binary, &count_);
  ReadBasicType(is, binary, &var_floor_);
  stats_.Read(is, binary);
}

Clusterable * ReadNew (std::istream &is, bool binary) const
virtual

Read data from a stream and return the corresponding object (const function; it's a class member because we need access to the vtable so generic code can read derived types).

Implements Clusterable.

Definition at line 180 of file clusterable-classes.cc.

References GaussClusterable::GaussClusterable(), and GaussClusterable::Read().

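Clusterable* GaussClusterable::ReadNew(std::istream &is, bool binary) const {
  GaussClusterable *gc = new GaussClusterable();
  gc->Read(is, binary);
  return gc;
}
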
void Scale ( BaseFloat  f)
virtual

Scale the stats by a non-negative factor f (subclasses are not required to override this).

Reimplemented from Clusterable.

Definition at line 167 of file clusterable-classes.cc.

References GaussClusterable::count_, KALDI_ASSERT, MatrixBase< Real >::Scale(), and GaussClusterable::stats_.

void GaussClusterable::Scale(BaseFloat f) {
  KALDI_ASSERT(f >= 0.0);
  count_ *= f;
  stats_.Scale(f);
}

void SetZero ()
inline virtual

Set stats to empty.

Implements Clusterable.

Definition at line 102 of file clusterable-classes.h.

References GaussClusterable::count_, MatrixBase< Real >::SetZero(), and GaussClusterable::stats_.

virtual void SetZero() {
  count_ = 0;
  stats_.SetZero();
}

void Sub (const Clusterable &other_in)
virtual

Subtract other stats.

Implements Clusterable.

Definition at line 152 of file clusterable-classes.cc.

References MatrixBase< Real >::AddMat(), GaussClusterable::count_, KALDI_ASSERT, GaussClusterable::stats_, and Clusterable::Type().

void GaussClusterable::Sub(const Clusterable &other_in) {
  KALDI_ASSERT(other_in.Type() == "gauss");
  const GaussClusterable *other =
      static_cast<const GaussClusterable*>(&other_in);
  count_ -= other->count_;
  stats_.AddMat(-1.0, other->stats_);
}

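Add() and Sub() receive a generic Clusterable reference, so they check the Type() tag before the static_cast; the cast is only valid once the tag matches. The sketch below reproduces that pattern in isolation, using hypothetical classes (ClusterableSketch, CountOnlySketch) with a plain assert standing in for KALDI_ASSERT.

```cpp
#include <cassert>
#include <string>

// Hypothetical minimal base class, loosely modeled on Clusterable.
class ClusterableSketch {
 public:
  virtual ~ClusterableSketch() {}
  virtual std::string Type() const = 0;
  virtual void Add(const ClusterableSketch &other_in) = 0;
};

// Hypothetical derived class holding only a count.
class CountOnlySketch : public ClusterableSketch {
 public:
  explicit CountOnlySketch(double count) : count_(count) {}
  std::string Type() const override { return "count"; }
  void Add(const ClusterableSketch &other_in) override {
    // Check the tag first; the static_cast below relies on it.
    assert(other_in.Type() == "count");
    const CountOnlySketch *other =
        static_cast<const CountOnlySketch *>(&other_in);
    count_ += other->count_;
  }
  double count() const { return count_; }

 private:
  double count_;
};
```

dynamic_cast is the RTTI-based alternative; with the string tag, a mismatch trips the assertion instead of yielding a null pointer that the caller must test.
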
virtual std::string Type () const
inline virtual

Return a string that identifies the concrete type. For this class it is "gauss", which Add() and Sub() check before downcasting their argument.

Implements Clusterable.

Definition at line 75 of file clusterable-classes.h.

virtual std::string Type() const { return "gauss"; }

void Write (std::ostream &os, bool binary) const
virtual

Write data to stream.

Implements Clusterable.

Definition at line 173 of file clusterable-classes.cc.

References GaussClusterable::count_, GaussClusterable::stats_, GaussClusterable::var_floor_, MatrixBase< Real >::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().

void GaussClusterable::Write(std::ostream &os, bool binary) const {
  WriteToken(os, binary, "GCL");  // magic string.
  WriteBasicType(os, binary, count_);
  WriteBasicType(os, binary, var_floor_);
  stats_.Write(os, binary);
}

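Write() and Read() share one field order behind a magic token, which lets Read() reject foreign data early through ExpectToken(). The sketch below illustrates that pattern with the standard iostreams only. The text layout and the names GaussRecordSketch, WriteSketch and ReadSketch are hypothetical; they do not reproduce Kaldi's actual text or binary encoding.

```cpp
#include <cassert>
#include <cstddef>
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical record with the same fields, in the same order, as Write():
// magic token "GCL", count, var_floor, then the two stats rows.
struct GaussRecordSketch {
  double count = 0.0;
  double var_floor = 0.0;
  std::vector<double> x_stats, x2_stats;
};

void WriteSketch(std::ostream &os, const GaussRecordSketch &g) {
  assert(g.x_stats.size() == g.x2_stats.size());
  os << std::setprecision(17);  // enough digits for a lossless double round trip
  os << "GCL " << g.count << ' ' << g.var_floor << ' ' << g.x_stats.size();
  for (double v : g.x_stats) os << ' ' << v;
  for (double v : g.x2_stats) os << ' ' << v;
  os << '\n';
}

GaussRecordSketch ReadSketch(std::istream &is) {
  std::string token;
  is >> token;
  // Like ExpectToken(), refuse to continue if the magic string is wrong.
  if (token != "GCL") throw std::runtime_error("expected token GCL");
  GaussRecordSketch g;
  std::size_t dim = 0;
  is >> g.count >> g.var_floor >> dim;
  g.x_stats.resize(dim);
  g.x2_stats.resize(dim);
  for (double &v : g.x_stats) is >> v;
  for (double &v : g.x2_stats) is >> v;
  if (!is) throw std::runtime_error("truncated GCL record");
  return g;
}
```
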
SubVector<double> x2_stats () const
inline

Returns row 1 of stats_: the weighted sum of squared observations (element-wise).

Definition at line 91 of file clusterable-classes.h.

References MatrixBase< Real >::Row(), and GaussClusterable::stats_.

Referenced by kaldi::ClusterGaussiansToUbm(), DiagGmm::DiagGmm(), and DiagGmm::MergeKmeans().

SubVector<double> x2_stats() const { return stats_.Row(1); }

SubVector<double> x_stats () const
inline

Returns row 0 of stats_: the weighted sum of observations.

Definition at line 90 of file clusterable-classes.h.

References MatrixBase< Real >::Row(), and GaussClusterable::stats_.

Referenced by kaldi::ClusterGaussiansToUbm(), DiagGmm::DiagGmm(), and DiagGmm::MergeKmeans().

SubVector<double> x_stats() const { return stats_.Row(0); }
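
x_stats() and x2_stats() return sums rather than parameters, so a caller that also has count() can recover a diagonal Gaussian from them. Assuming the same estimates Objf() uses (mean = x / n, variance = x^2 / n - mean^2, raised to the floor), a standalone sketch looks like this. StatsToGaussSketch and DiagGaussSketch are hypothetical names, and the callers listed above may treat flooring differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical diagonal Gaussian: per-dimension means and variances.
struct DiagGaussSketch {
  std::vector<double> mean, var;
};

// Hypothetical helper: convert accumulated sums into Gaussian parameters.
DiagGaussSketch StatsToGaussSketch(const std::vector<double> &x_stats,
                                   const std::vector<double> &x2_stats,
                                   double count, double var_floor) {
  assert(count > 0.0 && x_stats.size() == x2_stats.size());
  DiagGaussSketch g;
  for (std::size_t d = 0; d < x_stats.size(); ++d) {
    double mean = x_stats[d] / count;                // E[x]
    double var = x2_stats[d] / count - mean * mean;  // E[x^2] - E[x]^2
    g.mean.push_back(mean);
    g.var.push_back(std::max(var, var_floor));       // apply the floor
  }
  return g;
}
```

With the two-dimensional samples (1, 10) and (3, 10), the sums are x_stats {4, 20} and x2_stats {10, 200}; the second dimension has zero variance, so it is raised to the floor.
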

Member Data Documentation

double var_floor_
private

Variance floor: the minimum per-dimension variance used when Objf() evaluates the stats.
