lvtln.cc
Go to the documentation of this file.
1 // transform/lvtln.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 // 2014 Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <utility>
22 #include <vector>
23 using std::vector;
24 
25 #include "transform/lvtln.h"
26 
27 namespace kaldi {
28 
29 LinearVtln::LinearVtln(int32 dim, int32 num_classes, int32 default_class) {
30  default_class_ = default_class;
31  KALDI_ASSERT(default_class >= 0 && default_class < num_classes);
32  A_.resize(num_classes);
33  for (int32 i = 0; i < num_classes; i++) {
34  A_[i].Resize(dim, dim);
35  A_[i].SetUnit();
36  }
37  logdets_.clear();
38  logdets_.resize(num_classes, 0.0);
39  warps_.clear();
40  warps_.resize(num_classes, 1.0);
41 } // namespace kaldi
42 
43 
44 
45 void LinearVtln::Read(std::istream &is, bool binary) {
46  int32 sz;
47  ExpectToken(is, binary, "<LinearVtln>");
48  ReadBasicType(is, binary, &sz);
49  A_.resize(sz);
50  logdets_.resize(sz);
51  warps_.resize(sz);
52  for (int32 i = 0; i < sz; i++) {
53  ExpectToken(is, binary, "<A>");
54  A_[i].Read(is, binary);
55  ExpectToken(is, binary, "<logdet>");
56  ReadBasicType(is, binary, &(logdets_[i]));
57  ExpectToken(is, binary, "<warp>");
58  ReadBasicType(is, binary, &(warps_[i]));
59  }
60  std::string token;
61  ReadToken(is, binary, &token);
62  if (token == "</LinearVtln>") {
63  // the older code had a bug in that it wasn't writing or reading
64  // default_class_. The following guess at its value is likely to be
65  // correct.
66  default_class_ = (sz + 1) / 2;
67  } else {
68  KALDI_ASSERT(token == "<DefaultClass>");
69  ReadBasicType(is, binary, &default_class_);
70  ExpectToken(is, binary, "</LinearVtln>");
71  }
72 }
73 
74 void LinearVtln::Write(std::ostream &os, bool binary) const {
75  WriteToken(os, binary, "<LinearVtln>");
76  if(!binary) os << "\n";
77  int32 sz = A_.size();
78  KALDI_ASSERT(static_cast<size_t>(sz) == logdets_.size());
79  KALDI_ASSERT(static_cast<size_t>(sz) == warps_.size());
80  WriteBasicType(os, binary, sz);
81  for (int32 i = 0; i < sz; i++) {
82  WriteToken(os, binary, "<A>");
83  A_[i].Write(os, binary);
84  WriteToken(os, binary, "<logdet>");
85  WriteBasicType(os, binary, logdets_[i]);
86  WriteToken(os, binary, "<warp>");
87  WriteBasicType(os, binary, warps_[i]);
88  if(!binary) os << "\n";
89  }
90  WriteToken(os, binary, "<DefaultClass>");
91  WriteBasicType(os, binary, default_class_);
92  WriteToken(os, binary, "</LinearVtln>");
93 }
94 
95 
98  std::string norm_type, // "none", "offset", "diag"
99  BaseFloat logdet_scale,
100  MatrixBase<BaseFloat> *Ws, // output fMLLR transform, should be size dim x dim+1
101  int32 *class_idx, // the transform that was chosen...
102  BaseFloat *logdet_out,
103  BaseFloat *objf_impr, // versus no transform
104  BaseFloat *count) {
105  int32 dim = Dim();
106  KALDI_ASSERT(dim != 0);
107  if (norm_type != "none" && norm_type != "offset" && norm_type != "diag")
108  KALDI_ERR << "LinearVtln::ComputeTransform, norm_type should be "
109  "one of \"none\", \"offset\" or \"diag\"";
110 
111  if (accs.beta_ == 0.0) {
112  KALDI_WARN << "no stats, returning default transform";
113  int32 dim = Dim();
114  if (Ws) {
115  KALDI_ASSERT(Ws->NumRows() == dim && Ws->NumCols() == dim+1);
116  Ws->Range(0, dim, 0, dim).CopyFromMat(A_[default_class_]);
117  Ws->Range(0, dim, dim, 1).SetZero(); // Set last column to zero.
118  }
119  if (class_idx) *class_idx = default_class_;
120  if (logdet_out) *logdet_out = logdets_[default_class_];
121  if (objf_impr) *objf_impr = 0;
122  if (count) *count = 0;
123  return;
124  }
125 
126  Matrix<BaseFloat> best_transform(dim, dim+1);
127  best_transform.SetUnit();
128  BaseFloat old_objf = FmllrAuxFuncDiagGmm(best_transform, accs),
129  best_objf = -1.0e+100;
130  int32 best_class = -1;
131 
132  for (int32 i = 0; i < NumClasses(); i++) {
133  FmllrDiagGmmAccs accs_tmp(accs);
134  ApplyFeatureTransformToStats(A_[i], &accs_tmp);
135  // "old_trans" just needed by next function as "initial" transform.
136  Matrix<BaseFloat> old_trans(dim, dim+1); old_trans.SetUnit();
137  Matrix<BaseFloat> trans(dim, dim+1);
138  ComputeFmllrMatrixDiagGmm(old_trans, accs_tmp, norm_type,
139  100, // num iters.. don't care since norm_type != "full"
140  &trans);
141  Matrix<BaseFloat> product(dim, dim+1);
142  // product = trans * A_[i] (modulo messing about with offsets)
143  ComposeTransforms(trans, A_[i], false, &product);
144 
145  BaseFloat objf = FmllrAuxFuncDiagGmm(product, accs);
146 
147  if (logdet_scale != 1.0)
148  objf += accs.beta_ * (logdet_scale - 1.0) * logdets_[i];
149 
150  if (objf > best_objf) {
151  best_objf = objf;
152  best_class = i;
153  best_transform.CopyFromMat(product);
154  }
155  }
156  KALDI_ASSERT(best_class != -1);
157  if (Ws) Ws->CopyFromMat(best_transform);
158  if (class_idx) *class_idx = best_class;
159  if (logdet_out) *logdet_out = logdets_[best_class];
160  if (objf_impr) *objf_impr = best_objf - old_objf;
161  if (count) *count = accs.beta_;
162 }
163 
164 
165 
167  KALDI_ASSERT(i >= 0 && i < NumClasses());
168  KALDI_ASSERT(transform.NumRows() == transform.NumCols() &&
169  static_cast<int32>(transform.NumRows()) == Dim());
170  A_[i].CopyFromMat(transform);
171  logdets_[i] = A_[i].LogDet();
172 }
173 
175  KALDI_ASSERT(i >= 0 && i < NumClasses());
176  KALDI_ASSERT(warps_.size() == static_cast<size_t>(NumClasses()));
177  warps_[i] = warp;
178 }
179 
181  KALDI_ASSERT(i >= 0 && i < NumClasses());
182  return warps_[i];
183 }
184 
186  KALDI_ASSERT(i >= 0 && i < NumClasses());
187  KALDI_ASSERT(transform->NumRows() == transform->NumCols() &&
188  static_cast<int32>(transform->NumRows()) == Dim());
189  transform->CopyFromMat(A_[i]);
190 }
191 
192 
193 
194 } // end namespace kaldi
195 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
BaseFloat GetWarp(int32 i) const
Definition: lvtln.cc:180
void SetTransform(int32 i, const MatrixBase< BaseFloat > &transform)
Definition: lvtln.cc:166
std::vector< Matrix< BaseFloat > > A_
Definition: lvtln.h:88
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
int32 default_class_
Definition: lvtln.h:87
This does not work with multiple feature transforms.
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
void CopyFromMat(const MatrixBase< OtherReal > &M, MatrixTransposeType trans=kNoTrans)
Copy given matrix. (no resize is done).
void SetUnit()
Sets to zero, except ones along diagonal [for non-square matrices too].
int32 NumClasses() const
Definition: lvtln.h:78
void ApplyFeatureTransformToStats(const MatrixBase< BaseFloat > &xform, AffineXformStats *stats)
This function applies a feature-level transform to stats (useful for certain techniques based on fMLL...
int32 Dim() const
Definition: lvtln.h:77
const size_t count
void GetTransform(int32 i, MatrixBase< BaseFloat > *transform) const
Definition: lvtln.cc:185
void Read(std::istream &is, bool binary)
Definition: lvtln.cc:45
void ComputeTransform(const FmllrDiagGmmAccs &accs, std::string norm_type, BaseFloat logdet_scale, MatrixBase< BaseFloat > *Ws, int32 *class_idx, BaseFloat *logdet_out, BaseFloat *objf_impr=NULL, BaseFloat *count=NULL)
Compute the transform for the speaker.
Definition: lvtln.cc:97
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
void SetWarp(int32 i, BaseFloat warp)
Definition: lvtln.cc:174
std::vector< BaseFloat > warps_
Definition: lvtln.h:90
BaseFloat ComputeFmllrMatrixDiagGmm(const MatrixBase< BaseFloat > &in_xform, const AffineXformStats &stats, std::string fmllr_type, int32 num_iters, MatrixBase< BaseFloat > *out_xform)
This function internally calls ComputeFmllrMatrixDiagGmm{Full, Diagonal, Offset}, depending on "fmllr...
bool ComposeTransforms(const Matrix< BaseFloat > &a, const Matrix< BaseFloat > &b, bool b_is_affine, Matrix< BaseFloat > *c)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
float FmllrAuxFuncDiagGmm(const MatrixBase< float > &xform, const AffineXformStats &stats)
Returns the (diagonal-GMM) FMLLR auxiliary function value given the transform and the stats...
void Write(std::ostream &os, bool binary) const
Definition: lvtln.cc:74
SubMatrix< Real > Range(const MatrixIndexT row_offset, const MatrixIndexT num_rows, const MatrixIndexT col_offset, const MatrixIndexT num_cols) const
Return a sub-part of matrix.
Definition: kaldi-matrix.h:202
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
double beta_
beta_ is the occupation count.
std::vector< BaseFloat > logdets_
Definition: lvtln.h:89