OnlinePreconditionerSimple Class Reference
Collaboration diagram for OnlinePreconditionerSimple:

Public Member Functions

 OnlinePreconditionerSimple ()
 
void SetRank (int32 rank)
 
void PreconditionDirections (CuMatrixBase< BaseFloat > *R, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale)
 

Private Member Functions

BaseFloat Eta (int32 N) const
 
void PreconditionDirectionsCpu (MatrixBase< double > *R, VectorBase< double > *row_prod, BaseFloat *scale)
 
void Init (const MatrixBase< double > &R0)
 
void InitDefault (int32 D)
 

Private Attributes

int32 rank_
 
double num_samples_history_
 
double alpha_
 
double epsilon_
 
double delta_
 
Vector< double > d_t_
 
Matrix< double > R_t_
 
double rho_t_
 

Detailed Description

Definition at line 28 of file nnet-precondition-online-test.cc.

Constructor & Destructor Documentation

◆ OnlinePreconditionerSimple()

Member Function Documentation

◆ Eta()

BaseFloat Eta ( int32  N) const
private

Definition at line 120 of file nnet-precondition-online-test.cc.

References KALDI_ASSERT, and OnlinePreconditionerSimple::num_samples_history_.

Referenced by OnlinePreconditionerSimple::PreconditionDirectionsCpu(), and OnlinePreconditionerSimple::SetRank().

120  {
122  BaseFloat ans = 1.0 - exp(-N / num_samples_history_);
123  if (ans > 0.9) ans = 0.9;
124  return ans;
125 }
float BaseFloat
Definition: kaldi-types.h:29
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ Init()

void Init ( const MatrixBase< double > &  R0)
private

Definition at line 108 of file nnet-precondition-online-test.cc.

References rnnlm::i, OnlinePreconditionerSimple::InitDefault(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), and OnlinePreconditionerSimple::PreconditionDirections().

Referenced by OnlinePreconditionerSimple::PreconditionDirectionsCpu(), and OnlinePreconditionerSimple::SetRank().

108  {
109  int32 D = R0.NumCols(), N = R0.NumRows();
110  InitDefault(D);
111  int32 num_init_iters = 3;
112  for (int32 i = 0; i < num_init_iters; i++) {
113  CuMatrix<BaseFloat> R0_copy(R0);
114  CuVector<BaseFloat> row_products(N);
115  BaseFloat scale;
116  PreconditionDirections(&R0_copy, &row_products, &scale);
117  }
118 }
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
float BaseFloat
Definition: kaldi-types.h:29
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void PreconditionDirections(CuMatrixBase< BaseFloat > *R, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale)

◆ InitDefault()

void InitDefault ( int32  D)
private

Definition at line 84 of file nnet-precondition-online-test.cc.

References OnlinePreconditionerSimple::d_t_, OnlinePreconditionerSimple::epsilon_, rnnlm::i, KALDI_WARN, OnlinePreconditionerSimple::R_t_, OnlinePreconditionerSimple::rank_, Vector< Real >::Resize(), Matrix< Real >::Resize(), OnlinePreconditionerSimple::rho_t_, and VectorBase< Real >::Set().

Referenced by OnlinePreconditionerSimple::Init(), and OnlinePreconditionerSimple::SetRank().

84  {
85  if (rank_ >= D) {
86  KALDI_WARN << "Rank " << rank_ << " of online preconditioner is >= dim " << D
87  << ", setting it to "
88  << (D - 1) << " (but this is probably still too high)";
89  rank_ = D - 1;
90  }
91  int32 R = rank_;
92  R_t_.Resize(R, D);
93  for (int32 r = 0; r < R; r++) {
94  std::vector<int32> cols;
95  for (int32 c = r; c < D; c += R)
96  cols.push_back(c);
97  for (int32 i = 0; i < cols.size(); i++) {
98  int32 c = cols[i];
99  R_t_(r, c) = (i == 0 ? 1.1 : 1.0) /
100  sqrt(1.1 * 1.1 + cols.size() - 1);
101  }
102  }
103  d_t_.Resize(R);
104  d_t_.Set(epsilon_);
105  rho_t_ = epsilon_;
106 }
kaldi::int32 int32
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
#define KALDI_WARN
Definition: kaldi-error.h:150
void Set(Real f)
Set all members of a vector to a specified value.
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).

