clusterable-classes.cc
Go to the documentation of this file.
1 // tree/clusterable-classes.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation; Saarland University
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <algorithm>
21 #include <string>
22 #include "base/kaldi-math.h"
23 #include "itf/clusterable-itf.h"
25 
26 namespace kaldi {
27 
28 // ============================================================================
29 // Implementations common to all Clusterable classes (may be overridden for
30 // speed).
31 // ============================================================================
32 
34  Clusterable *copy = this->Copy();
35  copy->Add(other);
36  BaseFloat ans = copy->Objf();
37  delete copy;
38  return ans;
39 }
40 
42  Clusterable *copy = this->Copy();
43  copy->Sub(other);
44  BaseFloat ans = copy->Objf();
45  delete copy;
46  return ans;
47 }
48 
50  Clusterable *copy = this->Copy();
51  copy->Add(other);
52  BaseFloat ans = this->Objf() + other.Objf() - copy->Objf();
53  if (ans < 0) {
54  // This should not happen. Check if it is more than just rounding error.
55  if (std::fabs(ans) > 0.01 * (1.0 + std::fabs(copy->Objf()))) {
56  KALDI_WARN << "Negative number returned (badly defined Clusterable "
57  << "class?): ans= " << ans;
58  }
59  ans = 0;
60  }
61  delete copy;
62  return ans;
63 }
64 
65 // ============================================================================
66 // Implementation of ScalarClusterable class.
67 // ============================================================================
68 
70  if (count_ == 0) {
71  return 0;
72  } else {
73  KALDI_ASSERT(count_ > 0);
74  return -(x2_ - x_ * x_ / count_);
75  }
76 }
77 
78 void ScalarClusterable::Add(const Clusterable &other_in) {
79  KALDI_ASSERT(other_in.Type() == "scalar");
80  const ScalarClusterable *other =
81  static_cast<const ScalarClusterable*>(&other_in);
82  x_ += other->x_;
83  x2_ += other->x2_;
84  count_ += other->count_;
85 }
86 
87 void ScalarClusterable::Sub(const Clusterable &other_in) {
88  KALDI_ASSERT(other_in.Type() == "scalar");
89  const ScalarClusterable *other =
90  static_cast<const ScalarClusterable*>(&other_in);
91  x_ -= other->x_;
92  x2_ -= other->x2_;
93  count_ -= other->count_;
94 }
95 
98  ans->Add(*this);
99  return ans;
100 }
101 
102 void ScalarClusterable::Write(std::ostream &os, bool binary) const {
103  WriteToken(os, binary, "SCL"); // magic string.
104  WriteBasicType(os, binary, x_);
105  WriteBasicType(os, binary, x2_);
106  WriteBasicType(os, binary, count_);
107 }
108 
109 Clusterable* ScalarClusterable::ReadNew(std::istream &is, bool binary) const {
111  sc->Read(is, binary);
112  return sc;
113 }
114 
115 void ScalarClusterable::Read(std::istream &is, bool binary) {
116  ExpectToken(is, binary, "SCL");
117  ReadBasicType(is, binary, &x_);
118  ReadBasicType(is, binary, &x2_);
119  ReadBasicType(is, binary, &count_);
120 }
121 
122 std::string ScalarClusterable::Info() {
123  std::stringstream str;
124  if (count_ == 0) {
125  str << "[empty]";
126  } else {
127  str << "[mean " << (x_ / count_) << ", var " << (x2_ / count_ -
128  (x_ * x_ / (count_ * count_))) << "]";
129  }
130  return str.str();
131 }
132 
133 // ============================================================================
134 // Implementation of GaussClusterable class.
135 // ============================================================================
136 
138  BaseFloat weight) {
139  count_ += weight;
140  stats_.Row(0).AddVec(weight, vec);
141  stats_.Row(1).AddVec2(weight, vec);
142 }
143 
144 void GaussClusterable::Add(const Clusterable &other_in) {
145  KALDI_ASSERT(other_in.Type() == "gauss");
146  const GaussClusterable *other =
147  static_cast<const GaussClusterable*>(&other_in);
148  count_ += other->count_;
149  stats_.AddMat(1.0, other->stats_);
150 }
151 
152 void GaussClusterable::Sub(const Clusterable &other_in) {
153  KALDI_ASSERT(other_in.Type() == "gauss");
154  const GaussClusterable *other =
155  static_cast<const GaussClusterable*>(&other_in);
156  count_ -= other->count_;
157  stats_.AddMat(-1.0, other->stats_);
158 }
159 
161  KALDI_ASSERT(stats_.NumRows() == 2);
162  GaussClusterable *ans = new GaussClusterable(stats_.NumCols(), var_floor_);
163  ans->Add(*this);
164  return ans;
165 }
166 
168  KALDI_ASSERT(f >= 0.0);
169  count_ *= f;
170  stats_.Scale(f);
171 }
172 
173 void GaussClusterable::Write(std::ostream &os, bool binary) const {
174  WriteToken(os, binary, "GCL"); // magic string.
175  WriteBasicType(os, binary, count_);
176  WriteBasicType(os, binary, var_floor_);
177  stats_.Write(os, binary);
178 }
179 
180 Clusterable* GaussClusterable::ReadNew(std::istream &is, bool binary) const {
182  gc->Read(is, binary);
183  return gc;
184 }
185 
186 void GaussClusterable::Read(std::istream &is, bool binary) {
187  ExpectToken(is, binary, "GCL"); // magic string.
188  ReadBasicType(is, binary, &count_);
189  ReadBasicType(is, binary, &var_floor_);
190  stats_.Read(is, binary);
191 }
192 
194  if (count_ <= 0.0) {
195  if (count_ < -0.1) {
196  KALDI_WARN << "GaussClusterable::Objf(), count is negative " << count_;
197  }
198  return 0.0;
199  } else {
200  size_t dim = stats_.NumCols();
201  Vector<double> vars(dim);
202  double objf_per_frame = 0.0;
203  for (size_t d = 0; d < dim; d++) {
204  double mean(stats_(0, d) / count_), var = stats_(1, d) / count_ - mean
205  * mean, floored_var = std::max(var, var_floor_);
206  vars(d) = floored_var;
207  objf_per_frame += -0.5 * var / floored_var;
208  }
209  objf_per_frame += -0.5 * (vars.SumLog() + M_LOG_2PI * dim);
210  if (KALDI_ISNAN(objf_per_frame)) {
211  KALDI_WARN << "GaussClusterable::Objf(), objf is NaN";
212  return 0.0;
213  }
214  // KALDI_VLOG(2) << "count = " << count_ << ", objf_per_frame = "<< objf_per_frame
215  // << ", returning " << (objf_per_frame*count_) << ", floor = " << var_floor_;
216  return objf_per_frame * count_;
217  }
218 }
219 
220 
221 // ============================================================================
222 // Implementation of VectorClusterable class.
223 // ============================================================================
224 
225 void VectorClusterable::Add(const Clusterable &other_in) {
226  KALDI_ASSERT(other_in.Type() == "vector");
227  const VectorClusterable *other =
228  static_cast<const VectorClusterable*>(&other_in);
229  weight_ += other->weight_;
230  stats_.AddVec(1.0, other->stats_);
231  sumsq_ += other->sumsq_;
232 }
233 
234 void VectorClusterable::Sub(const Clusterable &other_in) {
235  KALDI_ASSERT(other_in.Type() == "vector");
236  const VectorClusterable *other =
237  static_cast<const VectorClusterable*>(&other_in);
238  weight_ -= other->weight_;
239  sumsq_ -= other->sumsq_;
240  stats_.AddVec(-1.0, other->stats_);
241  if (weight_ < 0.0) {
242  if (weight_ < -0.1 && weight_ < -0.0001 * fabs(other->weight_)) {
243  // a negative weight may indicate an algorithmic error if it is
244  // encountered.
245  KALDI_WARN << "Negative weight encountered " << weight_;
246  }
247  weight_ = 0.0;
248  }
249  if (weight_ == 0.0) {
250  sumsq_ = 0.0;
251  stats_.Set(0.0);
252  }
253 }
254 
257  ans->weight_ = weight_;
258  ans->sumsq_ = sumsq_;
259  ans->stats_ = stats_;
260  return ans;
261 }
262 
264  KALDI_ASSERT(f >= 0.0);
265  weight_ *= f;
266  stats_.Scale(f);
267  sumsq_ *= f;
268 }
269 
270 void VectorClusterable::Write(std::ostream &os, bool binary) const {
271  WriteToken(os, binary, "VCL"); // magic string.
272  WriteToken(os, binary, "<Weight>");
273  WriteBasicType(os, binary, weight_);
274  WriteToken(os, binary, "<Sumsq>");
275  WriteBasicType(os, binary, sumsq_);
276  WriteToken(os, binary, "<Stats>");
277  stats_.Write(os, binary);
278 }
279 
280 Clusterable* VectorClusterable::ReadNew(std::istream &is, bool binary) const {
282  vc->Read(is, binary);
283  return vc;
284 }
285 
286 void VectorClusterable::Read(std::istream &is, bool binary) {
287  ExpectToken(is, binary, "VCL"); // magic string.
288  ExpectToken(is, binary, "<Weight>");
289  ReadBasicType(is, binary, &weight_);
290  ExpectToken(is, binary, "<Sumsq>");
291  ReadBasicType(is, binary, &sumsq_);
292  ExpectToken(is, binary, "<Stats>");
293  stats_.Read(is, binary);
294 }
295 
297  BaseFloat weight):
298  weight_(weight), stats_(vector), sumsq_(0.0) {
299  stats_.Scale(weight);
300  KALDI_ASSERT(weight >= 0.0);
301  sumsq_ = VecVec(vector, vector) * weight;
302 }
303 
304 
306  double direct_sumsq;
307  if (weight_ > std::numeric_limits<BaseFloat>::min()) {
308  direct_sumsq = VecVec(stats_, stats_) / weight_;
309  } else {
310  direct_sumsq = 0.0;
311  }
312  // ans is a negated weighted sum of squared distances; it should not be
313  // positive.
314  double ans = -(sumsq_ - direct_sumsq);
315  if (ans > 0.0) {
316  if (ans > 1.0) {
317  KALDI_WARN << "Positive objective function encountered (treating as zero): "
318  << ans;
319  }
320  ans = 0.0;
321  }
322  return ans;
323 }
324 
325 
326 } // end namespace kaldi.
void AddStats(const VectorBase< BaseFloat > &vec, BaseFloat weight=1.0)
virtual void Sub(const Clusterable &other)=0
Subtract other stats.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
virtual void Add(const Clusterable &other)=0
Add other stats.
virtual std::string Type() const =0
Return a string that describes the inherited type.
#define M_LOG_2PI
Definition: kaldi-math.h:60
virtual void Write(std::ostream &os, bool binary) const
Write data to stream.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
virtual BaseFloat Objf() const =0
Return the objective function associated with the stats [assuming ML estimation]. ...
virtual BaseFloat ObjfMinus(const Clusterable &other) const
Return the objective function of the subtracted object this - other.
Real SumLog() const
Returns sum of the logs of the elements.
virtual void Scale(BaseFloat f)
Scale the stats by a positive number f [not mandatory to supply this].
virtual Clusterable * ReadNew(std::istream &is, bool binary) const
Read data from a stream and return the corresponding object (const function; it's a class member beca...
virtual void Sub(const Clusterable &other_in)
Subtract other stats.
VectorClusterable wraps vectors in a form accessible to generic clustering algorithms.
virtual void Add(const Clusterable &other_in)
Add other stats.
virtual Clusterable * Copy() const
Return a copy of this object.
virtual void Write(std::ostream &os, bool binary) const
Write data to stream.
virtual Clusterable * Copy() const =0
Return a copy of this object.
virtual void Add(const Clusterable &other_in)
Add other stats.
virtual BaseFloat Distance(const Clusterable &other) const
Return the objective function decrease from merging the two clusters, negated to be a positive number...
virtual Clusterable * Copy() const
Return a copy of this object.
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
virtual Clusterable * Copy() const
Return a copy of this object.
virtual void Add(const Clusterable &other_in)
Add other stats.
#define KALDI_WARN
Definition: kaldi-error.h:150
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
virtual Clusterable * ReadNew(std::istream &is, bool binary) const
Read data from a stream and return the corresponding object (const function; it's a class member beca...
void Scale(Real alpha)
Multiplies all elements by this constant.
virtual BaseFloat Objf() const
Return the objective function associated with the stats [assuming ML estimation]. ...
virtual BaseFloat ObjfPlus(const Clusterable &other) const
Return the objective function of the combined object this + other.
virtual void Write(std::ostream &os, bool binary) const
Write data to stream.
virtual Clusterable * ReadNew(std::istream &is, bool binary) const
Read data from a stream and return the corresponding object (const function; it's a class member beca...
#define KALDI_ISNAN
Definition: kaldi-math.h:72
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
virtual BaseFloat Objf() const
Return the objective function associated with the stats [assuming ML estimation]. ...
virtual BaseFloat Objf() const
Return the objective function associated with the stats [assuming ML estimation]. ...
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
virtual void Sub(const Clusterable &other_in)
Subtract other stats.
void Read(std::istream &is, bool binary)
void Read(std::istream &is, bool binary)
virtual void Scale(BaseFloat f)
Scale the stats by a positive number f [not mandatory to supply this].
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...
Real VecVec(const VectorBase< Real > &a, const VectorBase< Real > &b)
Returns dot product between v1 and v2.
Definition: kaldi-vector.cc:37
void Read(std::istream &is, bool binary)
virtual void Sub(const Clusterable &other_in)
Subtract other stats.
ScalarClusterable clusters scalars with x^2 loss.