regtree-mllr-diag-gmm.h
Go to the documentation of this file.
1 // transform/regtree-mllr-diag-gmm.h
2 
3 // Copyright 2009-2011 Saarland University; Jan Silovsky
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_TRANSFORM_REGTREE_MLLR_DIAG_GMM_H_
21 #define KALDI_TRANSFORM_REGTREE_MLLR_DIAG_GMM_H_
22 
23 #include <vector>
24 
25 #include "base/kaldi-common.h"
26 #include "gmm/am-diag-gmm.h"
29 #include "util/common-utils.h"
30 
31 namespace kaldi {
32 
33 
37 
41 
42  RegtreeMllrOptions(): min_count(1000.0), use_regtree(true) { }  // Default constructor: min_count starts at 1000.0 and use_regtree at true.
43 
44  void Register(OptionsItf *opts) {  // Registers both MLLR options on `opts`; each call passes the option name, a pointer to the member, and a help string.
45  opts->Register("mllr-min-count", &min_count,  // "mllr-min-count" is backed by min_count.
46  "Minimum count to estimate an MLLR transform.");
47  opts->Register("mllr-use-regtree", &use_regtree,  // "mllr-use-regtree" is backed by use_regtree.
48  "Use a regression-class tree for MLLR.");
49  }
50 };
51 
54  public:
56 
58  void Init(int32 num_xforms, int32 dim);  // Presumably sizes the object for num_xforms transforms over dim-dimensional features; definition not in this file - confirm.
59 
61  void SetUnit();  // NOTE(review): name suggests resetting every transform to the identity; definition not visible here - confirm.
62 
64  void TransformModel(const RegressionTree &regtree, AmDiagGmm *am);  // `am` is a non-const pointer, so presumably the model is adapted in place - confirm.
65 
67  void GetTransformedMeans(const RegressionTree &regtree, const AmDiagGmm &am,  // Const member; `am` is only read.
68  int32 pdf_index, MatrixBase<BaseFloat> *out) const;  // Result is written to *out; presumably the transformed means for pdf_index - confirm.
69 
70  void Write(std::ostream &out_stream, bool binary) const;  // Writes this object to out_stream; `binary` presumably selects binary vs. text format.
71  void Read(std::istream &in_stream, bool binary);  // Reads this object from in_stream; counterpart of Write().
72 
74  void SetParameters(const MatrixBase<BaseFloat> &mat, int32 regclass);  // The inline definition later in this file copies `mat` into xform_matrices_[regclass].
75  void set_bclass2xforms(const std::vector<int32> &in) { bclass2xforms_ = in; }  // Overwrites bclass2xforms_ with a copy of `in`.
76 
78  const std::vector< Matrix<BaseFloat> > xform_matrices() const {  // Accessor for the transform matrices.
79  return xform_matrices_;  // NOTE(review): returned by value, so each call copies the whole vector; baseclass_stats() in this file returns a const reference instead.
80  }
81 
82  private:
84  std::vector< Matrix<BaseFloat> > xform_matrices_;  // Transform matrices; SetParameters() indexes this vector by regclass.
86  std::vector<int32> bclass2xforms_;  // Assigned by set_bclass2xforms(); presumably maps a baseclass index to a transform index - confirm.
89 
90 // Cannot have copy constructor and assignment operator
92 };
93 
95  int32 regclass) {
96  xform_matrices_[regclass].CopyFromMat(mat, kNoTrans);
97 }
98 
104  public:
106  ~RegtreeMllrDiagGmmAccs() { DeletePointers(&baseclass_stats_); }  // Deletes the AffineXformStats objects that baseclass_stats_ holds by raw pointer.
107 
108  void Init(int32 num_bclass, int32 dim);  // Presumably allocates stats for num_bclass baseclasses of dimension dim; definition not in this file - confirm.
109  void SetZero();  // NOTE(review): name suggests clearing the accumulated stats; definition not visible here.
110 
113  BaseFloat AccumulateForGmm(const RegressionTree &regtree,  // Accumulates from one feature vector for a whole GMM; meaning of the returned BaseFloat is not shown here - confirm.
114  const AmDiagGmm &am,
115  const VectorBase<BaseFloat> &data,  // The feature vector.
116  int32 pdf_index, BaseFloat weight);  // pdf_index selects the GMM in `am`; `weight` presumably scales the contribution.
117 
119  void AccumulateForGaussian(const RegressionTree &regtree,  // Same inputs as AccumulateForGmm plus gauss_index; presumably accumulates for that single Gaussian - confirm.
120  const AmDiagGmm &am,
121  const VectorBase<BaseFloat> &data,
122  int32 pdf_index, int32 gauss_index,
123  BaseFloat weight);
124 
125  void Update(const RegressionTree &regtree, const RegtreeMllrOptions &opts,  // Const member: the accumulators themselves are not modified.
126  RegtreeMllrDiagGmm *out_mllr, BaseFloat *auxf_impr,  // Outputs: the transform object and, presumably, the auxiliary-function improvement - confirm.
127  BaseFloat *t) const;  // NOTE(review): meaning of *t is not visible in this file.
128 
129  void Write(std::ostream &out_stream, bool binary) const;  // Writes the accumulators to out_stream.
130  void Read(std::istream &in_stream, bool binary, bool add);  // Reads accumulators; `add` presumably sums into existing stats instead of replacing them - confirm.
131 
133  int32 Dim() const { return dim_; }  // Returns dim_, the dimension of the feature vectors.
134  int32 NumBaseClasses() const { return num_baseclasses_; }  // Returns num_baseclasses_, the number of baseclasses.
135  const std::vector<AffineXformStats*> &baseclass_stats() const {  // Read-only reference to the per-baseclass stats; no copy is made.
136  return baseclass_stats_;  // The pointed-to objects are deleted in the destructor, so callers must not delete them.
137  }
138 
139  private:
141  std::vector<AffineXformStats*> baseclass_stats_;  // Per-baseclass stats used for accumulation; entries are deleted in the destructor.
144 
146  BaseFloat MllrObjFunction(const Matrix<BaseFloat> &xform,  // Private const helper; presumably evaluates the MLLR objective for `xform` - confirm.
147  int32 bclass_id) const;  // bclass_id presumably selects the entry of baseclass_stats_ to evaluate against.
148 
149 // Cannot have copy constructor and assignment operator
151 };
152 
161 
162 } // namespace kaldi
163 
164 #endif // KALDI_TRANSFORM_REGTREE_MLLR_DIAG_GMM_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
TableWriter< KaldiObjectHolder< RegtreeMllrDiagGmm > > RegtreeMllrDiagGmmWriter
void Register(OptionsItf *opts)
An MLLR mean transformation is an affine transformation of Gaussian means.
std::vector< Matrix< BaseFloat > > xform_matrices_
Transform matrices: size() = num_xforms_.
int32 dim_
Dimension of feature vectors.
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
Configuration variables for MLLR transforms.
bool use_regtree
If 'true', find transforms to generate using regression tree.
BaseFloat min_count
Minimum occupancy for computing a transform.
void set_bclass2xforms(const std::vector< int32 > &in)
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
const std::vector< Matrix< BaseFloat > > xform_matrices() const
Accessors.
const std::vector< AffineXformStats * > & baseclass_stats() const
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type)
Definition: kaldi-utils.h:121
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
SequentialTableReader< KaldiObjectHolder< RegtreeMllrDiagGmm > > RegtreeMllrDiagGmmSeqReader
RandomAccessTableReader< KaldiObjectHolder< RegtreeMllrDiagGmm > > RandomAccessRegtreeMllrDiagGmmReader
void SetParameters(const MatrixBase< BaseFloat > &mat, int32 regclass)
Mutators.
A regression tree is a clustering of Gaussian densities in an acoustic model, such that the group of ...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
std::vector< AffineXformStats * > baseclass_stats_
Per-baseclass stats; used for accumulation.
int32 num_baseclasses_
Number of baseclasses.
Class for computing the maximum-likelihood estimates of the parameters of an acoustic model that uses...
RandomAccessTableReaderMapped< KaldiObjectHolder< RegtreeMllrDiagGmm > > RandomAccessRegtreeMllrDiagGmmReaderMapped
int32 num_xforms_
Number of transforms == xform_matrices_.size()
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
int32 dim_
Dimension of feature vectors.