am-diag-gmm.cc
Go to the documentation of this file.
1 // gmm/am-diag-gmm.cc
2 
3 // Copyright 2012 Arnab Ghoshal Johns Hopkins University (Author: Daniel Povey) Karel Vesely
4 // Copyright 2009-2011 Saarland University; Microsoft Corporation;
5 // Georg Stemmer
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #include <queue>
23 #include <string>
24 #include <vector>
25 using std::vector;
26 
27 #include "gmm/am-diag-gmm.h"
28 #include "util/stl-utils.h"
30 #include "tree/cluster-utils.h"
31 
32 namespace kaldi {
33 
36 }
37 
38 void AmDiagGmm::Init(const DiagGmm &proto, int32 num_pdfs) {
39  if (densities_.size() != 0) {
40  KALDI_WARN << "Init() called on a non-empty object. Contents will be "
41  "overwritten";
43  }
44  if (num_pdfs == 0) {
45  KALDI_WARN << "Init() called with number of pdfs = 0. Will do nothing.";
46  return;
47  }
48 
49  densities_.resize(num_pdfs, NULL);
50  for (vector<DiagGmm*>::iterator itr = densities_.begin(),
51  end = densities_.end(); itr != end; ++itr) {
52  *itr = new DiagGmm();
53  (*itr)->CopyFromDiagGmm(proto);
54  }
55 }
56 
57 void AmDiagGmm::AddPdf(const DiagGmm &gmm) {
58  if (densities_.size() != 0) // not the first gmm
59  KALDI_ASSERT(gmm.Dim() == this->Dim());
60 
61  DiagGmm *gmm_ptr = new DiagGmm();
62  gmm_ptr->CopyFromDiagGmm(gmm);
63  densities_.push_back(gmm_ptr);
64 }
65 
66 void AmDiagGmm::RemovePdf(int32 pdf_index) {
67  KALDI_ASSERT(static_cast<size_t>(pdf_index) < densities_.size());
68  delete densities_[pdf_index];
69  densities_.erase(densities_.begin() + pdf_index);
70 }
71 
73  int32 ans = 0;
74  for (size_t i = 0; i < densities_.size(); i++)
75  ans += densities_[i]->NumGauss();
76  return ans;
77 }
78 
80  if (densities_.size() != 0) {
82  }
83  densities_.resize(other.NumPdfs(), NULL);
84  for (int32 i = 0, end = densities_.size(); i < end; i++) {
85  densities_[i] = new DiagGmm();
86  densities_[i]->CopyFromDiagGmm(*other.densities_[i]);
87  }
88 }
89 
91  int32 num_bad = 0;
92  for (std::vector<DiagGmm*>::iterator itr = densities_.begin(),
93  end = densities_.end(); itr != end; ++itr) {
94  num_bad += (*itr)->ComputeGconsts();
95  }
96  if (num_bad > 0)
97  KALDI_WARN << "Found " << num_bad << " Gaussian components.";
98  return num_bad;
99 }
100 
101 
103  int32 target_components,
104  float perturb_factor, BaseFloat power,
105  BaseFloat min_count) {
106  int32 gauss_at_start = NumGauss();
107  std::vector<int32> targets;
108  GetSplitTargets(state_occs, target_components, power,
109  min_count, &targets);
110 
111  for (int32 i = 0; i < NumPdfs(); i++) {
112  if (densities_[i]->NumGauss() < targets[i])
113  densities_[i]->Split(targets[i], perturb_factor);
114  }
115 
116  KALDI_LOG << "Split " << NumPdfs() << " states with target = "
117  << target_components << ", power = " << power
118  << ", perturb_factor = " << perturb_factor
119  << " and min_count = " << min_count
120  << ", split #Gauss from " << gauss_at_start << " to "
121  << NumGauss();
122 }
123 
124 
126  int32 target_components,
127  BaseFloat power,
128  BaseFloat min_count) {
129  int32 gauss_at_start = NumGauss();
130  std::vector<int32> targets;
131  GetSplitTargets(state_occs, target_components,
132  power, min_count, &targets);
133 
134  for (int32 i = 0; i < NumPdfs(); i++) {
135  if (targets[i] == 0) targets[i] = 1; // can't merge below 1.
136  if (densities_[i]->NumGauss() > targets[i])
137  densities_[i]->Merge(targets[i]);
138  }
139 
140  KALDI_LOG << "Merged " << NumPdfs() << " states with target = "
141  << target_components << ", power = " << power
142  << " and min_count = " << min_count
143  << ", merged from " << gauss_at_start << " to "
144  << NumGauss();
145 }
146 
147 void AmDiagGmm::Read(std::istream &in_stream, bool binary) {
148  int32 num_pdfs, dim;
149 
150  ExpectToken(in_stream, binary, "<DIMENSION>");
151  ReadBasicType(in_stream, binary, &dim);
152  ExpectToken(in_stream, binary, "<NUMPDFS>");
153  ReadBasicType(in_stream, binary, &num_pdfs);
154  KALDI_ASSERT(num_pdfs > 0);
155  densities_.reserve(num_pdfs);
156  for (int32 i = 0; i < num_pdfs; i++) {
157  densities_.push_back(new DiagGmm());
158  densities_.back()->Read(in_stream, binary);
159  KALDI_ASSERT(densities_.back()->Dim() == dim);
160  }
161 }
162 
163 void AmDiagGmm::Write(std::ostream &out_stream, bool binary) const {
164  int32 dim = this->Dim();
165  if (dim == 0) {
166  KALDI_WARN << "Trying to write empty AmDiagGmm object.";
167  }
168  WriteToken(out_stream, binary, "<DIMENSION>");
169  WriteBasicType(out_stream, binary, dim);
170  WriteToken(out_stream, binary, "<NUMPDFS>");
171  WriteBasicType(out_stream, binary, static_cast<int32>(densities_.size()));
172  for (std::vector<DiagGmm*>::const_iterator it = densities_.begin(),
173  end = densities_.end(); it != end; ++it) {
174  (*it)->Write(out_stream, binary);
175  }
176 }
177 
179  if (ubm_num_gauss > intermediate_num_gauss)
180  KALDI_ERR << "Invalid parameters: --ubm-num_gauss=" << ubm_num_gauss
181  << " > --intermediate-num_gauss=" << intermediate_num_gauss;
182  if (ubm_num_gauss > max_am_gauss)
183  KALDI_ERR << "Invalid parameters: --ubm-num_gauss=" << ubm_num_gauss
184  << " > --max-am-gauss=" << max_am_gauss;
185  if (ubm_num_gauss <= 0)
186  KALDI_ERR << "Invalid parameters: --ubm-num_gauss=" << ubm_num_gauss;
187  if (cluster_varfloor <= 0)
188  KALDI_ERR << "Invalid parameters: --cluster-varfloor="
189  << cluster_varfloor;
190  if (reduce_state_factor <= 0 || reduce_state_factor > 1)
191  KALDI_ERR << "Invalid parameters: --reduce-state-factor="
192  << reduce_state_factor;
193 }
194 
196  const Vector<BaseFloat> &state_occs,
198  DiagGmm *ubm_out) {
199  opts.Check(); // Make sure the various # of Gaussians make sense.
200  if (am.NumGauss() > opts.max_am_gauss) {
201  KALDI_LOG << "ClusterGaussiansToUbm: first reducing num-gauss from " << am.NumGauss()
202  << " to " << opts.max_am_gauss;
203  AmDiagGmm tmp_am;
204  tmp_am.CopyFromAmDiagGmm(am);
205  BaseFloat power = 1.0, min_count = 1.0; // Make the power 1, which I feel
206  // is appropriate to the way we're doing the overall clustering procedure.
207  tmp_am.MergeByCount(state_occs, opts.max_am_gauss, power, min_count);
208 
209  if (tmp_am.NumGauss() > opts.max_am_gauss) {
210  KALDI_LOG << "Clustered down to " << tmp_am.NumGauss()
211  << "; will not cluster further";
212  opts.max_am_gauss = tmp_am.NumGauss();
213  }
214  ClusterGaussiansToUbm(tmp_am, state_occs, opts, ubm_out);
215  return;
216  }
217 
218  int32 num_pdfs = static_cast<int32>(am.NumPdfs()),
219  dim = am.Dim(),
220  num_clust_states = static_cast<int32>(opts.reduce_state_factor*num_pdfs);
221 
222  Vector<BaseFloat> tmp_mean(dim);
223  Vector<BaseFloat> tmp_var(dim);
224  DiagGmm tmp_gmm;
225  vector<Clusterable*> states;
226  states.reserve(num_pdfs); // NOT resize(); uses push_back.
227 
228  // Replace the GMM for each state with a single Gaussian.
229  KALDI_VLOG(1) << "Merging densities to 1 Gaussian per state.";
230  for (int32 pdf_index = 0; pdf_index < num_pdfs; pdf_index++) {
231  KALDI_VLOG(3) << "Merging Gausians for state : " << pdf_index;
232  tmp_gmm.CopyFromDiagGmm(am.GetPdf(pdf_index));
233  tmp_gmm.Merge(1);
234  tmp_gmm.GetComponentMean(0, &tmp_mean);
235  tmp_gmm.GetComponentVariance(0, &tmp_var);
236  tmp_var.AddVec2(1.0, tmp_mean); // make it x^2 stats.
237  // It may cause problems downstream if we add states with zero weights (see
238  // KALDI_ASSERT(weight > 0) below), so we put in a very small floor.
239  // These states with tiny weights will later get merged into other states.
240  BaseFloat this_weight = 1.0e-10 + state_occs(pdf_index);
241  tmp_mean.Scale(this_weight);
242  tmp_var.Scale(this_weight);
243  states.push_back(new GaussClusterable(tmp_mean, tmp_var,
244  opts.cluster_varfloor, this_weight));
245  }
246 
247  // Bottom-up clustering of the Gaussians corresponding to each state, which
248  // gives a partial clustering of states in the 'state_clusters' vector.
249  vector<int32> state_clusters;
250  KALDI_VLOG(1) << "Creating " << num_clust_states << " clusters of states.";
251  ClusterBottomUp(states, std::numeric_limits<BaseFloat>::max(), num_clust_states,
252  NULL /*actual clusters not needed*/,
253  &state_clusters /*get the cluster assignments*/);
254  DeletePointers(&states);
255 
256  // For each cluster of states, create a pool of all the Gaussians in those
257  // states, weighted by the state occupancies. This is done so that initially
258  // only the Gaussians corresponding to "similar" states (similarity as
259  // determined by the previous clustering) are merged.
260  vector< vector<Clusterable*> > state_clust_gauss;
261  state_clust_gauss.resize(num_clust_states);
262  for (int32 pdf_index = 0; pdf_index < num_pdfs; pdf_index++) {
263  int32 current_cluster = state_clusters[pdf_index];
264  for (int32 num_gauss = am.GetPdf(pdf_index).NumGauss(),
265  gauss_index = 0; gauss_index < num_gauss; ++gauss_index) {
266  am.GetGaussianMean(pdf_index, gauss_index, &tmp_mean);
267  am.GetGaussianVariance(pdf_index, gauss_index, &tmp_var);
268  tmp_var.AddVec2(1.0, tmp_mean); // make it x^2 stats.
269  // adding 1.0e-10 to the weight will prevent problems later on, see
270  // the line KALDI_ASSERT(weight > 0.0).
271  BaseFloat this_weight = (1.0e-10 + state_occs(pdf_index)) *
272  (am.GetPdf(pdf_index).weights())(gauss_index);
273  tmp_mean.Scale(this_weight);
274  tmp_var.Scale(this_weight);
275  state_clust_gauss[current_cluster].push_back(new GaussClusterable(
276  tmp_mean, tmp_var, opts.cluster_varfloor, this_weight));
277  }
278  }
279 
280  // This is an unlikely operating scenario, no need to handle this in a more
281  // optimized fashion.
282  if (opts.intermediate_num_gauss > am.NumGauss()) {
283  KALDI_WARN << "Intermediate num_gauss " << opts.intermediate_num_gauss
284  << " is more than num-gauss " << am.NumGauss()
285  << ", reducing it to " << am.NumGauss();
286  opts.intermediate_num_gauss = am.NumGauss();
287  }
288 
289  // The compartmentalized clusterer used below does not merge compartments.
290  if (opts.intermediate_num_gauss < num_clust_states) {
291  KALDI_WARN << "Intermediate num_gauss " << opts.intermediate_num_gauss
292  << " is less than # of preclustered states " << num_clust_states
293  << ", increasing it to " << num_clust_states;
294  opts.intermediate_num_gauss = num_clust_states;
295  }
296 
297  KALDI_VLOG(1) << "Merging from " << am.NumGauss() << " Gaussians in the "
298  << "acoustic model, down to " << opts.intermediate_num_gauss
299  << " Gaussians.";
300  vector< vector<Clusterable*> > gauss_clusters_out;
301  ClusterBottomUpCompartmentalized(state_clust_gauss, std::numeric_limits<BaseFloat>::max(),
303  &gauss_clusters_out, NULL);
304  for (int32 clust_index = 0; clust_index < num_clust_states; clust_index++)
305  DeletePointers(&state_clust_gauss[clust_index]);
306 
307  // Next, put the remaining clustered Gaussians into a single GMM.
308  KALDI_VLOG(1) << "Putting " << opts.intermediate_num_gauss << " Gaussians "
309  << "into a single GMM for final merge step.";
310  Matrix<BaseFloat> tmp_means(opts.intermediate_num_gauss, dim);
311  Matrix<BaseFloat> tmp_vars(opts.intermediate_num_gauss, dim);
312  Vector<BaseFloat> tmp_weights(opts.intermediate_num_gauss);
313  Vector<BaseFloat> tmp_vec(dim);
314  int32 gauss_index = 0;
315  for (int32 clust_index = 0; clust_index < num_clust_states; clust_index++) {
316  for (int32 i = gauss_clusters_out[clust_index].size()-1; i >=0; --i) {
317  GaussClusterable *this_cluster = static_cast<GaussClusterable*>(
318  gauss_clusters_out[clust_index][i]);
319  BaseFloat weight = this_cluster->count();
320  KALDI_ASSERT(weight > 0.0);
321  tmp_weights(gauss_index) = weight;
322  tmp_vec.CopyFromVec(this_cluster->x_stats());
323  tmp_vec.Scale(1.0 / weight);
324  tmp_means.CopyRowFromVec(tmp_vec, gauss_index);
325  tmp_vec.CopyFromVec(this_cluster->x2_stats());
326  tmp_vec.Scale(1.0 / weight);
327  tmp_vec.AddVec2(-1.0, tmp_means.Row(gauss_index)); // x^2 stats to var.
328  tmp_vars.CopyRowFromVec(tmp_vec, gauss_index);
329  gauss_index++;
330  }
331  DeletePointers(&(gauss_clusters_out[clust_index]));
332  }
333  tmp_gmm.Resize(opts.intermediate_num_gauss, dim);
334  tmp_weights.Scale(1.0/tmp_weights.Sum());
335  tmp_gmm.SetWeights(tmp_weights);
336  tmp_vars.InvertElements(); // need inverse vars...
337  tmp_gmm.SetInvVarsAndMeans(tmp_vars, tmp_means);
338 
339  // Finally, cluster to the desired number of Gaussians in the UBM.
340  if (opts.ubm_num_gauss < tmp_gmm.NumGauss()) {
341  tmp_gmm.Merge(opts.ubm_num_gauss);
342  KALDI_VLOG(1) << "Merged down to " << tmp_gmm.NumGauss() << " Gaussians.";
343  } else {
344  KALDI_WARN << "Not merging Gaussians since " << opts.ubm_num_gauss
345  << " < " << tmp_gmm.NumGauss();
346  }
347  ubm_out->CopyFromDiagGmm(tmp_gmm);
348 }
349 
350 } // namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
int32 Dim() const
Returns the dimensionality of the Gaussian mean vectors.
Definition: diag-gmm.h:74
void CopyFromAmDiagGmm(const AmDiagGmm &other)
Copies the parameters from another model. Allocates necessary memory.
Definition: am-diag-gmm.cc:79
void AddPdf(const DiagGmm &gmm)
Adds a GMM to the model, and increments the total number of PDFs.
Definition: am-diag-gmm.cc:57
void CopyFromDiagGmm(const DiagGmm &diaggmm)
Copies from given DiagGmm.
Definition: diag-gmm.cc:83
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
void RemovePdf(int32 pdf_index)
Definition: am-diag-gmm.cc:66
int32 NumGauss() const
Definition: am-diag-gmm.cc:72
void SetInvVarsAndMeans(const MatrixBase< Real > &invvars, const MatrixBase< Real > &means)
Use SetInvVarsAndMeans if updating both means and (inverse) variances.
Definition: diag-gmm-inl.h:63
void Merge(int32 target_components, std::vector< int32 > *history=NULL)
Merge the components and remember the order in which the components were merged (flat list of pairs) ...
Definition: diag-gmm.cc:295
int32 ComputeGconsts()
Sets the gconsts for all the PDFs.
Definition: am-diag-gmm.cc:90
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
void MergeByCount(const Vector< BaseFloat > &state_occs, int32 target_components, BaseFloat power, BaseFloat min_count)
Definition: am-diag-gmm.cc:125
void GetComponentMean(int32 gauss, VectorBase< Real > *out) const
Accessor for single component mean.
Definition: diag-gmm-inl.h:135
void Resize(int32 nMix, int32 dim)
Resizes arrays to this dim. Does not initialize data.
Definition: diag-gmm.cc:66
kaldi::int32 int32
BaseFloat ClusterBottomUpCompartmentalized(const std::vector< std::vector< Clusterable *> > &points, BaseFloat thresh, int32 min_clust, std::vector< std::vector< Clusterable *> > *clusters_out, std::vector< std::vector< int32 > > *assignments_out)
This is a bottom-up clustering where the points are pre-clustered in a set of compartments, such that only points in the same compartment are clustered together.
void GetSplitTargets(const Vector< BaseFloat > &state_occs, int32 target_components, BaseFloat power, BaseFloat min_count, std::vector< int32 > *targets)
Get Gaussian-mixture or substate-mixture splitting targets, according to a power rule (e...
void AddVec2(const Real alpha, const VectorBase< Real > &v)
Add vector: *this = *this + alpha * v^2 [element-wise squaring].
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
float BaseFloat
Definition: kaldi-types.h:29
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
void ClusterGaussiansToUbm(const AmDiagGmm &am, const Vector< BaseFloat > &state_occs, UbmClusteringOptions opts, DiagGmm *ubm_out)
Clusters the Gaussians in an acoustic model to a single GMM with specified number of components...
Definition: am-diag-gmm.cc:195
void GetGaussianVariance(int32 pdf_index, int32 gauss, VectorBase< BaseFloat > *out) const
Definition: am-diag-gmm.h:138
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
BaseFloat ClusterBottomUp(const std::vector< Clusterable *> &points, BaseFloat max_merge_thresh, int32 min_clust, std::vector< Clusterable *> *clusters_out, std::vector< int32 > *assignments_out)
A bottom-up clustering algorithm.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::vector< DiagGmm * > densities_
Definition: am-diag-gmm.h:99
const Vector< BaseFloat > & weights() const
Definition: diag-gmm.h:178
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: diag-gmm.h:72
void Scale(Real alpha)
Multiplies all elements by this constant.
Real Sum() const
Returns sum of the elements.
int32 Dim() const
Definition: am-diag-gmm.h:79
int32 NumPdfs() const
Definition: am-diag-gmm.h:82
void GetComponentVariance(int32 gauss, VectorBase< Real > *out) const
Accessor for single component variance.
Definition: diag-gmm-inl.h:145
DiagGmm & GetPdf(int32 pdf_index)
Accessors.
Definition: am-diag-gmm.h:119
void InvertElements()
Inverts all the elements of the matrix.
A class representing a vector.
Definition: kaldi-vector.h:406
SubVector< double > x2_stats() const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Write(std::ostream &out_stream, bool binary) const
Definition: am-diag-gmm.cc:163
void GetGaussianMean(int32 pdf_index, int32 gauss, VectorBase< BaseFloat > *out) const
Definition: am-diag-gmm.h:131
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void CopyRowFromVec(const VectorBase< Real > &v, const MatrixIndexT row)
Copy vector into specific row of matrix.
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
void SetWeights(const VectorBase< Real > &w)
Mutators for both float or double.
Definition: diag-gmm-inl.h:28
GaussClusterable wraps Gaussian statistics in a form accessible to generic clustering algorithms...
#define KALDI_LOG
Definition: kaldi-error.h:153
void Init(const DiagGmm &proto, int32 num_pdfs)
Initializes with a single "prototype" GMM.
Definition: am-diag-gmm.cc:38
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
SubVector< double > x_stats() const
void SplitByCount(const Vector< BaseFloat > &state_occs, int32 target_components, float perturb_factor, BaseFloat power, BaseFloat min_count)
Definition: am-diag-gmm.cc:102