Public Member Functions | |
OnlineNaturalGradientSimple () | |
void | SetRank (int32 rank) |
void | PreconditionDirections (CuMatrixBase< BaseFloat > *R, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale) |
Private Member Functions | |
BaseFloat | Eta (int32 N) const |
void | PreconditionDirectionsCpu (MatrixBase< double > *R, VectorBase< double > *row_prod, BaseFloat *scale) |
void | Init (const MatrixBase< double > &R0) |
void | InitDefault (int32 D) |
Private Attributes | |
int32 | rank_ |
double | num_samples_history_ |
double | alpha_ |
double | epsilon_ |
double | delta_ |
Vector< double > | d_t_ |
Matrix< double > | R_t_ |
double | rho_t_ |
Class definition at line 28 of file natural-gradient-online-test.cc.
Constructor & Destructor Documentation
OnlineNaturalGradientSimple () [inline]
Definition at line 30 of file natural-gradient-online-test.cc.
Member Function Documentation
BaseFloat Eta (int32 N) const [private]
Definition at line 120 of file natural-gradient-online-test.cc.
References KALDI_ASSERT, and OnlineNaturalGradientSimple::num_samples_history_.
Referenced by OnlineNaturalGradientSimple::PreconditionDirectionsCpu(), and OnlineNaturalGradientSimple::SetRank().
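This page does not show the body of Eta(). The sketch below assumes it implements the forgetting factor used by Kaldi's OnlineNaturalGradient: η = 1 − exp(−N / num_samples_history_) for a minibatch of N rows. The single KALDI_ASSERT and num_samples_history_ references are consistent with that. The free function EtaSketch is hypothetical:

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stand-in for OnlineNaturalGradientSimple::Eta(N).
// Assumes eta = 1 - exp(-N / num_samples_history): the weight given to a
// minibatch of N rows when blending it into the running statistics.
double EtaSketch(int N, double num_samples_history) {
  assert(num_samples_history > 0.0);  // plays the role of KALDI_ASSERT
  return 1.0 - std::exp(-N / num_samples_history);
}
```

Under this assumption, a minibatch with as many rows as num_samples_history_ gives η = 1 − e^(−1) ≈ 0.632. A longer history gives a smaller η, so older statistics are forgotten more slowly.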
void Init (const MatrixBase< double > &R0) [private]
Definition at line 108 of file natural-gradient-online-test.cc.
References rnnlm::i, OnlineNaturalGradientSimple::InitDefault(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), and OnlineNaturalGradientSimple::PreconditionDirections().
Referenced by OnlineNaturalGradientSimple::PreconditionDirectionsCpu(), and OnlineNaturalGradientSimple::SetRank().
void InitDefault (int32 D) [private]
Definition at line 84 of file natural-gradient-online-test.cc.
References OnlineNaturalGradientSimple::d_t_, OnlineNaturalGradientSimple::epsilon_, rnnlm::i, KALDI_WARN, OnlineNaturalGradientSimple::R_t_, OnlineNaturalGradientSimple::rank_, Vector< Real >::Resize(), Matrix< Real >::Resize(), OnlineNaturalGradientSimple::rho_t_, and VectorBase< Real >::Set().
Referenced by OnlineNaturalGradientSimple::Init(), and OnlineNaturalGradientSimple::SetRank().
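The references suggest that InitDefault(D) does four things: it sizes R_t_ (rank_ × D) and d_t_ (rank_), sets both d_t_ and rho_t_ to epsilon_, and issues a KALDI_WARN in some case. The sketch below assumes, as in Kaldi's OnlineNaturalGradient::InitDefault, that the warning fires when rank_ >= D and the rank is then reduced to D − 1. The struct SimpleState and the unit-vector fill of R_t are hypothetical simplifications. The page does not show how the class actually fills R_t_:

```cpp
#include <cstdio>
#include <vector>

// Hypothetical plain-C++ mirror of the members touched by InitDefault().
struct SimpleState {
  int rank = 0;
  double epsilon = 0.0;
  std::vector<std::vector<double>> R_t;  // rank x D, rows orthonormal
  std::vector<double> d_t;               // one entry per retained direction
  double rho_t = 0.0;
};

// Assumes D >= 2 so that the clamped rank D - 1 is positive.
void InitDefaultSketch(SimpleState* s, int D) {
  if (s->rank >= D) {
    // Assumed behaviour: warn and shrink the rank below the dimension.
    std::fprintf(stderr, "rank %d >= dim %d, setting it to %d\n",
                 s->rank, D, D - 1);
    s->rank = D - 1;
  }
  // Simplification (not the class's actual fill): unit vectors e_0..e_{rank-1},
  // which trivially have orthonormal rows.
  s->R_t.assign(s->rank, std::vector<double>(D, 0.0));
  for (int i = 0; i < s->rank; ++i) s->R_t[i][i] = 1.0;
  s->d_t.assign(s->rank, s->epsilon);  // every d_t entry starts at epsilon
  s->rho_t = s->epsilon;               // and so does rho_t
}
```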
void PreconditionDirections (CuMatrixBase< BaseFloat > *R, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale)
Definition at line 67 of file natural-gradient-online-test.cc.
References MatrixBase< Real >::CopyFromMat(), CuMatrixBase< Real >::CopyFromMat(), CuVectorBase< Real >::CopyFromVec(), VectorBase< Real >::CopyFromVec(), and OnlineNaturalGradientSimple::PreconditionDirectionsCpu().
Referenced by OnlineNaturalGradientSimple::Init(), OnlineNaturalGradientSimple::SetRank(), and kaldi::nnet3::UnitTestPreconditionDirectionsOnline().
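The references above suggest that PreconditionDirections() is a thin wrapper. It copies the single-precision CuMatrixBase/CuVectorBase data into double-precision CPU objects, calls PreconditionDirectionsCpu(), and copies the results back. The sketch below shows that copy-compute-copy-back pattern under three substitutions: std::vector stands in for the Kaldi matrix types, float stands in for BaseFloat, and a hypothetical CpuWorker callable stands in for PreconditionDirectionsCpu():

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical signature for the double-precision worker: modifies the rows
// in place, fills one row_prod entry per row, and writes a scale.
using CpuWorker = std::function<void(std::vector<std::vector<double>>* R,
                                     std::vector<double>* row_prod,
                                     float* scale)>;

void PreconditionDirectionsSketch(std::vector<std::vector<float>>* R,
                                  std::vector<float>* row_prod,
                                  float* scale,
                                  const CpuWorker& precondition_cpu) {
  // 1. Widen the single-precision input to double precision.
  std::vector<std::vector<double>> R_cpu(R->size());
  for (std::size_t i = 0; i < R->size(); ++i)
    R_cpu[i].assign((*R)[i].begin(), (*R)[i].end());
  std::vector<double> row_prod_cpu(R->size(), 0.0);

  // 2. Do the numerical work in double precision.
  precondition_cpu(&R_cpu, &row_prod_cpu, scale);

  // 3. Narrow the results back into the caller's buffers.
  for (std::size_t i = 0; i < R->size(); ++i)
    for (std::size_t j = 0; j < (*R)[i].size(); ++j)
      (*R)[i][j] = static_cast<float>(R_cpu[i][j]);
  row_prod->assign(row_prod_cpu.begin(), row_prod_cpu.end());
}
```

Converting only at the boundary presumably keeps float rounding out of the eigendecomposition and inversion steps listed for PreconditionDirectionsCpu().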
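void PreconditionDirectionsCpu (MatrixBase< double > *R, VectorBase< double > *row_prod, BaseFloat *scale) [private]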
Definition at line 128 of file natural-gradient-online-test.cc.
References VectorBase< Real >::Add(), VectorBase< Real >::AddDiagMat2(), SpMatrix< Real >::AddMat2(), SpMatrix< Real >::AddMat2Vec(), MatrixBase< Real >::AddMatMat(), MatrixBase< Real >::AddMatSp(), SpMatrix< Real >::AddSp(), PackedMatrix< Real >::AddToDiag(), VectorBase< Real >::AddVec(), OnlineNaturalGradientSimple::alpha_, VectorBase< Real >::ApplyFloor(), VectorBase< Real >::ApplyPow(), kaldi::AssertEqual(), MatrixBase< Real >::CopyFromMat(), VectorBase< Real >::CopyFromVec(), OnlineNaturalGradientSimple::d_t_, OnlineNaturalGradientSimple::delta_, SpMatrix< Real >::Eig(), OnlineNaturalGradientSimple::epsilon_, OnlineNaturalGradientSimple::Eta(), rnnlm::i, OnlineNaturalGradientSimple::Init(), SpMatrix< Real >::Invert(), VectorBase< Real >::InvertElements(), SpMatrix< Real >::IsUnit(), rnnlm::j, KALDI_ASSERT, KALDI_VLOG, KALDI_WARN, kaldi::kNoTrans, kaldi::kTrans, VectorBase< Real >::Max(), VectorBase< Real >::Min(), MatrixBase< Real >::MulRowsVec(), VectorBase< Real >::Norm(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), OnlineNaturalGradientSimple::R_t_, OnlineNaturalGradientSimple::rho_t_, VectorBase< Real >::Scale(), kaldi::SortSvd(), VectorBase< Real >::Sum(), SpMatrix< Real >::Trace(), kaldi::TraceMatMat(), and kaldi::VecVec().
Referenced by OnlineNaturalGradientSimple::PreconditionDirections(), and OnlineNaturalGradientSimple::SetRank().
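The full online update (the Eig, SortSvd, Invert and TraceMatMat calls above) is not reproduced here. Assume the class uses the factored Fisher-matrix estimate described in Kaldi's natural-gradient-online.h: F_t = R_t^T diag(d_t) R_t + rho_t I, where R_t has orthonormal rows (members R_t_, d_t_, rho_t_). Under that assumption the inverse has the closed form F_t^{-1} = (1/rho_t)(I - R_t^T diag(d_t / (d_t + rho_t)) R_t). The sketch below applies that inverse to the rows of X, holding R_t, d_t and rho_t fixed. It returns a scale gamma that restores the input's Frobenius norm, loosely mirroring the `scale` output parameter. The function name and the norm-preserving choice of gamma are assumptions:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<double>>;  // one inner vector per row

// Applies X <- X * F^{-1} in place, where F = R^T diag(d) R + rho * I and
// R (rank x dim) is assumed to have orthonormal rows. Returns gamma such that
// ||gamma * X_out||_F == ||X_in||_F; the caller decides whether to apply it.
double ApplyLowRankPreconditionerSketch(Mat* X, const Mat& R,
                                        const std::vector<double>& d,
                                        double rho) {
  double norm_in = 0.0, norm_out = 0.0;
  for (auto& row : *X) {
    // Projections of this row onto each direction of R: p = row * R^T.
    std::vector<double> p(R.size(), 0.0);
    for (std::size_t k = 0; k < R.size(); ++k)
      for (std::size_t j = 0; j < row.size(); ++j) p[k] += row[j] * R[k][j];
    for (std::size_t j = 0; j < row.size(); ++j) {
      norm_in += row[j] * row[j];
      // Low-rank correction term: (p * diag(d / (d + rho)) * R)[j].
      double correction = 0.0;
      for (std::size_t k = 0; k < R.size(); ++k)
        correction += p[k] * (d[k] / (d[k] + rho)) * R[k][j];
      row[j] = (row[j] - correction) / rho;
      norm_out += row[j] * row[j];
    }
  }
  return norm_out > 0.0 ? std::sqrt(norm_in / norm_out) : 1.0;
}
```

Because R_t has only rank_ rows, each row costs O(rank_ · D) rather than the O(D^2) of a dense inverse. That saving is what the low-rank-plus-diagonal factorization buys.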
void SetRank (int32 rank) [inline]
Definition at line 33 of file natural-gradient-online-test.cc.
References OnlineNaturalGradientSimple::Eta(), OnlineNaturalGradientSimple::Init(), OnlineNaturalGradientSimple::InitDefault(), OnlineNaturalGradientSimple::PreconditionDirections(), OnlineNaturalGradientSimple::PreconditionDirectionsCpu(), and OnlineNaturalGradientSimple::rank_.
Referenced by kaldi::nnet3::UnitTestPreconditionDirectionsOnline().
Member Data Documentation
double alpha_ [private]
Definition at line 56 of file natural-gradient-online-test.cc.
Referenced by OnlineNaturalGradientSimple::PreconditionDirectionsCpu().
Vector< double > d_t_ [private]
Definition at line 61 of file natural-gradient-online-test.cc.
Referenced by OnlineNaturalGradientSimple::InitDefault(), and OnlineNaturalGradientSimple::PreconditionDirectionsCpu().
double delta_ [private]
Definition at line 58 of file natural-gradient-online-test.cc.
Referenced by OnlineNaturalGradientSimple::PreconditionDirectionsCpu().
double epsilon_ [private]
Definition at line 57 of file natural-gradient-online-test.cc.
Referenced by OnlineNaturalGradientSimple::InitDefault(), and OnlineNaturalGradientSimple::PreconditionDirectionsCpu().
double num_samples_history_ [private]
Definition at line 55 of file natural-gradient-online-test.cc.
Referenced by OnlineNaturalGradientSimple::Eta().
Matrix< double > R_t_ [private]
Definition at line 62 of file natural-gradient-online-test.cc.
Referenced by OnlineNaturalGradientSimple::InitDefault(), and OnlineNaturalGradientSimple::PreconditionDirectionsCpu().
int32 rank_ [private]
Definition at line 54 of file natural-gradient-online-test.cc.
Referenced by OnlineNaturalGradientSimple::InitDefault(), and OnlineNaturalGradientSimple::SetRank().
double rho_t_ [private]
Definition at line 63 of file natural-gradient-online-test.cc.
Referenced by OnlineNaturalGradientSimple::InitDefault(), and OnlineNaturalGradientSimple::PreconditionDirectionsCpu().