clusterable-classes.h
Go to the documentation of this file.
1 // tree/clusterable-classes.h
2 
3 // Copyright 2009-2011 Microsoft Corporation; Saarland University
4 // 2014 Daniel Povey
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_TREE_CLUSTERABLE_CLASSES_H_
22 #define KALDI_TREE_CLUSTERABLE_CLASSES_H_ 1
23 
24 #include <string>
25 #include "itf/clusterable-itf.h"
26 #include "matrix/matrix-lib.h"
27 
28 namespace kaldi {
29 
32 
// ScalarClusterable clusters scalars with x^2 loss.
// NOTE(review): the class header and the declarations of x_, x2_ and count_
// are not visible in this listing; the comments below rely only on how those
// members are used in the inline code here.
35  public:
  // Creates empty stats: x_, x2_ and count_ are all initialized to zero.
36  ScalarClusterable(): x_(0), x2_(0), count_(0) {}
  // Creates stats for a single observation x: x_ = x, x2_ = x*x, count_ = 1.
  // Marked explicit, so a BaseFloat is not implicitly converted to this type.
37  explicit ScalarClusterable(BaseFloat x): x_(x), x2_(x*x), count_(1) {}
  // Returns the type string "scalar".
38  virtual std::string Type() const { return "scalar"; }
  // Returns the objective function associated with the stats
  // (defined outside this header).
39  virtual BaseFloat Objf() const;
  // Sets stats to empty: count_, x_ and x2_ are all set to zero.
40  virtual void SetZero() { count_ = x_ = x2_ = 0.0; }
  // Adds other stats (defined outside this header).
41  virtual void Add(const Clusterable &other_in);
  // Subtracts other stats (defined outside this header).
42  virtual void Sub(const Clusterable &other_in);
  // Returns a copy of this object.
43  virtual Clusterable* Copy() const;
  // Returns the normalizer for these stats: count_, converted to BaseFloat.
44  virtual BaseFloat Normalizer() const {
45  return static_cast<BaseFloat>(count_);
46  }
47 
48  // Writes the stats to a stream.  (Original author's note: the input side
  // will be organized later, as it is more complex.)
49  virtual void Write(std::ostream &os, bool binary) const;
  // Reads data from a stream and returns the corresponding new object.
50  virtual Clusterable* ReadNew(std::istream &is, bool binary) const;
51 
52  std::string Info(); // For debugging.
  // Returns x_/count_, or 0.0 when count_ is zero (guards the division).
  // NOTE(review): Mean() only reads members and could presumably be declared
  // const -- confirm no caller depends on the non-const signature.
53  BaseFloat Mean() { return (count_ != 0 ? x_/count_ : 0.0); }
54  private:
58 
  // Reads stats from a stream into this object.
  // NOTE(review): presumably the helper behind ReadNew() -- confirm in the
  // implementation file.
59  void Read(std::istream &is, bool binary);
60 };
61 
62 
// GaussClusterable wraps Gaussian statistics in a form accessible to generic
// clustering algorithms.  It holds a count, a two-row matrix of statistics
// (row 0: sum, row 1: sum-squared) and a variance floor.
// NOTE(review): the class header line is not visible in this listing.
66  public:
  // Creates an object with zero count and zero variance floor; stats_ is
  // default-constructed.
67  GaussClusterable(): count_(0.0), var_floor_(0.0) {}
  // Creates an object with zero count, a 2 x dim stats_ matrix and the given
  // variance floor.
68  GaussClusterable(int32 dim, BaseFloat var_floor):
69  count_(0.0), stats_(2, dim), var_floor_(var_floor) {}
70 
  // Creates an object from existing statistics: x_stats and x2_stats supply
  // the two statistic vectors, together with the variance floor and count.
71  GaussClusterable(const Vector<BaseFloat> &x_stats,
72  const Vector<BaseFloat> &x2_stats,
73  BaseFloat var_floor, BaseFloat count);
74 
  // Returns the type string "gauss".
75  virtual std::string Type() const { return "gauss"; }
  // Adds the vector `vec` to the stats with the given weight (default 1.0).
  // Defined outside this header.
76  void AddStats(const VectorBase<BaseFloat> &vec, BaseFloat weight = 1.0);
  // Returns the objective function associated with the stats.
77  virtual BaseFloat Objf() const;
  // Sets stats to empty.
78  virtual void SetZero();
  // Adds other stats.
79  virtual void Add(const Clusterable &other_in);
  // Subtracts other stats.
80  virtual void Sub(const Clusterable &other_in);
  // Returns the normalizer for these stats: count_ (stored as double,
  // converted to BaseFloat on return).
81  virtual BaseFloat Normalizer() const { return count_; }
  // Returns a copy of this object.
82  virtual Clusterable *Copy() const;
  // Scales the stats by a positive number f.
83  virtual void Scale(BaseFloat f);
  // Writes data to a stream.
84  virtual void Write(std::ostream &os, bool binary) const;
  // Reads data from a stream and returns the corresponding new object.
85  virtual Clusterable *ReadNew(std::istream &is, bool binary) const;
86  virtual ~GaussClusterable() {}
87 
  // Returns count_ (stored as double, converted to BaseFloat on return).
88  BaseFloat count() const { return count_; }
89  // The next two functions are not const-correct, because of SubVector.
  // x_stats() returns row 0 of stats_ (the sum).
90  SubVector<double> x_stats() const { return stats_.Row(0); }
  // x2_stats() returns row 1 of stats_ (the sum-squared).
