DiagGmm

kaldi Diagonal Gaussian Mixture Models

Classes

class  AmDiagGmm
 
struct  UbmClusteringOptions
 

Functions

 AmDiagGmm ()
 
 ~AmDiagGmm ()
 
void Init (const DiagGmm &proto, int32 num_pdfs)
 Initializes with a single "prototype" GMM.
 
void AddPdf (const DiagGmm &gmm)
 Adds a GMM to the model, and increments the total number of PDFs.
 
void CopyFromAmDiagGmm (const AmDiagGmm &other)
 Copies the parameters from another model. Allocates necessary memory.
 
void SplitPdf (int32 idx, int32 target_components, float perturb_factor)
 
void SplitByCount (const Vector< BaseFloat > &state_occs, int32 target_components, float perturb_factor, BaseFloat power, BaseFloat min_count)
 
void MergeByCount (const Vector< BaseFloat > &state_occs, int32 target_components, BaseFloat power, BaseFloat min_count)
 
int32 ComputeGconsts ()
 Sets the gconsts for all the PDFs.
 
BaseFloat LogLikelihood (const int32 pdf_index, const VectorBase< BaseFloat > &data) const
 
void Read (std::istream &in_stream, bool binary)
 
void Write (std::ostream &out_stream, bool binary) const
 
int32 Dim () const
 
int32 NumPdfs () const
 
int32 NumGauss () const
 
int32 NumGaussInPdf (int32 pdf_index) const
 
DiagGmm & GetPdf (int32 pdf_index)
 Accessors.
 
const DiagGmm & GetPdf (int32 pdf_index) const
 
void GetGaussianMean (int32 pdf_index, int32 gauss, VectorBase< BaseFloat > *out) const
 
void GetGaussianVariance (int32 pdf_index, int32 gauss, VectorBase< BaseFloat > *out) const
 
void SetGaussianMean (int32 pdf_index, int32 gauss_index, const VectorBase< BaseFloat > &in)
 Mutators.
 
void RemovePdf (int32 pdf_index)
 
 KALDI_DISALLOW_COPY_AND_ASSIGN (AmDiagGmm)
 
 UbmClusteringOptions ()
 
 UbmClusteringOptions (int32 ncomp, BaseFloat red, int32 interm_gauss, BaseFloat vfloor, int32 max_am_gauss)
 
void Register (OptionsItf *opts)
 
void Check ()
 
void ClusterGaussiansToUbm (const AmDiagGmm &am, const Vector< BaseFloat > &state_occs, UbmClusteringOptions opts, DiagGmm *ubm_out)
 Clusters the Gaussians in an acoustic model to a single GMM with specified number of components.
 

Variables

std::vector< DiagGmm * > densities_
 
int32 ubm_num_gauss
 
BaseFloat reduce_state_factor
 
int32 intermediate_num_gauss
 
BaseFloat cluster_varfloor
 
int32 max_am_gauss
 

Detailed Description

kaldi Diagonal Gaussian Mixture Models

Function Documentation

◆ AddPdf()

void AddPdf ( const DiagGmm &  gmm)

Adds a GMM to the model, and increments the total number of PDFs.

Definition at line 57 of file am-diag-gmm.cc.

References CopyFromDiagGmm(), AmDiagGmm::densities_, Dim(), AmDiagGmm::Dim(), and KALDI_ASSERT.

Referenced by AmDiagGmm::AmDiagGmm(), kaldi::InitAmGmm(), kaldi::InitAmGmmFromOld(), main(), UnitTestAmDiagGmm(), UnitTestMleAmDiagGmm(), UnitTestRegressionTree(), and kaldi::UnitTestRegtreeFmllrDiagGmm().

{
  if (densities_.size() != 0)  // not the first gmm
    KALDI_ASSERT(gmm.Dim() == this->Dim());

  DiagGmm *gmm_ptr = new DiagGmm();
  gmm_ptr->CopyFromDiagGmm(gmm);
  densities_.push_back(gmm_ptr);
}
std::vector< DiagGmm * > densities_
Definition: am-diag-gmm.h:99
int32 Dim() const
Definition: am-diag-gmm.h:79
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ AmDiagGmm()

◆ Check()

void Check ( )

Definition at line 178 of file am-diag-gmm.cc.

References KALDI_ERR.

Referenced by kaldi::ClusterGaussiansToUbm(), and main().

{
  if (ubm_num_gauss > intermediate_num_gauss)
    KALDI_ERR << "Invalid parameters: --ubm-num_gauss=" << ubm_num_gauss
              << " > --intermediate-num_gauss=" << intermediate_num_gauss;
  if (ubm_num_gauss > max_am_gauss)
    KALDI_ERR << "Invalid parameters: --ubm-num_gauss=" << ubm_num_gauss
              << " > --max-am-gauss=" << max_am_gauss;
  if (ubm_num_gauss <= 0)
    KALDI_ERR << "Invalid parameters: --ubm-num_gauss=" << ubm_num_gauss;
  if (cluster_varfloor <= 0)
    KALDI_ERR << "Invalid parameters: --cluster-varfloor="
              << cluster_varfloor;
  if (reduce_state_factor <= 0 || reduce_state_factor > 1)
    KALDI_ERR << "Invalid parameters: --reduce-state-factor="
              << reduce_state_factor;
}
#define KALDI_ERR
Definition: kaldi-error.h:147

◆ ClusterGaussiansToUbm()

void ClusterGaussiansToUbm ( const AmDiagGmm &  am,
const Vector< BaseFloat > &  state_occs,
UbmClusteringOptions  opts,
DiagGmm *  ubm_out 
)

Clusters the Gaussians in an acoustic model to a single GMM with specified number of components.

First, each state is mixed down to a single Gaussian; the states are then clustered by clustering these Gaussians bottom-up, with the number of state clusters determined by reduce_state_factor. Within each cluster of states, the Gaussians are merged so as to minimize the likelihood reduction until intermediate_num_gauss Gaussians remain in total; these are then merged into ubm_num_gauss Gaussians. This is the UBM initialization algorithm described in Section 2.1 of Povey et al., "The subspace Gaussian mixture model - A structured model for speech recognition", Computer Speech and Language, April 2011.
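The stage sizes in the description above can be sketched as a small standalone helper. This is an illustrative sketch, not Kaldi code: UbmStagePlan and PlanUbmStages are hypothetical names, and the clamping is an assumption modelled on the checks visible in the source listing of this function (the state-cluster count is a truncated fraction of the number of PDFs, and the intermediate count is kept between the number of state clusters and the total number of Gaussians).

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper (not part of Kaldi): the Gaussian counts that each
// stage of the clustering works towards.
struct UbmStagePlan {
  int num_clust_states;        // clusters of states after the first pass
  int intermediate_num_gauss;  // Gaussians kept after the per-cluster merge
  int ubm_num_gauss;           // Gaussians in the final UBM
};

UbmStagePlan PlanUbmStages(int num_pdfs, int total_gauss,
                           double reduce_state_factor,
                           int intermediate_num_gauss, int ubm_num_gauss) {
  assert(num_pdfs > 0 && total_gauss >= num_pdfs);
  assert(reduce_state_factor > 0.0 && reduce_state_factor <= 1.0);
  UbmStagePlan plan;
  // Truncating cast, as in the source listing.
  plan.num_clust_states = static_cast<int>(reduce_state_factor * num_pdfs);
  // Cannot keep more Gaussians than the model has ...
  plan.intermediate_num_gauss = std::min(intermediate_num_gauss, total_gauss);
  // ... nor fewer than one per cluster of states, because Gaussians from
  // different state clusters are never merged at this stage.
  plan.intermediate_num_gauss =
      std::max(plan.intermediate_num_gauss, plan.num_clust_states);
  // The final merge only ever reduces the count.
  plan.ubm_num_gauss = std::min(ubm_num_gauss, plan.intermediate_num_gauss);
  return plan;
}
```

For illustration, PlanUbmStages(1000, 20000, 0.5, 4000, 400) yields 500 state clusters, 4000 intermediate Gaussians and a 400-component UBM; the input values are arbitrary and are not claimed to be Kaldi defaults.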

Definition at line 195 of file am-diag-gmm.cc.

References VectorBase< Real >::AddVec2(), UbmClusteringOptions::Check(), UbmClusteringOptions::cluster_varfloor, kaldi::ClusterBottomUp(), kaldi::ClusterBottomUpCompartmentalized(), AmDiagGmm::CopyFromAmDiagGmm(), CopyFromDiagGmm(), VectorBase< Real >::CopyFromVec(), MatrixBase< Real >::CopyRowFromVec(), GaussClusterable::count(), kaldi::DeletePointers(), AmDiagGmm::Dim(), GetComponentMean(), GetComponentVariance(), AmDiagGmm::GetGaussianMean(), AmDiagGmm::GetGaussianVariance(), AmDiagGmm::GetPdf(), rnnlm::i, UbmClusteringOptions::intermediate_num_gauss, MatrixBase< Real >::InvertElements(), KALDI_ASSERT, KALDI_LOG, KALDI_VLOG, KALDI_WARN, UbmClusteringOptions::max_am_gauss, Merge(), AmDiagGmm::MergeByCount(), NumGauss(), AmDiagGmm::NumGauss(), AmDiagGmm::NumPdfs(), UbmClusteringOptions::reduce_state_factor, Resize(), MatrixBase< Real >::Row(), VectorBase< Real >::Scale(), SetInvVarsAndMeans(), SetWeights(), VectorBase< Real >::Sum(), UbmClusteringOptions::ubm_num_gauss, weights(), GaussClusterable::x2_stats(), and GaussClusterable::x_stats().

Referenced by main(), UbmClusteringOptions::Register(), and TestClustering().

{
  opts.Check(); // Make sure the various # of Gaussians make sense.
  if (am.NumGauss() > opts.max_am_gauss) {
    KALDI_LOG << "ClusterGaussiansToUbm: first reducing num-gauss from " << am.NumGauss()
              << " to " << opts.max_am_gauss;
    AmDiagGmm tmp_am;
    tmp_am.CopyFromAmDiagGmm(am);
    BaseFloat power = 1.0, min_count = 1.0; // Make the power 1, which I feel
    // is appropriate to the way we're doing the overall clustering procedure.
    tmp_am.MergeByCount(state_occs, opts.max_am_gauss, power, min_count);

    if (tmp_am.NumGauss() > opts.max_am_gauss) {
      KALDI_LOG << "Clustered down to " << tmp_am.NumGauss()
                << "; will not cluster further";
      opts.max_am_gauss = tmp_am.NumGauss();
    }
    ClusterGaussiansToUbm(tmp_am, state_occs, opts, ubm_out);
    return;
  }

  int32 num_pdfs = static_cast<int32>(am.NumPdfs()),
      dim = am.Dim(),
      num_clust_states = static_cast<int32>(opts.reduce_state_factor*num_pdfs);

  Vector<BaseFloat> tmp_mean(dim);
  Vector<BaseFloat> tmp_var(dim);
  DiagGmm tmp_gmm;
  vector<Clusterable*> states;
  states.reserve(num_pdfs); // NOT resize(); uses push_back.

  // Replace the GMM for each state with a single Gaussian.
  KALDI_VLOG(1) << "Merging densities to 1 Gaussian per state.";
  for (int32 pdf_index = 0; pdf_index < num_pdfs; pdf_index++) {
    KALDI_VLOG(3) << "Merging Gausians for state : " << pdf_index;
    tmp_gmm.CopyFromDiagGmm(am.GetPdf(pdf_index));
    tmp_gmm.Merge(1);
    tmp_gmm.GetComponentMean(0, &tmp_mean);
    tmp_gmm.GetComponentVariance(0, &tmp_var);
    tmp_var.AddVec2(1.0, tmp_mean); // make it x^2 stats.
    // It may cause problems downstream if we add states with zero weights (see
    // KALDI_ASSERT(weight > 0) below), so we put in a very small floor.
    // These states with tiny weights will later get merged into other states.
    BaseFloat this_weight = 1.0e-10 + state_occs(pdf_index);
    tmp_mean.Scale(this_weight);
    tmp_var.Scale(this_weight);
    states.push_back(new GaussClusterable(tmp_mean, tmp_var,
                                          opts.cluster_varfloor, this_weight));
  }

  // Bottom-up clustering of the Gaussians corresponding to each state, which
  // gives a partial clustering of states in the 'state_clusters' vector.
  vector<int32> state_clusters;
  KALDI_VLOG(1) << "Creating " << num_clust_states << " clusters of states.";
  ClusterBottomUp(states, std::numeric_limits<BaseFloat>::max(), num_clust_states,
                  NULL /*actual clusters not needed*/,
                  &state_clusters /*get the cluster assignments*/);
  DeletePointers(&states);

  // For each cluster of states, create a pool of all the Gaussians in those
  // states, weighted by the state occupancies. This is done so that initially
  // only the Gaussians corresponding to "similar" states (similarity as
  // determined by the previous clustering) are merged.
  vector< vector<Clusterable*> > state_clust_gauss;
  state_clust_gauss.resize(num_clust_states);
  for (int32 pdf_index = 0; pdf_index < num_pdfs; pdf_index++) {
    int32 current_cluster = state_clusters[pdf_index];
    for (int32 num_gauss = am.GetPdf(pdf_index).NumGauss(),
         gauss_index = 0; gauss_index < num_gauss; ++gauss_index) {
      am.GetGaussianMean(pdf_index, gauss_index, &tmp_mean);
      am.GetGaussianVariance(pdf_index, gauss_index, &tmp_var);
      tmp_var.AddVec2(1.0, tmp_mean); // make it x^2 stats.
      // adding 1.0e-10 to the weight will prevent problems later on, see
      // the line KALDI_ASSERT(weight > 0.0).
      BaseFloat this_weight = (1.0e-10 + state_occs(pdf_index)) *
          (am.GetPdf(pdf_index).weights())(gauss_index);
      tmp_mean.Scale(this_weight);
      tmp_var.Scale(this_weight);
      state_clust_gauss[current_cluster].push_back(new GaussClusterable(
          tmp_mean, tmp_var, opts.cluster_varfloor, this_weight));
    }
  }

  // This is an unlikely operating scenario, no need to handle this in a more
  // optimized fashion.
  if (opts.intermediate_num_gauss > am.NumGauss()) {
    KALDI_WARN << "Intermediate num_gauss " << opts.intermediate_num_gauss
               << " is more than num-gauss " << am.NumGauss()
               << ", reducing it to " << am.NumGauss();
    opts.intermediate_num_gauss = am.NumGauss();
  }

  // The compartmentalized clusterer used below does not merge compartments.
  if (opts.intermediate_num_gauss < num_clust_states) {
    KALDI_WARN << "Intermediate num_gauss " << opts.intermediate_num_gauss
               << " is less than # of preclustered states " << num_clust_states
               << ", increasing it to " << num_clust_states;
    opts.intermediate_num_gauss = num_clust_states;
  }

  KALDI_VLOG(1) << "Merging from " << am.NumGauss() << " Gaussians in the "
                << "acoustic model, down to " << opts.intermediate_num_gauss
                << " Gaussians.";
  vector< vector<Clusterable*> > gauss_clusters_out;
  ClusterBottomUpCompartmentalized(state_clust_gauss, std::numeric_limits<BaseFloat>::max(),
                                   opts.intermediate_num_gauss,
                                   &gauss_clusters_out, NULL);
  for (int32 clust_index = 0; clust_index < num_clust_states; clust_index++)
    DeletePointers(&state_clust_gauss[clust_index]);

  // Next, put the remaining clustered Gaussians into a single GMM.
  KALDI_VLOG(1) << "Putting " << opts.intermediate_num_gauss << " Gaussians "
                << "into a single GMM for final merge step.";
  Matrix<BaseFloat> tmp_means(opts.intermediate_num_gauss, dim);
  Matrix<BaseFloat> tmp_vars(opts.intermediate_num_gauss, dim);
  Vector<BaseFloat> tmp_weights(opts.intermediate_num_gauss);
  Vector<BaseFloat> tmp_vec(dim);
  int32 gauss_index = 0;
  for (int32 clust_index = 0; clust_index < num_clust_states; clust_index++) {
    for (int32 i = gauss_clusters_out[clust_index].size()-1; i >=0; --i) {
      GaussClusterable *this_cluster = static_cast<GaussClusterable*>(
          gauss_clusters_out[clust_index][i]);
      BaseFloat weight = this_cluster->count();
      KALDI_ASSERT(weight > 0.0);
      tmp_weights(gauss_index) = weight;
      tmp_vec.CopyFromVec(this_cluster->x_stats());
      tmp_vec.Scale(1.0 / weight);
      tmp_means.CopyRowFromVec(tmp_vec, gauss_index);
      tmp_vec.CopyFromVec(this_cluster->x2_stats());
      tmp_vec.Scale(1.0 / weight);
      tmp_vec.AddVec2(-1.0, tmp_means.Row(gauss_index)); // x^2 stats to var.
      tmp_vars.CopyRowFromVec(tmp_vec, gauss_index);
      gauss_index++;
    }
    DeletePointers(&(gauss_clusters_out[clust_index]));
  }
  tmp_gmm.Resize(opts.intermediate_num_gauss, dim);
  tmp_weights.Scale(1.0/tmp_weights.Sum());
  tmp_gmm.SetWeights(tmp_weights);
  tmp_vars.InvertElements(); // need inverse vars...
  tmp_gmm.SetInvVarsAndMeans(tmp_vars, tmp_means);

  // Finally, cluster to the desired number of Gaussians in the UBM.
  if (opts.ubm_num_gauss < tmp_gmm.NumGauss()) {
    tmp_gmm.Merge(opts.ubm_num_gauss);
    KALDI_VLOG(1) << "Merged down to " << tmp_gmm.NumGauss() << " Gaussians.";
  } else {
    KALDI_WARN << "Not merging Gaussians since " << opts.ubm_num_gauss
               << " < " << tmp_gmm.NumGauss();
  }
  ubm_out->CopyFromDiagGmm(tmp_gmm);
}
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
kaldi::int32 int32
BaseFloat ClusterBottomUpCompartmentalized(const std::vector< std::vector< Clusterable *> > &points, BaseFloat thresh, int32 min_clust, std::vector< std::vector< Clusterable *> > *clusters_out, std::vector< std::vector< int32 > > *assignments_out)
This is a bottom-up clustering where the points are pre-clustered in a set of compartments, such that only points in the same compartment are clustered together.
float BaseFloat
Definition: kaldi-types.h:29
void ClusterGaussiansToUbm(const AmDiagGmm &am, const Vector< BaseFloat > &state_occs, UbmClusteringOptions opts, DiagGmm *ubm_out)
Clusters the Gaussians in an acoustic model to a single GMM with specified number of components...
Definition: am-diag-gmm.cc:195
BaseFloat ClusterBottomUp(const std::vector< Clusterable *> &points, BaseFloat max_merge_thresh, int32 min_clust, std::vector< Clusterable *> *clusters_out, std::vector< int32 > *assignments_out)
A bottom-up clustering algorithm.
#define KALDI_WARN
Definition: kaldi-error.h:150
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
#define KALDI_LOG
Definition: kaldi-error.h:153
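
The "make it x^2 stats" and "x^2 stats to var" steps in the listing convert between a (weight, mean, variance) description of a diagonal Gaussian and weighted first- and second-order statistics, in which merging two Gaussians is plain addition. The following is a minimal standalone sketch of that round trip, not Kaldi code: DiagStats, ToStats, Merge and ToMeanVar are hypothetical names, and the variance floor that GaussClusterable receives is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the statistics a clusterable Gaussian holds:
// a count plus per-dimension sums of x and x^2.
struct DiagStats {
  double count;
  std::vector<double> x;   // count * mean
  std::vector<double> x2;  // count * (var + mean^2)
};

// Weighted Gaussian -> statistics (add mean^2 to var, then scale by weight).
DiagStats ToStats(double weight, const std::vector<double> &mean,
                  const std::vector<double> &var) {
  assert(weight > 0.0 && mean.size() == var.size());
  DiagStats s{weight, mean, var};
  for (std::size_t d = 0; d < mean.size(); ++d) {
    s.x[d] = weight * mean[d];
    s.x2[d] = weight * (var[d] + mean[d] * mean[d]);
  }
  return s;
}

// Merging two Gaussians is element-wise addition in this representation.
DiagStats Merge(const DiagStats &a, const DiagStats &b) {
  assert(a.x.size() == b.x.size());
  DiagStats m = a;
  m.count += b.count;
  for (std::size_t d = 0; d < a.x.size(); ++d) {
    m.x[d] += b.x[d];
    m.x2[d] += b.x2[d];
  }
  return m;
}

// Statistics -> mean and variance (divide by count, subtract mean^2).
void ToMeanVar(const DiagStats &s, std::vector<double> *mean,
               std::vector<double> *var) {
  assert(s.count > 0.0);
  mean->resize(s.x.size());
  var->resize(s.x.size());
  for (std::size_t d = 0; d < s.x.size(); ++d) {
    (*mean)[d] = s.x[d] / s.count;
    (*var)[d] = s.x2[d] / s.count - (*mean)[d] * (*mean)[d];
  }
}
```

Merging two unit-variance, equal-weight Gaussians with means 0 and 2 in this way gives mean 1 and variance 2: the merged variance includes the spread between the two means, which is why the statistics are kept in x^2 form while clustering. The assertion on a positive count corresponds to the small 1.0e-10 floor that the listing adds to each weight.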

◆ ComputeGconsts()

int32 ComputeGconsts ( )

Sets the gconsts for all the PDFs.

Returns the total number of Gaussians over all PDFs that are "invalid" e.g. due to zero weights or variances.
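
This page does not define "gconst". As a hedged illustration, the sketch below assumes it denotes the data-independent part of the weighted log-density of one diagonal Gaussian, that is log w - 0.5 (D log 2π + Σ log σ² + Σ μ²/σ²). GconstSketch and CountInvalid are hypothetical names, and treating every non-finite constant as "invalid" is an assumption made for the example; the library's own handling may differ in detail.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Data-independent part of log(w * N(x; mean, diag(var))), under the
// assumption stated in the text above.
double GconstSketch(double weight, const std::vector<double> &mean,
                    const std::vector<double> &var) {
  assert(mean.size() == var.size());
  const double kLog2Pi = std::log(2.0 * std::acos(-1.0));
  double g = std::log(weight) -
             0.5 * kLog2Pi * static_cast<double>(mean.size());
  for (std::size_t d = 0; d < mean.size(); ++d)
    g -= 0.5 * (std::log(var[d]) + mean[d] * mean[d] / var[d]);
  return g;
}

// Counts constants that are not finite, which is what a zero weight or a
// zero variance produces in GconstSketch.
int CountInvalid(const std::vector<double> &gconsts) {
  int num_bad = 0;
  for (double g : gconsts)
    if (!std::isfinite(g)) ++num_bad;
  return num_bad;
}
```

A zero weight makes log w equal to minus infinity, and a zero variance makes the per-dimension term non-finite, so either case is counted once per affected component.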

Definition at line 90 of file am-diag-gmm.cc.

References AmDiagGmm::densities_, and KALDI_WARN.

Referenced by AmDiagGmm::AmDiagGmm(), kaldi::DoRescalingUpdate(), and RegtreeMllrDiagGmm::TransformModel().

{
  int32 num_bad = 0;
  for (std::vector<DiagGmm*>::iterator itr = densities_.begin(),
      end = densities_.end(); itr != end; ++itr) {
    num_bad += (*itr)->ComputeGconsts();
  }
  if (num_bad > 0)
    KALDI_WARN << "Found " << num_bad << " Gaussian components.";
  return num_bad;
}

◆ CopyFromAmDiagGmm()

void CopyFromAmDiagGmm ( const AmDiagGmm &  other)

Copies the parameters from another model. Allocates necessary memory.

Definition at line 79 of file am-diag-gmm.cc.

References kaldi::DeletePointers(), AmDiagGmm::densities_, rnnlm::i, and AmDiagGmm::NumPdfs().

Referenced by AmDiagGmm::AmDiagGmm(), kaldi::ClusterGaussiansToUbm(), main(), TestAmDiagGmmAccsIO(), TestMllrAccsIO(), TestSplitStates(), and TestXformMean().

{
  if (densities_.size() != 0) {
    DeletePointers(&densities_);
  }
  densities_.resize(other.NumPdfs(), NULL);
  for (int32 i = 0, end = densities_.size(); i < end; i++) {
    densities_[i] = new DiagGmm();
    densities_[i]->CopyFromDiagGmm(*other.densities_[i]);
  }
}

◆ Dim()

◆ GetGaussianMean()

void GetGaussianMean ( int32  pdf_index,
int32  gauss,
VectorBase< BaseFloat > *  out 
) const
inline

Definition at line 131 of file am-diag-gmm.h.

References AmDiagGmm::densities_, and KALDI_ASSERT.

Referenced by RegressionTree::BuildTree(), kaldi::ClusterGaussiansToUbm(), RegtreeMllrDiagGmm::GetTransformedMeans(), AmDiagGmm::NumPdfs(), and RegtreeMllrDiagGmm::TransformModel().

{
  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
               && (densities_[pdf_index] != NULL));
  densities_[pdf_index]->GetComponentMean(gauss, out);
}

◆ GetGaussianVariance()

void GetGaussianVariance ( int32  pdf_index,
int32  gauss,
VectorBase< BaseFloat > *  out 
) const
inline

Definition at line 138 of file am-diag-gmm.h.

References AmDiagGmm::densities_, and KALDI_ASSERT.

Referenced by RegressionTree::BuildTree(), kaldi::ClusterGaussiansToUbm(), and AmDiagGmm::NumPdfs().

{
  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
               && (densities_[pdf_index] != NULL));
  densities_[pdf_index]->GetComponentVariance(gauss, out);
}

◆ GetPdf() [1/2]

DiagGmm & GetPdf ( int32  pdf_index)
inline

Accessors.

Definition at line 119 of file am-diag-gmm.h.

References AmDiagGmm::densities_, and KALDI_ASSERT.

Referenced by kaldi::AccStatsForUtterance(), AccumAmDiagGmm::AccumulateForGaussian(), RegtreeMllrDiagGmmAccs::AccumulateForGaussian(), RegtreeFmllrDiagGmmAccs::AccumulateForGaussian(), AccumAmDiagGmm::AccumulateForGmm(), RegtreeMllrDiagGmmAccs::AccumulateForGmm(), RegtreeFmllrDiagGmmAccs::AccumulateForGmm(), AccumAmDiagGmm::AccumulateForGmmTwofeats(), kaldi::AccumulateForUtterance(), RegressionTree::BuildTree(), kaldi::ClusterGaussiansToUbm(), BasisFmllrEstimate::ComputeAmDiagPrecond(), kaldi::ComputeAmGmmFeatureDeriv(), kaldi::DoRescalingUpdate(), SingleUtteranceGmmDecoder::EstimateFmllr(), SingleUtteranceGmmDecoder::GetGaussianPosteriors(), kaldi::GetStatsDerivative(), RegtreeMllrDiagGmm::GetTransformedMeans(), DecodableAmDiagGmmRegtreeMllr::GetXformedMeanInvVars(), AccumAmDiagGmm::Init(), kaldi::InitAmGmmFromOld(), kaldi::IsmoothStatsAmDiagGmmFromModel(), DecodableAmDiagGmmRegtreeFmllr::LogLikelihoodZeroBased(), DecodableAmDiagGmmUnmapped::LogLikelihoodZeroBased(), DecodableAmDiagGmmRegtreeMllr::LogLikelihoodZeroBased(), main(), kaldi::MapAmDiagGmmUpdate(), kaldi::MleAmDiagGmmUpdate(), AmDiagGmm::NumPdfs(), kaldi::ResizeModel(), TestXformMean(), kaldi::UpdateEbwAmDiagGmm(), and kaldi::UpdateEbwWeightsAmDiagGmm().

{
  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
               && (densities_[pdf_index] != NULL));
  return *(densities_[pdf_index]);
}

◆ GetPdf() [2/2]

const DiagGmm & GetPdf ( int32  pdf_index) const
inline

Definition at line 125 of file am-diag-gmm.h.

References AmDiagGmm::densities_, and KALDI_ASSERT.

{
  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
               && (densities_[pdf_index] != NULL));
  return *(densities_[pdf_index]);
}

◆ Init()

void Init ( const DiagGmm &  proto,
int32  num_pdfs 
)

Initializes with a single "prototype" GMM.

Definition at line 38 of file am-diag-gmm.cc.

References kaldi::DeletePointers(), AmDiagGmm::densities_, and KALDI_WARN.

Referenced by AmDiagGmm::AmDiagGmm(), and UnitTestRegtreeMllrDiagGmm().

{
  if (densities_.size() != 0) {
    KALDI_WARN << "Init() called on a non-empty object. Contents will be "
        "overwritten";
    DeletePointers(&densities_);
  }
  if (num_pdfs == 0) {
    KALDI_WARN << "Init() called with number of pdfs = 0. Will do nothing.";
    return;
  }

  densities_.resize(num_pdfs, NULL);
  for (vector<DiagGmm*>::iterator itr = densities_.begin(),
      end = densities_.end(); itr != end; ++itr) {
    *itr = new DiagGmm();
    (*itr)->CopyFromDiagGmm(proto);
  }
}

◆ KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( AmDiagGmm  )
private

◆ LogLikelihood()

BaseFloat LogLikelihood ( const int32  pdf_index,
const VectorBase< BaseFloat > &  data 
) const
inline

◆ MergeByCount()

void MergeByCount ( const Vector< BaseFloat > &  state_occs,
int32  target_components,
BaseFloat  power,
BaseFloat  min_count 
)

Definition at line 125 of file am-diag-gmm.cc.

References AmDiagGmm::densities_, kaldi::GetSplitTargets(), rnnlm::i, KALDI_LOG, AmDiagGmm::NumGauss(), and AmDiagGmm::NumPdfs().

Referenced by AmDiagGmm::AmDiagGmm(), kaldi::ClusterGaussiansToUbm(), and main().

{
  int32 gauss_at_start = NumGauss();
  std::vector<int32> targets;
  GetSplitTargets(state_occs, target_components,
                  power, min_count, &targets);

  for (int32 i = 0; i < NumPdfs(); i++) {
    if (targets[i] == 0) targets[i] = 1; // can't merge below 1.
    if (densities_[i]->NumGauss() > targets[i])
      densities_[i]->Merge(targets[i]);
  }

  KALDI_LOG << "Merged " << NumPdfs() << " states with target = "
            << target_components << ", power = " << power
            << " and min_count = " << min_count
            << ", merged from " << gauss_at_start << " to "
            << NumGauss();
}
int32 NumGauss() const
Definition: am-diag-gmm.cc:72
kaldi::int32 int32
void GetSplitTargets(const Vector< BaseFloat > &state_occs, int32 target_components, BaseFloat power, BaseFloat min_count, std::vector< int32 > *targets)
Get Gaussian-mixture or substate-mixture splitting targets, according to a power rule (e...
std::vector< DiagGmm * > densities_
Definition: am-diag-gmm.h:99
int32 NumPdfs() const
Definition: am-diag-gmm.h:82
#define KALDI_LOG
Definition: kaldi-error.h:153

◆ NumGauss()

int32 NumGauss ( ) const

Definition at line 72 of file am-diag-gmm.cc.

References AmDiagGmm::densities_, and rnnlm::i.

Referenced by RegressionTree::BuildTree(), kaldi::ClusterGaussiansToUbm(), main(), RegressionTree::MakeGauss2Bclass(), AmDiagGmm::MergeByCount(), kaldi::MleAmDiagGmmUpdate(), AmDiagGmm::NumPdfs(), RegressionTree::Read(), AmDiagGmm::SplitByCount(), TestClustering(), and TestSplitStates().

{
  int32 ans = 0;
  for (size_t i = 0; i < densities_.size(); i++)
    ans += densities_[i]->NumGauss();
  return ans;
}

◆ NumGaussInPdf()

int32 NumGaussInPdf ( int32  pdf_index) const
inline

Definition at line 113 of file am-diag-gmm.h.

References AmDiagGmm::densities_, and KALDI_ASSERT.

Referenced by RegressionTree::MakeGauss2Bclass(), AmDiagGmm::NumPdfs(), and kaldi::UnitTestRegtreeFmllrDiagGmm().

{
  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
               && (densities_[pdf_index] != NULL));
  return densities_[pdf_index]->NumGauss();
}

◆ NumPdfs()

◆ Read()

void Read ( std::istream &  in_stream,
bool  binary 
)

Definition at line 147 of file am-diag-gmm.cc.

References AmDiagGmm::densities_, kaldi::ExpectToken(), rnnlm::i, KALDI_ASSERT, and kaldi::ReadBasicType().

Referenced by AmDiagGmm::AmDiagGmm(), kaldi::InitAmGmmFromOld(), main(), and TestAmDiagGmmIO().

{
  int32 num_pdfs, dim;

  ExpectToken(in_stream, binary, "<DIMENSION>");
  ReadBasicType(in_stream, binary, &dim);
  ExpectToken(in_stream, binary, "<NUMPDFS>");
  ReadBasicType(in_stream, binary, &num_pdfs);
  KALDI_ASSERT(num_pdfs > 0);
  densities_.reserve(num_pdfs);
  for (int32 i = 0; i < num_pdfs; i++) {
    densities_.push_back(new DiagGmm());
    densities_.back()->Read(in_stream, binary);
    KALDI_ASSERT(densities_.back()->Dim() == dim);
  }
}
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
kaldi::int32 int32
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
std::vector< DiagGmm * > densities_
Definition: am-diag-gmm.h:99
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ Register()

void Register ( OptionsItf *  opts)
inline

Definition at line 176 of file am-diag-gmm.h.

References kaldi::ClusterGaussiansToUbm(), and OptionsItf::Register().

Referenced by main().

{
  std::string module = "UbmClusteringOptions: ";
  opts->Register("max-am-gauss", &max_am_gauss, module+
                 "We first reduce acoustic model to this max #Gauss before clustering.");
  opts->Register("ubm-num-gauss", &ubm_num_gauss, module+
                 "Number of Gaussians components in the final UBM.");
  opts->Register("ubm-numcomps", &ubm_num_gauss, module+
                 "Backward compatibility option (see ubm-num-gauss)");
  opts->Register("reduce-state-factor", &reduce_state_factor, module+
                 "Intermediate number of clustered states (as fraction of total states).");
  opts->Register("intermediate-num-gauss", &intermediate_num_gauss, module+
                 "Intermediate number of merged Gaussian components.");
  opts->Register("intermediate-numcomps", &intermediate_num_gauss, module+
                 "Backward compatibility option (see intermediate-num-gauss)");
  opts->Register("cluster-varfloor", &cluster_varfloor, module+
                 "Variance floor used in bottom-up state clustering.");
}

◆ RemovePdf()

void RemovePdf ( int32  pdf_index)
private

Definition at line 66 of file am-diag-gmm.cc.

References AmDiagGmm::densities_, and KALDI_ASSERT.

{
  KALDI_ASSERT(static_cast<size_t>(pdf_index) < densities_.size());
  delete densities_[pdf_index];
  densities_.erase(densities_.begin() + pdf_index);
}

◆ SetGaussianMean()

void SetGaussianMean ( int32  pdf_index,
int32  gauss_index,
const VectorBase< BaseFloat > &  in 
)
inline

Mutators.

Definition at line 145 of file am-diag-gmm.h.

References AmDiagGmm::densities_, and KALDI_ASSERT.

Referenced by AmDiagGmm::NumPdfs(), and RegtreeMllrDiagGmm::TransformModel().

{
  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
               && (densities_[pdf_index] != NULL));
  densities_[pdf_index]->SetComponentMean(gauss_index, in);
}

◆ SplitByCount()

void SplitByCount ( const Vector< BaseFloat > &  state_occs,
int32  target_components,
float  perturb_factor,
BaseFloat  power,
BaseFloat  min_count 
)

Definition at line 102 of file am-diag-gmm.cc.

References AmDiagGmm::densities_, kaldi::GetSplitTargets(), rnnlm::i, KALDI_LOG, AmDiagGmm::NumGauss(), and AmDiagGmm::NumPdfs().

Referenced by AmDiagGmm::AmDiagGmm(), main(), and TestSplitStates().

{
  int32 gauss_at_start = NumGauss();
  std::vector<int32> targets;
  GetSplitTargets(state_occs, target_components, power,
                  min_count, &targets);

  for (int32 i = 0; i < NumPdfs(); i++) {
    if (densities_[i]->NumGauss() < targets[i])
      densities_[i]->Split(targets[i], perturb_factor);
  }

  KALDI_LOG << "Split " << NumPdfs() << " states with target = "
            << target_components << ", power = " << power
            << ", perturb_factor = " << perturb_factor
            << " and min_count = " << min_count
            << ", split #Gauss from " << gauss_at_start << " to "
            << NumGauss();
}

◆ SplitPdf()

void SplitPdf ( int32  idx,
int32  target_components,
float  perturb_factor 
)
inline

Definition at line 152 of file am-diag-gmm.h.

References AmDiagGmm::densities_, and KALDI_ASSERT.

Referenced by AmDiagGmm::AmDiagGmm(), and kaldi::UnitTestRegtreeFmllrDiagGmm().

{
  KALDI_ASSERT((static_cast<size_t>(pdf_index) < densities_.size())
               && (densities_[pdf_index] != NULL));
  densities_[pdf_index]->Split(target_components, perturb_factor);
}

◆ UbmClusteringOptions() [1/2]

Definition at line 167 of file am-diag-gmm.h.

◆ UbmClusteringOptions() [2/2]

UbmClusteringOptions ( int32  ncomp,
BaseFloat  red,
int32  interm_gauss,
BaseFloat  vfloor,
int32  max_am_gauss 
)
inline

Definition at line 171 of file am-diag-gmm.h.

◆ Write()

void Write ( std::ostream &  out_stream,
bool  binary 
) const

Definition at line 163 of file am-diag-gmm.cc.

References AmDiagGmm::densities_, AmDiagGmm::Dim(), KALDI_WARN, kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by AmDiagGmm::AmDiagGmm(), main(), and TestAmDiagGmmIO().

{
  int32 dim = this->Dim();
  if (dim == 0) {
    KALDI_WARN << "Trying to write empty AmDiagGmm object.";
  }
  WriteToken(out_stream, binary, "<DIMENSION>");
  WriteBasicType(out_stream, binary, dim);
  WriteToken(out_stream, binary, "<NUMPDFS>");
  WriteBasicType(out_stream, binary, static_cast<int32>(densities_.size()));
  for (std::vector<DiagGmm*>::const_iterator it = densities_.begin(),
      end = densities_.end(); it != end; ++it) {
    (*it)->Write(out_stream, binary);
  }
}
kaldi::int32 int32
#define KALDI_WARN
Definition: kaldi-error.h:150
std::vector< DiagGmm * > densities_
Definition: am-diag-gmm.h:99
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
int32 Dim() const
Definition: am-diag-gmm.h:79
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
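
Read() and Write() agree on a header of two tagged integers, <DIMENSION> and <NUMPDFS>, followed by the per-PDF data. The sketch below imitates only that header, in a whitespace-separated text form, using the standard library. WriteHeaderSketch and ReadHeaderSketch are hypothetical names; the exact byte layout produced by WriteToken and WriteBasicType, and the binary mode, are not reproduced here.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Writes "<DIMENSION> dim <NUMPDFS> num_pdfs " (text form assumed for
// illustration only).
void WriteHeaderSketch(std::ostream &os, int dim, int num_pdfs) {
  os << "<DIMENSION> " << dim << " <NUMPDFS> " << num_pdfs << " ";
}

// Reads the header back. Returns false on an unexpected token or a
// non-positive PDF count; the listing's ExpectToken throws instead, and
// Read() asserts num_pdfs > 0.
bool ReadHeaderSketch(std::istream &is, int *dim, int *num_pdfs) {
  assert(dim != nullptr && num_pdfs != nullptr);
  std::string token;
  if (!(is >> token) || token != "<DIMENSION>") return false;
  if (!(is >> *dim)) return false;
  if (!(is >> token) || token != "<NUMPDFS>") return false;
  if (!(is >> *num_pdfs)) return false;
  return *num_pdfs > 0;
}
```

Writing the dimension before the PDFs lets the reader check each DiagGmm against it as it is loaded, which is what the KALDI_ASSERT inside the Read() loop does.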

◆ ~AmDiagGmm()

~AmDiagGmm ( )

Definition at line 34 of file am-diag-gmm.cc.

References kaldi::DeletePointers(), and AmDiagGmm::densities_.

Referenced by AmDiagGmm::AmDiagGmm().

{
  DeletePointers(&densities_);
}

Variable Documentation

◆ cluster_varfloor

BaseFloat cluster_varfloor

Definition at line 164 of file am-diag-gmm.h.

Referenced by kaldi::ClusterGaussiansToUbm().

◆ densities_

◆ intermediate_num_gauss

int32 intermediate_num_gauss

Definition at line 163 of file am-diag-gmm.h.

Referenced by kaldi::ClusterGaussiansToUbm().

◆ max_am_gauss

int32 max_am_gauss

Definition at line 165 of file am-diag-gmm.h.

Referenced by kaldi::ClusterGaussiansToUbm().

◆ reduce_state_factor

BaseFloat reduce_state_factor

Definition at line 162 of file am-diag-gmm.h.

Referenced by kaldi::ClusterGaussiansToUbm().

◆ ubm_num_gauss

int32 ubm_num_gauss

Definition at line 161 of file am-diag-gmm.h.

Referenced by kaldi::ClusterGaussiansToUbm().