21 #ifndef KALDI_NNET2_NNET_PRECONDITION_ONLINE_H_ 22 #define KALDI_NNET2_NNET_PRECONDITION_ONLINE_H_
std::mutex read_write_mutex_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void ComputeWt1(int32 N, const VectorBase< BaseFloat > &d_t, const VectorBase< BaseFloat > &d_t1, BaseFloat rho_t, BaseFloat rho_t1, const MatrixBase< BaseFloat > &U_t, const VectorBase< BaseFloat > &sqrt_c_t, const VectorBase< BaseFloat > &inv_sqrt_e_t, const CuMatrixBase< BaseFloat > &W_t, CuMatrixBase< BaseFloat > *J_t, CuMatrixBase< BaseFloat > *W_t1) const
Base class which provides matrix operations not involving resizing or allocation. ...
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
void InitDefault(int32 D)
static void InitOrthonormalSpecial(CuMatrixBase< BaseFloat > *R)
This function creates a matrix with orthonormal rows that is like the following matrix, except with each row normalized to have unit 2-norm: [ 1.1 0 1 0 1 0 0 1.1 0 1 0 1 ] The reason why the first element in each row is 1.1 and not 1, is for symmetry-breaking...
void ComputeEt(const VectorBase< BaseFloat > &d_t, BaseFloat beta_t, VectorBase< BaseFloat > *e_t, VectorBase< BaseFloat > *sqrt_e_t, VectorBase< BaseFloat > *inv_sqrt_e_t) const
void ComputeZt(int32 N, BaseFloat rho_t, const VectorBase< BaseFloat > &d_t, const VectorBase< BaseFloat > &inv_sqrt_e_t, const MatrixBase< BaseFloat > &K_t, const MatrixBase< BaseFloat > &L_t, SpMatrix< double > *Z_t) const
void SetAlpha(BaseFloat alpha)
BaseFloat Eta(int32 N) const
CuMatrix< BaseFloat > W_t_
void PreconditionDirections(CuMatrixBase< BaseFloat > *R, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale)
void SetNumSamplesHistory(BaseFloat num_samples_history)
BaseFloat num_samples_history_
int32 GetUpdatePeriod() const
BaseFloat GetNumSamplesHistory() const
void ReorthogonalizeXt1(const VectorBase< BaseFloat > &d_t1, BaseFloat rho_t1, CuMatrixBase< BaseFloat > *W_t1, CuMatrixBase< BaseFloat > *temp_W, CuMatrixBase< BaseFloat > *temp_O)
Matrix for CUDA computing.
A class representing a vector.
OnlinePreconditioner & operator=(const OnlinePreconditioner &other)
BaseFloat GetAlpha() const
void Init(const CuMatrixBase< BaseFloat > &R0)
void SetUpdatePeriod(int32 update_period)
Keywords for search: natural gradient, naturalgradient, NG-SGD.
Provides a vector abstraction class.
int32 num_updates_skipped_
void PreconditionDirectionsInternal(const int32 t, const BaseFloat rho_t, const Vector< BaseFloat > &d_t, CuMatrixBase< BaseFloat > *WJKL_t, CuMatrixBase< BaseFloat > *X_t, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale)
Vector for CUDA computing.