mle-diag-gmm.h
Go to the documentation of this file.
1 // gmm/mle-diag-gmm.h
2 
3 // Copyright 2009-2012 Saarland University; Georg Stemmer;
4 // Microsoft Corporation; Jan Silovsky; Yanmin Qian
5 // Johns Hopkins University (author: Daniel Povey)
6 // Cisco Systems (author: Neha Agrawal)
7 
8 // See ../../COPYING for clarification regarding multiple authors
9 //
10 // Licensed under the Apache License, Version 2.0 (the "License");
11 // you may not use this file except in compliance with the License.
12 // You may obtain a copy of the License at
13 //
14 // http://www.apache.org/licenses/LICENSE-2.0
15 //
16 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
18 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
19 // MERCHANTABLITY OR NON-INFRINGEMENT.
20 // See the Apache 2 License for the specific language governing permissions and
21 // limitations under the License.
22 
23 
24 #ifndef KALDI_GMM_MLE_DIAG_GMM_H_
25 #define KALDI_GMM_MLE_DIAG_GMM_H_ 1
26 
27 #include "gmm/diag-gmm.h"
28 #include "gmm/diag-gmm-normal.h"
29 #include "gmm/model-common.h"
30 #include "itf/options-itf.h"
31 
32 namespace kaldi {
33 
50  double min_variance;
53  // don't set var floor vector by default.
54  min_gaussian_weight = 1.0e-05;
55  min_gaussian_occupancy = 10.0;
56  min_variance = 0.001;
57  remove_low_count_gaussians = true;
58  }
59  void Register(OptionsItf *opts) {
60  std::string module = "MleDiagGmmOptions: ";
61  opts->Register("min-gaussian-weight", &min_gaussian_weight,
62  module+"Min Gaussian weight before we remove it.");
63  opts->Register("min-gaussian-occupancy", &min_gaussian_occupancy,
64  module+"Minimum occupancy to update a Gaussian.");
65  opts->Register("min-variance", &min_variance,
66  module+"Variance floor (absolute variance).");
67  opts->Register("remove-low-count-gaussians", &remove_low_count_gaussians,
68  module+"If true, remove Gaussians that fall below the floors.");
69  }
70 };
71 
72 
79 
84 
88 
89  MapDiagGmmOptions(): mean_tau(10.0),
90  variance_tau(50.0),
91  weight_tau(10.0) { }
92 
93  void Register(OptionsItf *opts) {
94  opts->Register("mean-tau", &mean_tau,
95  "Tau value for updating means.");
96  opts->Register("variance-tau", &mean_tau,
97  "Tau value for updating variances (note: only relevant if "
98  "update-flags contains \"v\".");
99  opts->Register("weight-tau", &weight_tau,
100  "Tau value for updating weights.");
101  }
102 };
103 
104 
105 
107  public:
108  AccumDiagGmm(): dim_(0), num_comp_(0), flags_(0) { }
109  explicit AccumDiagGmm(const DiagGmm &gmm, GmmFlagsType flags) {
110  Resize(gmm, flags);
111  }
112  // provide copy constructor.
113  explicit AccumDiagGmm(const AccumDiagGmm &other);
114 
115  void Read(std::istream &in_stream, bool binary, bool add);
116  void Write(std::ostream &out_stream, bool binary) const;
117 
119  void Resize(int32 num_gauss, int32 dim, GmmFlagsType flags);
121  void Resize(const DiagGmm &gmm, GmmFlagsType flags);
122 
124  int32 NumGauss() const { return num_comp_; }
126  int32 Dim() const { return dim_; }
127 
128  void SetZero(GmmFlagsType flags);
129  void Scale(BaseFloat f, GmmFlagsType flags);
130 
132  void AccumulateForComponent(const VectorBase<BaseFloat> &data,
133  int32 comp_index, BaseFloat weight);
134 
136  void AccumulateFromPosteriors(const VectorBase<BaseFloat> &data,
137  const VectorBase<BaseFloat> &gauss_posteriors);
138 
141  BaseFloat AccumulateFromDiag(const DiagGmm &gmm,
142  const VectorBase<BaseFloat> &data,
143  BaseFloat frame_posterior);
144 
148  BaseFloat AccumulateFromDiagMultiThreaded(
149  const DiagGmm &gmm,
150  const MatrixBase<BaseFloat> &data,
151  const VectorBase<BaseFloat> &frame_weights,
152  int32 num_threads);
153 
154 
158  void AddStatsForComponent(int32 comp_id,
159  double occ,
160  const VectorBase<double> &x_stats,
161  const VectorBase<double> &x2_stats);
162 
164  void Add(double scale, const AccumDiagGmm &acc);
165 
168  void SmoothStats(BaseFloat tau);
169 
174  void SmoothWithAccum(BaseFloat tau, const AccumDiagGmm &src_acc);
175 
179  void SmoothWithModel(BaseFloat tau, const DiagGmm &src_gmm);
180 
181  // Const accessors
182  GmmFlagsType Flags() const { return flags_; }
183  const VectorBase<double> &occupancy() const { return occupancy_; }
184  const MatrixBase<double> &mean_accumulator() const { return mean_accumulator_; }
185  const MatrixBase<double> &variance_accumulator() const { return variance_accumulator_; }
186 
187  // used in testing.
188  void AssertEqual(const AccumDiagGmm &other);
189  private:
194 
198 };
199 
200 
204 
205 
206 inline void AccumDiagGmm::Resize(const DiagGmm &gmm, GmmFlagsType flags) {
207  Resize(gmm.NumGauss(), gmm.Dim(), flags);
208 }
209 
// Computes the maximum-likelihood estimates of the parameters of a Gaussian
// mixture model (diagonal covariance).
//   config        - MLE update options.
//   diag_gmm_acc  - accumulated statistics; taken by const reference.
//   flags         - GmmFlagsType value (a bitwise OR of GMM flags);
//                   presumably selects which parameters are updated --
//                   confirm against the implementation.
//   gmm           - model passed by non-const pointer, so it can be modified.
//   obj_change_out, count_out - output pointers with no default value.
//                   NOTE(review): whether NULL is accepted here is not
//                   visible from this declaration.
//   floored_elements_out, floored_gauss_out, removed_gauss_out - optional
//                   output pointers; each defaults to NULL.
214 void MleDiagGmmUpdate(const MleDiagGmmOptions &config,
215  const AccumDiagGmm &diag_gmm_acc,
216  GmmFlagsType flags,
217  DiagGmm *gmm,
218  BaseFloat *obj_change_out,
219  BaseFloat *count_out,
220  int32 *floored_elements_out = NULL,
221  int32 *floored_gauss_out = NULL,
222  int32 *removed_gauss_out = NULL);
223 
// Maximum A Posteriori estimation of the model.
//   config        - MAP options (the mean/variance/weight tau values).
//   diag_gmm_acc  - accumulated statistics; taken by const reference.
//   flags         - GmmFlagsType value (a bitwise OR of GMM flags);
//                   presumably selects which parameters are updated --
//                   confirm against the implementation.
//   gmm           - model passed by non-const pointer, so it can be modified.
//   obj_change_out, count_out - output pointers with no default value.
//                   NOTE(review): whether NULL is accepted here is not
//                   visible from this declaration.
225 void MapDiagGmmUpdate(const MapDiagGmmOptions &config,
226  const AccumDiagGmm &diag_gmm_acc,
227  GmmFlagsType flags,
228  DiagGmm *gmm,
229  BaseFloat *obj_change_out,
230  BaseFloat *count_out);
231 
// Returns an objective value for `gmm` given the accumulated statistics in
// `diaggmm_acc`; per the generated documentation it is calculated using the
// DiagGMM exponential form. Both arguments are read-only.
// NOTE(review): the name suggests a maximum-likelihood objective; scaling
// and units are not visible from this declaration.
233 BaseFloat MlObjective(const DiagGmm &gmm,
234  const AccumDiagGmm &diaggmm_acc);
235 
236 } // End namespace kaldi
237 
238 
239 #endif // KALDI_GMM_MLE_DIAG_GMM_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 Dim() const
Returns the dimensionality of the Gaussian mean vectors.
Definition: diag-gmm.h:74
void MapDiagGmmUpdate(const MapDiagGmmOptions &config, const AccumDiagGmm &diag_gmm_acc, GmmFlagsType flags, DiagGmm *gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
Maximum A Posteriori estimation of the model.
GmmFlagsType AugmentGmmFlags(GmmFlagsType f)
Returns "augmented" version of flags: e.g.
Definition: model-common.cc:52
BaseFloat weight_tau
Tau value for the weights -- this tau value is applied per state, not per Gaussian.
Definition: mle-diag-gmm.h:87
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
BaseFloat min_gaussian_occupancy
Minimum count below which a Gaussian is not updated (and is removed, if remove_low_count_gaussians ==...
Definition: mle-diag-gmm.h:47
void MleDiagGmmUpdate(const MleDiagGmmOptions &config, const AccumDiagGmm &diag_gmm_acc, GmmFlagsType flags, DiagGmm *gmm, BaseFloat *obj_change_out, BaseFloat *count_out, int32 *floored_elements_out, int32 *floored_gaussians_out, int32 *removed_gaussians_out)
for computing the maximum-likelihood estimates of the parameters of a Gaussian mixture model...
BaseFloat MlObjective(const DiagGmm &gmm, const AccumDiagGmm &diag_gmm_acc)
Calc using the DiagGMM exponential form.
Matrix< double > mean_accumulator_
Definition: mle-diag-gmm.h:196
const VectorBase< double > & occupancy() const
Definition: mle-diag-gmm.h:183
kaldi::int32 int32
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
void Register(OptionsItf *opts)
Definition: mle-diag-gmm.h:93
double min_variance
Minimum allowed variance in any dimension (if no variance floor) It is in double since the variance i...
Definition: mle-diag-gmm.h:50
AccumDiagGmm(const DiagGmm &gmm, GmmFlagsType flags)
Definition: mle-diag-gmm.h:109
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
GmmFlagsType flags_
Flags corresponding to the accumulators that are stored.
Definition: mle-diag-gmm.h:193
const MatrixBase< double > & variance_accumulator() const
Definition: mle-diag-gmm.h:185
const MatrixBase< double > & mean_accumulator() const
Definition: mle-diag-gmm.h:184
GmmFlagsType Flags() const
Definition: mle-diag-gmm.h:182
void Register(OptionsItf *opts)
Definition: mle-diag-gmm.h:59
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: diag-gmm.h:72
Configuration variables like variance floor, minimum occupancy, etc.
Definition: mle-diag-gmm.h:38
Vector< double > occupancy_
Definition: mle-diag-gmm.h:195
int32 Dim() const
Returns the dimensionality of the feature vectors.
Definition: mle-diag-gmm.h:126
Vector< double > variance_floor_vector
Variance floor for each dimension [empty if not supplied].
Definition: mle-diag-gmm.h:41
BaseFloat min_gaussian_weight
Minimum weight below which a Gaussian is not updated (and is removed, if remove_low_count_gaussians =...
Definition: mle-diag-gmm.h:44
Matrix< double > variance_accumulator_
Definition: mle-diag-gmm.h:197
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
Definition: kaldi-math.h:276
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void Resize(int32 num_gauss, int32 dim, GmmFlagsType flags)
Allocates memory for accumulators.
BaseFloat variance_tau
Tau value for the variances.
Definition: mle-diag-gmm.h:83
BaseFloat mean_tau
Tau value for the means.
Definition: mle-diag-gmm.h:78
int32 NumGauss() const
Returns the number of mixture components.
Definition: mle-diag-gmm.h:124
Configuration variables for Maximum A Posteriori (MAP) update.
Definition: mle-diag-gmm.h:76