◆ PreconditionDirections()

void PreconditionDirections ( CuMatrixBase< BaseFloat > *  R,
CuVectorBase< BaseFloat > *  row_prod,
BaseFloat *  scale 
)

Definition at line 67 of file nnet-precondition-online-test.cc.

References MatrixBase< Real >::CopyFromMat(), CuMatrixBase< Real >::CopyFromMat(), CuVectorBase< Real >::CopyFromVec(), VectorBase< Real >::CopyFromVec(), and OnlinePreconditionerSimple::PreconditionDirectionsCpu().

Referenced by OnlinePreconditionerSimple::Init(), OnlinePreconditionerSimple::SetRank(), and kaldi::nnet2::UnitTestPreconditionDirectionsOnline().

70  {
71  Matrix<BaseFloat> R_cpu(*R);
72  Vector<BaseFloat> row_prod_cpu(*row_prod);
73  Matrix<double> R_cpu_dbl(R_cpu);
74  Vector<double> row_prod_cpu_dbl(row_prod_cpu);
75  PreconditionDirectionsCpu(&R_cpu_dbl,
76  &row_prod_cpu_dbl,
77  scale);
78  row_prod_cpu.CopyFromVec(row_prod_cpu_dbl);
79  R_cpu.CopyFromMat(R_cpu_dbl);
80  R->CopyFromMat(R_cpu);
81  row_prod->CopyFromVec(row_prod_cpu);
82 }
void CopyFromMat(const MatrixBase< OtherReal > &src, MatrixTransposeType trans=kNoTrans)
Definition: cu-matrix.cc:344
void CopyFromVec(const CuVectorBase< Real > &src)
Copy functions; these will crash if the dimensions do not match.
Definition: cu-vector.cc:1078
void PreconditionDirectionsCpu(MatrixBase< double > *R, VectorBase< double > *row_prod, BaseFloat *scale)
A class representing a vector.
Definition: kaldi-vector.h:406

◆ PreconditionDirectionsCpu()

void PreconditionDirectionsCpu ( MatrixBase< double > *  R,
VectorBase< double > *  row_prod,
BaseFloat *  scale 
)
private

Definition at line 128 of file nnet-precondition-online-test.cc.

References VectorBase< Real >::Add(), VectorBase< Real >::AddDiagMat2(), SpMatrix< Real >::AddMat2(), SpMatrix< Real >::AddMat2Vec(), MatrixBase< Real >::AddMatMat(), MatrixBase< Real >::AddMatSp(), SpMatrix< Real >::AddSp(), PackedMatrix< Real >::AddToDiag(), VectorBase< Real >::AddVec(), OnlinePreconditionerSimple::alpha_, VectorBase< Real >::ApplyFloor(), VectorBase< Real >::ApplyPow(), kaldi::AssertEqual(), MatrixBase< Real >::CopyFromMat(), VectorBase< Real >::CopyFromVec(), OnlinePreconditionerSimple::d_t_, OnlinePreconditionerSimple::delta_, SpMatrix< Real >::Eig(), OnlinePreconditionerSimple::epsilon_, OnlinePreconditionerSimple::Eta(), rnnlm::i, OnlinePreconditionerSimple::Init(), SpMatrix< Real >::Invert(), VectorBase< Real >::InvertElements(), SpMatrix< Real >::IsUnit(), rnnlm::j, KALDI_ASSERT, KALDI_VLOG, KALDI_WARN, kaldi::kNoTrans, kaldi::kTrans, VectorBase< Real >::Max(), VectorBase< Real >::Min(), MatrixBase< Real >::MulRowsVec(), VectorBase< Real >::Norm(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), OnlinePreconditionerSimple::R_t_, OnlinePreconditionerSimple::rho_t_, VectorBase< Real >::Scale(), kaldi::SortSvd(), VectorBase< Real >::Sum(), SpMatrix< Real >::Trace(), kaldi::TraceMatMat(), and kaldi::VecVec().

Referenced by OnlinePreconditionerSimple::PreconditionDirections(), and OnlinePreconditionerSimple::SetRank().