91  SubVector<double> x2_stats() const { return stats_.Row(1); }
92  private:
93  double count_;
94  Matrix<double> stats_; // two rows: sum, then sum-squared.
95  double var_floor_; // should be common for all objects created.
96 
  // Reads stats from a stream into this object.
  // NOTE(review): presumably the helper behind ReadNew() -- confirm in the
  // implementation file.
97  void Read(std::istream &is, bool binary);
98 };
99 
101 
103  count_ = 0;
104  stats_.SetZero();
105 }
106 
108  const Vector<BaseFloat> &x2_stats,
109  BaseFloat var_floor, BaseFloat count):
110  count_(count), stats_(2, x_stats.Dim()), var_floor_(var_floor) {
111  stats_.Row(0).CopyFromVec(x_stats);
112  stats_.Row(1).CopyFromVec(x2_stats);
113 }
114 
115 
// VectorClusterable wraps vectors in a form accessible to generic clustering
// algorithms.  It holds a total weight, a weighted sum of the source vectors
// and a sum-of-squares term used by the objective function.
// NOTE(review): the class header line is not visible in this listing.
122  public:
  // Creates empty stats: weight_ and sumsq_ are zero and stats_ is
  // default-constructed.
123  VectorClusterable(): weight_(0.0), sumsq_(0.0) {}
124 
  // Creates stats from a single source vector and its weight.
  // Defined outside this header.
125  VectorClusterable(const Vector<BaseFloat> &vector,
126  BaseFloat weight);
127 
  // Returns the type string "vector".
128  virtual std::string Type() const { return "vector"; }
129  // Objf is negated weighted sum of squared distances.
130  virtual BaseFloat Objf() const;
  // Sets stats to empty: zeroes weight_ and sumsq_ and sets every element of
  // stats_ to 0.0.
131  virtual void SetZero() { weight_ = 0.0; sumsq_ = 0.0; stats_.Set(0.0); }
  // Adds other stats.
132  virtual void Add(const Clusterable &other_in);
  // Subtracts other stats.
133  virtual void Sub(const Clusterable &other_in);
  // Returns the normalizer for these stats: weight_ (stored as double,
  // converted to BaseFloat on return).
134  virtual BaseFloat Normalizer() const { return weight_; }
  // Returns a copy of this object.
135  virtual Clusterable *Copy() const;
  // Scales the stats by a positive number f.
136  virtual void Scale(BaseFloat f);
  // Writes data to a stream.
137  virtual void Write(std::ostream &os, bool binary) const;
  // Reads data from a stream and returns the corresponding new object.
138  virtual Clusterable *ReadNew(std::istream &is, bool binary) const;
139  virtual ~VectorClusterable() {}
140 
141  private:
142  double weight_; // sum of weights of the source vectors. Never negative.
143  Vector<double> stats_; // Equals the weighted sum of the source vectors.
144  double sumsq_; // Equals the sum over all sources, of weight_ * vec.vec,
145  // where vec = stats_ / weight_. Used in computing
146  // the objective function.
  // Reads stats from a stream into this object.
  // NOTE(review): presumably the helper behind ReadNew() -- confirm in the
  // implementation file.
147  void Read(std::istream &is, bool binary);
148 };
149 
150 
151 
152 } // end namespace kaldi.
153 
154 #endif // KALDI_TREE_CLUSTERABLE_CLASSES_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
virtual void SetZero()
Set stats to empty.
virtual void Write(std::ostream &os, bool binary) const
Write data to stream.
virtual BaseFloat Normalizer() const
Return the normalizer (typically, count) associated with the stats.
virtual void SetZero()
Set stats to empty.
virtual void Scale(BaseFloat f)
Scale the stats by a positive number f [not mandatory to supply this].
virtual Clusterable * ReadNew(std::istream &is, bool binary) const
Read data from a stream and return the corresponding object (const function; it's a class member beca...
kaldi::int32 int32
VectorClusterable wraps vectors in a form accessible to generic clustering algorithms.
virtual void Add(const Clusterable &other_in)
Add other stats.
virtual Clusterable * Copy() const
Return a copy of this object.
GaussClusterable(int32 dim, BaseFloat var_floor)
virtual BaseFloat Normalizer() const
Return the normalizer (typically, count) associated with the stats.
virtual BaseFloat Normalizer() const
Return the normalizer (typically, count) associated with the stats.
const size_t count
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
virtual std::string Type() const
Return a string that describes the inherited type.
virtual Clusterable * Copy() const
Return a copy of this object.
virtual void Add(const Clusterable &other_in)
Add other stats.
virtual Clusterable * ReadNew(std::istream &is, bool binary) const
Read data from a stream and return the corresponding object (const function; it's a class member beca...
virtual BaseFloat Objf() const
Return the objective function associated with the stats [assuming ML estimation]. ...
virtual void Scale(BaseFloat f)
Scale the stats by a positive number f [not mandatory to supply this].
virtual void Write(std::ostream &os, bool binary) const
Write data to stream.
A class representing a vector.
Definition: kaldi-vector.h:406
virtual void SetZero()
Set stats to empty.
SubVector< double > x2_stats() const
virtual BaseFloat Objf() const
Return the objective function associated with the stats [assuming ML estimation]. ...
virtual std::string Type() const
Return a string that describes the inherited type.
virtual void Sub(const Clusterable &other_in)
Subtract other stats.
void Read(std::istream &is, bool binary)
virtual std::string Type() const
Return a string that describes the inherited type.
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...
SubVector< double > x_stats() const
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
void Read(std::istream &is, bool binary)
void Set(Real)
Sets all elements to a specific value.
virtual void Sub(const Clusterable &other_in)
Subtract other stats.
ScalarClusterable clusters scalars with x^2 loss.