natural-gradient-online.h
Go to the documentation of this file.
1 // nnet3/natural-gradient-online.h
2 
3 // Copyright 2013-2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NATURAL_GRADIENT_ONLINE_H_
21 #define KALDI_NNET3_NATURAL_GRADIENT_ONLINE_H_
22 
23 #include <iostream>
24 #include "base/kaldi-common.h"
25 #include "matrix/matrix-lib.h"
27 
28 namespace kaldi {
29 namespace nnet3 {
30 
31 
415  public:
417 
418  void SetRank(int32 rank);
419  void SetUpdatePeriod(int32 update_period);
420  // num_samples_history is a time-constant (in samples) that determines eta.
421  void SetNumSamplesHistory(BaseFloat num_samples_history);
422  // num_minibatches_history is a time-constant measured in minibatches that
423  // provides an alternative way to set eta (the constant that determines how
424  // fast we update the Fisher matrix). If set to a value >0, it overrides any
425  // value of 'num_samples_history' that is present.
426  void SetNumMinibatchesHistory(BaseFloat num_minibatches_history);
427 
428  void SetAlpha(BaseFloat alpha);
429  void TurnOnDebug() { self_debug_ = true; }  // Sets self_debug_ to true, which enables the SelfTest() checks.
432  BaseFloat GetAlpha() const { return alpha_; }  // Returns alpha_ (controls how much the Fisher matrix is smoothed with the unit matrix).
433  int32 GetRank() const { return rank_; }  // Returns rank_ (the rank of the correction to the unit matrix).
435 
436  // Sets frozen_: Freeze(true) disables the natural gradient update until Freeze(false) is called. See comment where 'frozen_' is declared.
437  inline void Freeze(bool frozen) { frozen_ = frozen; }
438 
461  BaseFloat *scale);
462 
463 
464 
465  // Copy constructor.
466  explicit OnlineNaturalGradient(const OnlineNaturalGradient &other);
467  // Assignment operator
469 
470  // Shallow swap
471  void Swap(OnlineNaturalGradient *other);
472  private:
473 
474 
475  // This is an internal function called from PreconditionDirections().
476  // Note: WJKL_t (dimension 2*R by D + R) is [ W_t L_t; J_t K_t ].
477  void PreconditionDirectionsInternal(const BaseFloat rho_t,
478  const BaseFloat tr_X_Xt,
479  bool updating,
480  const Vector<BaseFloat> &d_t,
481  CuMatrixBase<BaseFloat> *WJKL_t,
483 
484 
485  // Works out from t_ and various class variables whether we will update
486  // the parameters on this iteration (returns true if so).
487  bool Updating() const;
488 
489  void ComputeEt(const VectorBase<BaseFloat> &d_t,
490  BaseFloat beta_t,
492  VectorBase<BaseFloat> *sqrt_e_t,
493  VectorBase<BaseFloat> *inv_sqrt_e_t) const;
494 
495  void ComputeZt(int32 N,
496  BaseFloat rho_t,
497  const VectorBase<BaseFloat> &d_t,
498  const VectorBase<BaseFloat> &inv_sqrt_e_t,
499  const MatrixBase<BaseFloat> &K_t,
500  const MatrixBase<BaseFloat> &L_t,
501  SpMatrix<double> *Z_t) const;
502  // Computes W_{t+1}. Overwrites J_t.
503  void ComputeWt1(int32 N,
504  const VectorBase<BaseFloat> &d_t,
505  const VectorBase<BaseFloat> &d_t1,
506  BaseFloat rho_t,
507  BaseFloat rho_t1,
508  const MatrixBase<BaseFloat> &U_t,
509  const VectorBase<BaseFloat> &sqrt_c_t,
510  const VectorBase<BaseFloat> &inv_sqrt_e_t,
511  const CuMatrixBase<BaseFloat> &W_t,
513  CuMatrixBase<BaseFloat> *W_t1) const;
514 
515  // This function is called if C_t has high condition number; it makes sure
516  // that R_{t+1} is orthogonal. See the section in the extended comment above
517  // on "keeping R_t orthogonal".
518  void ReorthogonalizeRt1(const VectorBase<BaseFloat> &d_t1,
519  BaseFloat rho_t1,
521  CuMatrixBase<BaseFloat> *temp_W,
522  CuMatrixBase<BaseFloat> *temp_O);
523 
524  void Init(const CuMatrixBase<BaseFloat> &R0);
525 
526  // Initialize to some small 'default' values, called from Init(). Init() then
527  // does a few iterations of update with the first batch's data to give more
528  // reasonable values.
529  void InitDefault(int32 D);
530 
531  // initializes R, which is assumed to have at least as many columns as rows,
532  // to a specially designed matrix with orthonormal rows, that has no zero rows
533  // or columns.
535 
536  // Returns the value eta (with 0 < eta < 1) which reflects how fast we update
537  // the estimate of the Fisher matrix (larger == faster). This is a function
538  // rather than a constant because we set this indirectly, via
539  // num_samples_history_ or num_minibatches_history_. The argument N is the
540  // number of vectors we're preconditioning, which is the number of rows in the
541  // argument X to PreconditionDirections(); you can think of it as the number
542  // of vectors we're preconditioning (and in the common case it's some multiple
543  // of the minibatch size)
544  BaseFloat Eta(int32 N) const;
545 
546  // called if self_debug_ = true, makes sure the members satisfy certain
547  // properties.
548  void SelfTest() const;
549 
550  // Configuration values:
551 
552  // The rank of the correction to the unit matrix (e.g. 20).
554 
555  // After a few initial iterations of updating whenever we can, we start only
556  // updating the Fisher-matrix parameters every "update_period_" minibatches;
557  // this saves time.
559 
560 
561  // num_samples_history_ determines the value of eta, which in turn affects how
562  // fast we update our estimate of the covariance matrix. We've done it this
563  // way in order to make it easy to have a single configuration value that
564  // doesn't have to be changed when we change the minibatch size.
565  // Note: if num_minibatches_history_ is >0.0, it overrides this.
567 
568 
569  // num_minibatches_history_ is a simpler alternative to num_samples_history_ for
570  // determining the value of eta, which in turn affects how fast we update our
571  // estimate of the covariance matrix. eta will be set to 1.0 /
572  // num_minibatches_history_. We require that num_minibatches_history_ > 0.0;
573  // it will normally be something like 10.0, if set. It makes sense to set
574  // 'num_minibatches_history_' when the rows of the matrix we are
575  // preconditioning can't be interpreted as independent samples, so the number
576  // of rows is not relevant to determining how fast to update the covariance
577  // matrix.
579 
580 
581  // alpha controls how much we smooth the Fisher matrix with the unit matrix.
582  // e.g. alpha = 4.0.
584 
585  // epsilon is an absolute floor on the unit-matrix scaling factor rho_t in our
586  // Fisher estimate, which we set to 1.0e-10. We don't actually make this
587  // configurable from the command line. It's needed to avoid crashes on
588  // all-zero inputs.
590 
591  // delta is a relative floor on the unit-matrix scaling factor rho_t in our
592  // Fisher estimate, which we set to 1.0e-05: this is relative to the largest
593  // value of D_t. It's needed to control roundoff error. We apply the same
594  // floor to the eigenvalues in D_t.
596 
597  // this is set to true if the user has called the function Freeze(true), until
598  // they call Freeze(false). It's used to disable the natural gradient
599  // update (and stop incrementing t_). However, if the object is uninitialized
600  // (t_ == 0) it doesn't prevent it from being initialized. This is used
601  // in adversarial training to ensure that the Fisher matrix is updated only
602  // the *second* time we see the same data (to avoid biasing the update).
603  bool frozen_;
604 
605  // t is a counter that measures how many times the user has previously called
606  // PreconditionDirections(); it's 0 if that has never been called.
608 
609  // If true, activates certain checks.
611 
615 };
616 
617 } // namespace nnet3
618 } // namespace kaldi
619 
620 
621 #endif
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
static void InitOrthonormalSpecial(CuMatrixBase< BaseFloat > *R)
This function creates a matrix with orthonormal rows that is like the following matrix, except with each row normalized to have unit 2-norm: [ 1.1 0 1 0 1 0 0 1.1 0 1 0 1 ] The reason why the first element in each row is 1.1 and not 1, is for symmetry-breaking...
void ReorthogonalizeRt1(const VectorBase< BaseFloat > &d_t1, BaseFloat rho_t1, CuMatrixBase< BaseFloat > *W_t1, CuMatrixBase< BaseFloat > *temp_W, CuMatrixBase< BaseFloat > *temp_O)
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
void SetNumSamplesHistory(BaseFloat num_samples_history)
OnlineNaturalGradient & operator=(const OnlineNaturalGradient &other)
kaldi::int32 int32
Keywords for search: natural gradient, naturalgradient, NG-SGD.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
void ComputeWt1(int32 N, const VectorBase< BaseFloat > &d_t, const VectorBase< BaseFloat > &d_t1, BaseFloat rho_t, BaseFloat rho_t1, const MatrixBase< BaseFloat > &U_t, const VectorBase< BaseFloat > &sqrt_c_t, const VectorBase< BaseFloat > &inv_sqrt_e_t, const CuMatrixBase< BaseFloat > &W_t, CuMatrixBase< BaseFloat > *J_t, CuMatrixBase< BaseFloat > *W_t1) const
void PreconditionDirections(CuMatrixBase< BaseFloat > *X, BaseFloat *scale)
This call implements the main functionality of this class.
void ComputeEt(const VectorBase< BaseFloat > &d_t, BaseFloat beta_t, VectorBase< BaseFloat > *e_t, VectorBase< BaseFloat > *sqrt_e_t, VectorBase< BaseFloat > *inv_sqrt_e_t) const
void Swap(OnlineNaturalGradient *other)
Matrix for CUDA computing.
Definition: matrix-common.h:69
A class representing a vector.
Definition: kaldi-vector.h:406
void ComputeZt(int32 N, BaseFloat rho_t, const VectorBase< BaseFloat > &d_t, const VectorBase< BaseFloat > &inv_sqrt_e_t, const MatrixBase< BaseFloat > &K_t, const MatrixBase< BaseFloat > &L_t, SpMatrix< double > *Z_t) const
void SetNumMinibatchesHistory(BaseFloat num_minibatches_history)
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void Init(const CuMatrixBase< BaseFloat > &R0)
void PreconditionDirectionsInternal(const BaseFloat rho_t, const BaseFloat tr_X_Xt, bool updating, const Vector< BaseFloat > &d_t, CuMatrixBase< BaseFloat > *WJKL_t, CuMatrixBase< BaseFloat > *X_t)