LinearVtln Class Reference

#include <lvtln.h>

Collaboration diagram for LinearVtln:

Public Member Functions

 LinearVtln ()
 
 LinearVtln (int32 dim, int32 num_classes, int32 default_class)
 
void SetTransform (int32 i, const MatrixBase< BaseFloat > &transform)
 
void SetWarp (int32 i, BaseFloat warp)
 
BaseFloat GetWarp (int32 i) const
 
void GetTransform (int32 i, MatrixBase< BaseFloat > *transform) const
 
void ComputeTransform (const FmllrDiagGmmAccs &accs, std::string norm_type, BaseFloat logdet_scale, MatrixBase< BaseFloat > *Ws, int32 *class_idx, BaseFloat *logdet_out, BaseFloat *objf_impr=NULL, BaseFloat *count=NULL)
 Compute the transform for the speaker.
 
void Read (std::istream &is, bool binary)
 
void Write (std::ostream &os, bool binary) const
 
int32 Dim () const
 
int32 NumClasses () const
 
void GetOffset (const FmllrDiagGmmAccs &speaker_stats, int32 class_idx, VectorBase< BaseFloat > *offset) const
 

Protected Attributes

int32 default_class_
 
std::vector< Matrix< BaseFloat > > A_
 
std::vector< BaseFloat > logdets_
 
std::vector< BaseFloat > warps_
 

Friends

class LinearVtlnStats
 

Detailed Description

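LinearVtln stores, for each of NumClasses() classes, a square Dim() x Dim() linear feature transform (A_), its log-determinant (logdets_), and an associated warp factor (warps_). A newly constructed object has identity transforms, zero log-determinants and warp factors of 1.0.

Given a speaker's FMLLR statistics, ComputeTransform() tries every class in turn:
- It applies that class's transform to the statistics.
- It estimates an additional FMLLR correction of the requested norm_type ("none", "offset" or "diag").
- It scores the composed transform with the FMLLR auxiliary function. The score is adjusted by beta * (logdet_scale - 1) times the class's log-determinant.

The class with the highest score is selected. If the statistics are empty (beta == 0), the transform of default_class_ is returned instead.

Definition at line 40 of file lvtln.h.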

Constructor & Destructor Documentation

◆ LinearVtln() [1/2]

LinearVtln ( )
inline

Definition at line 42 of file lvtln.h.

References LinearVtln::ComputeTransform(), count, LinearVtln::GetTransform(), LinearVtln::GetWarp(), rnnlm::i, LinearVtln::Read(), LinearVtln::SetTransform(), LinearVtln::SetWarp(), and LinearVtln::Write().

42 { } // This initializer will probably be used prior to calling

◆ LinearVtln() [2/2]

LinearVtln ( int32  dim,
int32  num_classes,
int32  default_class 
)

Definition at line 29 of file lvtln.cc.

References LinearVtln::A_, LinearVtln::default_class_, rnnlm::i, KALDI_ASSERT, LinearVtln::logdets_, and LinearVtln::warps_.

29  {
30  default_class_ = default_class;
31  KALDI_ASSERT(default_class >= 0 && default_class < num_classes);
32  A_.resize(num_classes);
33  for (int32 i = 0; i < num_classes; i++) {
34  A_[i].Resize(dim, dim);
35  A_[i].SetUnit();
36  }
37  logdets_.clear();
38  logdets_.resize(num_classes, 0.0);
39  warps_.clear();
40  warps_.resize(num_classes, 1.0);
41 } // namespace kaldi
std::vector< Matrix< BaseFloat > > A_
Definition: lvtln.h:88
int32 default_class_
Definition: lvtln.h:87
kaldi::int32 int32
std::vector< BaseFloat > warps_
Definition: lvtln.h:90
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< BaseFloat > logdets_
Definition: lvtln.h:89

Member Function Documentation

◆ ComputeTransform()

void ComputeTransform ( const FmllrDiagGmmAccs &  accs,
std::string  norm_type,
BaseFloat  logdet_scale,
MatrixBase< BaseFloat > *  Ws,
int32 *  class_idx,
BaseFloat *  logdet_out,
BaseFloat *  objf_impr = NULL,
BaseFloat *  count = NULL 
)

Compute the transform for the speaker.

Definition at line 97 of file lvtln.cc.

References LinearVtln::A_, kaldi::ApplyFeatureTransformToStats(), AffineXformStats::beta_, kaldi::ComposeTransforms(), kaldi::ComputeFmllrMatrixDiagGmm(), MatrixBase< Real >::CopyFromMat(), LinearVtln::default_class_, LinearVtln::Dim(), kaldi::FmllrAuxFuncDiagGmm(), rnnlm::i, KALDI_ASSERT, KALDI_ERR, KALDI_WARN, LinearVtln::logdets_, LinearVtln::NumClasses(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), MatrixBase< Real >::Range(), and MatrixBase< Real >::SetUnit().

Referenced by LinearVtln::LinearVtln(), and main().

104  {
105  int32 dim = Dim();
106  KALDI_ASSERT(dim != 0);
107  if (norm_type != "none" && norm_type != "offset" && norm_type != "diag")
108  KALDI_ERR << "LinearVtln::ComputeTransform, norm_type should be "
109  "one of \"none\", \"offset\" or \"diag\"";
110 
111  if (accs.beta_ == 0.0) {
112  KALDI_WARN << "no stats, returning default transform";
113  int32 dim = Dim();
114  if (Ws) {
115  KALDI_ASSERT(Ws->NumRows() == dim && Ws->NumCols() == dim+1);
116  Ws->Range(0, dim, 0, dim).CopyFromMat(A_[default_class_]);
117  Ws->Range(0, dim, dim, 1).SetZero(); // Set last column to zero.
118  }
119  if (class_idx) *class_idx = default_class_;
120  if (logdet_out) *logdet_out = logdets_[default_class_];
121  if (objf_impr) *objf_impr = 0;
122  if (count) *count = 0;
123  return;
124  }
125 
126  Matrix<BaseFloat> best_transform(dim, dim+1);
127  best_transform.SetUnit();
128  BaseFloat old_objf = FmllrAuxFuncDiagGmm(best_transform, accs),
129  best_objf = -1.0e+100;
130  int32 best_class = -1;
131 
132  for (int32 i = 0; i < NumClasses(); i++) {
133  FmllrDiagGmmAccs accs_tmp(accs);
134  ApplyFeatureTransformToStats(A_[i], &accs_tmp);
135  // "old_trans" just needed by next function as "initial" transform.
136  Matrix<BaseFloat> old_trans(dim, dim+1); old_trans.SetUnit();
137  Matrix<BaseFloat> trans(dim, dim+1);
138  ComputeFmllrMatrixDiagGmm(old_trans, accs_tmp, norm_type,
139  100, // num iters.. don't care since norm_type != "full"
140  &trans);
141  Matrix<BaseFloat> product(dim, dim+1);
142  // product = trans * A_[i] (modulo messing about with offsets)
143  ComposeTransforms(trans, A_[i], false, &product);
144 
145  BaseFloat objf = FmllrAuxFuncDiagGmm(product, accs);
146 
147  if (logdet_scale != 1.0)
148  objf += accs.beta_ * (logdet_scale - 1.0) * logdets_[i];
149 
150  if (objf > best_objf) {
151  best_objf = objf;
152  best_class = i;
153  best_transform.CopyFromMat(product);
154  }
155  }
156  KALDI_ASSERT(best_class != -1);
157  if (Ws) Ws->CopyFromMat(best_transform);
158  if (class_idx) *class_idx = best_class;
159  if (logdet_out) *logdet_out = logdets_[best_class];
160  if (objf_impr) *objf_impr = best_objf - old_objf;
161  if (count) *count = accs.beta_;
162 }
std::vector< Matrix< BaseFloat > > A_
Definition: lvtln.h:88
int32 default_class_
Definition: lvtln.h:87
kaldi::int32 int32
int32 NumClasses() const
Definition: lvtln.h:78
void ApplyFeatureTransformToStats(const MatrixBase< BaseFloat > &xform, AffineXformStats *stats)
This function applies a feature-level transform to stats (useful for certain techniques based on fMLL...
int32 Dim() const
Definition: lvtln.h:77
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
BaseFloat ComputeFmllrMatrixDiagGmm(const MatrixBase< BaseFloat > &in_xform, const AffineXformStats &stats, std::string fmllr_type, int32 num_iters, MatrixBase< BaseFloat > *out_xform)
This function internally calls ComputeFmllrMatrixDiagGmm{Full, Diagonal, Offset}, depending on "fmllr...
bool ComposeTransforms(const Matrix< BaseFloat > &a, const Matrix< BaseFloat > &b, bool b_is_affine, Matrix< BaseFloat > *c)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
float FmllrAuxFuncDiagGmm(const MatrixBase< float > &xform, const AffineXformStats &stats)
Returns the (diagonal-GMM) FMLLR auxiliary function value given the transform and the stats...
std::vector< BaseFloat > logdets_
Definition: lvtln.h:89

◆ Dim()

int32 Dim ( ) const
inline

Definition at line 77 of file lvtln.h.

References LinearVtln::A_, and KALDI_ASSERT.

Referenced by LinearVtln::ComputeTransform(), LinearVtln::GetTransform(), main(), and LinearVtln::SetTransform().

77 { KALDI_ASSERT(!A_.empty()); return A_[0].NumRows(); }
std::vector< Matrix< BaseFloat > > A_
Definition: lvtln.h:88
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ GetOffset()

void GetOffset ( const FmllrDiagGmmAccs &  speaker_stats,
int32  class_idx,
VectorBase< BaseFloat > *  offset 
) const

Referenced by LinearVtln::NumClasses().

◆ GetTransform()

void GetTransform ( int32  i,
MatrixBase< BaseFloat > *  transform 
) const

Definition at line 185 of file lvtln.cc.

References LinearVtln::A_, MatrixBase< Real >::CopyFromMat(), LinearVtln::Dim(), KALDI_ASSERT, LinearVtln::NumClasses(), MatrixBase< Real >::NumCols(), and MatrixBase< Real >::NumRows().

Referenced by LinearVtln::LinearVtln().

185  {
186  KALDI_ASSERT(i >= 0 && i < NumClasses());
187  KALDI_ASSERT(transform->NumRows() == transform->NumCols() &&
188  static_cast<int32>(transform->NumRows()) == Dim());
189  transform->CopyFromMat(A_[i]);
190 }
std::vector< Matrix< BaseFloat > > A_
Definition: lvtln.h:88
kaldi::int32 int32
int32 NumClasses() const
Definition: lvtln.h:78
int32 Dim() const
Definition: lvtln.h:77
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ GetWarp()

BaseFloat GetWarp ( int32  i) const

Definition at line 180 of file lvtln.cc.

References rnnlm::i, KALDI_ASSERT, LinearVtln::NumClasses(), and LinearVtln::warps_.

Referenced by LinearVtln::LinearVtln(), and main().

180  {
181  KALDI_ASSERT(i >= 0 && i < NumClasses());
182  return warps_[i];
183 }
int32 NumClasses() const
Definition: lvtln.h:78
std::vector< BaseFloat > warps_
Definition: lvtln.h:90
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ NumClasses()

int32 NumClasses ( ) const
inline

Definition at line 78 of file lvtln.h.

References LinearVtln::A_, and LinearVtln::GetOffset().

Referenced by LinearVtln::ComputeTransform(), LinearVtln::GetTransform(), LinearVtln::GetWarp(), main(), LinearVtln::SetTransform(), and LinearVtln::SetWarp().

78 { return A_.size(); }
std::vector< Matrix< BaseFloat > > A_
Definition: lvtln.h:88

◆ Read()

void Read ( std::istream &  is,
bool  binary 
)

Definition at line 45 of file lvtln.cc.

References LinearVtln::A_, LinearVtln::default_class_, kaldi::ExpectToken(), rnnlm::i, KALDI_ASSERT, LinearVtln::logdets_, kaldi::ReadBasicType(), kaldi::ReadToken(), and LinearVtln::warps_.

Referenced by LinearVtln::LinearVtln().

45  {
46  int32 sz;
47  ExpectToken(is, binary, "<LinearVtln>");
48  ReadBasicType(is, binary, &sz);
49  A_.resize(sz);
50  logdets_.resize(sz);
51  warps_.resize(sz);
52  for (int32 i = 0; i < sz; i++) {
53  ExpectToken(is, binary, "<A>");
54  A_[i].Read(is, binary);
55  ExpectToken(is, binary, "<logdet>");
56  ReadBasicType(is, binary, &(logdets_[i]));
57  ExpectToken(is, binary, "<warp>");
58  ReadBasicType(is, binary, &(warps_[i]));
59  }
60  std::string token;
61  ReadToken(is, binary, &token);
62  if (token == "</LinearVtln>") {
63  // the older code had a bug in that it wasn't writing or reading
64  // default_class_. The following guess at its value is likely to be
65  // correct.
66  default_class_ = (sz + 1) / 2;
67  } else {
68  KALDI_ASSERT(token == "<DefaultClass>");
69  ReadBasicType(is, binary, &default_class_);
70  ExpectToken(is, binary, "</LinearVtln>");
71  }
72 }
std::vector< Matrix< BaseFloat > > A_
Definition: lvtln.h:88
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
int32 default_class_
Definition: lvtln.h:87
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
std::vector< BaseFloat > warps_
Definition: lvtln.h:90
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< BaseFloat > logdets_
Definition: lvtln.h:89

◆ SetTransform()

void SetTransform ( int32  i,
const MatrixBase< BaseFloat > &  transform 
)

Definition at line 166 of file lvtln.cc.

References LinearVtln::A_, LinearVtln::Dim(), rnnlm::i, KALDI_ASSERT, LinearVtln::logdets_, LinearVtln::NumClasses(), MatrixBase< Real >::NumCols(), and MatrixBase< Real >::NumRows().

Referenced by LinearVtln::LinearVtln(), and main().

166  {
167  KALDI_ASSERT(i >= 0 && i < NumClasses());
168  KALDI_ASSERT(transform.NumRows() == transform.NumCols() &&
169  static_cast<int32>(transform.NumRows()) == Dim());
170  A_[i].CopyFromMat(transform);
171  logdets_[i] = A_[i].LogDet();
172 }
std::vector< Matrix< BaseFloat > > A_
Definition: lvtln.h:88
kaldi::int32 int32
int32 NumClasses() const
Definition: lvtln.h:78
int32 Dim() const
Definition: lvtln.h:77
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< BaseFloat > logdets_
Definition: lvtln.h:89

◆ SetWarp()

void SetWarp ( int32  i,
BaseFloat  warp 
)

Definition at line 174 of file lvtln.cc.

References rnnlm::i, KALDI_ASSERT, LinearVtln::NumClasses(), and LinearVtln::warps_.

Referenced by LinearVtln::LinearVtln(), and main().

174  {
175  KALDI_ASSERT(i >= 0 && i < NumClasses());
176  KALDI_ASSERT(warps_.size() == static_cast<size_t>(NumClasses()));
177  warps_[i] = warp;
178 }
int32 NumClasses() const
Definition: lvtln.h:78
std::vector< BaseFloat > warps_
Definition: lvtln.h:90
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ Write()

void Write ( std::ostream &  os,
bool  binary 
) const

Definition at line 74 of file lvtln.cc.

References LinearVtln::A_, LinearVtln::default_class_, rnnlm::i, KALDI_ASSERT, LinearVtln::logdets_, LinearVtln::warps_, kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by LinearVtln::LinearVtln(), and main().

74  {
75  WriteToken(os, binary, "<LinearVtln>");
76  if(!binary) os << "\n";
77  int32 sz = A_.size();
78  KALDI_ASSERT(static_cast<size_t>(sz) == logdets_.size());
79  KALDI_ASSERT(static_cast<size_t>(sz) == warps_.size());
80  WriteBasicType(os, binary, sz);
81  for (int32 i = 0; i < sz; i++) {
82  WriteToken(os, binary, "<A>");
83  A_[i].Write(os, binary);
84  WriteToken(os, binary, "<logdet>");
85  WriteBasicType(os, binary, logdets_[i]);
86  WriteToken(os, binary, "<warp>");
87  WriteBasicType(os, binary, warps_[i]);
88  if(!binary) os << "\n";
89  }
90  WriteToken(os, binary, "<DefaultClass>");
91  WriteBasicType(os, binary, default_class_);
92  WriteToken(os, binary, "</LinearVtln>");
93 }
std::vector< Matrix< BaseFloat > > A_
Definition: lvtln.h:88
int32 default_class_
Definition: lvtln.h:87
kaldi::int32 int32
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
std::vector< BaseFloat > warps_
Definition: lvtln.h:90
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
std::vector< BaseFloat > logdets_
Definition: lvtln.h:89

Friends And Related Function Documentation

◆ LinearVtlnStats

friend class LinearVtlnStats
friend

Definition at line 85 of file lvtln.h.

Member Data Documentation

◆ A_

std::vector<Matrix<BaseFloat> > A_
protected

◆ default_class_

int32 default_class_
protected

◆ logdets_

std::vector<BaseFloat> logdets_
protected

◆ warps_

std::vector<BaseFloat> warps_
protected

The documentation for this class was generated from the following files: