BasisFmllrAccus Class Reference

Stats for fMLLR subspace estimation.

#include <basis-fmllr-diag-gmm.h>

Public Member Functions

 BasisFmllrAccus ()
 
 BasisFmllrAccus (int32 dim)
 
void ResizeAccus (int32 dim)
 
void Write (std::ostream &out_stream, bool binary) const
 Routines for reading and writing stats.
 
void Read (std::istream &in_stream, bool binary, bool add=false)
 
void AccuGradientScatter (const AffineXformStats &spk_stats)
 Accumulate gradient scatter for one (training) speaker.
 

Public Attributes

SpMatrix< BaseFloat > grad_scatter_
 Gradient scatter. Dim is [(D+1)*D] x [(D+1)*D].
 
int32 dim_
 Feature dimension.
 
double beta_
 Occupancy count.
 

Detailed Description

Stats for fMLLR subspace estimation.

This class is used only to estimate the "basis", which is done at training time. The class BasisFmllrEstimate contains the functions used at test time (see the function BasisFmllrCoefficients()).

Definition at line 73 of file basis-fmllr-diag-gmm.h.

Constructor & Destructor Documentation

◆ BasisFmllrAccus() [1/2]

BasisFmllrAccus ( )
inline

Definition at line 76 of file basis-fmllr-diag-gmm.h.

BasisFmllrAccus() { }

◆ BasisFmllrAccus() [2/2]

BasisFmllrAccus ( int32  dim)
inline explicit

Definition at line 77 of file basis-fmllr-diag-gmm.h.

explicit BasisFmllrAccus(int32 dim) {
  dim_ = dim;
  beta_ = 0;
  ResizeAccus(dim);
}

Member Function Documentation

◆ AccuGradientScatter()

void AccuGradientScatter ( const AffineXformStats &  spk_stats)

Accumulate gradient scatter for one (training) speaker.

To finish the process, we need to traverse the whole training set. Parallelization works if the speaker list is split and the stats are summed up by setting add=true in BasisFmllrEstimate::ReadBasis. See section 5.2 of the paper.

Definition at line 91 of file basis-fmllr-diag-gmm.cc.

References MatrixBase< Real >::AddMat(), AffineXformStats::beta_, BasisFmllrAccus::beta_, VectorBase< Real >::CopyRowsFromMat(), rnnlm::d, BasisFmllrAccus::dim_, AffineXformStats::G_, BasisFmllrAccus::grad_scatter_, AffineXformStats::K_, MatrixBase< Real >::Row(), MatrixBase< Real >::Scale(), and MatrixBase< Real >::SetUnit().

void BasisFmllrAccus::AccuGradientScatter(const AffineXformStats &spk_stats) {
  // Gradient of auxf w.r.t. xform_spk
  // Eq. (33)
  Matrix<double> grad_mat(dim_, dim_ + 1);
  grad_mat.SetUnit();
  grad_mat.Scale(spk_stats.beta_);
  grad_mat.AddMat(1.0, spk_stats.K_);
  for (int d = 0; d < dim_; ++d) {
    Matrix<double> G_d_mat(spk_stats.G_[d]);
    grad_mat.Row(d).AddVec(-1.0, G_d_mat.Row(d));
  }
  // Row stack of gradient matrix
  Vector<BaseFloat> grad_vec((dim_+1) * dim_);
  grad_vec.CopyRowsFromMat(grad_mat);
  // The amount of data beta_ is likely to be ZERO, especially
  // when silence-weight is set to be 0 and we are using the
  // per-utt mode.
  if (spk_stats.beta_ > 0) {
    beta_ += spk_stats.beta_;
    grad_scatter_.AddVec2(BaseFloat(1.0 / spk_stats.beta_), grad_vec);
  }
}
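
The accumulation performed by AccuGradientScatter() can be sketched without Kaldi types. The following is a minimal illustration under stated assumptions: ToySpeakerStats, ToyAccus, GradientAtIdentity and AccuGradientScatterToy are hypothetical names that do not exist in Kaldi, everything is kept in double precision, and the scatter is stored as a full matrix instead of a packed SpMatrix.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<double>>;

// Hypothetical stand-in for AffineXformStats (not the Kaldi class).
struct ToySpeakerStats {
  double beta;         // occupancy count for this speaker
  Mat K;               // D x (D+1)
  std::vector<Mat> G;  // D matrices, each (D+1) x (D+1)
};

// Hypothetical stand-in for BasisFmllrAccus.
struct ToyAccus {
  int dim;
  double beta = 0.0;
  Mat grad_scatter;  // full [(D+1)*D] x [(D+1)*D] matrix, for clarity only
  explicit ToyAccus(int d)
      : dim(d),
        grad_scatter((d + 1) * d, std::vector<double>((d + 1) * d, 0.0)) {
    assert(d > 0);
  }
};

// Gradient evaluated at the identity transform [I 0], row-stacked into a
// vector: row d of the gradient is beta * e_d + K[d] - G[d][d].
std::vector<double> GradientAtIdentity(const ToySpeakerStats &s, int dim) {
  std::vector<double> g;
  g.reserve((dim + 1) * dim);
  for (int d = 0; d < dim; ++d)
    for (int j = 0; j <= dim; ++j)
      g.push_back((j == d ? s.beta : 0.0) + s.K[d][j] - s.G[d][d][j]);
  return g;
}

// Adds (1 / beta) * g * g^T to the scatter. Speakers without data are
// skipped, matching the beta_ > 0 guard in the listing.
void AccuGradientScatterToy(const ToySpeakerStats &s, ToyAccus *acc) {
  if (s.beta <= 0.0) return;
  const std::vector<double> g = GradientAtIdentity(s, acc->dim);
  acc->beta += s.beta;
  for (std::size_t i = 0; i < g.size(); ++i)
    for (std::size_t j = 0; j < g.size(); ++j)
      acc->grad_scatter[i][j] += g[i] * g[j] / s.beta;
}
```

Calling AccuGradientScatterToy() once per training speaker and then inspecting acc.grad_scatter mirrors the traversal of the training set described above; the Kaldi class additionally converts the gradient vector to BaseFloat before the rank-one update.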

◆ Read()

void Read ( std::istream &  in_stream,
bool  binary,
bool  add = false 
)

Definition at line 66 of file basis-fmllr-diag-gmm.cc.

References BasisFmllrAccus::beta_, kaldi::ExpectToken(), BasisFmllrAccus::grad_scatter_, and kaldi::ReadBasicType().

void BasisFmllrAccus::Read(std::istream &is, bool binary, bool add) {
  ExpectToken(is, binary, "<BASISFMLLRACCUS>");
  ExpectToken(is, binary, "<BETA>");
  double tmp_beta = 0;
  ReadBasicType(is, binary, &tmp_beta);
  if (add) {
    beta_ += tmp_beta;
  } else {
    beta_ = tmp_beta;
  }
  ExpectToken(is, binary, "<GRADSCATTER>");
  grad_scatter_.Read(is, binary, add);
  ExpectToken(is, binary, "</BASISFMLLRACCUS>");
}
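
The effect of the add flag can be illustrated with a toy accumulator that is merged from several per-job outputs. This is only a sketch: ToyStats, ReadToyStats and SumToyStats are hypothetical helpers, and the text layout "<BETA> b <GRADSCATTER> n v1 ... vn" is an assumption made for the example, not Kaldi's on-disk format.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical toy accumulator; it mirrors only the add/overwrite logic.
struct ToyStats {
  double beta = 0.0;
  std::vector<double> scatter;  // stands in for grad_scatter_
};

// Reads "<BETA> b <GRADSCATTER> n v1 ... vn" (assumed toy layout).
// With add == true the values are summed into *stats; otherwise they
// replace its contents. Returns false on malformed input.
bool ReadToyStats(std::istream &is, bool add, ToyStats *stats) {
  std::string token;
  double beta = 0.0;
  std::size_t n = 0;
  if (!(is >> token) || token != "<BETA>") return false;
  if (!(is >> beta)) return false;
  if (!(is >> token) || token != "<GRADSCATTER>") return false;
  if (!(is >> n)) return false;
  std::vector<double> values(n);
  for (double &v : values)
    if (!(is >> v)) return false;
  if (add) {
    if (stats->scatter.empty()) stats->scatter.assign(n, 0.0);
    if (stats->scatter.size() != n) return false;  // dimension mismatch
    stats->beta += beta;
    for (std::size_t i = 0; i < n; ++i) stats->scatter[i] += values[i];
  } else {
    stats->beta = beta;
    stats->scatter = values;
  }
  return true;
}

// Merges the stats of several parallel jobs by reading each with add == true.
bool SumToyStats(const std::vector<std::string> &job_outputs, ToyStats *total) {
  for (const std::string &text : job_outputs) {
    std::istringstream is(text);
    if (!ReadToyStats(is, /*add=*/true, total)) return false;
  }
  return true;
}
```

Reading with add set to false discards whatever the accumulator held before, so a merge over split speaker lists has to pass true for every file it reads into the same accumulator.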

◆ ResizeAccus()

void ResizeAccus ( int32  dim)

Definition at line 82 of file basis-fmllr-diag-gmm.cc.

References BasisFmllrAccus::grad_scatter_, KALDI_ERR, and kaldi::kSetZero.

void BasisFmllrAccus::ResizeAccus(int32 dim) {
  if (dim <= 0) {
    KALDI_ERR << "Invalid feature dimension " << dim;  // dim=0 is not allowed
  } else {
    // 'kSetZero' may not be necessary, but makes computation safe
    grad_scatter_.Resize((dim + 1) * dim, kSetZero);
  }
}
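
Because the scatter matrix has side length (D+1)*D, its size grows quickly with the feature dimension. The helpers below are a hypothetical sketch (none of these functions exist in Kaldi) that computes the side length and the number of stored elements, assuming that a symmetric matrix keeps only one triangle including the diagonal, as SpMatrix does, and that the element size in bytes is supplied by the caller.

```cpp
#include <cassert>
#include <cstddef>

// Side length of the gradient scatter for feature dimension D: (D+1)*D.
std::size_t ScatterDim(int dim) {
  assert(dim > 0);  // same restriction as ResizeAccus(): dim must be positive
  return static_cast<std::size_t>(dim + 1) * static_cast<std::size_t>(dim);
}

// Elements stored for an n x n symmetric matrix when only one triangle
// (including the diagonal) is kept: n * (n + 1) / 2.
std::size_t PackedElements(std::size_t n) { return n * (n + 1) / 2; }

// Storage in bytes for the packed scatter, given the element size.
std::size_t PackedBytes(int dim, std::size_t bytes_per_element) {
  return PackedElements(ScatterDim(dim)) * bytes_per_element;
}
```

For D = 40, for instance, the side length is 1640 and the packed matrix holds 1,345,620 elements, which is about 5.4 MB if each element is a 4-byte float.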

◆ Write()

void Write ( std::ostream &  out_stream,
bool  binary 
) const

Routines for reading and writing stats.

Definition at line 53 of file basis-fmllr-diag-gmm.cc.

References BasisFmllrAccus::beta_, BasisFmllrAccus::grad_scatter_, kaldi::WriteBasicType(), and kaldi::WriteToken().

void BasisFmllrAccus::Write(std::ostream &os, bool binary) const {
  WriteToken(os, binary, "<BASISFMLLRACCUS>");
  WriteToken(os, binary, "<BETA>");
  WriteBasicType(os, binary, beta_);
  if (!binary) os << '\n';
  if (grad_scatter_.NumCols() != 0) {
    WriteToken(os, binary, "<GRADSCATTER>");
    grad_scatter_.Write(os, binary);
  }
  WriteToken(os, binary, "</BASISFMLLRACCUS>");
}
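
The token order produced by Write() can be mimicked with a small text-mode writer. This is a hedged sketch: WriteToyStats and HasGradScatter are hypothetical helpers, and the exact spacing and number formatting are assumptions, not Kaldi's real output.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical text-mode writer that follows the token order of Write().
std::string WriteToyStats(double beta, const std::vector<double> &scatter) {
  std::ostringstream os;
  os << "<BASISFMLLRACCUS> <BETA> " << beta << '\n';
  if (!scatter.empty()) {  // same idea as the NumCols() != 0 guard
    os << "<GRADSCATTER> ";
    for (double v : scatter) os << v << ' ';
    os << '\n';
  }
  os << "</BASISFMLLRACCUS> ";
  return os.str();
}

// True if the text contains the token that Read() expects unconditionally.
bool HasGradScatter(const std::string &text) {
  return text.find("<GRADSCATTER>") != std::string::npos;
}
```

Judging only from the two listings, Write() omits the <GRADSCATTER> section for an accumulator that was never resized, while Read() always expects it, so such output would presumably fail to read back. This is an observation from the listed code rather than a tested claim.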

Member Data Documentation

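◆ beta_

double beta_

Occupancy count.

◆ dim_

int32 dim_

Feature dimension.

◆ grad_scatter_

SpMatrix< BaseFloat > grad_scatter_

Gradient scatter. Dim is [(D+1)*D] x [(D+1)*D].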


The documentation for this class was generated from the following files: