diag-gmm.h
Go to the documentation of this file.
1 // gmm/diag-gmm.h
2 
3 // Copyright 2009-2011 Microsoft Corporation;
4 // Saarland University (Author: Arnab Ghoshal);
5 // Georg Stemmer; Jan Silovsky
6 // 2012 Arnab Ghoshal
7 // 2013-2014 Johns Hopkins University (author: Daniel Povey)
8 
9 // See ../../COPYING for clarification regarding multiple authors
10 //
11 // Licensed under the Apache License, Version 2.0 (the "License");
12 // you may not use this file except in compliance with the License.
13 // You may obtain a copy of the License at
14 //
15 // http://www.apache.org/licenses/LICENSE-2.0
16 //
17 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
18 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
19 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
20 // MERCHANTABLITY OR NON-INFRINGEMENT.
21 // See the Apache 2 License for the specific language governing permissions and
22 // limitations under the License.
23 
24 #ifndef KALDI_GMM_DIAG_GMM_H_
25 #define KALDI_GMM_DIAG_GMM_H_ 1
26 
27 #include <utility>
28 #include <vector>
29 
30 #include "base/kaldi-common.h"
31 #include "gmm/model-common.h"
32 #include "matrix/matrix-lib.h"
33 #include "tree/cluster-utils.h"
35 
36 namespace kaldi {
37 
38 class FullGmm;
39 class DiagGmmNormal;
40 
42 class DiagGmm {
44  friend class DiagGmmNormal;
45 
46  public:
48  DiagGmm() : valid_gconsts_(false) { }
49 
50  explicit DiagGmm(const DiagGmm &gmm): valid_gconsts_(false) {
51  CopyFromDiagGmm(gmm);
52  }
53 
56  DiagGmm(const GaussClusterable &gc, BaseFloat var_floor);
57 
59  void CopyFromNormal(const DiagGmmNormal &diag_gmm_normal);
60 
61  DiagGmm(int32 nMix, int32 dim): valid_gconsts_(false) { Resize(nMix, dim); }
62 
66  explicit DiagGmm(const std::vector<std::pair<BaseFloat, const DiagGmm*> > &gmms);
67 
69  void Resize(int32 nMix, int32 dim);
70 
72  int32 NumGauss() const { return weights_.Dim(); }
74  int32 Dim() const { return means_invvars_.NumCols(); }
75 
77  void CopyFromDiagGmm(const DiagGmm &diaggmm);
79  void CopyFromFullGmm(const FullGmm &fullgmm);
80 
83 
85  void LogLikelihoods(const VectorBase<BaseFloat> &data,
86  Vector<BaseFloat> *loglikes) const;
87 
91  void LogLikelihoods(const MatrixBase<BaseFloat> &data,
92  Matrix<BaseFloat> *loglikes) const;
93 
94 
100  const std::vector<int32> &indices,
101  Vector<BaseFloat> *loglikes) const;
102 
107  int32 num_gselect,
108  std::vector<int32> *output) const;
109 
114  int32 num_gselect,
115  std::vector<std::vector<int32> > *output) const;
116 
122  const std::vector<int32> &preselect,
123  int32 num_gselect,
124  std::vector<int32> *output) const;
125 
129  Vector<BaseFloat> *posteriors) const;
130 
135  int32 comp_id) const;
136 
140 
142  void Generate(VectorBase<BaseFloat> *output);
143 
146  void Split(int32 target_components, float perturb_factor,
147  std::vector<int32> *history = NULL);
148 
151  void Perturb(float perturb_factor);
152 
155  void Merge(int32 target_components, std::vector<int32> *history = NULL);
156 
157  // Merge the components to a specified target #components: this
158  // version uses a different approach based on K-means.
159  void MergeKmeans(int32 target_components,
161 
162  void Write(std::ostream &os, bool binary) const;
163  void Read(std::istream &in, bool binary);
164 
166  void Interpolate(BaseFloat rho, const DiagGmm &source,
167  GmmFlagsType flags = kGmmAll);
168 
170  void Interpolate(BaseFloat rho, const FullGmm &source,
171  GmmFlagsType flags = kGmmAll);
172 
174  const Vector<BaseFloat> &gconsts() const {
176  return gconsts_;
177  }
178  const Vector<BaseFloat> &weights() const { return weights_; }
179  const Matrix<BaseFloat> &means_invvars() const { return means_invvars_; }
180  const Matrix<BaseFloat> &inv_vars() const { return inv_vars_; }
181  bool valid_gconsts() const { return valid_gconsts_; }
182 
184  void RemoveComponent(int32 gauss, bool renorm_weights);
185 
187  void RemoveComponents(const std::vector<int32> &gauss, bool renorm_weights);
188 
190  template<class Real>
191  void SetWeights(const VectorBase<Real> &w);
192 
194  template<class Real>
195  void SetMeans(const MatrixBase<Real> &m);
197  template<class Real>
198  void SetInvVarsAndMeans(const MatrixBase<Real> &invvars,
199  const MatrixBase<Real> &means);
201  template<class Real>
202  void SetInvVars(const MatrixBase<Real> &v);
203 
205  template<class Real>
206  void GetVars(Matrix<Real> *v) const;
208  template<class Real>
209  void GetMeans(Matrix<Real> *m) const;
210 
213  template<class Real>
214  void SetComponentMean(int32 gauss, const VectorBase<Real> &in);
217  template<class Real>
218  void SetComponentInvVar(int32 gauss, const VectorBase<Real> &in);
220  inline void SetComponentWeight(int32 gauss, BaseFloat weight);
221 
223  template<class Real>
224  void GetComponentMean(int32 gauss, VectorBase<Real> *out) const;
225 
227  template<class Real>
228  void GetComponentVariance(int32 gauss, VectorBase<Real> *out) const;
229 
230  private:
237 
238  // merged_components_logdet computes logdet for merged components
239  // f1, f2 are first-order stats (normalized by zero-order stats)
240  // s1, s2 are second-order stats (normalized by zero-order stats)
242  const VectorBase<BaseFloat> &f1,
243  const VectorBase<BaseFloat> &f2,
244  const VectorBase<BaseFloat> &s1,
245  const VectorBase<BaseFloat> &s2) const;
246 
247  private:
248  const DiagGmm &operator=(const DiagGmm &other); // Disallow assignment
249 };
250 
252 std::ostream &
253 operator << (std::ostream &os, const kaldi::DiagGmm &gmm);
255 std::istream &
256 operator >> (std::istream &is, kaldi::DiagGmm &gmm);
257 
258 } // End namespace kaldi
259 
260 #include "gmm/diag-gmm-inl.h" // templated functions.
261 
262 #endif // KALDI_GMM_DIAG_GMM_H_
std::ostream & operator<<(std::ostream &os, const MatrixBase< Real > &M)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 Dim() const
Returns the dimensionality of the Gaussian mean vectors.
Definition: diag-gmm.h:74
void CopyFromDiagGmm(const DiagGmm &diaggmm)
Copies from given DiagGmm.
Definition: diag-gmm.cc:83
void Perturb(float perturb_factor)
Perturbs the component means with a random vector multiplied by the perturb factor.
Definition: diag-gmm.cc:215
void Interpolate(BaseFloat rho, const DiagGmm &source, GmmFlagsType flags=kGmmAll)
this = rho x source + (1-rho) x this
Definition: diag-gmm.cc:645
void SetInvVarsAndMeans(const MatrixBase< Real > &invvars, const MatrixBase< Real > &means)
Use SetInvVarsAndMeans if updating both means and (inverse) variances.
Definition: diag-gmm-inl.h:63
void Write(std::ostream &os, bool binary) const
Definition: diag-gmm.cc:705
void Merge(int32 target_components, std::vector< int32 > *history=NULL)
Merge the components and remember the order in which the components were merged (flat list of pairs) ...
Definition: diag-gmm.cc:295
void LogLikelihoodsPreselect(const VectorBase< BaseFloat > &data, const std::vector< int32 > &indices, Vector< BaseFloat > *loglikes) const
Outputs the per-component log-likelihoods of a subset of mixture components.
Definition: diag-gmm.cc:566
void Split(int32 target_components, float perturb_factor, std::vector< int32 > *history=NULL)
Split the components and remember the order in which the components were split.
Definition: diag-gmm.cc:154
Definition for Gaussian Mixture Model with diagonal covariances in normal mode: where the parameters ...
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
const Matrix< BaseFloat > & means_invvars() const
Definition: diag-gmm.h:179
Definition for Gaussian Mixture Model with full covariances.
Definition: full-gmm.h:40
void GetComponentMean(int32 gauss, VectorBase< Real > *out) const
Accessor for single component mean.
Definition: diag-gmm-inl.h:135
void Resize(int32 nMix, int32 dim)
Resizes arrays to this dim. Does not initialize data.
Definition: diag-gmm.cc:66
const Vector< BaseFloat > & gconsts() const
Const accessors.
Definition: diag-gmm.h:174
bool valid_gconsts() const
Definition: diag-gmm.h:181
int32 ComputeGconsts()
Sets the gconsts.
Definition: diag-gmm.cc:114
kaldi::int32 int32
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
void SetMeans(const MatrixBase< Real > &m)
Use SetMeans to update only the Gaussian means (and not variances)
Definition: diag-gmm-inl.h:43
void SetComponentMean(int32 gauss, const VectorBase< Real > &in)
Mutators for single component; support float or double. Set mean for a single component - internally ...
Definition: diag-gmm-inl.h:52
void GetVars(Matrix< Real > *v) const
Accessor for covariances.
Definition: diag-gmm-inl.h:115
BaseFloat ComponentLogLikelihood(const VectorBase< BaseFloat > &data, int32 comp_id) const
Computes the log-likelihood of a data point given a single Gaussian component.
Definition: diag-gmm.cc:497
BaseFloat ComponentPosteriors(const VectorBase< BaseFloat > &data, Vector< BaseFloat > *posteriors) const
Computes the posterior probabilities of all Gaussian components given a data point.
Definition: diag-gmm.cc:601
void RemoveComponent(int32 gauss, bool renorm_weights)
Removes single component from model.
Definition: diag-gmm.cc:617
BaseFloat LogLikelihood(const VectorBase< BaseFloat > &data) const
Returns the log-likelihood of a data point (vector) given the GMM.
Definition: diag-gmm.cc:517
void GetMeans(Matrix< Real > *m) const
Accessor for means.
Definition: diag-gmm-inl.h:123
void RemoveComponents(const std::vector< int32 > &gauss, bool renorm_weights)
Removes multiple components from model; "gauss" must not have dups.
Definition: diag-gmm.cc:632
DiagGmm(const DiagGmm &gmm)
Definition: diag-gmm.h:50
bool valid_gconsts_
Recompute gconsts_ if false.
Definition: diag-gmm.h:233
Matrix< BaseFloat > inv_vars_
Inverted (diagonal) variances.
Definition: diag-gmm.h:235
const Vector< BaseFloat > & weights() const
Definition: diag-gmm.h:178
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: diag-gmm.h:72
void LogLikelihoods(const VectorBase< BaseFloat > &data, Vector< BaseFloat > *loglikes) const
Outputs the per-component log-likelihoods.
Definition: diag-gmm.cc:528
BaseFloat GaussianSelectionPreselect(const VectorBase< BaseFloat > &data, const std::vector< int32 > &preselect, int32 num_gselect, std::vector< int32 > *output) const
Get gaussian selection information for one frame.
Definition: diag-gmm.cc:875
void MergeKmeans(int32 target_components, ClusterKMeansOptions cfg=ClusterKMeansOptions())
Definition: diag-gmm.cc:231
void SetInvVars(const MatrixBase< Real > &v)
Set the (inverse) variances and recompute means_invvars_.
Definition: diag-gmm-inl.h:78
std::istream & operator>>(std::istream &is, Matrix< Real > &M)
void GetComponentVariance(int32 gauss, VectorBase< Real > *out) const
Accessor for single component variance.
Definition: diag-gmm-inl.h:145
void CopyFromFullGmm(const FullGmm &fullgmm)
Copies from given FullGmm.
Definition: diag-gmm.cc:92
void Read(std::istream &in, bool binary)
Definition: diag-gmm.cc:728
Vector< BaseFloat > weights_
weights (not log).
Definition: diag-gmm.h:234
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
Vector< BaseFloat > gconsts_
Equals log(weight) - 0.5 * (log det(var) + mean*mean*inv(var))
Definition: diag-gmm.h:232
BaseFloat GaussianSelection(const VectorBase< BaseFloat > &data, int32 num_gselect, std::vector< int32 > *output) const
Get gaussian selection information for one frame.
Definition: diag-gmm.cc:765
DiagGmm(int32 nMix, int32 dim)
Definition: diag-gmm.h:61
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
void Generate(VectorBase< BaseFloat > *output)
Generates a random data-point from this distribution.
Definition: diag-gmm.cc:922
void SetComponentInvVar(int32 gauss, const VectorBase< Real > &in)
Set inv-var for single component (recommend to do this before setting the mean, if doing both...
Definition: diag-gmm-inl.h:97
DiagGmm()
Empty constructor.
Definition: diag-gmm.h:48
void SetWeights(const VectorBase< Real > &w)
Mutators for both float or double.
Definition: diag-gmm-inl.h:28
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...
void CopyFromNormal(const DiagGmmNormal &diag_gmm_normal)
Copies from DiagGmmNormal; does not resize.
Definition: diag-gmm.cc:918
Matrix< BaseFloat > means_invvars_
Means times inverted variance.
Definition: diag-gmm.h:236
void SetComponentWeight(int32 gauss, BaseFloat weight)
Set weight for single component.
Definition: diag-gmm-inl.h:34
const Matrix< BaseFloat > & inv_vars() const
Definition: diag-gmm.h:180
BaseFloat merged_components_logdet(BaseFloat w1, BaseFloat w2, const VectorBase< BaseFloat > &f1, const VectorBase< BaseFloat > &f2, const VectorBase< BaseFloat > &s1, const VectorBase< BaseFloat > &s2) const
Definition: diag-gmm.cc:471
const DiagGmm & operator=(const DiagGmm &other)