mle-am-diag-gmm.cc
Go to the documentation of this file.
1 // gmm/mle-am-diag-gmm.cc
2 
3 // Copyright 2009-2011 Saarland University (Author: Arnab Ghoshal);
4 // Microsoft Corporation; Georg Stemmer; Yanmin Qian
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "gmm/am-diag-gmm.h"
22 #include "gmm/mle-am-diag-gmm.h"
23 #include "util/stl-utils.h"
24 
25 namespace kaldi {
26 
28  KALDI_ASSERT(index >= 0 && index < static_cast<int32>(gmm_accumulators_.size()));
29  return *(gmm_accumulators_[index]);
30 }
31 
33  KALDI_ASSERT(index >= 0 && index < static_cast<int32>(gmm_accumulators_.size()));
34  return *(gmm_accumulators_[index]);
35 }
36 
39 }
40 
41 void AccumAmDiagGmm::Init(const AmDiagGmm &model,
42  GmmFlagsType flags) {
43  DeletePointers(&gmm_accumulators_); // in case was non-empty when called.
44  gmm_accumulators_.resize(model.NumPdfs(), NULL);
45  for (int32 i = 0; i < model.NumPdfs(); i++) {
47  gmm_accumulators_[i]->Resize(model.GetPdf(i), flags);
48  }
49 }
50 
51 void AccumAmDiagGmm::Init(const AmDiagGmm &model,
52  int32 dim, GmmFlagsType flags) {
53  KALDI_ASSERT(dim > 0);
54  DeletePointers(&gmm_accumulators_); // in case was non-empty when called.
55  gmm_accumulators_.resize(model.NumPdfs(), NULL);
56  for (int32 i = 0; i < model.NumPdfs(); i++) {
58  gmm_accumulators_[i]->Resize(model.GetPdf(i).NumGauss(),
59  dim, flags);
60  }
61 }
62 
64  for (size_t i = 0; i < gmm_accumulators_.size(); i++) {
65  gmm_accumulators_[i]->SetZero(flags);
66  }
67 }
68 
70  const AmDiagGmm &model, const VectorBase<BaseFloat> &data,
71  int32 gmm_index, BaseFloat weight) {
72  KALDI_ASSERT(static_cast<size_t>(gmm_index) < gmm_accumulators_.size());
73  BaseFloat log_like =
74  gmm_accumulators_[gmm_index]->AccumulateFromDiag(model.GetPdf(gmm_index),
75  data, weight);
76  total_log_like_ += log_like * weight;
77  total_frames_ += weight;
78  return log_like;
79 }
80 
82  const AmDiagGmm &model,
83  const VectorBase<BaseFloat> &data1,
84  const VectorBase<BaseFloat> &data2,
85  int32 gmm_index,
86  BaseFloat weight) {
87  KALDI_ASSERT(static_cast<size_t>(gmm_index) < gmm_accumulators_.size());
88  const DiagGmm &gmm = model.GetPdf(gmm_index);
89  AccumDiagGmm &acc = *(gmm_accumulators_[gmm_index]);
90  Vector<BaseFloat> posteriors;
91  BaseFloat log_like = gmm.ComponentPosteriors(data1, &posteriors);
92  posteriors.Scale(weight);
93  acc.AccumulateFromPosteriors(data2, posteriors);
94  total_log_like_ += log_like * weight;
95  total_frames_ += weight;
96  return log_like;
97 }
98 
99 
101  const AmDiagGmm &model, const VectorBase<BaseFloat> &data,
102  int32 gmm_index, const VectorBase<BaseFloat> &posteriors) {
103  KALDI_ASSERT(gmm_index >= 0 && gmm_index < NumAccs());
104  gmm_accumulators_[gmm_index]->AccumulateFromPosteriors(data, posteriors);
105  total_frames_ += posteriors.Sum();
106 }
107 
109  const AmDiagGmm &am, const VectorBase<BaseFloat> &data,
110  int32 gmm_index, int32 gauss_index, BaseFloat weight) {
111  KALDI_ASSERT(gmm_index >= 0 && gmm_index < NumAccs());
112  KALDI_ASSERT(gauss_index >= 0
113  && gauss_index < am.GetPdf(gmm_index).NumGauss());
114  gmm_accumulators_[gmm_index]->AccumulateForComponent(data, gauss_index, weight);
115 }
116 
117 void AccumAmDiagGmm::Read(std::istream &in_stream, bool binary,
118  bool add) {
119  int32 num_pdfs;
120  ExpectToken(in_stream, binary, "<NUMPDFS>");
121  ReadBasicType(in_stream, binary, &num_pdfs);
122  KALDI_ASSERT(num_pdfs > 0);
123  if (!add || (add && gmm_accumulators_.empty())) {
124  gmm_accumulators_.resize(num_pdfs, NULL);
125  for (std::vector<AccumDiagGmm*>::iterator it = gmm_accumulators_.begin(),
126  end = gmm_accumulators_.end(); it != end; ++it) {
127  delete *it;
128  *it = new AccumDiagGmm();
129  (*it)->Read(in_stream, binary, add);
130  }
131  } else {
132  if (gmm_accumulators_.size() != static_cast<size_t> (num_pdfs))
133  KALDI_ERR << "Adding accumulators but num-pdfs do not match: "
134  << (gmm_accumulators_.size()) << " vs. "
135  << (num_pdfs);
136  for (std::vector<AccumDiagGmm*>::iterator it = gmm_accumulators_.begin(),
137  end = gmm_accumulators_.end(); it != end; ++it)
138  (*it)->Read(in_stream, binary, add);
139  }
140  // TODO(arnab): Bad hack! Need to make this self-delimiting.
141  in_stream.peek(); // This will set the EOF bit for older accs.
142  if (!in_stream.eof()) {
143  double like, frames;
144  ExpectToken(in_stream, binary, "<total_like>");
145  ReadBasicType(in_stream, binary, &like);
146  total_log_like_ = (add)? total_log_like_ + like : like;
147  ExpectToken(in_stream, binary, "<total_frames>");
148  ReadBasicType(in_stream, binary, &frames);
149  total_frames_ = (add)? total_frames_ + frames : frames;
150  }
151 }
152 
153 void AccumAmDiagGmm::Write(std::ostream &out_stream, bool binary) const {
154  int32 num_pdfs = gmm_accumulators_.size();
155  WriteToken(out_stream, binary, "<NUMPDFS>");
156  WriteBasicType(out_stream, binary, num_pdfs);
157  for (std::vector<AccumDiagGmm*>::const_iterator it =
158  gmm_accumulators_.begin(), end = gmm_accumulators_.end(); it != end; ++it) {
159  (*it)->Write(out_stream, binary);
160  }
161  WriteToken(out_stream, binary, "<total_like>");
162  WriteBasicType(out_stream, binary, total_log_like_);
163 
164  WriteToken(out_stream, binary, "<total_frames>");
165  WriteBasicType(out_stream, binary, total_frames_);
166 }
167 
168 
169 // BaseFloat AccumAmDiagGmm::TotCount() const {
170 // BaseFloat ans = 0.0;
171 // for (int32 pdf = 0; pdf < NumAccs(); pdf++)
172 // ans += gmm_accumulators_[pdf]->occupancy().Sum();
173 // return ans;
174 // }
175 
176 void ResizeModel (int32 dim, AmDiagGmm *am_gmm) {
177  for (int32 pdf_id = 0; pdf_id < am_gmm->NumPdfs(); pdf_id++) {
178  DiagGmm &pdf = am_gmm->GetPdf(pdf_id);
179  pdf.Resize(pdf.NumGauss(), dim);
180  Matrix<BaseFloat> inv_vars(pdf.NumGauss(), dim);
181  inv_vars.Set(1.0); // make all vars 1.
182  pdf.SetInvVars(inv_vars);
183  pdf.ComputeGconsts();
184  }
185 }
186 
188  const AccumAmDiagGmm &am_diag_gmm_acc,
189  GmmFlagsType flags,
190  AmDiagGmm *am_gmm,
191  BaseFloat *obj_change_out,
192  BaseFloat *count_out) {
193  if (am_diag_gmm_acc.Dim() != am_gmm->Dim()) {
194  KALDI_ASSERT(am_diag_gmm_acc.Dim() != 0);
195  KALDI_WARN << "Dimensions of accumulator " << am_diag_gmm_acc.Dim()
196  << " and gmm " << am_gmm->Dim() << " do not match, resizing "
197  << " GMM and setting to zero-mean, unit-variance.";
198  ResizeModel(am_diag_gmm_acc.Dim(), am_gmm);
199  }
200 
201  KALDI_ASSERT(am_gmm != NULL);
202  KALDI_ASSERT(am_diag_gmm_acc.NumAccs() == am_gmm->NumPdfs());
203  if (obj_change_out != NULL) *obj_change_out = 0.0;
204  if (count_out != NULL) *count_out = 0.0;
205 
206  BaseFloat tot_obj_change = 0.0, tot_count = 0.0;
207  int32 tot_elems_floored = 0, tot_gauss_floored = 0,
208  tot_gauss_removed = 0;
209  for (int32 i = 0; i < am_diag_gmm_acc.NumAccs(); i++) {
210  BaseFloat obj_change, count;
211  int32 elems_floored, gauss_floored, gauss_removed;
212 
213  MleDiagGmmUpdate(config, am_diag_gmm_acc.GetAcc(i), flags,
214  &(am_gmm->GetPdf(i)),
215  &obj_change, &count, &elems_floored,
216  &gauss_floored, &gauss_removed);
217  tot_obj_change += obj_change;
218  tot_count += count;
219  tot_elems_floored += elems_floored;
220  tot_gauss_floored += gauss_floored;
221  tot_gauss_removed += gauss_removed;
222  }
223  if (obj_change_out != NULL) *obj_change_out = tot_obj_change;
224  if (count_out != NULL) *count_out = tot_count;
225  KALDI_LOG << tot_elems_floored << " variance elements floored in "
226  << tot_gauss_floored << " Gaussians, out of "
227  << am_gmm->NumGauss();
228  if (config.remove_low_count_gaussians) {
229  KALDI_LOG << "Removed " << tot_gauss_removed
230  << " Gaussians due to counts < --min-gaussian-occupancy="
231  << config.min_gaussian_occupancy
232  << " and --remove-low-count-gaussians=true";
233  }
234 }
235 
236 
238  const AccumAmDiagGmm &am_diag_gmm_acc,
239  GmmFlagsType flags,
240  AmDiagGmm *am_gmm,
241  BaseFloat *obj_change_out,
242  BaseFloat *count_out) {
243  KALDI_ASSERT(am_gmm != NULL && am_diag_gmm_acc.Dim() == am_gmm->Dim() &&
244  am_diag_gmm_acc.NumAccs() == am_gmm->NumPdfs());
245  if (obj_change_out != NULL) *obj_change_out = 0.0;
246  if (count_out != NULL) *count_out = 0.0;
247  BaseFloat tmp_obj_change, tmp_count;
248  BaseFloat *p_obj = (obj_change_out != NULL) ? &tmp_obj_change : NULL,
249  *p_count = (count_out != NULL) ? &tmp_count : NULL;
250 
251  for (int32 i = 0; i < am_diag_gmm_acc.NumAccs(); i++) {
252  MapDiagGmmUpdate(config, am_diag_gmm_acc.GetAcc(i), flags,
253  &(am_gmm->GetPdf(i)), p_obj, p_count);
254 
255  if (obj_change_out != NULL) *obj_change_out += tmp_obj_change;
256  if (count_out != NULL) *count_out += tmp_count;
257  }
258 }
259 
260 
262  double ans = 0.0;
263  for (int32 i = 0; i < NumAccs(); i++) {
264  const AccumDiagGmm &acc = GetAcc(i);
265  ans += acc.occupancy().Sum();
266  }
267  return ans;
268 }
269 
271  for (int32 i = 0; i < NumAccs(); i++) {
272  AccumDiagGmm &acc = GetAcc(i);
273  acc.Scale(scale, acc.Flags());
274  }
275  total_frames_ *= scale;
276  total_log_like_ *= scale;
277 }
278 
279 void AccumAmDiagGmm::Add(BaseFloat scale, const AccumAmDiagGmm &other) {
280  total_frames_ += scale * other.total_frames_;
281  total_log_like_ += scale * other.total_log_like_;
282 
283  int32 num_accs = NumAccs();
284  KALDI_ASSERT(num_accs == other.NumAccs());
285  for (int32 i = 0; i < num_accs; i++)
286  gmm_accumulators_[i]->Add(scale, *(other.gmm_accumulators_[i]));
287 }
288 
289 } // namespace kaldi
void MleAmDiagGmmUpdate(const MleDiagGmmOptions &config, const AccumAmDiagGmm &am_diag_gmm_acc, GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
for computing the maximum-likelihood estimates of the parameters of an acoustic model that uses diago...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void MapDiagGmmUpdate(const MapDiagGmmOptions &config, const AccumDiagGmm &diag_gmm_acc, GmmFlagsType flags, DiagGmm *gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
Maximum A Posteriori estimation of the model.
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
int32 NumGauss() const
Definition: am-diag-gmm.cc:72
void MapAmDiagGmmUpdate(const MapDiagGmmOptions &config, const AccumAmDiagGmm &am_diag_gmm_acc, GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
Maximum A Posteriori update.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
BaseFloat min_gaussian_occupancy
Minimum count below which a Gaussian is not updated (and is removed, if remove_low_count_gaussians ==...
Definition: mle-diag-gmm.h:47
void MleDiagGmmUpdate(const MleDiagGmmOptions &config, const AccumDiagGmm &diag_gmm_acc, GmmFlagsType flags, DiagGmm *gmm, BaseFloat *obj_change_out, BaseFloat *count_out, int32 *floored_elements_out, int32 *floored_gaussians_out, int32 *removed_gaussians_out)
for computing the maximum-likelihood estimates of the parameters of a Gaussian mixture model...
void SetZero(GmmFlagsType flags)
BaseFloat AccumulateForGmmTwofeats(const AmDiagGmm &model, const VectorBase< BaseFloat > &data1, const VectorBase< BaseFloat > &data2, int32 gmm_index, BaseFloat weight)
Accumulate stats for a single GMM in the model; uses data1 for getting posteriors and data2 for stats...
BaseFloat AccumulateForGmm(const AmDiagGmm &model, const VectorBase< BaseFloat > &data, int32 gmm_index, BaseFloat weight)
Accumulate stats for a single GMM in the model; returns log likelihood.
void Resize(int32 nMix, int32 dim)
Resizes arrays to this dim. Does not initialize data.
Definition: diag-gmm.cc:66
int32 ComputeGconsts()
Sets the gconsts.
Definition: diag-gmm.cc:114
const VectorBase< double > & occupancy() const
Definition: mle-diag-gmm.h:183
kaldi::int32 int32
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
void Scale(BaseFloat f, GmmFlagsType flags)
void Add(BaseFloat scale, const AccumAmDiagGmm &other)
void ResizeModel(int32 dim, AmDiagGmm *am_gmm)
void AccumulateFromPosteriors(const AmDiagGmm &model, const VectorBase< BaseFloat > &data, int32 gmm_index, const VectorBase< BaseFloat > &posteriors)
Accumulates stats for a single GMM in the model using pre-computed Gaussian posteriors.
double total_frames_
Total counts & likelihood (for diagnostics)
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
void AccumulateForGaussian(const AmDiagGmm &am, const VectorBase< BaseFloat > &data, int32 gmm_index, int32 gauss_index, BaseFloat weight)
Accumulate stats for a single Gaussian component in the model.
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
GmmFlagsType Flags() const
Definition: mle-diag-gmm.h:182
void Scale(BaseFloat scale)
BaseFloat TotStatsCount() const
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: diag-gmm.h:72
Configuration variables like variance floor, minimum occupancy, etc.
Definition: mle-diag-gmm.h:38
void Scale(Real alpha)
Multiplies all elements by this constant.
Real Sum() const
Returns sum of the elements.
void Read(std::istream &in_stream, bool binary, bool add=false)
void SetInvVars(const MatrixBase< Real > &v)
Set the (inverse) variances and recompute means_invvars_.
Definition: diag-gmm-inl.h:78
int32 Dim() const
Definition: am-diag-gmm.h:79
int32 NumPdfs() const
Definition: am-diag-gmm.h:82
DiagGmm & GetPdf(int32 pdf_index)
Accessors.
Definition: am-diag-gmm.h:119
A class representing a vector.
Definition: kaldi-vector.h:406
const AccumDiagGmm & GetAcc(int32 index) const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Write(std::ostream &out_stream, bool binary) const
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
#define KALDI_LOG
Definition: kaldi-error.h:153
void AccumulateFromPosteriors(const VectorBase< BaseFloat > &data, const VectorBase< BaseFloat > &gauss_posteriors)
Accumulate for all components, given the posteriors.
void Init(const AmDiagGmm &model, GmmFlagsType flags)
Initializes accumulators for each GMM based on the number of components and dimension.
std::vector< AccumDiagGmm * > gmm_accumulators_
MLE accumulators and update methods for the GMMs.
void Set(Real)
Sets all elements to a specific value.
Configuration variables for Maximum A Posteriori (MAP) update.
Definition: mle-diag-gmm.h:76