131  {
132  if (R_t_.NumRows() == 0)
133  Init(*X_t);
134  int32 R = R_t_.NumRows(), D = R_t_.NumCols(), N = X_t->NumRows();
135  BaseFloat eta = Eta(N);
136 
137  SpMatrix<double> F_t(D);
138  // F_t =(def) R_t^T D_t R_t + \rho_t I
139  F_t.AddToDiag(rho_t_);
140  F_t.AddMat2Vec(1.0, R_t_, kTrans, d_t_, 1.0);
141 
142  // Make sure F_t is +ve definite.
143  {
144  KALDI_ASSERT(d_t_.Min() > 0);
145  Vector<double> eigs(D);
146  F_t.Eig(&eigs, NULL);
147  KALDI_ASSERT(eigs.Min() > 0);
148  }
149 
150  // S_t =(def) 1/N X_t^T X_t.
151  SpMatrix<double> S_t(D);
152  S_t.AddMat2(1.0 / N, *X_t, kTrans, 0.0);
153 
154  // T_t =(def) \eta S_t + (1-\eta) F_t
155  SpMatrix<double> T_t(D);
156  T_t.AddSp(eta, S_t);
157  T_t.AddSp(1.0 - eta, F_t);
158 
159  // Y_t =(def) R_t T_t
160  Matrix<double> Y_t(R, D);
161  Y_t.AddMatSp(1.0, R_t_, kNoTrans, T_t, 0.0);
162 
163  // Z_t =(def) Y_t Y_t^T
164  SpMatrix<double> Z_t(R);
165  Z_t.AddMat2(1.0, Y_t, kNoTrans, 0.0);
166 
167  Matrix<double> U_t(R, R);
168  Vector<double> c_t(R);
169  // decompose Z_t = U_t C_t U_t^T
170  Z_t.Eig(&c_t, &U_t);
171  SortSvd(&c_t, &U_t);
172  double c_t_floor = pow(rho_t_ * (1.0 - eta), 2);
173  int32 nf;
174  c_t.ApplyFloor(c_t_floor, &nf);
175  if (nf > 0) {
176  KALDI_WARN << "Floored " << nf << " elements of c_t.";
177  }
178  // KALDI_LOG << "c_t is " << c_t;
179  // KALDI_LOG << "U_t is " << U_t;
180  // KALDI_LOG << "Z_t is " << Z_t;
181 
182  Vector<double> sqrt_c_t(c_t);
183  sqrt_c_t.ApplyPow(0.5);
184  Vector<double> inv_sqrt_c_t(sqrt_c_t);
185  inv_sqrt_c_t.InvertElements();
186  Matrix<double> R_t1(R, D);
187  // R_{t+1} = C_t^{-0.5} U_t^T Y_t
188  R_t1.AddMatMat(1.0, U_t, kTrans, Y_t, kNoTrans, 0.0);
189  R_t1.MulRowsVec(inv_sqrt_c_t);
190 
191  double rho_t1 = (1.0 / (D - R)) *
192  (eta * S_t.Trace() + (1.0 - eta) * (D * rho_t_ + d_t_.Sum()) - sqrt_c_t.Sum());
193 
194  Vector<double> d_t1(sqrt_c_t);
195  d_t1.Add(-rho_t1);
196 
197  double floor_val = std::max(epsilon_, delta_ * sqrt_c_t.Max());
198  if (rho_t1 < floor_val) {
199  KALDI_WARN << "flooring rho_{t+1} to " << floor_val << ", was " << rho_t1;
200  rho_t1 = floor_val;
201  }
202  d_t1.ApplyFloor(floor_val, &nf);
203  if (nf > 0) {
204  KALDI_VLOG(3) << "d_t1 was " << d_t1;
205  KALDI_WARN << "Floored " << nf << " elements of d_{t+1}.";
206  }
207  // a check.
208  if (nf == 0 && rho_t1 > floor_val) {
209  double tr_F_t1 = D * rho_t1 + d_t1.Sum(), tr_T_t = T_t.Trace();
210  AssertEqual(tr_F_t1, tr_T_t);
211  }
212 
213  // G_t = F_t + alpha/D tr(F_t)
214  SpMatrix<double> G_t(F_t);
215  G_t.AddToDiag(alpha_ / D * F_t.Trace());
216  SpMatrix<double> G_t_inv(G_t);
217  G_t_inv.Invert();
218 
219  double beta_t = rho_t_ + alpha_/D * F_t.Trace();
220  // X_hat_t = beta_t X_t G_t^{-1}.
221  Matrix<double> X_hat_t(N, D);
222  X_hat_t.AddMatSp(beta_t, *X_t, kNoTrans, G_t_inv, 0.0);
223 
224  double tr_x_x = TraceMatMat(*X_t, *X_t, kTrans),
225  tr_Xhat_Xhat = TraceMatMat(X_hat_t, X_hat_t, kTrans);
226  double gamma = (tr_Xhat_Xhat == 0 ? 1.0 : sqrt(tr_x_x / tr_Xhat_Xhat));
227 
228  X_t->CopyFromMat(X_hat_t);
229  row_prod->AddDiagMat2(1.0, *X_t, kNoTrans, 0.0);
230  *scale = gamma;
231 
232  // Update the parameters
233  rho_t_ = rho_t1;
234  d_t_.CopyFromVec(d_t1);
235  R_t_.CopyFromMat(R_t1);
236 
237  KALDI_VLOG(3) << "rho_t_ = " << rho_t_;
238  KALDI_VLOG(3) << "d_t_ = " << d_t_;
239  KALDI_VLOG(3) << "R_t_ = " << R_t_;
240 
241 
242  { // check that R_t_ R_t_^T = I.
243  SpMatrix<double> unit(R);
244  unit.AddMat2(1.0, R_t_, kNoTrans, 0.0);
245  if (!unit.IsUnit(1.0e-03)) {
246  KALDI_WARN << "R is not orthogonal, reorthogonalizing.";
247  for (int32 i = 0; i < R; i++) {
248  SubVector<double> row(R_t_, i);
249  for (int32 j = 0; j < i; j++) {
250  SubVector<double> row_j(R_t_, j);
251  row.AddVec(-VecVec(row_j, row), row_j);
252  }
253  row.Scale(1.0 / row.Norm(2.0));
254  }
255  }
256  unit.AddMat2(1.0, R_t_, kNoTrans, 0.0);
257  KALDI_ASSERT(unit.IsUnit(1.0e-03));
258  }
259 }
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
void AddDiagMat2(Real alpha, const MatrixBase< Real > &M, MatrixTransposeType trans=kNoTrans, Real beta=1.0)
Add the diagonal of a matrix times itself: *this = alpha * diag(M M^T) + beta * *this (if trans == kNoTrans)...
kaldi::int32 int32
void CopyFromMat(const MatrixBase< OtherReal > &M, MatrixTransposeType trans=kNoTrans)
Copy given matrix. (no resize is done).
Real Min() const
Returns the minimum value of any element, or +infinity for the empty vector.
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
float BaseFloat
Definition: kaldi-types.h:29
#define KALDI_WARN
Definition: kaldi-error.h:150
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
Returns tr(A B), or tr(A B^T) if trans == kTrans. (Declared early because it is a friend function.)
Real Sum() const
Returns sum of the elements.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
Definition: kaldi-math.h:276
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
Real VecVec(const VectorBase< Real > &a, const VectorBase< Real > &b)
Returns the dot product of a and b.
Definition: kaldi-vector.cc:37
Represents a non-allocating general vector which can be defined as a sub-vector of a higher-level vector, or as a row of a matrix.
Definition: kaldi-vector.h:501
void SortSvd(VectorBase< Real > *s, MatrixBase< Real > *U, MatrixBase< Real > *Vt, bool sort_on_absolute_value)
Function to ensure that SVD is sorted.

◆ SetRank()

Member Data Documentation

◆ alpha_

double alpha_
private

◆ d_t_

◆ delta_

double delta_
private

◆ epsilon_

◆ num_samples_history_

double num_samples_history_
private

Definition at line 55 of file nnet-precondition-online-test.cc.

Referenced by OnlinePreconditionerSimple::Eta().

◆ R_t_

◆ rank_

◆ rho_t_


The documentation for this class was generated from the following file: