regtree-fmllr-diag-gmm.h
Go to the documentation of this file.
1 // transform/regtree-fmllr-diag-gmm.h
2 
3 // Copyright 2009-2011 Saarland University; Georg Stemmer;
4 // Microsoft Corporation
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 
22 #ifndef KALDI_TRANSFORM_REGTREE_FMLLR_DIAG_GMM_H_
23 #define KALDI_TRANSFORM_REGTREE_FMLLR_DIAG_GMM_H_
24 
25 #include <vector>
26 
27 #include "base/kaldi-common.h"
28 #include "gmm/am-diag-gmm.h"
31 #include "util/kaldi-table.h"
32 #include "util/kaldi-holder.h"
33 
34 namespace kaldi {
35 
36 
39  std::string update_type;  ///< "full", "diag", "offset", "none"
// NOTE(review): the declarations of min_count (BaseFloat) and num_iters (int32),
// which the constructor and Register() below use, are missing from this listing.
42  bool use_regtree;  ///< If 'true', find transforms to generate using regression tree.
43 
45  RegtreeFmllrOptions(): update_type("full"), min_count(1000.0),
46  num_iters(10), use_regtree(true) { }
47 
48  void Register(OptionsItf *opts) {
49  opts->Register("fmllr-update-type", &update_type,
50  "Update type for fMLLR (\"full\"|\"diag\"|\"offset\"|\"none\")");
51  opts->Register("fmllr-min-count", &min_count,
52  "Minimum count to estimate an fMLLR transform.");
53  opts->Register("fmllr-num-iters", &num_iters,
54  "Number of fMLLR iterations (if using an iterative update).");
55  opts->Register("fmllr-use-regtree", &use_regtree,
56  "Use a regression-class tree for fMLLR.");
57  }
58 };
59 
60 
70  public:
/// Default constructor: sets dim_ and num_xforms_ to -1 and marks the cached
/// log-determinants as not valid (valid_logdet_ = false).
71  RegtreeFmllrDiagGmm() : dim_(-1), num_xforms_(-1), valid_logdet_(false) {}
// NOTE(review): the copy-constructor signature line
// "RegtreeFmllrDiagGmm(const RegtreeFmllrDiagGmm &other)" is missing from this
// listing; the initializer list below belongs to it. It copies dim_,
// num_xforms_, xform_matrices_, logdet_, valid_logdet_ and bclass2xforms_
// from `other`.
73  : dim_(other.dim_), num_xforms_(other.num_xforms_),
74  xform_matrices_(other.xform_matrices_), logdet_(other.logdet_),
75  valid_logdet_(other.valid_logdet_),
76  bclass2xforms_(other.bclass2xforms_) {}
// Sets up storage for `num_xforms` transforms of feature dimension `dim`.
// NOTE(review): initial transform values are defined in the .cc file — confirm.
79  void Init(size_t num_xforms, size_t dim);
80  void Validate();
81  void SetUnit();
// Recomputes logdet_. GetLogDets() asserts valid_logdet_, so this presumably
// also sets valid_logdet_ — confirm in the .cc file.
84  void ComputeLogDets();
// Writes transformed versions of `in` into *out.
// NOTE(review): presumably one output vector per transform — confirm.
86  void TransformFeature(const VectorBase<BaseFloat> &in,
87  std::vector< Vector<BaseFloat> > *out) const;
// Serialization; `binary` presumably selects binary vs. text format.
88  void Write(std::ostream &out_stream, bool binary) const;
89  void Read(std::istream &in_stream, bool binary);
90 
92  int32 Dim() const { return dim_; }
93  int32 NumBaseClasses() const { return bclass2xforms_.size(); }
94  int32 NumRegClasses() const { return num_xforms_; }
95  void GetXformMatrix(int32 xform_index, Matrix<BaseFloat> *out) const;
96  void GetLogDets(VectorBase<BaseFloat> *out) const;
97  int32 Base2RegClass(int32 bclass) const { return bclass2xforms_[bclass]; }
98 
100  void SetParameters(const MatrixBase<BaseFloat> &mat, size_t regclass);
101  void set_bclass2xforms(const std::vector<int32> &in) { bclass2xforms_ = in; }
102 
103  private:
// NOTE(review): the declarations of dim_, num_xforms_, logdet_ and
// valid_logdet_ (all used by the constructors above) are missing from this
// listing.
106  std::vector< Matrix<BaseFloat> > xform_matrices_;  ///< Transform matrices.
109  std::vector<int32> bclass2xforms_;  ///< Base class -> transform index; read by Base2RegClass().
111
// Declared private (and not defined in this header) so that assignment is
// unavailable to users of the class; copy construction remains allowed.
112  void operator = (const RegtreeFmllrDiagGmm&); // Disallow assignment operator
113 };
114 
116  Matrix<BaseFloat> *out) const {
117  if (xform_index >= num_xforms_) {
118  KALDI_ERR << "Index (" << xform_index << ") out of range [0, "
119  << num_xforms_ << "]";
120  }
121  out->Resize(dim_, dim_ + 1);
122  out->CopyFromMat(xform_matrices_[xform_index], kNoTrans);
123 }
124 
126  size_t regclass) {
127  xform_matrices_[regclass].CopyFromMat(mat, kNoTrans);
128  valid_logdet_ = false;
129 }
130 
132  KALDI_ASSERT(valid_logdet_ && out->Dim() == logdet_.Dim());
133  out->CopyFromVec(logdet_);
134 }
135 
142 
149  public:
150  RegtreeFmllrDiagGmmAccs() : num_baseclasses_(-1), dim_(-1) {}
151  ~RegtreeFmllrDiagGmmAccs() { DeletePointers(&baseclass_stats_); }
152 
// Sets up accumulators for `num_bclass` base classes of feature dimension `dim`.
153  void Init(size_t num_bclass, size_t dim);
154  void SetZero();
155
// Accumulates stats for `data` against every Gaussian of pdf `pdf_index` in
// `am`, scaled by `weight`; `regtree` supplies the base-class grouping.
// NOTE(review): the returned BaseFloat is presumably the frame log-likelihood
// — confirm in the .cc file.
161  BaseFloat AccumulateForGmm(const RegressionTree &regtree,
162  const AmDiagGmm &am,
163  const VectorBase<BaseFloat> &data,
164  size_t pdf_index, BaseFloat weight);
165
// Accumulates stats for `data` against a single Gaussian (`gauss_index`) of
// pdf `pdf_index`, scaled by `weight`.
167  void AccumulateForGaussian(const RegressionTree &regtree,
168  const AmDiagGmm &am,
169  const VectorBase<BaseFloat> &data,
170  size_t pdf_index, size_t gauss_index,
171  BaseFloat weight);
172
// Estimates transforms from the accumulated stats into *out_fmllr, configured
// by `opts`. *auxf_impr and *tot_t are outputs; by their names, presumably the
// auxiliary-function improvement and the total count — confirm.
173  void Update(const RegressionTree &regtree, const RegtreeFmllrOptions &opts,
174  RegtreeFmllrDiagGmm *out_fmllr, BaseFloat *auxf_impr,
175  BaseFloat *tot_t) const;
176
// Serialization. NOTE(review): `add` presumably sums the read stats into the
// existing ones instead of replacing them — confirm.
177  void Write(std::ostream &out_stream, bool binary) const;
178  void Read(std::istream &in_stream, bool binary, bool add);
179 
181  int32 Dim() const { return dim_; }
182  int32 NumBaseClasses() const { return num_baseclasses_; }
183  const std::vector<AffineXformStats*> &baseclass_stats() const {
184  return baseclass_stats_;
185  }
186 
187  private:
// NOTE(review): the declarations of num_baseclasses_ and dim_ are missing from
// this listing.
189  std::vector<AffineXformStats*> baseclass_stats_;  ///< Per-baseclass stats; used for accumulation. Deleted by the destructor.
194
195  // Cannot have copy constructor and assignment operator: a member-wise copy
// would duplicate the raw pointers in baseclass_stats_, and both destructors
// would then delete them. (The line that enforces this is not shown here.)
197 };
198 
199 
200 
201 
202 } // namespace kaldi
203 
204 #endif // KALDI_TRANSFORM_REGTREE_FMLLR_DIAG_GMM_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
void set_bclass2xforms(const std::vector< int32 > &in)
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
std::vector< AffineXformStats * > baseclass_stats_
Per-baseclass stats; used for accumulation.
void Register(OptionsItf *opts)
void GetLogDets(VectorBase< BaseFloat > *out) const
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
int32 dim_
Dimension of feature vectors.
kaldi::int32 int32
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type)
Definition: kaldi-utils.h:121
void CopyFromMat(const MatrixBase< OtherReal > &M, MatrixTransposeType trans=kNoTrans)
Copy given matrix. (no resize is done).
TableWriter< KaldiObjectHolder< RegtreeFmllrDiagGmm > > RegtreeFmllrDiagGmmWriter
RandomAccessTableReaderMapped< KaldiObjectHolder< RegtreeFmllrDiagGmm > > RandomAccessRegtreeFmllrDiagGmmReaderMapped
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
An FMLLR (feature-space MLLR) transformation, also called CMLLR (constrained MLLR) is an affine trans...
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
Configuration variables for FMLLR transforms.
A regression tree is a clustering of Gaussian densities in an acoustic model, such that the group of ...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
SequentialTableReader< KaldiObjectHolder< RegtreeFmllrDiagGmm > > RegtreeFmllrDiagGmmSeqReader
#define KALDI_ERR
Definition: kaldi-error.h:147
BaseFloat min_count
Minimum occupancy for computing a transform.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
bool use_regtree
If 'true', find transforms to generate using regression tree.
void GetXformMatrix(int32 xform_index, Matrix< BaseFloat > *out) const
int32 num_xforms_
Number of transform matrices.
int32 num_baseclasses_
Number of baseclasses.
int32 num_iters
Number of iterations (if using an iterative update)
std::vector< Matrix< BaseFloat > > xform_matrices_
Transform matrices.
void SetParameters(const MatrixBase< BaseFloat > &mat, size_t regclass)
Mutators.
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
Class for computing the accumulators needed for the maximum-likelihood estimate of FMLLR transforms f...
int32 dim_
Dimension of feature vectors.
Vector< BaseFloat > logdet_
Log-determinants of the Jacobians.
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
std::string update_type
"full", "diag", "offset", "none"
const std::vector< AffineXformStats * > & baseclass_stats() const
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
RandomAccessTableReader< KaldiObjectHolder< RegtreeFmllrDiagGmm > > RandomAccessRegtreeFmllrDiagGmmReader
RegtreeFmllrDiagGmm(const RegtreeFmllrDiagGmm &other)
int32 Dim() const
Accessors.
bool valid_logdet_
Whether logdets are for current transforms.
int32 Base2RegClass(int32 bclass) const