nnet-precondition-online.h
Go to the documentation of this file.
1 // nnet2/nnet-precondition-online.h
2 
3 // Copyright 2013-2015 Johns Hopkins University (author: Daniel Povey)
4 // 2015 Xiaohui Zhang
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_NNET2_NNET_PRECONDITION_ONLINE_H_
22 #define KALDI_NNET2_NNET_PRECONDITION_ONLINE_H_
23 
24 #include <iostream>
25 #include <mutex>
26 #include "base/kaldi-common.h"
27 #include "matrix/matrix-lib.h"
29 
30 namespace kaldi {
31 namespace nnet2 {
32 
33 
414  public:
416 
// Setters for the configuration values; their definitions are not visible in
// this listing. NOTE(review): presumably SetRank/SetUpdatePeriod/SetAlpha write
// rank_/update_period_/alpha_ (see the member comments further down) — confirm.
417  void SetRank(int32 rank);
418  void SetUpdatePeriod(int32 update_period);
419  // num_samples_history is a time-constant (in samples) that determines eta.
// (eta is the learning rate returned by Eta().)
420  void SetNumSamplesHistory(BaseFloat num_samples_history);
// alpha controls how much the Fisher-matrix estimate is smoothed with the unit
// matrix (e.g. alpha = 4.0).
421  void SetAlpha(BaseFloat alpha);
// Sets self_debug_ to true; when it is true, SelfTest() is used to check the
// members.
422  void TurnOnDebug() { self_debug_ = true; }
// Accessors: alpha_ is the unit-matrix smoothing constant; rank_ is the rank of
// the correction to the unit matrix.
424  BaseFloat GetAlpha() const { return alpha_; }
425  int32 GetRank() const { return rank_; }
427 
428  // The "R" pointer is both the input (R in the comment) and the output (P in
429  // the comment; equal to the preconditioned directions before scaling by
430  // gamma). If the pointer "row_prod" is supplied, it's set to the inner product
431  // of each row of the preconditioned directions P, at output, with itself.
432  // You would need to apply "scale" to R and "scale * scale" to row_prod, to
433  // get the preconditioned directions; we don't do this ourselves, in order to
434  // save CUDA calls.
436  CuVectorBase<BaseFloat> *row_prod,
437  BaseFloat *scale);
438 
439  // Copy constructor.
// Declared explicit, so only direct-initialization compiles
// (e.g. `OnlinePreconditioner b(a);`); copy-initialization such as
// `OnlinePreconditioner b = a;`, or passing an lvalue by value, does not.
// NOTE(review): the class holds std::mutex members, which are not copyable,
// so the definition presumably copies everything except the mutexes — confirm
// in the .cc file.
440  explicit OnlinePreconditioner(const OnlinePreconditioner &other);
441 // Assignment operator.
443  private:
444 
445  // This does the work of PreconditionDirections (the top-level
446  // function handles some multithreading issues and then calls this function).
447  // Note: WJKL_t (dimension 2*R by D + R) is [ W_t L_t; J_t K_t ].
449  const BaseFloat rho_t,
450  const Vector<BaseFloat> &d_t,
451  CuMatrixBase<BaseFloat> *WJKL_t,
453  CuVectorBase<BaseFloat> *row_prod,
454  BaseFloat *scale);
455 
456  void ComputeEt(const VectorBase<BaseFloat> &d_t,
457  BaseFloat beta_t,
459  VectorBase<BaseFloat> *sqrt_e_t,
460  VectorBase<BaseFloat> *inv_sqrt_e_t) const;
461 
// Computes Z_t and writes it to *Z_t (in double precision). Z_t is defined in
// the extended comment near the top of the file, which is not part of this
// listing. Inputs:
//   N            - the number of vectors being preconditioned in this
//                  minibatch (the same N that Eta() takes).
//   rho_t        - the current unit-matrix scaling factor of the Fisher
//                  estimate.
//   d_t          - presumably the diagonal of D_t — confirm.
//   inv_sqrt_e_t - as produced by ComputeEt().
//   K_t, L_t     - correspond to blocks of WJKL_t = [ W_t L_t; J_t K_t ];
//                  note they are passed as MatrixBase, not CuMatrixBase.
// The function is const and does not modify the object.
462  void ComputeZt(int32 N,
463  BaseFloat rho_t,
464  const VectorBase<BaseFloat> &d_t,
465  const VectorBase<BaseFloat> &inv_sqrt_e_t,
466  const MatrixBase<BaseFloat> &K_t,
467  const MatrixBase<BaseFloat> &L_t,
468  SpMatrix<double> *Z_t) const;
469  // Computes W_{t+1}. Overwrites J_t.
470  void ComputeWt1(int32 N,
471  const VectorBase<BaseFloat> &d_t,
472  const VectorBase<BaseFloat> &d_t1,
473  BaseFloat rho_t,
474  BaseFloat rho_t1,
475  const MatrixBase<BaseFloat> &U_t,
476  const VectorBase<BaseFloat> &sqrt_c_t,
477  const VectorBase<BaseFloat> &inv_sqrt_e_t,
478  const CuMatrixBase<BaseFloat> &W_t,
480  CuMatrixBase<BaseFloat> *W_t1) const;
481 
482  // This function is called if C_t has a high condition number; it makes sure
483  // that R_{t+1} is orthogonal. See the section in the extended comment above
484  // on "keeping R_t orthogonal".
485  void ReorthogonalizeXt1(const VectorBase<BaseFloat> &d_t1,
486  BaseFloat rho_t1,
488  CuMatrixBase<BaseFloat> *temp_W,
489  CuMatrixBase<BaseFloat> *temp_O);
490 
// Initializes the object from R0, the first minibatch of directions to be
// preconditioned. As described on InitDefault(), this calls InitDefault() and
// then does a few iterations of update with this data to get more reasonable
// values.
491  void Init(const CuMatrixBase<BaseFloat> &R0);
492 
493  // Initialize to some small 'default' values, called from Init(). Init() then
494  // does a few iterations of update with the first batch's data to give more
495  // reasonable values.
// NOTE(review): D is presumably the dimension of the vectors being
// preconditioned (the "D" in the dimensions of WJKL_t) — confirm.
496  void InitDefault(int32 D);
497 
498  // initializes R, which is assumed to have at least as many columns as rows,
499  // to a specially designed matrix with orthonormal rows, that has no zero rows
500  // or columns.
502 
503  // Returns the learning rate eta as a function of the number of samples
504  // (actually, N is the number of vectors we're preconditioning, which due to
505  // context is not always exactly the same as the number of samples). The
506  // value returned depends on num_samples_history_ (see SetNumSamplesHistory()).
507  BaseFloat Eta(int32 N) const;
508 
509  // Called when self_debug_ is true (see TurnOnDebug()); checks that the
510  // members satisfy certain properties. const: does not modify the object.
511  void SelfTest() const;
512 
513  // Configuration values:
514 
515  // The rank of the correction to the unit matrix (e.g. 20).
517 
518  // After a few initial iterations of updating whenever we can, we start only
519  // updating the Fisher-matrix parameters every "update_period_" minibatches;
520  // this saves time.
522 
523  // num_samples_history_ determines the value of eta, which in turn affects how
524  // fast we update our estimate of the covariance matrix. We've done it this
525  // way in order to make it easy to have a single configuration value that
526  // doesn't have to be changed when we change the minibatch size.
528 
529  // alpha controls how much we smooth the Fisher matrix with the unit matrix.
530  // e.g. alpha = 4.0.
532 
533  // epsilon is an absolute floor on the unit-matrix scaling factor rho_t in our
534  // Fisher estimate, which we set to 1.0e-10. We don't actually make this
535  // configurable from the command line. It's needed to avoid crashes on
536  // all-zero inputs.
538 
539  // delta is a relative floor on the unit-matrix scaling factor rho_t in our
540  // Fisher estimate, which we set to 1.0e-05: this is relative to the largest
541  // value of D_t. It's needed to control roundoff error. We apply the same
542  // floor to the eigenvalues in D_t.
544 
545  // t is a counter that measures how many updates we've done.
547 
548  // This keeps track of how many minibatches we've skipped updating the parameters,
549  // since the most recent update; it's used in enforcing "update_period_", which
550  // is a mechanism to avoid spending too much time updating the subspace (which can
551  // be wasteful).
553 
554  // If true, activates certain checks.
556 
560 
561 
562  // Used to prevent parameters being read or written in an inconsistent state.
// std::mutex is neither copyable nor movable; each object owns its own.
563  std::mutex read_write_mutex_;
564 
565  // This mutex is used to control which thread gets to update the
566  // parameters, in multi-threaded code.
// The top-level PreconditionDirections() handles the multithreading before
// calling PreconditionDirectionsInternal().
567  std::mutex update_mutex_;
568 };
569 
570 } // namespace nnet2
571 } // namespace kaldi
572 
573 
574 #endif
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void ComputeWt1(int32 N, const VectorBase< BaseFloat > &d_t, const VectorBase< BaseFloat > &d_t1, BaseFloat rho_t, BaseFloat rho_t1, const MatrixBase< BaseFloat > &U_t, const VectorBase< BaseFloat > &sqrt_c_t, const VectorBase< BaseFloat > &inv_sqrt_e_t, const CuMatrixBase< BaseFloat > &W_t, CuMatrixBase< BaseFloat > *J_t, CuMatrixBase< BaseFloat > *W_t1) const
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
static void InitOrthonormalSpecial(CuMatrixBase< BaseFloat > *R)
This function creates a matrix with orthonormal rows that is like the following matrix, except with each row normalized to have unit 2-norm: $\begin{bmatrix} 1.1 & 0 & 1 & 0 & 1 & 0 \\ 0 & 1.1 & 0 & 1 & 0 & 1 \end{bmatrix}$. The reason why the leading nonzero element in each row is 1.1 and not 1 is for symmetry-breaking...
void ComputeEt(const VectorBase< BaseFloat > &d_t, BaseFloat beta_t, VectorBase< BaseFloat > *e_t, VectorBase< BaseFloat > *sqrt_e_t, VectorBase< BaseFloat > *inv_sqrt_e_t) const
void ComputeZt(int32 N, BaseFloat rho_t, const VectorBase< BaseFloat > &d_t, const VectorBase< BaseFloat > &inv_sqrt_e_t, const MatrixBase< BaseFloat > &K_t, const MatrixBase< BaseFloat > &L_t, SpMatrix< double > *Z_t) const
void PreconditionDirections(CuMatrixBase< BaseFloat > *R, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale)
void SetNumSamplesHistory(BaseFloat num_samples_history)
void ReorthogonalizeXt1(const VectorBase< BaseFloat > &d_t1, BaseFloat rho_t1, CuMatrixBase< BaseFloat > *W_t1, CuMatrixBase< BaseFloat > *temp_W, CuMatrixBase< BaseFloat > *temp_O)
Matrix for CUDA computing.
Definition: matrix-common.h:69
A class representing a vector.
Definition: kaldi-vector.h:406
OnlinePreconditioner & operator=(const OnlinePreconditioner &other)
void Init(const CuMatrixBase< BaseFloat > &R0)
Keywords for search: natural gradient, naturalgradient, NG-SGD.
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void PreconditionDirectionsInternal(const int32 t, const BaseFloat rho_t, const Vector< BaseFloat > &d_t, CuMatrixBase< BaseFloat > *WJKL_t, CuMatrixBase< BaseFloat > *X_t, CuVectorBase< BaseFloat > *row_prod, BaseFloat *scale)
Vector for CUDA computing.
Definition: matrix-common.h:72