decodable-am-diag-gmm-regtree.cc
Go to the documentation of this file.
1 // transform/decodable-am-diag-gmm-regtree.cc
2 
3 // Copyright 2009-2011 Saarland University; Lukas Burget
4 // 2013 Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABILITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <vector>
22 using std::vector;
23 
25 
26 namespace kaldi {
27 
28 
30  int32 state) {
31  KALDI_ASSERT(frame < NumFramesReady() && frame >= 0);
32  KALDI_ASSERT(state < NumIndices() && state >= 0);
33 
34  if (!valid_logdets_) {
37  valid_logdets_ = true;
38  }
39 
40  if (log_like_cache_[state].hit_time == frame) {
41  return log_like_cache_[state].log_like; // return cached value, if found
42  }
43 
44  const DiagGmm &pdf = acoustic_model_.GetPdf(state);
45  const VectorBase<BaseFloat> &data = feature_matrix_.Row(frame);
46 
47  // check if everything is in order
48  if (pdf.Dim() != data.Dim()) {
49  KALDI_ERR << "Dim mismatch: data dim = " << data.Dim()
50  << " vs. model dim = " << pdf.Dim();
51  }
52  if (!pdf.valid_gconsts()) {
53  KALDI_ERR << "State " << (state) << ": Must call ComputeGconsts() "
54  "before computing likelihood.";
55  }
56 
57  if (frame != previous_frame_) { // cache the transformed & squared stats.
60  vector< Vector <BaseFloat> >::iterator it = xformed_data_squared_.begin(),
61  end = xformed_data_squared_.end();
62  for (; it != end; ++it) { it->ApplyPow(2.0); }
63  previous_frame_ = frame;
64  }
65 
66  Vector<BaseFloat> loglikes(pdf.gconsts()); // need to recreate for each pdf
67  int32 baseclass, regclass;
68  for (int32 comp_id = 0, num_comp = pdf.NumGauss(); comp_id < num_comp;
69  ++comp_id) {
70  baseclass = regtree_.Gauss2BaseclassId(state, comp_id);
71  regclass = fmllr_xform_.Base2RegClass(baseclass);
72  // loglikes += means * inv(vars) * data.
73  loglikes(comp_id) += VecVec(pdf.means_invvars().Row(comp_id),
74  xformed_data_[regclass]);
75  // loglikes += -0.5 * inv(vars) * data_sq.
76  loglikes(comp_id) -= 0.5 * VecVec(pdf.inv_vars().Row(comp_id),
77  xformed_data_squared_[regclass]);
78  loglikes(comp_id) += logdets_(regclass);
79  }
80 
81  BaseFloat log_sum = loglikes.LogSumExp(log_sum_exp_prune_);
82  if (KALDI_ISNAN(log_sum) || KALDI_ISINF(log_sum))
83  KALDI_ERR << "Invalid answer (overflow or invalid variances/features?)";
84 
85  log_like_cache_[state].log_like = log_sum;
86  log_like_cache_[state].hit_time = frame;
87 
88  return log_sum;
89 }
90 
92  DeletePointers(&xformed_mean_invvars_);
93  DeletePointers(&xformed_gconsts_);
94 }
95 
96 
98  if (xformed_mean_invvars_.size() != 0)
99  DeletePointers(&xformed_mean_invvars_);
100  if (xformed_gconsts_.size() != 0)
101  DeletePointers(&xformed_gconsts_);
102  int32 num_pdfs = acoustic_model_.NumPdfs();
103  xformed_mean_invvars_.resize(num_pdfs);
104  xformed_gconsts_.resize(num_pdfs);
105  is_cached_.resize(num_pdfs, false);
107 }
108 
109 
110 // This is almost the same code as DiagGmm::ComputeGconsts, except that
111 // means are used instead of means * inv(vars). This saves some computation.
112 static void ComputeGconsts(const VectorBase<BaseFloat> &weights,
113  const MatrixBase<BaseFloat> &means,
114  const MatrixBase<BaseFloat> &inv_vars,
115  VectorBase<BaseFloat> *gconsts_out) {
116  int32 num_gauss = weights.Dim();
117  int32 dim = means.NumCols();
118  KALDI_ASSERT(means.NumRows() == num_gauss
119  && inv_vars.NumRows() == num_gauss && inv_vars.NumCols() == dim);
120  KALDI_ASSERT(gconsts_out->Dim() == num_gauss);
121 
122  BaseFloat offset = -0.5 * M_LOG_2PI * dim; // constant term in gconst.
123  int32 num_bad = 0;
124 
125  for (int32 gauss = 0; gauss < num_gauss; gauss++) {
126  KALDI_ASSERT(weights(gauss) >= 0); // Cannot have negative weights.
127  BaseFloat gc = Log(weights(gauss)) + offset; // May be -inf if weights == 0
128  for (int32 d = 0; d < dim; d++) {
129  gc += 0.5 * Log(inv_vars(gauss, d)) - 0.5 * means(gauss, d)
130  * means(gauss, d) * inv_vars(gauss, d); // diff from DiagGmm version.
131  }
132 
133  if (KALDI_ISNAN(gc)) { // negative infinity is OK but NaN is not acceptable
134  KALDI_ERR << "At component " << gauss
135  << ", not a number in gconst computation";
136  }
137  if (KALDI_ISINF(gc)) {
138  num_bad++;
139  // If positive infinity, make it negative infinity.
140  // Want to make sure the answer becomes -inf in the end, not NaN.
141  if (gc > 0) gc = -gc;
142  }
143  (*gconsts_out)(gauss) = gc;
144  }
145  if (num_bad > 0)
146  KALDI_WARN << num_bad << " unusable components found while computing "
147  << "gconsts.";
148 }
149 
150 
152  int32 state) {
153  if (is_cached_[state]) { // found in cache
154  KALDI_ASSERT(xformed_mean_invvars_[state] != NULL);
155  KALDI_VLOG(3) << "For PDF index " << state << ": transformed means "
156  << "found in cache.";
157  return *xformed_mean_invvars_[state];
158  } else { // transform the means and cache them
159  KALDI_ASSERT(xformed_mean_invvars_[state] == NULL);
160  KALDI_VLOG(3) << "For PDF index " << state << ": transforming means.";
161  int32 num_gauss = acoustic_model_.GetPdf(state).NumGauss(),
162  dim = acoustic_model_.Dim();
163  const Vector<BaseFloat> &weights = acoustic_model_.GetPdf(state).weights();
164  const Matrix<BaseFloat> &invvars = acoustic_model_.GetPdf(state).inv_vars();
165  xformed_mean_invvars_[state] = new Matrix<BaseFloat>(num_gauss, dim);
166  mllr_xform_.GetTransformedMeans(regtree_, acoustic_model_, state,
167  xformed_mean_invvars_[state]);
168  xformed_gconsts_[state] = new Vector<BaseFloat>(num_gauss);
169  // At this point, the transformed means haven't been multiplied with
170  // the inv vars, and they are used to compute gconsts first.
171  ComputeGconsts(weights, *xformed_mean_invvars_[state], invvars,
172  xformed_gconsts_[state]);
173  // Finally, multiply the transformed means with the inv vars.
174  xformed_mean_invvars_[state]->MulElements(invvars);
175  is_cached_[state] = true;
176  return *xformed_mean_invvars_[state];
177  }
178 }
179 
181  int32 state) {
182  if (!is_cached_[state]) {
183  KALDI_ERR << "GConsts not cached for state: " << state << ". Must call "
184  << "GetXformedMeanInvVars() first.";
185  }
186  KALDI_ASSERT(xformed_gconsts_[state] != NULL);
187  return *xformed_gconsts_[state];
188 }
189 
191  int32 state) {
192 // KALDI_ERR << "Function not completely implemented yet.";
193  KALDI_ASSERT(frame < NumFramesReady() && frame >= 0);
194  KALDI_ASSERT(state < NumIndices() && state >= 0);
195 
196  if (log_like_cache_[state].hit_time == frame) {
197  return log_like_cache_[state].log_like; // return cached value, if found
198  }
199 
200  const DiagGmm &pdf = acoustic_model_.GetPdf(state);
201  const VectorBase<BaseFloat> &data = feature_matrix_.Row(frame);
202 
203  // check if everything is in order
204  if (pdf.Dim() != data.Dim()) {
205  KALDI_ERR << "Dim mismatch: data dim = " << data.Dim()
206  << " vs. model dim = " << pdf.Dim();
207  }
208 
209  if (frame != previous_frame_) { // cache the squared stats.
210  data_squared_.CopyFromVec(feature_matrix_.Row(frame));
211  data_squared_.ApplyPow(2.0);
212  previous_frame_ = frame;
213  }
214 
215  const Matrix<BaseFloat> &means_invvars = GetXformedMeanInvVars(state);
216  const Vector<BaseFloat> &gconsts = GetXformedGconsts(state);
217 
218  Vector<BaseFloat> loglikes(gconsts); // need to recreate for each pdf
219  // loglikes += means * inv(vars) * data.
220  loglikes.AddMatVec(1.0, means_invvars, kNoTrans, data, 1.0);
221  // loglikes += -0.5 * inv(vars) * data_sq.
222  loglikes.AddMatVec(-0.5, pdf.inv_vars(), kNoTrans, data_squared_, 1.0);
223 
224  BaseFloat log_sum = loglikes.LogSumExp(log_sum_exp_prune_);
225  if (KALDI_ISNAN(log_sum) || KALDI_ISINF(log_sum))
226  KALDI_ERR << "Invalid answer (overflow or invalid variances/features?)";
227 
228  log_like_cache_[state].log_like = log_sum;
229  log_like_cache_[state].hit_time = frame;
230 
231  return log_sum;
232 }
233 
234 } // namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 Dim() const
Returns the dimensionality of the Gaussian mean vectors.
Definition: diag-gmm.h:74
int32 Gauss2BaseclassId(size_t pdf_id, size_t gauss_id) const
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
#define M_LOG_2PI
Definition: kaldi-math.h:60
std::vector< Vector< BaseFloat > > xformed_data_
void InitCache()
Initializes the mean & gconst caches.
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
const Matrix< BaseFloat > & means_invvars() const
Definition: diag-gmm.h:179
#define KALDI_ISINF
Definition: kaldi-math.h:73
virtual BaseFloat LogLikelihoodZeroBased(int32 frame, int32 state_index)
const Vector< BaseFloat > & gconsts() const
Const accessors.
Definition: diag-gmm.h:174
bool valid_gconsts() const
Definition: diag-gmm.h:181
void GetLogDets(VectorBase< BaseFloat > *out) const
kaldi::int32 int32
Real LogSumExp(Real prune=-1.0) const
Returns log(sum(exp())) without exp overflow If prune > 0.0, ignores terms less than the max - prune...
Vector< BaseFloat > data_squared_
Cache for fast likelihood calculation.
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
double Log(double x)
Definition: kaldi-math.h:100
std::vector< Vector< BaseFloat > > xformed_data_squared_
virtual int32 NumIndices() const
Returns the number of states in the acoustic model (they will be indexed one-based, i.e.
const Matrix< BaseFloat > & GetXformedMeanInvVars(int32 state_index)
Get the transformed means times inverse variances for a given pdf, and cache them.
void TransformFeature(const VectorBase< BaseFloat > &in, std::vector< Vector< BaseFloat > > *out) const
Get the transformed features for each of the transforms.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
const Vector< BaseFloat > & weights() const
Definition: diag-gmm.h:178
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: diag-gmm.h:72
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
const Matrix< BaseFloat > & feature_matrix_
void AddMatVec(const Real alpha, const MatrixBase< Real > &M, const MatrixTransposeType trans, const VectorBase< Real > &v, const Real beta)
Add matrix times vector : this <– beta*this + alpha*M*v.
Definition: kaldi-vector.cc:92
std::vector< LikelihoodCacheRecord > log_like_cache_
int32 Dim() const
Definition: am-diag-gmm.h:79
int32 NumPdfs() const
Definition: am-diag-gmm.h:82
DiagGmm & GetPdf(int32 pdf_index)
Accessors.
Definition: am-diag-gmm.h:119
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ISNAN
Definition: kaldi-math.h:72
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
const Vector< BaseFloat > & GetXformedGconsts(int32 state_index)
Get the cached (while computing transformed means) gconsts for likelihood calculation.
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
virtual int32 NumFramesReady() const
The call NumFramesReady() will return the number of frames currently available for this decodable obj...
virtual BaseFloat LogLikelihoodZeroBased(int32 frame, int32 state_index)
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
Real VecVec(const VectorBase< Real > &a, const VectorBase< Real > &b)
Returns dot product between v1 and v2.
Definition: kaldi-vector.cc:37
int32 Base2RegClass(int32 bclass) const
static void ComputeGconsts(const VectorBase< BaseFloat > &weights, const MatrixBase< BaseFloat > &means, const MatrixBase< BaseFloat > &inv_vars, VectorBase< BaseFloat > *gconsts_out)
const Matrix< BaseFloat > & inv_vars() const
Definition: diag-gmm.h:180