OnlinePreconditionerSimple Class Reference

Public Member Functions

OnlinePreconditionerSimple ()
void SetRank (int32 rank)
void PreconditionDirections (CuMatrixBase< BaseFloat > *R, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale)

Private Member Functions

BaseFloat Eta (int32 N) const
void PreconditionDirectionsCpu (MatrixBase< double > *R, VectorBase< double > *row_prod, BaseFloat *scale)
void Init (const MatrixBase< double > &R0)
void InitDefault (int32 D)

Private Attributes

int32 rank_
double num_samples_history_
double alpha_
double epsilon_
double delta_
Vector< double > d_t_
Matrix< double > R_t_
double rho_t_

Detailed Description

Definition at line 28 of file nnet-precondition-online-test.cc.

Constructor & Destructor Documentation

Member Function Documentation

BaseFloat Eta (int32 N) const
private

Definition at line 120 of file nnet-precondition-online-test.cc.

References KALDI_ASSERT, and OnlinePreconditionerSimple::num_samples_history_.

Referenced by OnlinePreconditionerSimple::PreconditionDirectionsCpu().

120 {
122   BaseFloat ans = 1.0 - exp(-N / num_samples_history_);
123   if (ans > 0.9) ans = 0.9;
124   return ans;
125 }
float BaseFloat
Definition: kaldi-types.h:29
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
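
The visible lines of Eta() compute a value that grows with the number of rows N and is capped at 0.9. A minimal standalone sketch of that arithmetic follows; the function name EtaSketch and the explicit num_samples_history parameter are hypothetical (in the class the value comes from the member num_samples_history_), it uses double where the listing uses BaseFloat, and it does not reproduce the source line that the listing omits.

```cpp
#include <algorithm>
#include <cmath>

// Hypothetical free-function version of the visible part of Eta().
// num_samples_history stands in for the member num_samples_history_
// and is assumed to be positive.
double EtaSketch(int num_rows, double num_samples_history) {
  // 1 - exp(-N / history): approaches 1 as the minibatch grows.
  double ans = 1.0 - std::exp(-num_rows / num_samples_history);
  // Same 0.9 cap as in the listing.
  return std::min(ans, 0.9);
}
```

With num_rows equal to num_samples_history the uncapped value is 1 - exp(-1), roughly 0.632; very large minibatches hit the 0.9 cap.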
void Init (const MatrixBase< double > &R0)
private

Definition at line 108 of file nnet-precondition-online-test.cc.

References rnnlm::i, OnlinePreconditionerSimple::InitDefault(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), and OnlinePreconditionerSimple::PreconditionDirections().

Referenced by OnlinePreconditionerSimple::PreconditionDirectionsCpu().

108 {
109   int32 D = R0.NumCols(), N = R0.NumRows();
110   InitDefault(D);
111   int32 num_init_iters = 3;
112   for (int32 i = 0; i < num_init_iters; i++) {
113     CuMatrix<BaseFloat> R0_copy(R0);
114     CuVector<BaseFloat> row_products(N);
115     BaseFloat scale;
116     PreconditionDirections(&R0_copy, &row_products, &scale);
117   }
118 }
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:58
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:61
void InitDefault (int32 D)
private

Definition at line 84 of file nnet-precondition-online-test.cc.

References OnlinePreconditionerSimple::d_t_, OnlinePreconditionerSimple::epsilon_, rnnlm::i, KALDI_WARN, OnlinePreconditionerSimple::R_t_, OnlinePreconditionerSimple::rank_, Vector< Real >::Resize(), Matrix< Real >::Resize(), OnlinePreconditionerSimple::rho_t_, and VectorBase< Real >::Set().

Referenced by OnlinePreconditionerSimple::Init().

 84 {
 85   if (rank_ >= D) {
 86     KALDI_WARN << "Rank " << rank_ << " of online preconditioner is >= dim " << D
 87                << ", setting it to "
 88                << (D - 1) << " (but this is probably still too high)";
 89     rank_ = D - 1;
 90   }
 91   int32 R = rank_;
 92   R_t_.Resize(R, D);
 93   for (int32 r = 0; r < R; r++) {
 94     std::vector<int32> cols;
 95     for (int32 c = r; c < D; c += R)
 96       cols.push_back(c);
 97     for (int32 i = 0; i < cols.size(); i++) {
 98       int32 c = cols[i];
 99       R_t_(r, c) = (i == 0 ? 1.1 : 1.0) /
100           sqrt(1.1 * 1.1 + cols.size() - 1);
101     }
102   }
103   d_t_.Resize(R);
104   d_t_.Set(epsilon_);
105   rho_t_ = epsilon_;
106 }
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
#define KALDI_WARN
Definition: kaldi-error.h:130
void Set(Real f)
Set all members of a vector to a specified value.
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
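
The loop in InitDefault() gives row r the columns r, r + R, r + 2R, ... and scales the entries so that each row has unit norm; because the rows use disjoint columns, they come out orthonormal. The sketch below reproduces that construction with std::vector in place of Kaldi's Matrix. The name InitOrthonormalRows is hypothetical, and the sketch assumes 0 < rank < dim (the listing enforces rank < dim by lowering rank_).

```cpp
#include <cmath>
#include <vector>

// Hypothetical helper mirroring the loop in InitDefault(): returns a
// rank x dim matrix (as a vector of rows) with orthonormal rows.
// Assumes 0 < rank < dim.
std::vector<std::vector<double>> InitOrthonormalRows(int rank, int dim) {
  std::vector<std::vector<double>> R(rank, std::vector<double>(dim, 0.0));
  for (int r = 0; r < rank; r++) {
    // Row r owns columns r, r + rank, r + 2 * rank, ... below dim.
    int num_cols = (dim - r + rank - 1) / rank;
    // First owned entry is 1.1, the rest 1.0, then normalize the row.
    double norm = std::sqrt(1.1 * 1.1 + num_cols - 1);
    for (int i = 0; i < num_cols; i++)
      R[r][r + i * rank] = (i == 0 ? 1.1 : 1.0) / norm;
  }
  return R;
}
```

For rank 2 and dim 5, row 0 is nonzero in columns 0, 2 and 4, and row 1 in columns 1 and 3.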
void PreconditionDirections (CuMatrixBase< BaseFloat > *R, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale)

Definition at line 67 of file nnet-precondition-online-test.cc.

References MatrixBase< Real >::CopyFromMat(), CuMatrixBase< Real >::CopyFromMat(), CuVectorBase< Real >::CopyFromVec(), VectorBase< Real >::CopyFromVec(), and OnlinePreconditionerSimple::PreconditionDirectionsCpu().

Referenced by OnlinePreconditionerSimple::Init(), and kaldi::nnet2::UnitTestPreconditionDirectionsOnline().

70 {
71   Matrix<BaseFloat> R_cpu(*R);
72   Vector<BaseFloat> row_prod_cpu(*row_prod);
73   Matrix<double> R_cpu_dbl(R_cpu);
74   Vector<double> row_prod_cpu_dbl(row_prod_cpu);
75   PreconditionDirectionsCpu(&R_cpu_dbl,
76                             &row_prod_cpu_dbl,
77                             scale);
78   row_prod_cpu.CopyFromVec(row_prod_cpu_dbl);
79   R_cpu.CopyFromMat(R_cpu_dbl);
80   R->CopyFromMat(R_cpu);
81   row_prod->CopyFromVec(row_prod_cpu);
82 }
void CopyFromMat(const MatrixBase< OtherReal > &src, MatrixTransposeType trans=kNoTrans)
Definition: cu-matrix.cc:337
void CopyFromVec(const CuVectorBase< Real > &src)
Copy functions; these will crash if the dimensions do not match.
Definition: cu-vector.cc:970
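
PreconditionDirections() is a thin wrapper: it copies its single-precision inputs into double-precision CPU objects, calls the double-precision worker, and copies the results back. The pattern can be sketched without Kaldi as follows; ProcessInDouble and ProcessViaDouble are hypothetical names, std::vector stands in for the Kaldi matrix and vector types, and the worker here merely doubles each element so that the round trip is observable.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the double-precision worker
// (PreconditionDirectionsCpu in the class); it just doubles each value.
void ProcessInDouble(std::vector<double> *data) {
  for (double &x : *data) x *= 2.0;
}

// Widen to double, run the worker, then narrow back into the caller's
// single-precision buffer, as the wrapper in the listing does.
void ProcessViaDouble(std::vector<float> *data) {
  std::vector<double> wide(data->begin(), data->end());  // float -> double
  ProcessInDouble(&wide);
  for (std::size_t i = 0; i < data->size(); i++)
    (*data)[i] = static_cast<float>(wide[i]);            // double -> float
}
```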
void PreconditionDirectionsCpu (MatrixBase< double > *R, VectorBase< double > *row_prod, BaseFloat *scale)
private

Definition at line 128 of file nnet-precondition-online-test.cc.

References VectorBase< Real >::Add(), VectorBase< Real >::AddDiagMat2(), SpMatrix< Real >::AddMat2(), SpMatrix< Real >::AddMat2Vec(), MatrixBase< Real >::AddMatMat(), MatrixBase< Real >::AddMatSp(), SpMatrix< Real >::AddSp(), PackedMatrix< Real >::AddToDiag(), VectorBase< Real >::AddVec(), OnlinePreconditionerSimple::alpha_, VectorBase< Real >::ApplyFloor(), VectorBase< Real >::ApplyPow(), kaldi::AssertEqual(), MatrixBase< Real >::CopyFromMat(), VectorBase< Real >::CopyFromVec(), OnlinePreconditionerSimple::d_t_, OnlinePreconditionerSimple::delta_, SpMatrix< Real >::Eig(), OnlinePreconditionerSimple::epsilon_, OnlinePreconditionerSimple::Eta(), rnnlm::i, OnlinePreconditionerSimple::Init(), SpMatrix< Real >::Invert(), VectorBase< Real >::InvertElements(), SpMatrix< Real >::IsUnit(), rnnlm::j, KALDI_ASSERT, KALDI_VLOG, KALDI_WARN, kaldi::kNoTrans, kaldi::kTrans, VectorBase< Real >::Max(), VectorBase< Real >::Min(), MatrixBase< Real >::MulRowsVec(), VectorBase< Real >::Norm(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), OnlinePreconditionerSimple::R_t_, OnlinePreconditionerSimple::rho_t_, VectorBase< Real >::Scale(), kaldi::SortSvd(), VectorBase< Real >::Sum(), SpMatrix< Real >::Trace(), kaldi::TraceMatMat(), and kaldi::VecVec().

Referenced by OnlinePreconditionerSimple::PreconditionDirections().

131 {
132   if (R_t_.NumRows() == 0)
133     Init(*X_t);
134   int32 R = R_t_.NumRows(), D = R_t_.NumCols(), N = X_t->NumRows();
135   BaseFloat eta = Eta(N);
136
137   SpMatrix<double> F_t(D);
138   // F_t =(def) R_t^T D_t R_t + \rho_t I
139   F_t.AddToDiag(rho_t_);
140   F_t.AddMat2Vec(1.0, R_t_, kTrans, d_t_, 1.0);
141
142   // Make sure F_t is +ve definite.
143   {
144     KALDI_ASSERT(d_t_.Min() > 0);
145     Vector<double> eigs(D);
146     F_t.Eig(&eigs, NULL);
147     KALDI_ASSERT(eigs.Min() > 0);
148   }
149
150   // S_t =(def) 1/N X_t^T X_t.
151   SpMatrix<double> S_t(D);
152   S_t.AddMat2(1.0 / N, *X_t, kTrans, 0.0);
153
154   // T_t =(def) \eta S_t + (1-\eta) F_t
155   SpMatrix<double> T_t(D);
156   T_t.AddSp(eta, S_t);
157   T_t.AddSp(1.0 - eta, F_t);
158
159   // Y_t =(def) R_t T_t
160   Matrix<double> Y_t(R, D);
161   Y_t.AddMatSp(1.0, R_t_, kNoTrans, T_t, 0.0);
162
163   // Z_t =(def) Y_t Y_t^T
164   SpMatrix<double> Z_t(R);
165   Z_t.AddMat2(1.0, Y_t, kNoTrans, 0.0);
166
167   Matrix<double> U_t(R, R);
168   Vector<double> c_t(R);
169   // decompose Z_t = U_t C_t U_t^T
170   Z_t.Eig(&c_t, &U_t);
171   SortSvd(&c_t, &U_t);
172   double c_t_floor = pow(rho_t_ * (1.0 - eta), 2);
173   int32 nf = c_t.ApplyFloor(c_t_floor);
174   if (nf > 0) {
175     KALDI_WARN << "Floored " << nf << " elements of c_t.";
176   }
177   // KALDI_LOG << "c_t is " << c_t;
178   // KALDI_LOG << "U_t is " << U_t;
179   // KALDI_LOG << "Z_t is " << Z_t;
180
181   Vector<double> sqrt_c_t(c_t);
182   sqrt_c_t.ApplyPow(0.5);
183   Vector<double> inv_sqrt_c_t(sqrt_c_t);
184   inv_sqrt_c_t.InvertElements();
185   Matrix<double> R_t1(R, D);
186   // R_{t+1} = C_t^{-0.5} U_t^T Y_t
187   R_t1.AddMatMat(1.0, U_t, kTrans, Y_t, kNoTrans, 0.0);
188   R_t1.MulRowsVec(inv_sqrt_c_t);
189
190   double rho_t1 = (1.0 / (D - R)) *
191       (eta * S_t.Trace() + (1.0 - eta) * (D * rho_t_ + d_t_.Sum()) - sqrt_c_t.Sum());
192
193   Vector<double> d_t1(sqrt_c_t);
194   d_t1.Add(-rho_t1);
195
196   double floor_val = std::max(epsilon_, delta_ * sqrt_c_t.Max());
197   if (rho_t1 < floor_val) {
198     KALDI_WARN << "flooring rho_{t+1} to " << floor_val << ", was " << rho_t1;
199     rho_t1 = floor_val;
200   }
201   nf = d_t1.ApplyFloor(floor_val);
202   if (nf > 0) {
203     KALDI_VLOG(3) << "d_t1 was " << d_t1;
204     KALDI_WARN << "Floored " << nf << " elements of d_{t+1}.";
205   }
206   // a check.
207   if (nf == 0 && rho_t1 > floor_val) {
208     double tr_F_t1 = D * rho_t1 + d_t1.Sum(), tr_T_t = T_t.Trace();
209     AssertEqual(tr_F_t1, tr_T_t);
210   }
211
212   // G_t = F_t + alpha/D tr(F_t)
213   SpMatrix<double> G_t(F_t);
214   G_t.AddToDiag(alpha_ / D * F_t.Trace());
215   SpMatrix<double> G_t_inv(G_t);
216   G_t_inv.Invert();
217
218   double beta_t = rho_t_ + alpha_/D * F_t.Trace();
219   // X_hat_t = beta_t X_t G_t^{-1}.
220   Matrix<double> X_hat_t(N, D);
221   X_hat_t.AddMatSp(beta_t, *X_t, kNoTrans, G_t_inv, 0.0);
222
223   double tr_x_x = TraceMatMat(*X_t, *X_t, kTrans),
224       tr_Xhat_Xhat = TraceMatMat(X_hat_t, X_hat_t, kTrans);
225   double gamma = (tr_Xhat_Xhat == 0 ? 1.0 : sqrt(tr_x_x / tr_Xhat_Xhat));
226
227   X_t->CopyFromMat(X_hat_t);
228   row_prod->AddDiagMat2(1.0, *X_t, kNoTrans, 0.0);
229   *scale = gamma;
230
231   // Update the parameters
232   rho_t_ = rho_t1;
233   d_t_.CopyFromVec(d_t1);
234   R_t_.CopyFromMat(R_t1);
235
236   KALDI_VLOG(3) << "rho_t_ = " << rho_t_;
237   KALDI_VLOG(3) << "d_t_ = " << d_t_;
238   KALDI_VLOG(3) << "R_t_ = " << R_t_;
239
240
241   { // check that R_t_ R_t_^T = I.
242     SpMatrix<double> unit(R);
243     unit.AddMat2(1.0, R_t_, kNoTrans, 0.0);
244     if (!unit.IsUnit(1.0e-03)) {
245       KALDI_WARN << "R is not orthogonal, reorthogonalizing.";
246       for (int32 i = 0; i < R; i++) {
247         SubVector<double> row(R_t_, i);
248         for (int32 j = 0; j < i; j++) {
249           SubVector<double> row_j(R_t_, j);
250           row.AddVec(-VecVec(row_j, row), row_j);
251         }
252         row.Scale(1.0 / row.Norm(2.0));
253       }
254     }
255     unit.AddMat2(1.0, R_t_, kNoTrans, 0.0);
256     KALDI_ASSERT(unit.IsUnit(1.0e-03));
257   }
258 }
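
Read off the comments and statements in the listing above, one call appears to perform the following update. This is a summary sketch derived only from that listing, not from separate documentation. Here \eta is the value returned by Eta(N), \alpha is alpha_, D_t = diag(d_t) is the diagonal matrix built from d_t_ (not to be confused with the dimension D), R is the rank, and the flooring steps are left out of the recursion.

```latex
\begin{align*}
F_t &= R_t^\top D_t R_t + \rho_t I, \qquad D_t = \operatorname{diag}(d_t) \\
S_t &= \tfrac{1}{N} X_t^\top X_t \\
T_t &= \eta S_t + (1 - \eta) F_t \\
Y_t &= R_t T_t \\
Z_t &= Y_t Y_t^\top = U_t C_t U_t^\top \\
R_{t+1} &= C_t^{-1/2} U_t^\top Y_t \\
\rho_{t+1} &= \frac{1}{D - R}\Bigl(\eta \operatorname{tr}(S_t)
  + (1 - \eta)\bigl(D \rho_t + \operatorname{tr}(D_t)\bigr)
  - \operatorname{tr}(C_t^{1/2})\Bigr) \\
D_{t+1} &= C_t^{1/2} - \rho_{t+1} I \\
G_t &= F_t + \tfrac{\alpha}{D} \operatorname{tr}(F_t)\, I \\
\beta_t &= \rho_t + \tfrac{\alpha}{D} \operatorname{tr}(F_t) \\
\hat{X}_t &= \beta_t X_t G_t^{-1} \\
\gamma_t &= \sqrt{\operatorname{tr}(X_t X_t^\top) \,/\, \operatorname{tr}(\hat{X}_t \hat{X}_t^\top)}
\end{align*}
```

In the code, the diagonal of C_t is floored at (\rho_t (1 - \eta))^2, and \rho_{t+1} and the entries of d_{t+1} are floored at max(epsilon_, delta_ times the largest entry of C_t^{1/2}); \gamma_t is set to 1 when tr(\hat{X}_t \hat{X}_t^\top) is zero. When no floor is applied, the check at line 209 follows from the recursion: D \rho_{t+1} + tr(D_{t+1}) = (D - R) \rho_{t+1} + tr(C_t^{1/2}) = \eta tr(S_t) + (1 - \eta)(D \rho_t + tr(D_t)) = tr(T_t), where the last step uses tr(F_t) = D \rho_t + tr(D_t), which holds when R_t R_t^\top = I.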
void AddDiagMat2(Real alpha, const MatrixBase< Real > &M, MatrixTransposeType trans=kNoTrans, Real beta=1.0)
Add the diagonal of a matrix times itself: *this = diag(M M^T) + beta * *this (if trans == kNoTrans)...
Real Sum() const
Returns sum of the elements.
void CopyFromMat(const MatrixBase< OtherReal > &M, MatrixTransposeType trans=kNoTrans)
Copy given matrix. (no resize is done).
Real Min() const
Returns the minimum value of any element, or +infinity for the empty vector.
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
Definition: kaldi-math.h:273
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
Real VecVec(const VectorBase< Real > &a, const VectorBase< Real > &b)
Returns dot product between a and b.
Definition: kaldi-vector.cc:36
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:482
void SortSvd(VectorBase< Real > *s, MatrixBase< Real > *U, MatrixBase< Real > *Vt, bool sort_on_absolute_value)
Function to ensure that SVD is sorted.
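
The fallback at the end of PreconditionDirectionsCpu() (lines 246 to 253 of the listing) re-orthonormalizes the rows of R_t_ one at a time: it subtracts from row i its projection onto each earlier row, then rescales row i to unit 2-norm. A minimal sketch of that loop is given below, with std::vector standing in for Kaldi's Matrix and SubVector; the names Row, Dot and Reorthogonalize are hypothetical, and the sketch assumes the rows are linearly independent so that no norm is zero.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

using Row = std::vector<double>;

// Dot product of two equally sized rows (the role VecVec plays in the listing).
double Dot(const Row &a, const Row &b) {
  double sum = 0.0;
  for (std::size_t k = 0; k < a.size(); k++) sum += a[k] * b[k];
  return sum;
}

// Orthonormalizes the rows in place, in the same order as the listing.
// Assumes the rows are linearly independent.
void Reorthogonalize(std::vector<Row> *rows) {
  for (std::size_t i = 0; i < rows->size(); i++) {
    Row &row = (*rows)[i];
    for (std::size_t j = 0; j < i; j++) {
      const Row &row_j = (*rows)[j];
      // Remove the component of row i along the already-normalized row j.
      double proj = Dot(row_j, row);
      for (std::size_t k = 0; k < row.size(); k++) row[k] -= proj * row_j[k];
    }
    // Scale row i to unit 2-norm.
    double norm = std::sqrt(Dot(row, row));
    for (double &x : row) x /= norm;
  }
}
```

Because each projection is computed from the partially updated row rather than from the original one, this is the modified form of Gram-Schmidt.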

Member Data Documentation

double alpha_
private
double delta_
private
double num_samples_history_
private

Definition at line 55 of file nnet-precondition-online-test.cc.

Referenced by OnlinePreconditionerSimple::Eta().


The documentation for this class was generated from the following file: