am-diag-gmm-test.cc
Go to the documentation of this file.
1 // gmm/am-diag-gmm-test.cc
2 
3 // Copyright 2009-2011 Saarland University
4 // Author: Arnab Ghoshal
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "gmm/model-test-common.h"
22 #include "gmm/am-diag-gmm.h"
23 #include "util/kaldi-io.h"
24 
25 using kaldi::AmDiagGmm;
26 using kaldi::int32;
27 using kaldi::BaseFloat;
28 namespace ut = kaldi::unittest;
29 
30 // Tests the Read() and Write() methods, in both binary and ASCII mode, as well
31 // as Check(), CopyFromSgmm(), and methods in likelihood computations.
32 void TestAmDiagGmmIO(const AmDiagGmm &am_gmm) {
33  int32 dim = am_gmm.Dim();
34 
35  kaldi::Vector<BaseFloat> feat(dim);
36  for (int32 d = 0; d < dim; d++) {
37  feat(d) = kaldi::RandGauss();
38  }
39 
40  BaseFloat loglike = 0.0;
41  for (int32 i = 0; i < am_gmm.NumPdfs(); i++)
42  loglike += am_gmm.LogLikelihood(i, feat);
43  // First, non-binary write
44  am_gmm.Write(kaldi::Output("tmpf", false).Stream(), false);
45 
46  bool binary_in;
47  AmDiagGmm *am_gmm1 = new AmDiagGmm();
48  // Non-binary read
49  kaldi::Input ki1("tmpf", &binary_in);
50  am_gmm1->Read(ki1.Stream(), binary_in);
51  BaseFloat loglike1 = 0.0;
52  for (int32 i = 0; i < am_gmm1->NumPdfs(); i++)
53  loglike1 += am_gmm1->LogLikelihood(i, feat);
54  kaldi::AssertEqual(loglike, loglike1, 1e-4);
55 
56  // Next, binary write
57  am_gmm1->Write(kaldi::Output("tmpfb", true).Stream(), true);
58  delete am_gmm1;
59 
60  AmDiagGmm *am_gmm2 = new AmDiagGmm();
61  // Binary read
62  kaldi::Input ki2("tmpfb", &binary_in);
63  am_gmm2->Read(ki2.Stream(), binary_in);
64  BaseFloat loglike2 = 0.0;
65  for (int32 i = 0; i < am_gmm2->NumPdfs(); i++)
66  loglike2 += am_gmm2->LogLikelihood(i, feat);
67  kaldi::AssertEqual(loglike, loglike2, 1e-4);
68  delete am_gmm2;
69 
70  unlink("tmpf");
71  unlink("tmpfb");
72 }
73 
74 void TestSplitStates(const AmDiagGmm &am_gmm) {
75  int32 target_comp = 2 * am_gmm.NumGauss();
76  kaldi::Vector<BaseFloat> occs(am_gmm.NumPdfs());
77  for (int32 i = 0; i < occs.Dim(); i++)
78  occs(i) = std::fabs(kaldi::RandGauss()) * (kaldi::RandUniform()+1) * 4;
79  AmDiagGmm *am_gmm1 = new AmDiagGmm();
80  am_gmm1->CopyFromAmDiagGmm(am_gmm);
81  am_gmm1->SplitByCount(occs, target_comp, 0.01, 0.2, 0.0);
82 
83  int32 dim = am_gmm.Dim();
84  kaldi::Vector<BaseFloat> feat(dim);
85  for (int32 d = 0; d < dim; d++) {
86  feat(d) = kaldi::RandGauss();
87  }
88  BaseFloat loglike = am_gmm.LogLikelihood(0, feat);
89  BaseFloat loglike1 = am_gmm1->LogLikelihood(0, feat);
90  kaldi::AssertEqual(loglike, loglike1, 1e-2);
91 
92  delete am_gmm1;
93 }
94 
95 void TestClustering(const AmDiagGmm &am_gmm) {
96  int32 target_comp = am_gmm.NumGauss() / 5,
97  interm_comp = am_gmm.NumGauss() / 2;
98  kaldi::Vector<BaseFloat> occs(am_gmm.NumPdfs());
99  for (int32 i = 0; i < occs.Dim(); i++)
100  occs(i) = std::fabs(kaldi::RandGauss()) * (kaldi::RandUniform()+1) * 4;
101 
102  kaldi::UbmClusteringOptions ubm_opts(target_comp, 0.2, interm_comp, 0.01, 30);
103  kaldi::DiagGmm ubm;
104  ClusterGaussiansToUbm(am_gmm, occs, ubm_opts, &ubm);
105 }
106 
108  int32 dim = 1 + kaldi::RandInt(0, 9), // random dimension of the gmm
109  num_pdfs = 5 + kaldi::RandInt(0, 9); // random number of states
110 
111  AmDiagGmm am_gmm;
112  for (int32 i = 0; i < num_pdfs; i++) {
113  int32 num_comp = 1 + kaldi::RandInt(0, 9); // random number of mixtures
114  kaldi::DiagGmm gmm;
115  ut::InitRandDiagGmm(dim, num_comp, &gmm);
116  am_gmm.AddPdf(gmm);
117  }
118 
119  TestAmDiagGmmIO(am_gmm);
120  TestSplitStates(am_gmm);
121  TestClustering(am_gmm);
122 }
123 
124 int main() {
125  for (int i = 0; i < 5; i++)
127  std::cout << "Test OK.\n";
128  return 0;
129 }
void CopyFromAmDiagGmm(const AmDiagGmm &other)
Copies the parameters from another model. Allocates necessary memory.
Definition: am-diag-gmm.cc:79
void AddPdf(const DiagGmm &gmm)
Adds a GMM to the model, and increments the total number of PDFs.
Definition: am-diag-gmm.cc:57
int32 NumGauss() const
Definition: am-diag-gmm.cc:72
float RandUniform(struct RandomState *state=NULL)
Returns a random number strictly between 0 and 1.
Definition: kaldi-math.h:151
void TestAmDiagGmmIO(const AmDiagGmm &am_gmm)
float RandGauss(struct RandomState *state=NULL)
Definition: kaldi-math.h:155
kaldi::int32 int32
void TestSplitStates(const AmDiagGmm &am_gmm)
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
void ClusterGaussiansToUbm(const AmDiagGmm &am, const Vector< BaseFloat > &state_occs, UbmClusteringOptions opts, DiagGmm *ubm_out)
Clusters the Gaussians in an acoustic model to a single GMM with specified number of components...
Definition: am-diag-gmm.cc:195
BaseFloat LogLikelihood(const int32 pdf_index, const VectorBase< BaseFloat > &data) const
Definition: am-diag-gmm.h:108
void InitRandDiagGmm(int32 dim, int32 num_comp, DiagGmm *gmm)
int main()
int32 Dim() const
Definition: am-diag-gmm.h:79
int32 NumPdfs() const
Definition: am-diag-gmm.h:82
void UnitTestAmDiagGmm()
A class representing a vector.
Definition: kaldi-vector.h:406
void TestClustering(const AmDiagGmm &am_gmm)
void Write(std::ostream &out_stream, bool binary) const
Definition: am-diag-gmm.cc:163
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
Definition: kaldi-math.h:276
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95
void SplitByCount(const Vector< BaseFloat > &state_occs, int32 target_components, float perturb_factor, BaseFloat power, BaseFloat min_count)
Definition: am-diag-gmm.cc:102