am-diag-gmm.h
Go to the documentation of this file.
1 // gmm/am-diag-gmm.h
2 
3 // Copyright 2009-2012 Saarland University (Author: Arnab Ghoshal)
4 // Johns Hopkins University (Author: Daniel Povey)
5 // Karel Vesely
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #ifndef KALDI_GMM_AM_DIAG_GMM_H_
23 #define KALDI_GMM_AM_DIAG_GMM_H_ 1
24 
25 #include <vector>
26 
27 #include "base/kaldi-common.h"
28 #include "gmm/diag-gmm.h"
29 #include "itf/options-itf.h"
30 
31 namespace kaldi {
35 
// Acoustic model made of diagonal-covariance GMMs: holds one DiagGmm pointer
// per pdf in densities_, addressed by a zero-based pdf index.
// NOTE(review): this text is a numbered documentation listing. The index at
// the end of the page also lists `int32 ComputeGconsts()` ("Sets the gconsts
// for all the PDFs.") and `KALDI_DISALLOW_COPY_AND_ASSIGN(AmDiagGmm)`, neither
// of which is visible in the class below (listing numbers 69-71 and 104 are
// absent) -- confirm against the real header.
36 class AmDiagGmm {
37  public:
38  AmDiagGmm() {}
    // Defined out of line. NOTE(review): densities_ stores raw pointers;
    // presumably the destructor releases them -- confirm in am-diag-gmm.cc.
39  ~AmDiagGmm();
40 
    // Initializes with a single "prototype" GMM.
42  void Init(const DiagGmm &proto, int32 num_pdfs);
    // Adds a GMM to the model, and increments the total number of PDFs.
44  void AddPdf(const DiagGmm &gmm);
    // Copies the parameters from another model. Allocates necessary memory.
46  void CopyFromAmDiagGmm(const AmDiagGmm &other);
47 
    // Splits one pdf; the inline definition forwards to
    // DiagGmm::Split(target_components, perturb_factor) on the pdf at `idx`.
48  void SplitPdf(int32 idx, int32 target_components, float perturb_factor);
49 
50  // In SplitByCount we use the "target_components" and "power"
51  // to work out targets for each state (according to power-of-occupancy rule),
52  // and any state less than its target gets mixed up. If some states
53  // were over their target, this may take the #Gauss over the target.
54  // we enforce a min-count on Gaussians while splitting (don't split
55  // if it would take it below min-count).
56  void SplitByCount(const Vector<BaseFloat> &state_occs,
57  int32 target_components, float perturb_factor,
58  BaseFloat power, BaseFloat min_count);
59 
60 
61  // In MergeByCount we use the "target_components" and "power"
62  // to work out targets for each state (according to power-of-occupancy rule),
63  // and any state over its target gets mixed down. If some states
64  // were under their target, this may take the #Gauss below the target.
65  void MergeByCount(const Vector<BaseFloat> &state_occs,
66  int32 target_components,
67  BaseFloat power, BaseFloat min_count);
68 
72 
    // Returns the result of LogLikelihood(data) on the pdf at pdf_index.
73  BaseFloat LogLikelihood(const int32 pdf_index,
74  const VectorBase<BaseFloat> &data) const;
75 
76  void Read(std::istream &in_stream, bool binary);
77  void Write(std::ostream &out_stream, bool binary) const;
78 
    // Dimension reported by the first stored pdf; 0 when no pdf is stored.
79  int32 Dim() const {
80  return (densities_.size() > 0)? densities_[0]->Dim() : 0;
81  }
    // Number of stored pdfs (size of densities_).
82  int32 NumPdfs() const { return densities_.size(); }
    // Defined in am-diag-gmm.cc.
    // NOTE(review): presumably the Gaussian count summed over all pdfs -- confirm.
83  int32 NumGauss() const;
    // Number of Gaussians in the single pdf at pdf_index.
84  int32 NumGaussInPdf(int32 pdf_index) const;
85 
    // Accessors. All of them assert that pdf_index is in range and that the
    // stored pointer is non-NULL before dereferencing it.
87  DiagGmm& GetPdf(int32 pdf_index);
88  const DiagGmm& GetPdf(int32 pdf_index) const;
89  void GetGaussianMean(int32 pdf_index, int32 gauss,
90  VectorBase<BaseFloat> *out) const;
91  void GetGaussianVariance(int32 pdf_index, int32 gauss,
92  VectorBase<BaseFloat> *out) const;
93 
    // Mutators.
95  void SetGaussianMean(int32 pdf_index, int32 gauss_index,
96  const VectorBase<BaseFloat> &in);
97 
98  private:
    // One pointer per pdf; the position in the vector is the pdf index.
99  std::vector<DiagGmm*> densities_;
100 // int32 dim_;
101 
     // Private helper, defined in am-diag-gmm.cc.
102  void RemovePdf(int32 pdf_index);
103 
105 };
106 
107 
109  const int32 pdf_index, const VectorBase<BaseFloat> &data) const {
110  return densities_[pdf_index]->LogLikelihood(data);
111 }
112 
// Returns NumGauss() of the pdf stored at pdf_index.
// Asserts first that pdf_index is within densities_ and that the stored
// pointer is non-NULL. The cast to size_t makes a negative index wrap to a
// very large value, so it fails the same range check.
113 inline int32 AmDiagGmm::NumGaussInPdf(int32 pdf_index) const {
114  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
115  && (densities_[pdf_index] != NULL));
116  return densities_[pdf_index]->NumGauss();
117 }
118 
// Returns a mutable reference to the pdf stored at pdf_index.
// Asserts that pdf_index is within densities_ (a negative index wraps to a
// large size_t and fails the check) and that the stored pointer is non-NULL.
119 inline DiagGmm& AmDiagGmm::GetPdf(int32 pdf_index) {
120  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
121  && (densities_[pdf_index] != NULL));
122  return *(densities_[pdf_index]);
123 }
124 
// Const overload: returns a read-only reference to the pdf stored at pdf_index.
// Asserts that pdf_index is within densities_ (a negative index wraps to a
// large size_t and fails the check) and that the stored pointer is non-NULL.
125 inline const DiagGmm& AmDiagGmm::GetPdf(int32 pdf_index) const {
126  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
127  && (densities_[pdf_index] != NULL));
128  return *(densities_[pdf_index]);
129 }
130 
// Fetches the mean of component `gauss` of the pdf at pdf_index by forwarding
// to DiagGmm::GetComponentMean(gauss, out); the result is delivered through
// `out`. Only pdf_index is checked here (in range, pointer non-NULL); `gauss`
// and `out` are passed through unchecked by this function.
131 inline void AmDiagGmm::GetGaussianMean(int32 pdf_index, int32 gauss,
132  VectorBase<BaseFloat> *out) const {
133  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
134  && (densities_[pdf_index] != NULL));
135  densities_[pdf_index]->GetComponentMean(gauss, out);
136 }
137 
// Fetches the variance of component `gauss` of the pdf at pdf_index by
// forwarding to DiagGmm::GetComponentVariance(gauss, out); the result is
// delivered through `out`. Only pdf_index is checked here (in range, pointer
// non-NULL); `gauss` and `out` are passed through unchecked by this function.
138 inline void AmDiagGmm::GetGaussianVariance(int32 pdf_index, int32 gauss,
139  VectorBase<BaseFloat> *out) const {
140  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
141  && (densities_[pdf_index] != NULL));
142  densities_[pdf_index]->GetComponentVariance(gauss, out);
143 }
144 
// Sets the mean of component gauss_index of the pdf at pdf_index by forwarding
// to DiagGmm::SetComponentMean(gauss_index, in).
// Only pdf_index is checked here (in range, pointer non-NULL); gauss_index and
// `in` are passed through unchecked by this function.
145 inline void AmDiagGmm::SetGaussianMean(int32 pdf_index, int32 gauss_index,
146  const VectorBase<BaseFloat> &in) {
147  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
148  && (densities_[pdf_index] != NULL));
149  densities_[pdf_index]->SetComponentMean(gauss_index, in);
150 }
151 
// Splits the pdf at pdf_index by forwarding to
// DiagGmm::Split(target_components, perturb_factor).
// Asserts that pdf_index is within densities_ (a negative index wraps to a
// large size_t and fails the check) and that the stored pointer is non-NULL.
// The in-class declaration names the first parameter `idx`; it is the same
// parameter as pdf_index here.
152 inline void AmDiagGmm::SplitPdf(int32 pdf_index,
153  int32 target_components,
154  float perturb_factor) {
155  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
156  && (densities_[pdf_index] != NULL));
157  densities_[pdf_index]->Split(target_components, perturb_factor);
158 }
159 
166 
168  : ubm_num_gauss(400), reduce_state_factor(0.2),
169  intermediate_num_gauss(4000), cluster_varfloor(0.01),
170  max_am_gauss(20000) {}
// Constructor that sets every option explicitly:
//   ncomp        -> ubm_num_gauss
//   red          -> reduce_state_factor
//   interm_gauss -> intermediate_num_gauss
//   vfloor       -> cluster_varfloor
//   max_am_gauss -> max_am_gauss
// The last parameter has the same name as the member it initializes. In the
// initializer `max_am_gauss(max_am_gauss)` the name before the parentheses is
// the member and the name inside is the parameter, so the member is set
// correctly.
171  UbmClusteringOptions(int32 ncomp, BaseFloat red, int32 interm_gauss,
172  BaseFloat vfloor, int32 max_am_gauss)
173  : ubm_num_gauss(ncomp), reduce_state_factor(red),
174  intermediate_num_gauss(interm_gauss), cluster_varfloor(vfloor),
175  max_am_gauss(max_am_gauss) {}
// Registers every option of this struct with `opts`, binding each option name
// to the corresponding member variable. Each help string is prefixed with
// "UbmClusteringOptions: ".
// "ubm-numcomps" and "intermediate-numcomps" are backward-compatibility names
// bound to the same variables as "ubm-num-gauss" and "intermediate-num-gauss",
// so either spelling writes to the same member.
// `opts` is dereferenced without a null check.
176  void Register(OptionsItf *opts) {
177  std::string module = "UbmClusteringOptions: ";
178  opts->Register("max-am-gauss", &max_am_gauss, module+
179  "We first reduce acoustic model to this max #Gauss before clustering.");
180  opts->Register("ubm-num-gauss", &ubm_num_gauss, module+
181  "Number of Gaussians components in the final UBM.");
     // Legacy alias: writes to the same variable as "ubm-num-gauss".
182  opts->Register("ubm-numcomps", &ubm_num_gauss, module+
183  "Backward compatibility option (see ubm-num-gauss)");
184  opts->Register("reduce-state-factor", &reduce_state_factor, module+
185  "Intermediate number of clustered states (as fraction of total states).");
186  opts->Register("intermediate-num-gauss", &intermediate_num_gauss, module+
187  "Intermediate number of merged Gaussian components.");
     // Legacy alias: writes to the same variable as "intermediate-num-gauss".
188  opts->Register("intermediate-numcomps", &intermediate_num_gauss, module+
189  "Backward compatibility option (see intermediate-num-gauss)");
190  opts->Register("cluster-varfloor", &cluster_varfloor, module+
191  "Variance floor used in bottom-up state clustering.");
192  }
193 
194  void Check();
195 };
196 
// Clusters the Gaussians in an acoustic model to a single GMM with a specified
// number of components; the result is written to *ubm_out.
// NOTE(review): listing line 210 is absent from this text. The index at the
// end of the page gives the full signature as
//   (const AmDiagGmm &am, const Vector<BaseFloat> &state_occs,
//    UbmClusteringOptions opts, DiagGmm *ubm_out)
// so a `UbmClusteringOptions opts,` parameter belongs between state_occs and
// ubm_out -- confirm against the real header.
208 void ClusterGaussiansToUbm(const AmDiagGmm &am,
209  const Vector<BaseFloat> &state_occs,
211  DiagGmm *ubm_out);
212 
213 
214 
215 
216 } // namespace kaldi
217 
219 #endif // KALDI_GMM_AM_DIAG_GMM_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CopyFromAmDiagGmm(const AmDiagGmm &other)
Copies the parameters from another model. Allocates necessary memory.
Definition: am-diag-gmm.cc:79
void AddPdf(const DiagGmm &gmm)
Adds a GMM to the model, and increments the total number of PDFs.
Definition: am-diag-gmm.cc:57
void RemovePdf(int32 pdf_index)
Definition: am-diag-gmm.cc:66
int32 NumGauss() const
Definition: am-diag-gmm.cc:72
int32 ComputeGconsts()
Sets the gconsts for all the PDFs.
Definition: am-diag-gmm.cc:90
void MergeByCount(const Vector< BaseFloat > &state_occs, int32 target_components, BaseFloat power, BaseFloat min_count)
Definition: am-diag-gmm.cc:125
kaldi::int32 int32
void SetGaussianMean(int32 pdf_index, int32 gauss_index, const VectorBase< BaseFloat > &in)
Mutators.
Definition: am-diag-gmm.h:145
int32 NumGaussInPdf(int32 pdf_index) const
Definition: am-diag-gmm.h:113
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
void ClusterGaussiansToUbm(const AmDiagGmm &am, const Vector< BaseFloat > &state_occs, UbmClusteringOptions opts, DiagGmm *ubm_out)
Clusters the Gaussians in an acoustic model to a single GMM with specified number of components...
Definition: am-diag-gmm.cc:195
BaseFloat LogLikelihood(const int32 pdf_index, const VectorBase< BaseFloat > &data) const
Definition: am-diag-gmm.h:108
void GetGaussianVariance(int32 pdf_index, int32 gauss, VectorBase< BaseFloat > *out) const
Definition: am-diag-gmm.h:138
std::vector< DiagGmm * > densities_
Definition: am-diag-gmm.h:99
KALDI_DISALLOW_COPY_AND_ASSIGN(AmDiagGmm)
UbmClusteringOptions(int32 ncomp, BaseFloat red, int32 interm_gauss, BaseFloat vfloor, int32 max_am_gauss)
Definition: am-diag-gmm.h:171
int32 Dim() const
Definition: am-diag-gmm.h:79
int32 NumPdfs() const
Definition: am-diag-gmm.h:82
DiagGmm & GetPdf(int32 pdf_index)
Accessors.
Definition: am-diag-gmm.h:119
void Register(OptionsItf *opts)
Definition: am-diag-gmm.h:176
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Write(std::ostream &out_stream, bool binary) const
Definition: am-diag-gmm.cc:163
void GetGaussianMean(int32 pdf_index, int32 gauss, VectorBase< BaseFloat > *out) const
Definition: am-diag-gmm.h:131
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
void SplitPdf(int32 idx, int32 target_components, float perturb_factor)
Definition: am-diag-gmm.h:152
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void Init(const DiagGmm &proto, int32 num_pdfs)
Initializes with a single "prototype" GMM.
Definition: am-diag-gmm.cc:38
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
void SplitByCount(const Vector< BaseFloat > &state_occs, int32 target_components, float perturb_factor, BaseFloat power, BaseFloat min_count)
Definition: am-diag-gmm.cc:102