OnlinePreconditioner Class Reference

Keywords for search: natural gradient, naturalgradient, NG-SGD.

#include <nnet-precondition-online.h>

Public Member Functions

 OnlinePreconditioner ()
 
void SetRank (int32 rank)
 
void SetUpdatePeriod (int32 update_period)
 
void SetNumSamplesHistory (BaseFloat num_samples_history)
 
void SetAlpha (BaseFloat alpha)
 
void TurnOnDebug ()
 
BaseFloat GetNumSamplesHistory () const
 
BaseFloat GetAlpha () const
 
int32 GetRank () const
 
int32 GetUpdatePeriod () const
 
void PreconditionDirections (CuMatrixBase< BaseFloat > *R, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale)
 
 OnlinePreconditioner (const OnlinePreconditioner &other)
 
OnlinePreconditioner & operator= (const OnlinePreconditioner &other)
 

Private Member Functions

void PreconditionDirectionsInternal (const int32 t, const BaseFloat rho_t, const Vector< BaseFloat > &d_t, CuMatrixBase< BaseFloat > *WJKL_t, CuMatrixBase< BaseFloat > *X_t, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale)
 
void ComputeEt (const VectorBase< BaseFloat > &d_t, BaseFloat beta_t, VectorBase< BaseFloat > *e_t, VectorBase< BaseFloat > *sqrt_e_t, VectorBase< BaseFloat > *inv_sqrt_e_t) const
 
void ComputeZt (int32 N, BaseFloat rho_t, const VectorBase< BaseFloat > &d_t, const VectorBase< BaseFloat > &inv_sqrt_e_t, const MatrixBase< BaseFloat > &K_t, const MatrixBase< BaseFloat > &L_t, SpMatrix< double > *Z_t) const
 
void ComputeWt1 (int32 N, const VectorBase< BaseFloat > &d_t, const VectorBase< BaseFloat > &d_t1, BaseFloat rho_t, BaseFloat rho_t1, const MatrixBase< BaseFloat > &U_t, const VectorBase< BaseFloat > &sqrt_c_t, const VectorBase< BaseFloat > &inv_sqrt_e_t, const CuMatrixBase< BaseFloat > &W_t, CuMatrixBase< BaseFloat > *J_t, CuMatrixBase< BaseFloat > *W_t1) const
 
void ReorthogonalizeXt1 (const VectorBase< BaseFloat > &d_t1, BaseFloat rho_t1, CuMatrixBase< BaseFloat > *W_t1, CuMatrixBase< BaseFloat > *temp_W, CuMatrixBase< BaseFloat > *temp_O)
 
void Init (const CuMatrixBase< BaseFloat > &R0)
 
void InitDefault (int32 D)
 
BaseFloat Eta (int32 N) const
 
void SelfTest () const
 

Static Private Member Functions

static void InitOrthonormalSpecial (CuMatrixBase< BaseFloat > *R)
 This function creates a matrix with orthonormal rows that is like the following matrix, except with each row normalized to have unit 2-norm: [ 1.1 0 1 0 1 0 ; 0 1.1 0 1 0 1 ]. The reason why the first element in each row is 1.1 and not 1 is for symmetry-breaking.
 

Private Attributes

int32 rank_
 
int32 update_period_
 
BaseFloat num_samples_history_
 
BaseFloat alpha_
 
BaseFloat epsilon_
 
BaseFloat delta_
 
int32 t_
 
int32 num_updates_skipped_
 
bool self_debug_
 
CuMatrix< BaseFloat > W_t_
 
BaseFloat rho_t_
 
Vector< BaseFloat > d_t_
 
std::mutex read_write_mutex_
 
std::mutex update_mutex_
 

Detailed Description

This method is explained in the paper "Parallel training of DNNs with Natural Gradient and Parameter Averaging" by D. Povey, X. Zhang and S. Khudanpur, ICLR Workshop, 2015, where it is referred to as online NG-SGD. Note that the method exported from this header is just the core of the algorithm, and some outer-level parts of it are implemented in class NaturalGradientAffineComponent.

The rest of this extended comment describes how we keep an up-to-date estimate of the inverse of a scatter matrix, in an online way. This is the same as the estimation of one of the A or B quantities in the paper. This comment is slightly redundant with the paper (it actually precedes the paper), but we keep it in case it is useful in understanding our method.

We consider the problem of doing online estimation of a (scaled-identity plus low-rank) approximation of a Fisher matrix... since the Fisher matrix is a scatter of vector-valued derivatives and we will be given the derivatives (or at least terms in a factorization of the derivatives which need not concern us right now), we can just think of the present task as being the online accumulation of a (low-rank plus scaled-identity) approximation to a variance of a distribution with mean zero.

Later on we'll think about how to get easy access to the inverse of this approximate variance, which is what we really need.

Our approximation to the Fisher matrix (the scatter of derivatives) will be of the following form (and just think of this as an approximate variance matrix of some arbitrary quantities).

F_t =(def) R_t^T D_t R_t + \rho_t I

(t is the minibatch index), where R_t is an R by D matrix with orthonormal rows (1 <= R < D is our chosen rank), D_t is a positive-definite diagonal matrix, and \rho_t > 0. Suppose the dimension of F_t is D. Let the vectors whose variance we are approximating be provided in minibatches of size M (M can vary from iteration to iteration, but it won't vary in the normal case, so we omit the subscript t). The batch of gradients is given as X_t \in Re^{M \times D}, i.e. each row is one of the vectors whose scatter we're estimating. On the t'th iteration, define the scatter S_t of the input vectors X_t as:

S_t =(def) 1/N X_t^T X_t (eqn:St)

(where N is the minibatch size, the same quantity as M above). Be careful not to confuse the rank R with the input X_t (we would typeset X_t in bold if this were not plain text, to make the distinction clearer). We want F_t to approach some kind of time-weighted average of the S_t quantities, to the extent permitted by the limitation of the rank R. We want the F_t quantities to stay "fresh" (since we'll be doing this in an SGD context and the parameters will be slowly changing). We use a constant 0 < \eta < 1 to control the updating rate. Our update for R_t is based on the power method. Define the smoothed scatter

T_t =(def) \eta S_t + (1-\eta) F_t

We'll use this in place of the observed scatter S_t, to slow down the update. Define

Y_t =(def) R_t T_t

which can be expanded as follows (using R_t R_t^T = I in the last step):

Y_t = R_t (\eta S_t + (1-\eta) F_t)
    = R_t (\eta S_t + (1-\eta) (R_t^T D_t R_t + \rho_t I))
    = \eta R_t S_t + (1-\eta) (D_t + \rho_t I) R_t

It is useful to think of Y_t as having each of the top eigenvectors of the scatter scaled by the corresponding eigenvalue \lambda_i. We compute the following R by R matrix:

Z_t =(def) Y_t Y_t^T

and do the symmetric eigenvalue decomposition

Z_t = U_t C_t U_t^T

where C_t is diagonal and U_t orthogonal; the diagonal elements of C_t will be positive (since \rho_t > 0, T_t is positive definite; since R_t has full row rank and T_t is positive definite, Y_t has full row rank; hence Z_t is positive definite). The diagonal elements of C_t can be thought of as corresponding to the squares of our current estimate of the top eigenvalues of the scatter matrix. [We should check that no element of C_t is <= 0.]

It is easy to show that C_t^{-0.5} U_t^T Z_t U_t C_t^{-0.5} = I, so

(C_t^{-0.5} U_t^T Y_t) (Y_t^T U_t C_t^{-0.5}) = I.

Define

R_{t+1} =(def) C_t^{-0.5} U_t^T Y_t

and it's clear that R_{t+1} R_{t+1}^T = I. We will set

D_{t+1} =(def) C_t^{0.5} - \rho_{t+1} I    (eqn:dt1)

which ensures that for each row r of R_{t+1}, the variance of our scatter matrix F_{t+1} in the direction of that row will be the square root of the corresponding diagonal element of C_t. This makes sense because, as we have pointed out, the diagonal elements of C_t can be thought of as corresponding to squared eigenvalues. But a proper treatment of this would require convergence analysis that would get quite complicated. We will choose \rho_{t+1} in order to ensure that tr(F_{t+1}) = tr(T_t).

For any t,

tr(F_t) = D \rho_t + tr(D_t)
tr(T_t) = \eta tr(S_t) + (1-\eta) tr(F_t)
        = \eta tr(S_t) + (1-\eta) (D \rho_t + tr(D_t))

Expanding out D_{t+1} from (eqn:dt1) in the expression for tr(F_{t+1}):

tr(F_{t+1}) = D \rho_{t+1} + tr(D_{t+1})
            = D \rho_{t+1} + tr(C_t^{0.5} - \rho_{t+1} I)
            = (D - R) \rho_{t+1} + tr(C_t^{0.5})

and equating tr(F_{t+1}) with tr(T_t) (since F_{t+1} is supposed to be a low-rank approximation to T_t), we have

(D - R) \rho_{t+1} + tr(C_t^{0.5}) = \eta tr(S_t) + (1-\eta) (D \rho_t + tr(D_t))

Solving for \rho_{t+1},

\rho_{t+1} = 1/(D - R) (\eta tr(S_t) + (1-\eta)(D \rho_t + tr(D_t)) - tr(C_t^{0.5})).    (eqn:rhot1)
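As a minimal numeric sketch of (eqn:rhot1), the function below evaluates the trace-matching update from scalar inputs. The name NextRho and its argument names are hypothetical and are not part of Kaldi; rho_t stands for \rho_t and eta for the update-rate constant.

```cpp
#include <cassert>

// Hypothetical helper (not part of Kaldi): evaluates (eqn:rhot1).
// D = dimension, R = rank, eta = update rate, tr_S = tr(S_t),
// rho_t = current rho, tr_D = tr(D_t), tr_sqrt_C = tr(C_t^{0.5}).
double NextRho(int D, int R, double eta, double tr_S,
               double rho_t, double tr_D, double tr_sqrt_C) {
  assert(R >= 1 && R < D && eta > 0.0 && eta < 1.0);
  // tr(T_t) = eta tr(S_t) + (1 - eta) tr(F_t), with tr(F_t) = D rho_t + tr(D_t).
  double tr_T = eta * tr_S + (1.0 - eta) * (D * rho_t + tr_D);
  // Choose rho_{t+1} so that tr(F_{t+1}) = tr(T_t).
  return (tr_T - tr_sqrt_C) / (D - R);
}
```

For example, with D = 10, R = 2, eta = 0.5, tr(S_t) = 4, rho_t = 0.1, tr(D_t) = 1 and tr(C_t^{0.5}) = 1, we get tr(T_t) = 3 and the function returns 0.25; the resulting tr(F_{t+1}) = 8 * 0.25 + 1 = 3 matches tr(T_t).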

Note that it is technically possible that diagonal elements of D_{t+1} may be negative, but we can still show that F_{t+1} is strictly positive definite if F_t was strictly positive definite.

If the quantities for which we are computing the Fisher matrix are all zero for some reason, the sequence of F_t will geometrically approach zero, which would cause problems with inversion; to prevent this happening, after setting D_{t+1} and \rho_{t+1} as above, we floor \rho_{t+1} to a small value (like 1.0e-10).

OK, we have described the updating of R_t, D_t and \rho_t. Next, we need to figure out how to efficiently multiply by the inverse of F_t. Our experience from working with the old preconditioning method was that it's best not to use the inverse of the Fisher matrix itself, but a version of the Fisher matrix that's smoothed with some constant times the identity. Below, \alpha is a configuration value (e.g. 4.0 seemed to work well). The following formula is designed to ensure that the smoothing varies proportionally with the scale of F_t:

G_t =(def) F_t + \alpha/D tr(F_t) I
    = R_t^T D_t R_t + (\rho_t + \alpha/D tr(F_t)) I
    = R_t^T D_t R_t + \beta_t I

where

\beta_t =(def) \rho_t + \alpha/D tr(F_t)
        = \rho_t (1+\alpha) + \alpha/D tr(D_t)    (eqn:betat2)

Define

\hat{X}_t =(def) \beta_t X_t G_t^{-1}.

The factor of \beta_t is inserted arbitrarily, as it just happens to be convenient to put unit scale on X_t in the formula for \hat{X}_t; it will be canceled out in the next step anyway. Then our final preconditioned minibatch of vectors is:

\bar{X}_t = \gamma_t \hat{X}_t

where

\gamma_t = sqrt(tr(X_t X_t^T) / tr(\hat{X}_t \hat{X}_t^T)).

The factor of \gamma_t ensures that \bar{X}_t is scaled to have the same overall 2-norm as the input X_t. We found in previous versions of this method that this rescaling was helpful, as otherwise there are certain situations (e.g. at the start of training) where the preconditioned derivatives can get very large. Note that this rescaling introduces a small bias into the training, because now the scale applied to a given sample depends on that sample itself, albeit in an increasingly diluted way as the minibatch size gets large.

To efficiently compute G_t^{-1}, we will use the Woodbury matrix identity. Writing the Woodbury formula for the symmetric case,

(A + U D U^T)^{-1} = A^{-1} - A^{-1} U (D^{-1} + U^T A^{-1} U)^{-1} U^T A^{-1}

Substituting A = \beta_t I, D = D_t and U = R_t^T, this becomes

G_t^{-1} = 1/\beta_t I - 1/\beta_t^2 R_t^T (D_t^{-1} + 1/\beta_t I)^{-1} R_t
         = 1/\beta_t (I - R_t^T E_t R_t)

where

E_t =(def) 1/\beta_t (D_t^{-1} + 1/\beta_t I)^{-1},    (eqn:etdef)

so

e_{tii} = 1/\beta_t * 1/(1/d_{tii} + 1/\beta_t)    (eqn:tii)
        = 1/(\beta_t/d_{tii} + 1)
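The Woodbury form can be checked numerically on a tiny case. The sketch below is illustrative only: it uses plain std::vector rather than Kaldi's matrix classes, and the function name ApplyGInverse is hypothetical. It computes x G^{-1} for a single row vector x, where G = R^T diag(d) R + beta I, using G^{-1} = (1/beta)(I - R^T E R) with e_i = 1/(beta/d_i + 1).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative only; names are hypothetical and not part of Kaldi.
// R holds the orthonormal rows r_i (each of dimension D), d holds the
// diagonal of D_t, and beta > 0 is the smoothed identity scale.
std::vector<double> ApplyGInverse(const std::vector<std::vector<double> > &R,
                                  const std::vector<double> &d, double beta,
                                  const std::vector<double> &x) {
  assert(beta > 0.0 && R.size() == d.size());
  std::vector<double> y(x);
  for (std::size_t i = 0; i < R.size(); i++) {
    assert(R[i].size() == x.size() && d[i] != 0.0);
    double e_i = 1.0 / (beta / d[i] + 1.0);  // diagonal element of E_t
    double dot = 0.0;                        // r_i . x
    for (std::size_t k = 0; k < x.size(); k++) dot += R[i][k] * x[k];
    for (std::size_t k = 0; k < x.size(); k++) y[k] -= e_i * dot * R[i][k];
  }
  for (std::size_t k = 0; k < y.size(); k++) y[k] /= beta;
  return y;
}
```

With R = [0.6 0.8], d = 2 and beta = 0.5, G is [1.22 0.96; 0.96 1.78], and applying the function to x = [1 0] gives [1.424 -0.768], which is the first row of the exact inverse of G.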

We would like an efficient-to-compute expression for \hat{X}_t, without too many separate invocations of kernels on the GPU.

\hat{X}_t = \beta_t X_t G_t^{-1}
          = X_t - X_t R_t^T E_t R_t

For efficient operation on the GPU, we want to reduce the number of high-dimensional operations that we do (defining "high-dimension" as anything involving D or M, but not R, since R is likely small, such as 20). We define

W_t =(def) E_t^{0.5} R_t.

We will actually be storing W_t on the GPU rather than R_t, in order to reduce the number of operations on the GPU. We can now write:

\hat{X}_t = X_t - X_t W_t^T W_t    (eqn:pt2)

The following, which we'll compute on the GPU, are going to be useful in computing quantities like Z_t:

H_t =(def) X_t W_t^T    (dim is N by R)
J_t =(def) H_t^T X_t    (dim is R by D)
    = W_t X_t^T X_t
K_t =(def) J_t J_t^T    (dim is R by R, symmetric); transfer this to CPU.
L_t =(def) H_t^T H_t    (dim is R by R, symmetric); transfer this to CPU.
    = W_t X_t^T X_t W_t^T

Note: L_t may also be computed as

L_t = J_t W_t^T

which may be more efficient if D < N.

Note: after we have computed H_t we can directly compute

\hat{X}_t = X_t - H_t W_t
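The two-step form (first H = X W^T, then X - H W) can be sketched on the CPU as follows. This is a minimal illustration with a hypothetical row-major Mat alias and function name, not the GPU code Kaldi uses; the point is that only two products involve the large dimension D.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

typedef std::vector<std::vector<double> > Mat;  // row-major; hypothetical alias

// Computes H = X W^T (N by R) and returns Xhat = X - H W (N by D).
Mat PreconditionWithW(const Mat &X, const Mat &W, Mat *H) {
  assert(H != nullptr);
  std::size_t N = X.size(), R = W.size();
  H->assign(N, std::vector<double>(R, 0.0));
  Mat Xhat(X);
  for (std::size_t n = 0; n < N; n++) {
    assert(R == 0 || X[n].size() == W[0].size());
    for (std::size_t r = 0; r < R; r++)  // row n of H
      for (std::size_t k = 0; k < X[n].size(); k++)
        (*H)[n][r] += X[n][k] * W[r][k];
    for (std::size_t r = 0; r < R; r++)  // subtract row n of H W
      for (std::size_t k = 0; k < X[n].size(); k++)
        Xhat[n][k] -= (*H)[n][r] * W[r][k];
  }
  return Xhat;
}
```

For a single input row X = [1 0] and W = [0.3 0.4], this gives H = [0.3] and Xhat = [0.91 -0.12].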

We need to determine how Y_t and Z_t relate to the quantities we just defined. First, we'll expand out H_t, J_t, K_t and L_t in terms of the more fundamental quantities.

H_t = X_t R_t^T E_t^{0.5}
J_t = E_t^{0.5} R_t X_t^T X_t
K_t = E_t^{0.5} R_t X_t^T X_t X_t^T X_t R_t^T E_t^{0.5}
L_t = E_t^{0.5} R_t X_t^T X_t R_t^T E_t^{0.5}

We wrote above that

Y_t = \eta R_t S_t + (1-\eta) (D_t + \rho_t I) R_t

so

Y_t = \eta/N R_t X_t^T X_t + (1-\eta) (D_t + \rho_t I) R_t
    = \eta/N E_t^{-0.5} J_t + (1-\eta) (D_t + \rho_t I) R_t    (eqn:yt)

We will expand Z_t using the expression for Y_t in the line above:

Z_t = Y_t Y_t^T
    = (\eta/N)^2 E_t^{-0.5} J_t J_t^T E_t^{-0.5}
     +(\eta/N)(1-\eta) E_t^{-0.5} J_t R_t^T (D_t + \rho_t I)
     +(\eta/N)(1-\eta) (D_t + \rho_t I) R_t J_t^T E_t^{-0.5}
     +(1-\eta)^2 (D_t + \rho_t I)^2
    = (\eta/N)^2 E_t^{-0.5} K_t E_t^{-0.5}
     +(\eta/N)(1-\eta) E_t^{-0.5} L_t E_t^{-0.5} (D_t + \rho_t I)
     +(\eta/N)(1-\eta) (D_t + \rho_t I) E_t^{-0.5} L_t E_t^{-0.5}
     +(1-\eta)^2 (D_t + \rho_t I)^2    (eqn:Zt)

We compute Z_t on the CPU using the expression above, and then do the symmetric eigenvalue decomposition (also on the CPU):

Z_t = U_t C_t U_t^T

and we make sure the eigenvalues are sorted from largest to smallest, for reasons that will be mentioned later.

Mathematically, no diagonal element of C_t can be less than (1-\eta)^2 \rho_t^2, and since negative or zero elements of C_t would cause us a problem later, we floor C_t to this value (see below regarding how we ensure R_{t+1} has orthonormal rows).

We will continue the discussion below regarding what we do with C_t and U_t. Next, we need to digress briefly and describe how to compute tr(\hat{X}_t \hat{X}_t^T) and tr(X_t X_t^T), since these appear in expressions for \gamma_t (needed to produce the output \bar{X}_t), and for \rho_{t+1}. It happens that we need, for purposes of applying "max_change" in the neural net code, the squared 2-norm of each row of the output \bar{X}_t. In order to be able to compute \gamma_t, it's most convenient to compute this squared row-norm for each row of \hat{X}_t, as a vector, to compute tr(\hat{X}_t \hat{X}_t^T) from this vector as its sum, and to then work back to compute tr(X_t X_t^T) from the relation between \hat{X}_t and X_t. We can then scale the row-norms we computed for \hat{X}_t, so they apply to \bar{X}_t.

For current purposes, you can imagine that we computed tr(\hat{X}_t \hat{X}_t^T) directly. Using (from eqn:pt2)

\hat{X}_t = X_t - X_t W_t^T W_t,

we can expand tr(\hat{X}_t \hat{X}_t^T) as:

tr(\hat{X}_t \hat{X}_t^T) = tr(X_t X_t^T) + tr(X_t W_t^T W_t W_t^T W_t X_t^T) - 2 tr(X_t W_t^T W_t X_t^T)
    = tr(X_t X_t^T) + tr(W_t X_t^T X_t W_t^T W_t W_t^T) - 2 tr(W_t X_t^T X_t W_t^T)
    = tr(X_t X_t^T) + tr(L_t W_t W_t^T) - 2 tr(L_t)
    = tr(X_t X_t^T) + tr(L_t E_t) - 2 tr(L_t)

and all quantities have already been computed (or are quick to compute, such as the small traces on the right), except tr(X_t X_t^T), so we can write

tr(X_t X_t^T) = tr(\hat{X}_t \hat{X}_t^T) - tr(L_t E_t) + 2 tr(L_t)

and the above expression can be used to obtain tr(X_t X_t^T). We can then do

\gamma_t <-- sqrt(tr(X_t X_t^T) / tr(\hat{X}_t \hat{X}_t^T))

(or one if the denominator is zero), and then

\bar{X}_t <-- \gamma_t \hat{X}_t

We can then output the per-row squared 2-norms of \bar{X}_t by scaling those we computed from \hat{X}_t by \gamma_t^2.

OK, the digression on how to compute \gamma_t and tr(X_t X_t^T) is over. We now return to the computation of R_{t+1}, W_{t+1}, \rho_{t+1}, D_{t+1} and E_{t+1}.

We found above in (eqn:rhot1)

\rho_{t+1} = 1/(D - R) (\eta tr(S_t) + (1-\eta)(D \rho_t + tr(D_t)) - tr(C_t^{0.5})).

Expanding out S_t from its definition in (eqn:St),

\rho_{t+1} = 1/(D - R) (\eta/N tr(X_t X_t^T) + (1-\eta)(D \rho_t + tr(D_t)) - tr(C_t^{0.5})).

We can compute this directly as all the quantities involved are already known or easy to compute. Next, from (eqn:dt1), we compute

D_{t+1} = C_t^{0.5} - \rho_{t+1} I

At this point if \rho_{t+1} is smaller than some small value \epsilon, e.g. 1.0e-10, we set it to \epsilon; as mentioned, we do this to stop F_t approaching zero if all inputs are zero. Next, if any diagonal element D_{t+1,i,i} has absolute value less than \epsilon, we set it to +\epsilon. This is to ensure that diagonal elements of E are never zero, which would cause problems.
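The two floors just described can be written as a short sketch. FloorRhoAndD is a hypothetical name and follows the text above literally; it is not taken from the Kaldi source, which may differ in detail.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not Kaldi's code): floors rho to epsilon, and replaces
// any diagonal element of D whose absolute value is below epsilon by +epsilon.
void FloorRhoAndD(double epsilon, double *rho, std::vector<double> *d) {
  assert(epsilon > 0.0 && rho != nullptr && d != nullptr);
  if (*rho < epsilon) *rho = epsilon;
  for (std::size_t i = 0; i < d->size(); i++)
    if (std::fabs((*d)[i]) < epsilon) (*d)[i] = epsilon;  // note the + sign
}
```

Elements of D that are negative but not tiny are left alone, consistent with the earlier remark that D_{t+1} may have negative diagonal elements.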

Next, we compute (from eqn:betat2, eqn:etdef, eqn:tii),

\beta_{t+1} = \rho_{t+1} (1+\alpha) + \alpha/D tr(D_{t+1})
E_{t+1} = 1/\beta_{t+1} (D_{t+1}^{-1} + 1/\beta_{t+1} I)^{-1},

i.e.: e_{t+1,ii} = 1/(\beta_{t+1}/d_{t+1,ii} + 1)

We'll want to store D_{t+1}. We next want to compute W_{t+1}.

Before computing W_{t+1}, we need to find an expression for

R_{t+1} = C_t^{-0.5} U_t^T Y_t

Expanding out Y_t using the expression in (eqn:yt),

R_{t+1} = C_t^{-0.5} U_t^T (\eta/N E_t^{-0.5} J_t + (1-\eta) (D_t + \rho_t I) R_t)
        = (\eta/N C_t^{-0.5} U_t^T E_t^{-0.5}) J_t
         +((1-\eta) C_t^{-0.5} U_t^T (D_t + \rho_t I) E_t^{-0.5}) W_t

What we actually want is W_{t+1} = E_{t+1}^{0.5} R_{t+1}:

W_{t+1} = (\eta/N E_{t+1}^{0.5} C_t^{-0.5} U_t^T E_t^{-0.5}) J_t
         +((1-\eta) E_{t+1}^{0.5} C_t^{-0.5} U_t^T (D_t + \rho_t I) E_t^{-0.5}) W_t

and to minimize the number of matrix-matrix multiplies we can factorize this as:

W_{t+1} = A_t B_t
A_t = (\eta/N) E_{t+1}^{0.5} C_t^{-0.5} U_t^T E_t^{-0.5}
B_t = J_t + (1-\eta)/(\eta/N) (D_t + \rho_t I) W_t

[note: we use the fact that (D_t + \rho_t I) and E_t^{-0.5} commute because they are diagonal].

A_t is computed on the CPU and transferred from there to the GPU, B_t is computed on the GPU, and the multiplication of A_t with B_t is done on the GPU.

Keeping R_t orthogonal

Our method requires the R_t matrices to be orthogonal (which we define to mean that R_t R_t^T = I). If roundoff error causes this equality to be significantly violated, it could cause a problem for the stability of our method. We now address our method for making sure that the R_t values stay orthogonal. We do this in the algorithm described above, after creating W_{t+1}. This extra step is only executed if the condition number of C_t (i.e. the ratio of its largest to smallest diagonal element) exceeds a specified threshold, such as 1.0e+06 [this is tested before applying the floor to C_t]. The threshold was determined empirically by finding the largest value needed to ensure a certain level of orthogonality in R_{t+1}. For purposes of the present discussion, since R_{t+1} is not actually stored, define it as E_{t+1}^{-0.5} W_{t+1}. Define the following (and we will just use t instead of t+1 below, as all quantities have the same subscript):

O_t =(def) R_t R_t^T = E_t^{-0.5} W_t W_t^T E_t^{-0.5}

(and we would compute this by computing W_t W_t^T on the GPU, transferring it to the CPU, and doing the rest there). If O_t is not sufficiently close to the unit matrix, we can re-orthogonalize as follows. Do the Cholesky decomposition

O_t = C C^T

Clearly C^{-1} O_t C^{-T} = I, so if we correct R_t with

R_t <-- C^{-1} R_t

we can ensure orthogonality. If R_t's first k rows are orthogonal, this transform will not affect them, because of its lower-triangular structure. This is good because (thanks to the eigenvalue sorting) the eigenvectors with the larger eigenvalues come first, and it is more critical to keep them pointing in the same direction. Any loss of orthogonality will be dealt with by modifying the smaller eigenvectors. As a modification to W_t, this would be:

W_t <-- (E_t^{0.5} C^{-1} E_t^{-0.5}) W_t,

and the matrix in parentheses is computed on the CPU, transferred to the GPU, and the multiplication is done there.
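The Cholesky-based correction R <-- C^{-1} R can be illustrated with a small CPU-only sketch. It operates directly on R rather than on W_t, uses a hypothetical row-major Mat alias, and is not the code of ReorthogonalizeXt1(); it only demonstrates the technique and its lower-triangular property.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

typedef std::vector<std::vector<double> > Mat;  // row-major; hypothetical alias

// Replaces R (r by d, full row rank) with C^{-1} R, where R R^T = C C^T is the
// Cholesky factorization with C lower triangular, so that R R^T becomes I.
void Reorthogonalize(Mat *R) {
  assert(R != nullptr);
  std::size_t r = R->size();
  Mat O(r, std::vector<double>(r, 0.0));  // O = R R^T
  for (std::size_t i = 0; i < r; i++)
    for (std::size_t j = 0; j < r; j++)
      for (std::size_t k = 0; k < (*R)[i].size(); k++)
        O[i][j] += (*R)[i][k] * (*R)[j][k];
  Mat C(r, std::vector<double>(r, 0.0));  // Cholesky factor of O
  for (std::size_t i = 0; i < r; i++) {
    for (std::size_t j = 0; j <= i; j++) {
      double sum = O[i][j];
      for (std::size_t k = 0; k < j; k++) sum -= C[i][k] * C[j][k];
      if (i == j) {
        assert(sum > 0.0);  // O must be positive definite
        C[i][i] = std::sqrt(sum);
      } else {
        C[i][j] = sum / C[j][j];
      }
    }
  }
  // Forward substitution, in place: rows j < i already hold the new values.
  for (std::size_t i = 0; i < r; i++) {
    for (std::size_t k = 0; k < (*R)[i].size(); k++) {
      double v = (*R)[i][k];
      for (std::size_t j = 0; j < i; j++) v -= C[i][j] * (*R)[j][k];
      (*R)[i][k] = v / C[i][i];
    }
  }
}
```

For R = [1 0 0; 0.1 1 0], the first row already has unit norm and is left unchanged, while the second row becomes [0 1 0].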

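Initialization

Now, a note on what we do on time t = 0, i.e. for the first minibatch. We initialize R_0 to the top R eigenvectors of 1/N X_0^T X_0, where N is the minibatch size (the number of rows of X_0, which is the argument R0 of Init()). If L is the corresponding R by R diagonal matrix of eigenvalues, then we will set D_0 = L - \rho_0 I. We set \rho_0 to ensure that

tr(F_0) = 1/N tr(X_0 X_0^T),
tr(D_0) + \rho_0 D = 1/N tr(X_0 X_0^T),
tr(L) - \rho_0 R + \rho_0 D = 1/N tr(X_0 X_0^T),
\rho_0 = (1/N tr(X_0 X_0^T) - tr(L)) / (D - R)

We then floor \rho_0 to \epsilon (e.g. 1.0e-10) and also floor the diagonal elements of D_0 to \epsilon; this ensures that we won't crash for zero inputs. Note that Init() as listed on this page does not perform this eigenvalue decomposition: it starts from InitDefault() and runs a few iterations of PreconditionDirections() on the same data, which its comments describe as a faster way of starting.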

A note on multi-threading. This technique was really designed for use with a GPU, where we won't have multi-threading, but we want it to work also on a CPU, where we may have multiple worker threads. Our approach is as follows (we do this when we're about to start updating the parameters R_t, D_t, and derived quantities):

For time t > 0 (where the matrices are already initialized), before starting the part of the computation that updates the parameters (R_t, D_t, and derived quantities), we try to lock a mutex that guards the OnlinePreconditioner. If we can lock it right away, we go ahead and do the update, but if not, we just abandon the attempt to update those quantities.

We will have another mutex to ensure that when we access quantities like W_t, they are all "in sync" (and we don't access them while they are being written by another thread). This mutex will only be locked for short periods of time.
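The locking scheme described above might look like the following sketch. The class SkippableUpdater and its methods are hypothetical; only the member names update_mutex_, read_write_mutex_ and num_updates_skipped_ mirror the private attributes listed on this page, and the logic is an assumed illustration rather than the actual implementation.

```cpp
#include <cassert>
#include <mutex>

// Hypothetical illustration of "try to lock, otherwise skip the update".
class SkippableUpdater {
 public:
  SkippableUpdater() : state_(0), num_updates_skipped_(0) {}

  // Returns true if the update was done, false if it was skipped because
  // another thread currently holds the update mutex.
  bool TryUpdate(int new_state) {
    std::unique_lock<std::mutex> update_lock(update_mutex_, std::try_to_lock);
    if (!update_lock.owns_lock()) {
      std::lock_guard<std::mutex> rw_lock(read_write_mutex_);
      num_updates_skipped_++;
      return false;
    }
    // ... the expensive computation of new parameters would go here ...
    std::lock_guard<std::mutex> rw_lock(read_write_mutex_);  // held briefly
    state_ = new_state;
    return true;
  }

  int Read() {
    std::lock_guard<std::mutex> rw_lock(read_write_mutex_);
    return state_;
  }

  int NumSkipped() {
    std::lock_guard<std::mutex> rw_lock(read_write_mutex_);
    return num_updates_skipped_;
  }

 private:
  std::mutex update_mutex_;      // guards the (long) update computation
  std::mutex read_write_mutex_;  // keeps reads and writes of state_ in sync
  int state_;
  int num_updates_skipped_;
};
```

The design choice here is that a worker thread never blocks waiting for an update: if another thread is already updating, it simply proceeds with the existing parameters, and only the short read/write sections are serialized.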

Note: it might be a good idea to make sure that the R_t still retain orthonormal rows even in the presence of roundoff, without errors accumulating. My instinct is that this isn't going to be a problem.

Definition at line 413 of file nnet-precondition-online.h.

Constructor & Destructor Documentation

◆ OnlinePreconditioner() [1/2]

◆ OnlinePreconditioner() [2/2]

OnlinePreconditioner ( const OnlinePreconditioner &  other)
explicit

Definition at line 595 of file nnet-precondition-online.cc.

OnlinePreconditioner::OnlinePreconditioner(const OnlinePreconditioner &other):
    rank_(other.rank_), update_period_(other.update_period_),
    num_samples_history_(other.num_samples_history_),
    alpha_(other.alpha_), epsilon_(other.epsilon_), delta_(other.delta_),
    t_(other.t_), num_updates_skipped_(other.num_updates_skipped_),
    self_debug_(other.self_debug_), W_t_(other.W_t_),
    rho_t_(other.rho_t_), d_t_(other.d_t_) {
  // use default constructor for the mutexes.
}

Member Function Documentation

◆ ComputeEt()

void ComputeEt ( const VectorBase< BaseFloat > &  d_t,
BaseFloat  beta_t,
VectorBase< BaseFloat > *  e_t,
VectorBase< BaseFloat > *  sqrt_e_t,
VectorBase< BaseFloat > *  inv_sqrt_e_t 
) const
private

Definition at line 577 of file nnet-precondition-online.cc.

References VectorBase< Real >::ApplyPow(), VectorBase< Real >::CopyFromVec(), rnnlm::d, VectorBase< Real >::Data(), VectorBase< Real >::Dim(), rnnlm::i, and VectorBase< Real >::InvertElements().

Referenced by OnlinePreconditioner::ComputeWt1(), OnlinePreconditioner::GetUpdatePeriod(), OnlinePreconditioner::PreconditionDirectionsInternal(), OnlinePreconditioner::ReorthogonalizeXt1(), and OnlinePreconditioner::SelfTest().

void OnlinePreconditioner::ComputeEt(const VectorBase<BaseFloat> &d_t,
                                     BaseFloat beta_t,
                                     VectorBase<BaseFloat> *e_t,
                                     VectorBase<BaseFloat> *sqrt_e_t,
                                     VectorBase<BaseFloat> *inv_sqrt_e_t) const {
  // e_{tii} = 1/(\beta_t/d_{tii} + 1)
  int32 D = d_t.Dim();
  const BaseFloat *d = d_t.Data();
  BaseFloat *e = e_t->Data();
  for (int32 i = 0; i < D; i++)
    e[i] = 1.0 / (beta_t / d[i] + 1);
  sqrt_e_t->CopyFromVec(*e_t);
  sqrt_e_t->ApplyPow(0.5);
  inv_sqrt_e_t->CopyFromVec(*sqrt_e_t);
  inv_sqrt_e_t->InvertElements();
}

◆ ComputeWt1()

void ComputeWt1 ( int32  N,
const VectorBase< BaseFloat > &  d_t,
const VectorBase< BaseFloat > &  d_t1,
BaseFloat  rho_t,
BaseFloat  rho_t1,
const MatrixBase< BaseFloat > &  U_t,
const VectorBase< BaseFloat > &  sqrt_c_t,
const VectorBase< BaseFloat > &  inv_sqrt_e_t,
const CuMatrixBase< BaseFloat > &  W_t,
CuMatrixBase< BaseFloat > *  J_t,
CuMatrixBase< BaseFloat > *  W_t1 
) const
private

Definition at line 501 of file nnet-precondition-online.cc.

References CuMatrixBase< Real >::AddDiagVecMat(), CuMatrixBase< Real >::AddMatMat(), OnlinePreconditioner::alpha_, OnlinePreconditioner::ComputeEt(), VectorBase< Real >::Dim(), OnlinePreconditioner::Eta(), rnnlm::i, VectorBase< Real >::InvertElements(), rnnlm::j, KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, kaldi::kUndefined, CuMatrixBase< Real >::NumCols(), and VectorBase< Real >::Sum().

Referenced by OnlinePreconditioner::GetUpdatePeriod(), and OnlinePreconditioner::PreconditionDirectionsInternal().

void OnlinePreconditioner::ComputeWt1(int32 N,
                                      const VectorBase<BaseFloat> &d_t,
                                      const VectorBase<BaseFloat> &d_t1,
                                      BaseFloat rho_t,
                                      BaseFloat rho_t1,
                                      const MatrixBase<BaseFloat> &U_t,
                                      const VectorBase<BaseFloat> &sqrt_c_t,
                                      const VectorBase<BaseFloat> &inv_sqrt_e_t,
                                      const CuMatrixBase<BaseFloat> &W_t,
                                      CuMatrixBase<BaseFloat> *J_t,
                                      CuMatrixBase<BaseFloat> *W_t1) const {

  int32 R = d_t.Dim(), D = W_t.NumCols();
  BaseFloat eta = Eta(N);

  // \beta_{t+1} = \rho_{t+1} (1+\alpha) + \alpha/D tr(D_{t+1})
  BaseFloat beta_t1 = rho_t1 * (1.0 + alpha_) + alpha_ * d_t1.Sum() / D;
  KALDI_ASSERT(beta_t1 > 0.0);
  Vector<BaseFloat> e_t1(R, kUndefined), sqrt_e_t1(R, kUndefined),
      inv_sqrt_e_t1(R, kUndefined);
  ComputeEt(d_t1, beta_t1, &e_t1, &sqrt_e_t1, &inv_sqrt_e_t1);
  Vector<BaseFloat> inv_sqrt_c_t(sqrt_c_t);
  inv_sqrt_c_t.InvertElements();

  Vector<BaseFloat> w_t_coeff(R);
  for (int32 i = 0; i < R; i++)
    w_t_coeff(i) = (1.0 - eta) / (eta/N) * (d_t(i) + rho_t);
  CuVector<BaseFloat> w_t_coeff_gpu(w_t_coeff);
  // B_t = J_t + (1-\eta)/(\eta/N) (D_t + \rho_t I) W_t
  // (computed in place, so J_t holds B_t after this call).
  J_t->AddDiagVecMat(1.0, w_t_coeff_gpu, W_t, kNoTrans, 1.0);

  // A_t = (\eta/N) E_{t+1}^{0.5} C_t^{-0.5} U_t^T E_t^{-0.5}
  Matrix<BaseFloat> A_t(U_t, kTrans);
  for (int32 i = 0; i < R; i++) {
    BaseFloat i_factor = (eta / N) * sqrt_e_t1(i) * inv_sqrt_c_t(i);
    for (int32 j = 0; j < R; j++) {
      BaseFloat j_factor = inv_sqrt_e_t(j);
      A_t(i, j) *= i_factor * j_factor;
    }
  }
  // W_{t+1} = A_t B_t
  CuMatrix<BaseFloat> A_t_gpu(A_t);
  W_t1->AddMatMat(1.0, A_t_gpu, kNoTrans, *J_t, kNoTrans, 0.0);
}

◆ ComputeZt()

void ComputeZt ( int32  N,
BaseFloat  rho_t,
const VectorBase< BaseFloat > &  d_t,
const VectorBase< BaseFloat > &  inv_sqrt_e_t,
const MatrixBase< BaseFloat > &  K_t,
const MatrixBase< BaseFloat > &  L_t,
SpMatrix< double > *  Z_t 
) const
private

Definition at line 546 of file nnet-precondition-online.cc.

References VectorBase< Real >::Add(), VectorBase< Real >::Dim(), OnlinePreconditioner::Eta(), rnnlm::i, and rnnlm::j.

Referenced by OnlinePreconditioner::GetUpdatePeriod(), and OnlinePreconditioner::PreconditionDirectionsInternal().

void OnlinePreconditioner::ComputeZt(int32 N,
                                     BaseFloat rho_t,
                                     const VectorBase<BaseFloat> &d_t,
                                     const VectorBase<BaseFloat> &inv_sqrt_e_t,
                                     const MatrixBase<BaseFloat> &K_t,
                                     const MatrixBase<BaseFloat> &L_t,
                                     SpMatrix<double> *Z_t) const {
  // Use doubles because the range of quantities in Z_t can get large (fourth
  // power of data), and we want to avoid overflow. This routine is fast.
  BaseFloat eta = Eta(N);
  Vector<BaseFloat> d_t_rho_t(d_t);
  d_t_rho_t.Add(rho_t);  // now d_t_rho_t is diag(D_t + \rho_t I).
  double etaN = eta / N, eta1 = 1.0 - eta,
      etaN_sq = etaN * etaN, eta1_sq = eta1 * eta1,
      etaN_eta1 = etaN * eta1;
  int32 R = d_t.Dim();
  for (int32 i = 0; i < R; i++) {
    double inv_sqrt_e_t_i = inv_sqrt_e_t(i), d_t_rho_t_i = d_t_rho_t(i);
    for (int32 j = 0; j <= i; j++) {
      double inv_sqrt_e_t_j = inv_sqrt_e_t(j), d_t_rho_t_j = d_t_rho_t(j),
          L_t_i_j = 0.5 * (L_t(i, j) + L_t(j, i)),
          K_t_i_j = 0.5 * (K_t(i, j) + K_t(j, i));
      // See (eqn:Zt) in header.
      (*Z_t)(i, j) = etaN_sq * inv_sqrt_e_t_i * K_t_i_j * inv_sqrt_e_t_j
          + etaN_eta1 * inv_sqrt_e_t_i * L_t_i_j * inv_sqrt_e_t_j * d_t_rho_t_j
          + etaN_eta1 * d_t_rho_t_i * inv_sqrt_e_t_i * L_t_i_j * inv_sqrt_e_t_j
          + (i == j ? eta1_sq * d_t_rho_t_i * d_t_rho_t_i : 0.0);
    }
  }
}

◆ Eta()

BaseFloat Eta ( int32  N) const
private

Definition at line 492 of file nnet-precondition-online.cc.

References KALDI_ASSERT, and OnlinePreconditioner::num_samples_history_.

Referenced by OnlinePreconditioner::ComputeWt1(), OnlinePreconditioner::ComputeZt(), OnlinePreconditioner::GetUpdatePeriod(), and OnlinePreconditioner::PreconditionDirectionsInternal().

BaseFloat OnlinePreconditioner::Eta(int32 N) const {
  BaseFloat ans = 1.0 - exp(-N / num_samples_history_);
  // Don't let eta approach 1 too closely, as it can lead to NaN's appearing if
  // the input is all zero.
  if (ans > 0.9) ans = 0.9;
  return ans;
}

◆ GetAlpha()

BaseFloat GetAlpha ( ) const
inline

Definition at line 424 of file nnet-precondition-online.h.

References OnlinePreconditioner::alpha_.

◆ GetNumSamplesHistory()

BaseFloat GetNumSamplesHistory ( ) const
inline

◆ GetRank()

int32 GetRank ( ) const
inline

◆ GetUpdatePeriod()

◆ Init()

void Init ( const CuMatrixBase< BaseFloat > &  R0)
private

Definition at line 123 of file nnet-precondition-online.cc.

References OnlinePreconditioner::d_t_, rnnlm::i, OnlinePreconditioner::InitDefault(), kaldi::kUndefined, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), OnlinePreconditioner::PreconditionDirections(), OnlinePreconditioner::rank_, OnlinePreconditioner::rho_t_, OnlinePreconditioner::t_, and OnlinePreconditioner::W_t_.

Referenced by OnlinePreconditioner::GetUpdatePeriod(), and OnlinePreconditioner::PreconditionDirections().

void OnlinePreconditioner::Init(const CuMatrixBase<BaseFloat> &R0) {
  int32 D = R0.NumCols();
  // for locking reasons it's better to use a different object.
  OnlinePreconditioner this_copy(*this);
  this_copy.InitDefault(D);

  CuMatrix<BaseFloat> R0_copy(R0.NumRows(), R0.NumCols(), kUndefined);
  // number of iterations with the same data from a pseudorandom start.
  // this is a faster way of starting than doing eigenvalue decomposition.
  int32 num_init_iters = 3;
  for (int32 i = 0; i < num_init_iters; i++) {
    BaseFloat scale;
    R0_copy.CopyFromMat(R0);
    this_copy.PreconditionDirections(&R0_copy, NULL, &scale);
  }
  rank_ = this_copy.rank_;
  W_t_.Swap(&this_copy.W_t_);
  d_t_.Swap(&this_copy.d_t_);
  rho_t_ = this_copy.rho_t_;
  t_ = 0;
}

◆ InitDefault()

void InitDefault ( int32  D)
private

Definition at line 75 of file nnet-precondition-online.cc.

References OnlinePreconditioner::alpha_, OnlinePreconditioner::d_t_, OnlinePreconditioner::delta_, OnlinePreconditioner::epsilon_, OnlinePreconditioner::InitOrthonormalSpecial(), KALDI_ASSERT, KALDI_WARN, kaldi::kUndefined, OnlinePreconditioner::num_samples_history_, OnlinePreconditioner::rank_, OnlinePreconditioner::rho_t_, OnlinePreconditioner::t_, and OnlinePreconditioner::W_t_.

Referenced by OnlinePreconditioner::GetUpdatePeriod(), and OnlinePreconditioner::Init().

void OnlinePreconditioner::InitDefault(int32 D) {
  if (rank_ >= D) {
    KALDI_WARN << "Rank " << rank_ << " of online preconditioner is >= dim " << D
               << ", setting it to "
               << (D - 1) << " (but this is probably still too high)";
    rank_ = D - 1;
  }
  if (rank_ == 0) {
    // Dimension of input data was 1, so the natural gradient preconditioner
    // would always be the unit matrix.
    // We'll handle this as a special case, for generality.
    return;
  }
  KALDI_ASSERT(alpha_ >= 0.0);
  KALDI_ASSERT(rank_ > 0);
  KALDI_ASSERT(epsilon_ > 0.0 && epsilon_ <= 1.0e-05);  // plausible values.
  KALDI_ASSERT(delta_ > 0.0 && delta_ <= 1.0e-02);  // plausible values.

  // to initialize, in the equation
  //   F_t =(def) R_t^T D_t R_t + \rho_t I
  // we will set the orthogonal R_t to a special orthogonal matrix with no zero
  // rows or columns (see the function), rho_t to epsilon,
  // and D_t to epsilon.  But we don't store R_t directly.  Instead, we store
  //   W_t =(def) E_t^{0.5} R_t,
  // where E_t =(def) 1/\beta_t (D_t^{-1} + 1/\beta_t I)^{-1}
  // from (eqn:tii),
  //   e_{tii} = 1/(\beta_t/d_{tii} + 1),
  // where
  //   \beta_t =(def) \rho_t + \alpha/D tr(F_t)
  //           = epsilon + alpha/D * (epsilon * D + epsilon * rank)
  //           = epsilon * (1 + alpha * (D + rank) / D)
  // And d_{tii} is epsilon, so
  //   e_{tii} = 1/((1 + alpha * (D + rank) / D) + 1)  [for each i.]
  //           = 1/(2 + alpha * (D + rank) / D).
  BaseFloat epsilon = epsilon_;  // we could make this a bit more.
  rho_t_ = epsilon;
  d_t_.Resize(rank_, kUndefined);
  d_t_.Set(epsilon);
  W_t_.Resize(rank_, D, kUndefined);
  // after the next line, W_t_ will store the orthogonal matrix R_t.
  InitOrthonormalSpecial(&W_t_);
  BaseFloat E_tii = 1.0 / ( 2.0 + (D + rank_) * alpha_ / D );
  // W_t =(def) E_t^{0.5} R_t.
  W_t_.Scale(sqrt(E_tii));
  t_ = 0;
}

◆ InitOrthonormalSpecial()

void InitOrthonormalSpecial ( CuMatrixBase< BaseFloat > *  R)
static private

This function creates a matrix with orthonormal rows that is like the following matrix, except with each row normalized to have unit 2-norm:

[ 1.1 0 1 0 1 0
  0 1.1 0 1 0 1 ]

The first element in each row is 1.1 rather than 1 for symmetry-breaking: we don't want any weighted sum of all these rows to be all ones, because the derivative in that direction can be zero in some architectures, and that forces an inefficient CPU-based renormalization.

Definition at line 45 of file nnet-precondition-online.cc.

References CuMatrixBase< Real >::AddElements(), CuMatrixBase< Real >::AddMatMat(), rnnlm::i, CuMatrixBase< Real >::IsUnit(), KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), and CuMatrixBase< Real >::SetZero().

Referenced by OnlinePreconditioner::GetUpdatePeriod(), and OnlinePreconditioner::InitDefault().

45  {
46  int32 num_rows = R->NumRows(), num_cols = R->NumCols();
47  KALDI_ASSERT(num_cols >= num_rows);
48  R->SetZero();
49  std::vector<MatrixElement<BaseFloat> > elems;
50  elems.reserve(num_cols);
51  BaseFloat first_elem = 1.1;
52  for (int32 r = 0; r < num_rows; r++) {
53  std::vector<int32> cols; // columns that have an entry for this row
54  for (int32 c = r; c < num_cols; c += num_rows)
55  cols.push_back(c);
56  BaseFloat normalizer = 1.0 / sqrt(first_elem * first_elem +
57  cols.size() - 1);
58  for (size_t i = 0; i < cols.size(); i++) {
59  int32 c = cols[i];
60  MatrixElement<BaseFloat> e = { r, c,
61  normalizer * (i == 0 ? first_elem :
62  BaseFloat(1.0)) };
63  elems.push_back(e);
64  }
65  }
66  R->AddElements(1.0, elems);
67  { // TODO: remove this testing code.
68  CuMatrix<BaseFloat> prod(num_rows, num_rows);
69  prod.AddMatMat(1.0, *R, kNoTrans, *R, kTrans, 0.0);
70  KALDI_ASSERT(prod.IsUnit());
71  }
72 }
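The construction can be reproduced on the CPU to see why the rows come out orthonormal: row r only touches columns r, r + num_rows, r + 2 * num_rows and so on, so different rows never share a column. This is a sketch using std::vector instead of CuMatrixBase; MakeOrthonormalSpecial and RowDot are hypothetical names, and the code assumes num_cols >= num_rows as the listing asserts.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Sketch only: builds the same pattern as the listing above with std::vector.
// MakeOrthonormalSpecial is a hypothetical name, not a Kaldi API.
// Assumes num_cols >= num_rows, so every row has at least one entry.
Matrix MakeOrthonormalSpecial(int num_rows, int num_cols) {
  const double first_elem = 1.1;
  Matrix m(num_rows, std::vector<double>(num_cols, 0.0));
  for (int r = 0; r < num_rows; r++) {
    // Count the entries of row r: columns r, r + num_rows, r + 2 * num_rows, ...
    int count = 0;
    for (int c = r; c < num_cols; c += num_rows) count++;
    double normalizer = 1.0 / std::sqrt(first_elem * first_elem + count - 1);
    bool first = true;
    for (int c = r; c < num_cols; c += num_rows) {
      m[r][c] = normalizer * (first ? first_elem : 1.0);
      first = false;
    }
  }
  return m;
}

// Inner product of rows a and b of m.
double RowDot(const Matrix &m, int a, int b) {
  double sum = 0.0;
  for (std::size_t c = 0; c < m[a].size(); c++) sum += m[a][c] * m[b][c];
  return sum;
}
```

For a 2 x 6 matrix, RowDot of a row with itself should be 1 up to rounding and RowDot of the two rows should be exactly 0, which is the property the testing code at the end of the listing asserts via IsUnit().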

◆ operator=()

OnlinePreconditioner & operator= ( const OnlinePreconditioner &  other)

Definition at line 605 of file nnet-precondition-online.cc.

References OnlinePreconditioner::alpha_, OnlinePreconditioner::d_t_, OnlinePreconditioner::delta_, OnlinePreconditioner::epsilon_, OnlinePreconditioner::num_samples_history_, OnlinePreconditioner::rank_, OnlinePreconditioner::rho_t_, OnlinePreconditioner::self_debug_, OnlinePreconditioner::t_, OnlinePreconditioner::update_period_, and OnlinePreconditioner::W_t_.

Referenced by OnlinePreconditioner::GetUpdatePeriod().

606  {
607  rank_ = other.rank_;
608  update_period_ = other.update_period_;
609  num_samples_history_ = other.num_samples_history_;
610  alpha_ = other.alpha_;
611  epsilon_ = other.epsilon_;
612  delta_ = other.delta_;
613  t_ = other.t_;
614  self_debug_ = other.self_debug_;
615  W_t_ = other.W_t_;
616  rho_t_ = other.rho_t_;
617  d_t_ = other.d_t_;
618  return *this;
619 }

◆ PreconditionDirections()

void PreconditionDirections ( CuMatrixBase< BaseFloat > *  R,
CuVectorBase< BaseFloat > *  row_prod,
BaseFloat *  scale 
)

Definition at line 145 of file nnet-precondition-online.cc.

References CuVectorBase< Real >::AddDiagMat2(), OnlinePreconditioner::d_t_, OnlinePreconditioner::Init(), kaldi::kNoTrans, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), OnlinePreconditioner::PreconditionDirectionsInternal(), CuMatrixBase< Real >::Range(), OnlinePreconditioner::read_write_mutex_, OnlinePreconditioner::rho_t_, OnlinePreconditioner::t_, and OnlinePreconditioner::W_t_.

Referenced by OnlinePreconditioner::GetUpdatePeriod(), OnlinePreconditioner::Init(), kaldi::nnet2::UnitTestPreconditionDirectionsOnline(), and AffineComponentPreconditionedOnline::Update().

148  {
149  if (X_t->NumCols() == 1) {
150  // If the dimension of the space equals one then our natural gradient update
151  // with rescaling becomes a no-op, but the code wouldn't naturally handle it
152  // because rank would be zero. Support this as a special case.
153  if (row_prod)
154  row_prod->AddDiagMat2(1.0, *X_t, kNoTrans, 0.0);
155  *scale = 1.0;
156  return;
157  }
158 
159  if (row_prod == NULL) {
160  CuVector<BaseFloat> row_prod_tmp(X_t->NumRows());
161  PreconditionDirections(X_t, &row_prod_tmp, scale);
162  return;
163  }
164 
165  read_write_mutex_.lock();
166  if (t_ == -1) // not initialized
167  Init(*X_t);
168 
169  // Now t_ >= 0.
170  // We create local copies of the class variables... this is intended for
171  // multi-threaded safety so we can't read them in an inconsistent state,
172  // but we don't really waste anything here (a copy of W_t is needed anyway,
173  // if we're to update it).
174  int32 t = t_, R = W_t_.NumRows(), D = W_t_.NumCols();
175  // space for W_t, J_t, K_t, L_t.
176  CuMatrix<BaseFloat> WJKL_t(2 * R, D + R);
177  WJKL_t.Range(0, R, 0, D).CopyFromMat(W_t_);
178  BaseFloat rho_t(rho_t_);
179  Vector<BaseFloat> d_t(d_t_);
180  read_write_mutex_.unlock();
181  PreconditionDirectionsInternal(t, rho_t, d_t, &WJKL_t, X_t, row_prod, scale);
182 }
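The locking pattern in this function (copy the shared state while holding read_write_mutex_, then do the expensive work on the copies without the lock) can be illustrated in isolation. The sketch below is an assumption-laden simplification: SharedState, Snapshot and TakeSnapshot are hypothetical names, and it uses std::lock_guard where the listing calls lock() and unlock() explicitly.

```cpp
#include <cassert>
#include <mutex>
#include <vector>

// Sketch only: hypothetical stand-in for the preconditioner's shared state.
struct SharedState {
  std::mutex read_write_mutex;
  int t = 0;
  double rho_t = 0.0;
  std::vector<double> d_t;
};

// A consistent local copy of the state, taken while holding the lock.
struct Snapshot {
  int t;
  double rho_t;
  std::vector<double> d_t;
};

Snapshot TakeSnapshot(SharedState *state) {
  // lock_guard releases the mutex on scope exit, including on exceptions.
  // The returned Snapshot is built before the guard is destroyed.
  std::lock_guard<std::mutex> lock(state->read_write_mutex);
  return Snapshot{state->t, state->rho_t, state->d_t};
}
```

Later writes to the shared state do not affect a snapshot that was already taken, so a thread working from it never sees t, rho_t and d_t from different time steps.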

◆ PreconditionDirectionsInternal()

void PreconditionDirectionsInternal ( const int32  t,
const BaseFloat  rho_t,
const Vector< BaseFloat > &  d_t,
CuMatrixBase< BaseFloat > *  WJKL_t,
CuMatrixBase< BaseFloat > *  X_t,
CuVectorBase< BaseFloat > *  row_prod,
BaseFloat *  scale 
)
private

Definition at line 300 of file nnet-precondition-online.cc.

References VectorBase< Real >::Add(), CuVectorBase< Real >::AddDiagMat2(), CuMatrixBase< Real >::AddMatMat(), OnlinePreconditioner::alpha_, VectorBase< Real >::ApplyFloor(), VectorBase< Real >::ApplyPow(), kaldi::ApproxEqual(), OnlinePreconditioner::ComputeEt(), OnlinePreconditioner::ComputeWt1(), OnlinePreconditioner::ComputeZt(), MatrixBase< Real >::CopyLowerToUpper(), OnlinePreconditioner::d_t_, OnlinePreconditioner::delta_, SpMatrix< Real >::Eig(), OnlinePreconditioner::epsilon_, OnlinePreconditioner::Eta(), rnnlm::i, KALDI_ASSERT, KALDI_VLOG, KALDI_WARN, kaldi::kNoTrans, kaldi::kTrans, VectorBase< Real >::Max(), OnlinePreconditioner::num_updates_skipped_, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), OnlinePreconditioner::rank_, OnlinePreconditioner::read_write_mutex_, OnlinePreconditioner::ReorthogonalizeXt1(), OnlinePreconditioner::rho_t_, PackedMatrix< Real >::Scale(), VectorBase< Real >::Scale(), OnlinePreconditioner::self_debug_, OnlinePreconditioner::SelfTest(), kaldi::SortSvd(), CuVectorBase< Real >::Sum(), VectorBase< Real >::Sum(), OnlinePreconditioner::t_, SpMatrix< Real >::Trace(), kaldi::TraceMatMat(), OnlinePreconditioner::update_mutex_, OnlinePreconditioner::update_period_, and OnlinePreconditioner::W_t_.

Referenced by OnlinePreconditioner::GetUpdatePeriod(), and OnlinePreconditioner::PreconditionDirections().

307  {
308  int32 N = X_t->NumRows(), // Minibatch size.
309  D = X_t->NumCols(), // Dimensions of vectors we're preconditioning
310  R = rank_; // Rank of correction to unit matrix.
311  KALDI_ASSERT(R > 0 && R < D);
312  BaseFloat eta = Eta(N);
313 
314  CuMatrix<BaseFloat> H_t(N, R);
315  const CuSubMatrix<BaseFloat> W_t(*WJKL_t, 0, R, 0, D);
316  // Below, WJ_t and LK_t are combinations of two matrices,
317  // which we define in order to combine two separate multiplications into one.
318  CuSubMatrix<BaseFloat> J_t(*WJKL_t, R, R, 0, D),
319  L_t(*WJKL_t, 0, R, D, R),
320  K_t(*WJKL_t, R, R, D, R),
321  WJ_t(*WJKL_t, 0, 2 * R, 0, D),
322  LK_t(*WJKL_t, 0, 2 * R, D, R);
323 
324  H_t.AddMatMat(1.0, *X_t, kNoTrans, W_t, kTrans, 0.0); // H_t = X_t W_t^T
325 
326  bool locked = update_mutex_.try_lock();
327  if (locked) {
328  // Just hard-code it here that we do 10 updates before skipping any.
329  const int num_initial_updates = 10;
330  if (t_ > t || (num_updates_skipped_ < update_period_ - 1 &&
331  t_ >= num_initial_updates)) {
332  update_mutex_.unlock();
333  // We got the lock but we were already beaten to it by another thread, or
334  // we don't want to update yet due to update_period_ > 1 (this saves
335  // compute), so release the lock.
336  locked = false;
337  }
338  }
339 
340  if (!locked) {
341  // We're not updating the parameters, either because another thread is
342  // working on updating them, or because another thread already did so from
343  // the same or later starting point (making our update stale), or because
344  // update_period_ > 1. We just apply the preconditioning and return.
345 
346  // note: we don't bother with any locks before incrementing
347  // num_updates_skipped_ below, because the worst that could happen is that,
348  // on very rare occasions, we could skip one or two more updates than we
349  // intended.
350  num_updates_skipped_++;
351 
352  BaseFloat tr_Xt_XtT = TraceMatMat(*X_t, *X_t, kTrans);
353  // X_hat_t = X_t - H_t W_t
354  X_t->AddMatMat(-1.0, H_t, kNoTrans, W_t, kNoTrans, 1.0);
355  // each element i of row_prod will be inner product of row i of X_hat_t with
356  // itself.
357  row_prod->AddDiagMat2(1.0, *X_t, kNoTrans, 0.0);
358  BaseFloat tr_Xhat_XhatT = row_prod->Sum();
359  KALDI_ASSERT(tr_Xhat_XhatT == tr_Xhat_XhatT); // Check for NaN.
360  BaseFloat gamma_t = (tr_Xhat_XhatT == 0.0 ? 1.0 :
361  sqrt(tr_Xt_XtT / tr_Xhat_XhatT));
362  *scale = gamma_t;
363  return;
364  }
365  J_t.AddMatMat(1.0, H_t, kTrans, *X_t, kNoTrans, 0.0); // J_t = H_t^T X_t
366 
367  bool compute_lk_together = (N > D);
368 
369  if (compute_lk_together) {
370  // do the following two multiplies in one operation...
371  // note
372  // L_t = W_t J_t^T
373  // K_t = J_t J_t^T
374  // Note: L_t was defined as L_t = J_t W_t^T, but it's actually symmetric,
375  // so we can compute it as L_t = W_t J_t^T.
376  LK_t.AddMatMat(1.0, WJ_t, kNoTrans, J_t, kTrans, 0.0);
377  } else {
378  K_t.SymAddMat2(1.0, J_t, kNoTrans, 0.0);
379  L_t.SymAddMat2(1.0, H_t, kTrans, 0.0);
380  }
381 
382  Matrix<BaseFloat> LK_cpu(LK_t); // contains L and K on the CPU.
383  SubMatrix<BaseFloat> L_t_cpu(LK_cpu, 0, R, 0, R),
384  K_t_cpu(LK_cpu, R, R, 0, R);
385  if (!compute_lk_together) {
386  // the SymAddMat2 operations only set the lower triangle and diagonal.
387  L_t_cpu.CopyLowerToUpper();
388  K_t_cpu.CopyLowerToUpper();
389  }
390 
391  // beta_t = \rho_t(1+\alpha) + \alpha/D tr(D_t)
392  BaseFloat beta_t = rho_t * (1.0 + alpha_) + alpha_ * d_t.Sum() / D;
393  Vector<BaseFloat> e_t(R), sqrt_e_t(R), inv_sqrt_e_t(R);
394  ComputeEt(d_t, beta_t, &e_t, &sqrt_e_t, &inv_sqrt_e_t);
395  KALDI_VLOG(5) << "e_t = " << e_t;
396 
397  // The double-precision Z_t here, and the scaling, is to avoid potential
398  // overflow, because Z_t is proportional to the fourth power of data.
399  SpMatrix<double> Z_t_double(R);
400  ComputeZt(N, rho_t, d_t, inv_sqrt_e_t, K_t_cpu, L_t_cpu, &Z_t_double);
401  BaseFloat z_t_scale = std::max<double>(1.0, Z_t_double.Trace());
402  Z_t_double.Scale(1.0 / z_t_scale);
403  SpMatrix<BaseFloat> Z_t_scaled(Z_t_double);
404 
405  Matrix<BaseFloat> U_t(R, R);
406  Vector<BaseFloat> c_t(R);
407  // do the symmetric eigenvalue decomposition Z_t = U_t C_t U_t^T.
408  Z_t_scaled.Eig(&c_t, &U_t);
409  SortSvd(&c_t, &U_t);
410  c_t.Scale(z_t_scale);
411 
412  const BaseFloat condition_threshold = 1.0e+06;
413  // must_reorthogonalize will be true if the last diagonal element of c_t is
414  // negative, since we don't take the absolute value, but this is the right
415  // thing anyway.
416  bool must_reorthogonalize = (c_t(0) > condition_threshold * c_t(R - 1));
417 
418  BaseFloat c_t_floor = pow(rho_t * (1 - eta), 2);
419  int32 nf;
420  c_t.ApplyFloor(c_t_floor, &nf);
421  if (nf > 0)
422  must_reorthogonalize = true;
423  if (nf > 0 && self_debug_) {
424  KALDI_WARN << "Floored " << nf << " elements of C_t.";
425  }
426  BaseFloat tr_Xt_XtT_check;
427  if (self_debug_)
428  tr_Xt_XtT_check = TraceMatMat(*X_t, *X_t, kTrans);
429 
430  X_t->AddMatMat(-1.0, H_t, kNoTrans, W_t, kNoTrans, 1.0); // X_hat_t = X_t - H_t W_t
431  // set *row_prod to inner products of each row of X_hat_t with itself.
432  row_prod->AddDiagMat2(1.0, *X_t, kNoTrans, 0.0);
433 
434  BaseFloat tr_Xhat_XhatT = row_prod->Sum();
435  // tr(X_t X_t^T) = tr(X_hat_t X_hat_t^T) - tr(L_t E_t) + 2 tr(L_t)
436  double tr_Xt_XtT = tr_Xhat_XhatT;
437  for (int32 i = 0; i < R; i++)
438  tr_Xt_XtT += L_t_cpu(i, i) * (2.0 - e_t(i));
439  if (self_debug_) {
440  KALDI_ASSERT(ApproxEqual(tr_Xt_XtT, tr_Xt_XtT_check));
441  }
442  BaseFloat gamma_t = (tr_Xhat_XhatT == 0.0 ? 1.0 :
443  sqrt(tr_Xt_XtT / tr_Xhat_XhatT));
444  *scale = gamma_t;
445 
446  Vector<BaseFloat> sqrt_c_t(c_t);
447  sqrt_c_t.ApplyPow(0.5);
448 
449  // \rho_{t+1} = 1/(D - R) (\eta/N tr(X_t X_t^T) + (1-\eta)(D \rho_t + tr(D_t)) - tr(C_t^{0.5})).
450  BaseFloat rho_t1 = 1.0 / (D - R) * (eta / N * tr_Xt_XtT
451  + (1-eta)*(D * rho_t + d_t.Sum())
452  - sqrt_c_t.Sum());
453  // D_{t+1} = C_t^{0.5} - \rho_{t+1} I
454  Vector<BaseFloat> d_t1(sqrt_c_t);
455  d_t1.Add(-rho_t1);
456  BaseFloat floor_val = std::max(epsilon_, delta_ * sqrt_c_t.Max());
457  if (rho_t1 < floor_val)
458  rho_t1 = floor_val;
459  d_t1.ApplyFloor(floor_val);
460 
461  CuMatrix<BaseFloat> W_t1(R, D); // W_{t+1}
462  ComputeWt1(N, d_t, d_t1, rho_t, rho_t1, U_t, sqrt_c_t, inv_sqrt_e_t,
463  W_t, &J_t, &W_t1);
464 
465  if (must_reorthogonalize) {
466  if (self_debug_) {
467  KALDI_WARN << "Reorthogonalizing.";
468  }
469  ReorthogonalizeXt1(d_t1,
470  rho_t1,
471  &W_t1,
472  &J_t,
473  &L_t);
474  }
475 
476  // Commit the new parameters.
477  read_write_mutex_.lock();
478  KALDI_ASSERT(t_ == t); // we already ensured this.
479  t_ = t + 1;
481  W_t_.Swap(&W_t1);
482  d_t_.CopyFromVec(d_t1);
483  rho_t_ = rho_t1;
484 
485  if (self_debug_)
486  SelfTest();
487 
488  read_write_mutex_.unlock();
489  update_mutex_.unlock();
490 }
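The branch taken when no update is performed reduces to two steps: form X_hat_t = X_t - H_t W_t with H_t = X_t W_t^T, then compute gamma_t so that scaling X_hat_t by gamma_t restores the original value of tr(X_t X_t^T). The following CPU sketch shows just that arithmetic; PreconditionNoUpdate is a hypothetical name and the code is not the Kaldi implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Sketch only: CPU version of the non-updating branch in the listing above.
// PreconditionNoUpdate is a hypothetical name, not a Kaldi API.
// x is N x D and is overwritten with X_hat_t; w is R x D.
// Returns gamma_t, or 1.0 when X_hat_t is exactly zero.
double PreconditionNoUpdate(Matrix *x, const Matrix &w) {
  double tr_x = 0.0, tr_xhat = 0.0;
  for (std::vector<double> &row : *x) {
    for (double v : row) tr_x += v * v;  // accumulates tr(X_t X_t^T)
    // h = row * W_t^T, one entry per row of W_t.
    std::vector<double> h(w.size(), 0.0);
    for (std::size_t r = 0; r < w.size(); r++)
      for (std::size_t c = 0; c < row.size(); c++) h[r] += row[c] * w[r][c];
    // row -= h * W_t, giving the matching row of X_hat_t.
    for (std::size_t r = 0; r < w.size(); r++)
      for (std::size_t c = 0; c < row.size(); c++) row[c] -= h[r] * w[r][c];
    for (double v : row) tr_xhat += v * v;  // accumulates tr(X_hat_t X_hat_t^T)
  }
  return tr_xhat == 0.0 ? 1.0 : std::sqrt(tr_x / tr_xhat);
}
```

Because W_t = E_t^{0.5} R_t has rows of norm below one, the subtraction shrinks the components along the rows of R_t rather than removing them entirely. The updating branch avoids the extra TraceMatMat call by recovering tr(X_t X_t^T) from L_t and E_t instead.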

◆ ReorthogonalizeXt1()

void ReorthogonalizeXt1 ( const VectorBase< BaseFloat > &  d_t1,
BaseFloat  rho_t1,
CuMatrixBase< BaseFloat > *  W_t1,
CuMatrixBase< BaseFloat > *  temp_W,
CuMatrixBase< BaseFloat > *  temp_O 
)
private

Definition at line 184 of file nnet-precondition-online.cc.

References CuMatrixBase< Real >::AddMatMat(), OnlinePreconditioner::alpha_, TpMatrix< Real >::Cholesky(), OnlinePreconditioner::ComputeEt(), CuMatrixBase< Real >::CopyFromMat(), MatrixBase< Real >::CopyFromTp(), rnnlm::i, TpMatrix< Real >::Invert(), SpMatrix< Real >::IsUnit(), rnnlm::j, KALDI_ERR, KALDI_WARN, kaldi::kNoTrans, kaldi::kTakeLower, kaldi::kUndefined, PackedMatrix< Real >::Max(), CuMatrixBase< Real >::MulRowsVec(), CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), MatrixBase< Real >::OrthogonalizeRows(), OnlinePreconditioner::self_debug_, VectorBase< Real >::Sum(), and CuMatrixBase< Real >::SymAddMat2().

Referenced by OnlinePreconditioner::GetUpdatePeriod(), and OnlinePreconditioner::PreconditionDirectionsInternal().

189  {
190  // threshold is a configuration value: a desired threshold on orthogonality,
191  // below which we won't reorthogonalize.
192  const BaseFloat threshold = 1.0e-03;
193 
194  int32 R = W_t1->NumRows(), D = W_t1->NumCols();
195  BaseFloat beta_t1 = rho_t1 * (1.0 + alpha_) + alpha_ * d_t1.Sum() / D;
196  Vector<BaseFloat> e_t1(R, kUndefined), sqrt_e_t1(R, kUndefined),
197  inv_sqrt_e_t1(R, kUndefined);
198  ComputeEt(d_t1, beta_t1, &e_t1, &sqrt_e_t1, &inv_sqrt_e_t1);
199 
200  temp_O->SymAddMat2(1.0, *W_t1, kNoTrans, 0.0);
201  // O_t = E_t^{-0.5} W_t W_t^T E_t^{-0.5}
202  Matrix<BaseFloat> O_mat(*temp_O);
203  SpMatrix<BaseFloat> O(O_mat, kTakeLower);
204  for (int32 i = 0; i < R; i++) {
205  BaseFloat i_factor = inv_sqrt_e_t1(i);
206  for (int32 j = 0; j <= i; j++) {
207  BaseFloat j_factor = inv_sqrt_e_t1(j);
208  O(i, j) *= i_factor * j_factor;
209  }
210  }
211  if (O.IsUnit(threshold)) {
212  if (self_debug_) {
213  KALDI_WARN << "Not reorthogonalizing since already orthogonal: " << O;
214  }
215  return;
216  }
217  TpMatrix<BaseFloat> C(R);
218  try {
219  C.Cholesky(O);
220  C.Invert(); // Now it's C^{-1}.
221  if (!(C.Max() < 100.0))
222  KALDI_ERR << "Cholesky out of expected range, "
223  << "reorthogonalizing with Gram-Schmidt";
224  } catch (...) {
225  // We do a Gram-Schmidt orthogonalization, which is a bit less efficient but
226  // more robust than the method using Cholesky.
227  KALDI_WARN << "Cholesky or Invert() failed while re-orthogonalizing R_t. "
228  << "Re-orthogonalizing on CPU.";
229  Matrix<BaseFloat> cpu_W_t1(*W_t1);
230  cpu_W_t1.OrthogonalizeRows();
231  W_t1->CopyFromMat(cpu_W_t1);
232  // at this point cpu_W_t1 represents R_{t+1}- it has orthonormal
233  // rows. Do: W_{t+1} = E_{t+1}^{0.5} R_{t+1}
234  CuVector<BaseFloat> sqrt_e_t1_gpu(sqrt_e_t1);
235  W_t1->MulRowsVec(sqrt_e_t1_gpu);
236  return;
237  }
238  // Next, compute (E_{t+1}^{0.5} C^{-1} E_{t+1}^{-0.5});
239  // all quantities here are those at time t+1, not t.
240  for (int32 i = 0; i < R; i++) {
241  BaseFloat i_factor = sqrt_e_t1(i);
242  for (int32 j = 0; j < i; j++) {
243  // skip j == i because i_factor * j_factor == 1 for j == i.
244  BaseFloat j_factor = inv_sqrt_e_t1(j);
245  C(i, j) *= i_factor * j_factor;
246  }
247  }
248  O_mat.CopyFromTp(C);
249  temp_O->CopyFromMat(O_mat);
250  temp_W->CopyFromMat(*W_t1);
251  W_t1->AddMatMat(1.0, *temp_O, kNoTrans, *temp_W, kNoTrans, 0.0);
252 }
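The core of the Cholesky path is the identity that if O = W W^T = C C^T with C lower triangular, then C^{-1} W has orthonormal rows. The sketch below demonstrates that identity on small dense matrices and deliberately leaves out the E_{t+1} scaling that the listing applies on both sides. It uses forward substitution instead of forming C^{-1} explicitly, which gives the same result; Gram, Cholesky and OrthonormalizeRows are hypothetical helpers, not Kaldi functions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Returns a * a^T for a matrix a (sketch helper, not a Kaldi API).
Matrix Gram(const Matrix &a) {
  std::size_t n = a.size();
  Matrix g(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; i++)
    for (std::size_t j = 0; j < n; j++)
      for (std::size_t k = 0; k < a[i].size(); k++)
        g[i][j] += a[i][k] * a[j][k];
  return g;
}

// Lower-triangular factor c of a symmetric positive definite o, so that
// o = c * c^T. No error handling: assumes o really is positive definite.
Matrix Cholesky(const Matrix &o) {
  std::size_t n = o.size();
  Matrix c(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; i++) {
    for (std::size_t j = 0; j <= i; j++) {
      double sum = o[i][j];
      for (std::size_t k = 0; k < j; k++) sum -= c[i][k] * c[j][k];
      c[i][j] = (i == j) ? std::sqrt(sum) : sum / c[j][j];
    }
  }
  return c;
}

// Replaces w by c^{-1} w via forward substitution, where w w^T = c c^T.
// Afterwards the rows of w are orthonormal.
void OrthonormalizeRows(Matrix *w) {
  Matrix c = Cholesky(Gram(*w));
  for (std::size_t i = 0; i < w->size(); i++) {
    // Rows k < i already hold the solved values at this point.
    for (std::size_t k = 0; k < i; k++)
      for (std::size_t col = 0; col < (*w)[i].size(); col++)
        (*w)[i][col] -= c[i][k] * (*w)[k][col];
    for (double &v : (*w)[i]) v /= c[i][i];
  }
}
```

When O is close to singular the factorization becomes unreliable, which is presumably why the listing checks C.Max() and falls back to Gram-Schmidt on the CPU.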

◆ SelfTest()

void SelfTest ( ) const
private

Definition at line 255 of file nnet-precondition-online.cc.

References CuSpMatrix< Real >::AddMat2(), OnlinePreconditioner::alpha_, OnlinePreconditioner::ComputeEt(), OnlinePreconditioner::d_t_, OnlinePreconditioner::delta_, OnlinePreconditioner::epsilon_, rnnlm::i, SpMatrix< Real >::IsUnit(), rnnlm::j, KALDI_ASSERT, KALDI_WARN, kaldi::kNoTrans, kaldi::kUndefined, OnlinePreconditioner::rho_t_, and OnlinePreconditioner::W_t_.

Referenced by OnlinePreconditioner::GetUpdatePeriod(), and OnlinePreconditioner::PreconditionDirectionsInternal().

255  {
257  BaseFloat d_t_max = d_t_.Max(), d_t_min = d_t_.Min();
258  KALDI_ASSERT(d_t_min >= epsilon_);
259  KALDI_ASSERT(d_t_min > 0.9 * delta_ * d_t_max);
260  KALDI_ASSERT(rho_t_ > 0.9 * delta_ * d_t_max);
261 
262  int32 D = W_t_.NumCols(), R = W_t_.NumRows();
263  BaseFloat beta_t = rho_t_ * (1.0 + alpha_) + alpha_ * d_t_.Sum() / D;
264  Vector<BaseFloat> e_t(R, kUndefined), sqrt_e_t(R, kUndefined),
265  inv_sqrt_e_t(R, kUndefined);
266  ComputeEt(d_t_, beta_t, &e_t, &sqrt_e_t, &inv_sqrt_e_t);
267 
268  CuSpMatrix<BaseFloat> S(R);
269  S.AddMat2(1.0, W_t_, kNoTrans, 0.0);
270  SpMatrix<BaseFloat> O(S);
271  for (int32 i = 0; i < R; i++) {
272  BaseFloat i_factor = inv_sqrt_e_t(i);
273  for (int32 j = 0; j <= i; j++) {
274  BaseFloat j_factor = inv_sqrt_e_t(j);
275  O(i, j) *= i_factor * j_factor;
276  }
277  }
278  if (!O.IsUnit(1.0e-04) || O(0, 0) != O(0, 0)) {
279  BaseFloat worst_error = 0.0;
280  int32 worst_i = 0, worst_j = 0;
281  for (int32 i = 0; i < R; i++) {
282  for (int32 j = 0; j < R; j++) {
283  BaseFloat elem = O(i, j);
284  BaseFloat error = fabs(elem - (i == j ? 1.0 : 0.0));
285  if (error > worst_error || error != error) {
286  worst_error = error;
287  worst_i = i;
288  worst_j = j;
289  }
290  }
291  }
292  if (worst_error > 1.0e-02 || worst_error != worst_error) {
293  KALDI_WARN << "Failed to verify W_t (worst error: O[" << worst_i << ','
294  << worst_j << "] = " << O(worst_i, worst_j)
295  << ", d_t = " << d_t_;
296  }
297  }
298 }
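The search for the worst deviation from the unit matrix relies on the fact that NaN compares unequal to itself, so a NaN entry is always reported. A standalone sketch of that search is shown here, with std::isnan standing in for the `error != error` test; WorstError and FindWorstError are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Sketch only: hypothetical result type, not part of Kaldi.
struct WorstError {
  double error;
  std::size_t i, j;
};

// Returns the largest |o(i, j) - delta_ij| and where it occurs.
// A NaN entry always replaces the current worst value, and once the worst
// value is NaN no finite error can replace it, because comparisons with
// NaN are false.
WorstError FindWorstError(const Matrix &o) {
  WorstError worst = {0.0, 0, 0};
  for (std::size_t i = 0; i < o.size(); i++) {
    for (std::size_t j = 0; j < o[i].size(); j++) {
      double error = std::fabs(o[i][j] - (i == j ? 1.0 : 0.0));
      if (error > worst.error || std::isnan(error)) {
        worst = {error, i, j};
      }
    }
  }
  return worst;
}
```

The listing applies two tolerances: it only runs this search when IsUnit(1.0e-04) fails, and it only warns when the worst error exceeds 1.0e-02 or is NaN.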

◆ SetAlpha()

void SetAlpha ( BaseFloat  alpha)

Definition at line 634 of file nnet-precondition-online.cc.

References OnlinePreconditioner::alpha_, and KALDI_ASSERT.

634  {
635  KALDI_ASSERT(alpha >= 0.0);
636  alpha_ = alpha;
637 }

◆ SetNumSamplesHistory()

void SetNumSamplesHistory ( BaseFloat  num_samples_history)

Definition at line 629 of file nnet-precondition-online.cc.

References KALDI_ASSERT, and OnlinePreconditioner::num_samples_history_.

629  {
630  KALDI_ASSERT(num_samples_history > 0.0 &&
631  num_samples_history < 1.0e+6);
632  num_samples_history_ = num_samples_history;
633 }

◆ SetRank()

void SetRank ( int32  rank)

Definition at line 621 of file nnet-precondition-online.cc.

References KALDI_ASSERT, and OnlinePreconditioner::rank_.

Referenced by kaldi::nnet2::UnitTestPreconditionDirectionsOnline().

621  {
622  KALDI_ASSERT(rank > 0);
623  rank_ = rank;
624 }

◆ SetUpdatePeriod()

void SetUpdatePeriod ( int32  update_period)

Definition at line 625 of file nnet-precondition-online.cc.

References KALDI_ASSERT, and OnlinePreconditioner::update_period_.

625  {
626  KALDI_ASSERT(update_period > 0);
627  update_period_ = update_period;
628 }

◆ TurnOnDebug()

Member Data Documentation

◆ alpha_

◆ d_t_

◆ delta_

◆ epsilon_

◆ num_samples_history_

◆ num_updates_skipped_

int32 num_updates_skipped_
private

◆ rank_

◆ read_write_mutex_

std::mutex read_write_mutex_
private

◆ rho_t_

◆ self_debug_

◆ t_

◆ update_mutex_

std::mutex update_mutex_
private

◆ update_period_

◆ W_t_


The documentation for this class was generated from the following files: