plda.h
Go to the documentation of this file.
1 // ivector/plda.h
2 
3 // Copyright 2013 Daniel Povey
4 // 2015 David Snyder
5 
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABILITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #ifndef KALDI_IVECTOR_PLDA_H_
23 #define KALDI_IVECTOR_PLDA_H_
24 
25 #include <vector>
26 #include <algorithm>
27 #include "base/kaldi-common.h"
28 #include "matrix/matrix-lib.h"
29 #include "gmm/model-common.h"
30 #include "gmm/diag-gmm.h"
31 #include "gmm/full-gmm.h"
32 #include "itf/options-itf.h"
33 #include "util/common-utils.h"
34 
35 namespace kaldi {
36 
37 /* This code implements Probabilistic Linear Discriminant Analysis: see
38  "Probabilistic Linear Discriminant Analysis" by Sergey Ioffe, ECCV 2006.
39  At least, that was the inspiration. The E-M is an efficient method
40  that I derived myself (note: it could be made even more efficient but
41  it doesn't seem to be necessary as it's already very fast).
42 
43  This implementation of PLDA only supports estimating with a between-class
44  dimension equal to the feature dimension. If you want a between-class
45  covariance that has a lower dimension, you can just remove the smallest
46  elements of the diagonalized between-class covariance matrix. This is not
47  100% exact (wouldn't give you as good likelihood as E-M estimation with that
48  dimension) but it's close enough. */
49 
50 struct PldaConfig {
51  // This config is for the application of PLDA as a transform to iVectors,
52  // prior to dot-product scoring.
55  PldaConfig(): normalize_length(true), simple_length_norm(false) { }
56  void Register(OptionsItf *opts) {
57  opts->Register("normalize-length", &normalize_length,
58  "If true, do length normalization as part of PLDA (see "
59  "code for details). This does not set the length unit; "
60  "by default it instead ensures that the inner product "
61  "with the PLDA model's inverse variance (which is a "
62  "function of how many utterances the iVector was averaged "
63  "over) has the expected value, equal to the iVector "
64  "dimension.");
65 
66  opts->Register("simple-length-normalization", &simple_length_norm,
67  "If true, replace the default length normalization by an "
68  "alternative that normalizes the length of the iVectors to "
69  "be equal to the square root of the iVector dimension.");
70  }
71 };
72 
73 
74 class Plda {
75  public:
76  Plda() { }
77 
78  explicit Plda(const Plda &other):
79  mean_(other.mean_),
80  transform_(other.transform_),
81  psi_(other.psi_),
82  offset_(other.offset_) {
83  };
101  double TransformIvector(const PldaConfig &config,
102  const VectorBase<double> &ivector,
103  int32 num_enroll_examples,
104  VectorBase<double> *transformed_ivector) const;
105 
108  float TransformIvector(const PldaConfig &config,
109  const VectorBase<float> &ivector,
110  int32 num_enroll_examples,
111  VectorBase<float> *transformed_ivector) const;
112 
120  double LogLikelihoodRatio(const VectorBase<double> &transformed_enroll_ivector,
121  int32 num_enroll_utts,
122  const VectorBase<double> &transformed_test_ivector)
123  const;
124 
125 
132  void SmoothWithinClassCovariance(double smoothing_factor);
133 
138  void ApplyTransform(const Matrix<double> &in_transform);
139 
140  int32 Dim() const { return mean_.Dim(); }
141  void Write(std::ostream &os, bool binary) const;
142  void Read(std::istream &is, bool binary);
143  protected:
144  void ComputeDerivedVars(); // computes offset_.
145  friend class PldaEstimator;
147 
148  Vector<double> mean_; // mean of samples in original space.
149  Matrix<double> transform_; // of dimension Dim() by Dim();
150  // this transform makes within-class covar unit
151  // and diagonalizes the between-class covar.
152  Vector<double> psi_; // of dimension Dim(). The between-class
153  // (diagonal) covariance elements, in decreasing order.
154 
155  Vector<double> offset_; // derived variable: -1.0 * transform_ * mean_
156 
157  private:
158  Plda &operator = (const Plda &other); // disallow assignment
159 
166  double GetNormalizationFactor(const VectorBase<double> &transformed_ivector,
167  int32 num_examples) const;
168 
169 };
170 
171 
172 class PldaStats {
173  public:
174  PldaStats(): dim_(0) { }
175 
181  void AddSamples(double weight,
182  const Matrix<double> &group);
183 
184  int32 Dim() const { return dim_; }
185 
186  void Init(int32 dim);
187 
188  void Sort() { std::sort(class_info_.begin(), class_info_.end()); }
189  bool IsSorted() const;
190  ~PldaStats();
191  protected:
192 
193  friend class PldaEstimator;
194 
197  int64 num_examples_; // total number of examples, summed over classes.
198  double class_weight_; // total over classes, of their weight.
199  double example_weight_; // total over classes, of weight times #examples.
200 
201  Vector<double> sum_; // Weighted sum of class means (normalize by
202  // class_weight_ to get mean).
203 
204  SpMatrix<double> offset_scatter_; // Sum over all examples, of the weight
205 // times the outer product of (example - class-mean) with itself.
206 
207  // We have one of these objects per class.
208  struct ClassInfo {
209  double weight;
210  Vector<double> *mean; // owned here, but as a pointer so
211  // sort can be lightweight
212  int32 num_examples; // the number of examples in the class
213  bool operator < (const ClassInfo &other) const {
214  return (num_examples < other.num_examples);
215  }
216  ClassInfo(double weight, Vector<double> *mean, int32 num_examples):
217  weight(weight), mean(mean), num_examples(num_examples) { }
218  };
219 
220  std::vector<ClassInfo> class_info_;
221  private:
223 };
224 
225 
228  PldaEstimationConfig(): num_em_iters(10){ }
229  void Register(OptionsItf *opts) {
230  opts->Register("num-em-iters", &num_em_iters,
231  "Number of iterations of E-M used for PLDA estimation");
232  }
233 };
234 
236  public:
237  PldaEstimator(const PldaStats &stats);
238 
239  void Estimate(const PldaEstimationConfig &config,
240  Plda *output);
241 private:
243 
246  double ComputeObjfPart1() const;
247 
250  double ComputeObjfPart2() const;
251 
253  double ComputeObjf() const;
254 
255  int32 Dim() const { return stats_.Dim(); }
256 
257  void EstimateOneIter();
258 
259  void InitParameters();
260 
261  void ResetPerIterStats();
262 
263  // gets stats from intra-class variation (stats_.offset_scatter_).
264  void GetStatsFromIntraClass();
265 
266  // gets part of stats relating to class means.
267  void GetStatsFromClassMeans();
268 
269  // M-step
270  void EstimateFromStats();
271 
272  // Copy to output.
273  void GetOutput(Plda *plda);
274 
276 
279 
280  // These stats are reset on each iteration.
282  double within_var_count_; // count corresponding to within_var_stats_
284 double between_var_count_; // count corresponding to between_var_stats_
285 
287 };
288 
289 
294 
296  mean_diff_scale(1.0),
297  within_covar_scale(0.3),
298  between_covar_scale(0.7) { }
299 
300  void Register(OptionsItf *opts) {
301  opts->Register("mean-diff-scale", &mean_diff_scale,
302  "Scale with which to add to the total data variance, the outer "
303  "product of the difference between the original mean and the "
304  "adaptation-data mean");
305  opts->Register("within-covar-scale", &within_covar_scale,
306  "Scale that determines how much of excess variance in a "
307  "particular direction gets attributed to within-class covar.");
308  opts->Register("between-covar-scale", &between_covar_scale,
309  "Scale that determines how much of excess variance in a "
310  "particular direction gets attributed to between-class covar.");
311 
312  }
313 };
314 
320  public:
321  PldaUnsupervisedAdaptor(): tot_weight_(0.0) { }
322  // Add stats to this class. Normally the weight will be 1.0.
323  void AddStats(double weight, const Vector<double> &ivector);
324  void AddStats(double weight, const Vector<float> &ivector);
325 
326 
327  void UpdatePlda(const PldaUnsupervisedAdaptorConfig &config,
328  Plda *plda) const;
329  private:
330 
331  double tot_weight_;
334 };
335 
336 
337 
338 } // namespace kaldi
339 
340 #endif
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
SpMatrix< double > between_var_
Definition: plda.h:278
Vector< double > mean_stats_
Definition: plda.h:332
ClassInfo(double weight, Vector< double > *mean, int32 num_examples)
Definition: plda.h:216
SpMatrix< double > within_var_stats_
Definition: plda.h:281
double example_weight_
Definition: plda.h:199
void GetOutput(OnlineFeatureInterface *a, Matrix< BaseFloat > *output)
Plda(const Plda &other)
Definition: plda.h:78
Vector< double > offset_
Definition: plda.h:155
Matrix< double > transform_
Definition: plda.h:149
kaldi::int32 int32
SpMatrix< double > offset_scatter_
Definition: plda.h:204
Vector< double > sum_
Definition: plda.h:201
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type)
Definition: kaldi-utils.h:121
const PldaStats & stats_
Definition: plda.h:275
PldaStats::ClassInfo ClassInfo
Definition: plda.h:242
SpMatrix< double > variance_stats_
Definition: plda.h:333
double ComputeObjf(bool batchnorm_test_mode, bool dropout_test_mode, const std::vector< NnetExample > &egs, const Nnet &nnet, NnetComputeProb *prob_computer)
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
void Register(OptionsItf *opts)
Definition: plda.h:300
Vector< double > psi_
Definition: plda.h:152
double class_weight_
Definition: plda.h:198
Plda()
Definition: plda.h:76
This class takes unlabeled iVectors from the domain of interest and uses their mean and variance to a...
Definition: plda.h:319
Vector< double > * mean
Definition: plda.h:210
void Register(OptionsItf *opts)
Definition: plda.h:56
int64 num_classes_
Definition: plda.h:196
bool operator<(const Int32Pair &a, const Int32Pair &b)
Definition: cu-matrixdim.h:83
bool IsSorted(const std::vector< T > &vec)
Returns true if the vector is sorted.
Definition: stl-utils.h:47
double between_var_count_
Definition: plda.h:284
Vector< double > mean_
Definition: plda.h:148
int32 Dim() const
Definition: plda.h:184
int32 Dim() const
Definition: plda.h:255
double within_var_count_
Definition: plda.h:282
SpMatrix< double > within_var_
Definition: plda.h:277
std::vector< ClassInfo > class_info_
Definition: plda.h:220
int32 dim_
Definition: plda.h:195
void Register(OptionsItf *opts)
Definition: plda.h:229
SpMatrix< double > between_var_stats_
Definition: plda.h:283
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
int32 Dim() const
Definition: plda.h:140
bool normalize_length
Definition: plda.h:53
int64 num_examples_
Definition: plda.h:197
void Sort()
Definition: plda.h:188
bool simple_length_norm
Definition: plda.h:54