NaturalGradientRepeatedAffineComponent Class Reference

#include <nnet-simple-component.h>


Public Member Functions

 NaturalGradientRepeatedAffineComponent ()
virtual std::string Type () const
 Returns a string such as "SigmoidComponent", describing the type of the object.
virtual Component * Copy () const
 Copies the component (deep copy).
 NaturalGradientRepeatedAffineComponent (const NaturalGradientRepeatedAffineComponent &other)
virtual void ConsolidateMemory ()
 This virtual function relates to memory management and to avoiding fragmentation.
- Public Member Functions inherited from RepeatedAffineComponent
virtual int32 InputDim () const
 Returns the input dimension of this component.
virtual int32 OutputDim () const
 Returns the output dimension of this component.
virtual std::string Info () const
 Returns some text-form information about this component, for diagnostics.
virtual void InitFromConfig (ConfigLine *cfl)
 Initializes from a ConfigLine object.
 RepeatedAffineComponent ()
virtual int32 Properties () const
 Returns a bitmask of the component's properties.
virtual void * Propagate (const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
 Propagate function.
virtual void Backprop (const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const
 Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform a model update.
virtual void Read (std::istream &is, bool binary)
 Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed.
virtual void Write (std::ostream &os, bool binary) const
 Writes the component to a stream.
virtual void Scale (BaseFloat scale)
 When called on an UpdatableComponent, this virtual function scales the parameters by "scale".
virtual void Add (BaseFloat alpha, const Component &other)
 When called on an UpdatableComponent, this virtual function adds the parameters of another updatable component, times some constant, to the current parameters.
virtual void PerturbParams (BaseFloat stddev)
 This function is to be used in testing.
virtual BaseFloat DotProduct (const UpdatableComponent &other) const
 Computes the dot product between the parameters of two instances of a Component.
virtual int32 NumParameters () const
 Returns the total dimension of the parameters in this class.
virtual void Vectorize (VectorBase< BaseFloat > *params) const
 Turns the parameters into vector form.
virtual void UnVectorize (const VectorBase< BaseFloat > &params)
 Converts the parameters from vector form.
const CuVector< BaseFloat > & BiasParams () const
const CuMatrix< BaseFloat > & LinearParams () const
 RepeatedAffineComponent (const RepeatedAffineComponent &other)
friend BlockAffineComponent::BlockAffineComponent (const RepeatedAffineComponent &rac)
 RepeatedAffineComponent (const RepeatedAffineComponent &other)
friend BlockAffineComponent::BlockAffineComponent (const RepeatedAffineComponent &rac)
- Public Member Functions inherited from UpdatableComponent
 UpdatableComponent (const UpdatableComponent &other)
 UpdatableComponent ()
virtual ~UpdatableComponent ()
virtual void SetUnderlyingLearningRate (BaseFloat lrate)
 Sets the learning rate of gradient descent; it gets multiplied by learning_rate_factor_.
virtual void SetActualLearningRate (BaseFloat lrate)
 Sets the learning rate directly, bypassing learning_rate_factor_.
virtual void SetAsGradient ()
 Sets is_gradient_ to true and sets learning_rate_ to 1, ignoring learning_rate_factor_.
virtual BaseFloat LearningRateFactor ()
virtual void SetLearningRateFactor (BaseFloat lrate_factor)
void SetUpdatableConfigs (const UpdatableComponent &other)
virtual void FreezeNaturalGradient (bool freeze)
 Freezes or unfreezes natural-gradient updates, if applicable (to be overridden by components that use natural gradient).
BaseFloat LearningRate () const
 Gets the learning rate to be used in gradient descent.
BaseFloat MaxChange () const
 Returns the per-component max-change value, which is interpreted as the maximum change (in l2 norm) in parameters that is allowed per minibatch for this component.
void SetMaxChange (BaseFloat max_change)
BaseFloat L2Regularization () const
 Returns the l2 regularization constant, which may be set in any updatable component (usually from the config file).
void SetL2Regularization (BaseFloat a)
- Public Member Functions inherited from Component
virtual void StoreStats (const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo)
 This function may store stats on average activation values and, for some component types, the average value of the derivative of the nonlinearity.
virtual void ZeroStats ()
 Components that provide an implementation of StoreStats() should also provide an implementation of ZeroStats(), to set those stats to zero.
virtual void GetInputIndexes (const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
 This function only does something interesting for non-simple Components.
virtual bool IsComputable (const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
 This function only does something interesting for non-simple Components, and it exists to make it possible to manage optionally-required inputs.
virtual void ReorderIndexes (std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
 This function only does something interesting for non-simple Components.
virtual ComponentPrecomputedIndexes * PrecomputeIndexes (const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
 This function must return NULL for simple Components.
virtual void DeleteMemo (void *memo) const
 This virtual function only needs to be overridden by Components that return a non-NULL memo from their Propagate() function.
 Component ()
virtual ~Component ()
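The dimension and parameter-count functions of RepeatedAffineComponent follow from the fact that one linear block (block_dim_out rows by block_dim_in columns, as in the Update() listing on this page) and one bias are shared across num_repeats_ slices of the input. The sketch below is a hypothetical illustration, not Kaldi API: the struct and helper names are invented here, and the parameter-count formula assumes that NumParameters() counts the shared block and shared bias once.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical summary of a RepeatedAffineComponent's shapes (not Kaldi API).
struct RepeatedAffineShape {
  int32_t block_dim_in;   // corresponds to linear_params_.NumCols()
  int32_t block_dim_out;  // corresponds to linear_params_.NumRows()
  int32_t num_repeats;    // corresponds to num_repeats_

  // Each input row holds num_repeats consecutive blocks of size block_dim_in.
  int32_t InputDim() const { return block_dim_in * num_repeats; }
  // Each output row holds num_repeats consecutive blocks of size block_dim_out.
  int32_t OutputDim() const { return block_dim_out * num_repeats; }
  // Assumption: shared linear block plus shared bias, counted once, so the
  // count does not grow with num_repeats.
  int32_t NumParameters() const {
    return block_dim_out * block_dim_in + block_dim_out;
  }
};

// Hypothetical helper that validates the dimensions before building the shape.
RepeatedAffineShape MakeShape(int32_t block_dim_in, int32_t block_dim_out,
                              int32_t num_repeats) {
  assert(block_dim_in > 0 && block_dim_out > 0 && num_repeats > 0);
  return RepeatedAffineShape{block_dim_in, block_dim_out, num_repeats};
}
```

For example, a 64-to-32 block repeated 10 times maps a 640-dimensional input to a 320-dimensional output while holding only 32 * 64 + 32 = 2080 parameters.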

Private Member Functions

virtual void Update (const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_deriv)
const NaturalGradientRepeatedAffineComponent & operator= (const NaturalGradientRepeatedAffineComponent &other)
virtual void SetNaturalGradientConfigs ()

Private Attributes

OnlineNaturalGradient preconditioner_in_

Additional Inherited Members

- Static Public Member Functions inherited from Component
static Component * ReadNew (std::istream &is, bool binary)
 Reads a component from a stream (works out its type). Dies on error.
static Component * NewComponentOfType (const std::string &type)
 Returns a new Component of the given type.
- Protected Member Functions inherited from RepeatedAffineComponent
void Init (int32 input_dim, int32 output_dim, int32 num_repeats, BaseFloat param_stddev, BaseFloat bias_mean, BaseFloat bias_stddev)
const RepeatedAffineComponent & operator= (const RepeatedAffineComponent &other)
- Protected Member Functions inherited from UpdatableComponent
void InitLearningRatesFromConfig (ConfigLine *cfl)
std::string ReadUpdatableCommon (std::istream &is, bool binary)
void WriteUpdatableCommon (std::ostream &is, bool binary) const
- Protected Attributes inherited from RepeatedAffineComponent
CuMatrix< BaseFloat > linear_params_
CuVector< BaseFloat > bias_params_
int32 num_repeats_
- Protected Attributes inherited from UpdatableComponent
BaseFloat learning_rate_
 Learning rate (typically 0.0 to 0.01).
BaseFloat learning_rate_factor_
 Learning-rate factor (normally 1.0, but can be set to another value so that when you call SetLearningRate(), that value will be scaled by this factor).
BaseFloat l2_regularize_
 L2 regularization constant.
bool is_gradient_
 True if this component is to be treated as a gradient rather than as parameters.
BaseFloat max_change_
 Configuration value for imposing max-change.
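The learning-rate setters and the max-change value described above interact as sketched below. This is a hypothetical stand-in, not the UpdatableComponent implementation: LearningRateState and ApplyMaxChange are invented names, the setters only mirror the documented behavior (the underlying rate is multiplied by the factor, the actual rate bypasses it, SetAsGradient() forces a rate of 1), and the max-change step assumes the update is rescaled so its l2 norm does not exceed max_change, with a non-positive value meaning "no limit".

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical stand-in for the learning-rate fields of UpdatableComponent.
struct LearningRateState {
  float learning_rate = 0.001f;
  float learning_rate_factor = 1.0f;
  bool is_gradient = false;

  // Documented behavior of SetUnderlyingLearningRate(): scaled by the factor.
  void SetUnderlyingLearningRate(float lrate) {
    learning_rate = lrate * learning_rate_factor;
  }
  // Documented behavior of SetActualLearningRate(): the factor is bypassed.
  void SetActualLearningRate(float lrate) { learning_rate = lrate; }
  // Documented behavior of SetAsGradient(): treat as gradient, rate of 1.
  void SetAsGradient() {
    is_gradient = true;
    learning_rate = 1.0f;
  }
};

// Rescales a proposed parameter change so that its l2 norm is at most
// max_change. Assumption of this sketch: max_change <= 0 disables the limit.
std::vector<float> ApplyMaxChange(std::vector<float> delta, float max_change) {
  if (max_change <= 0.0f) return delta;
  double sum_sq = 0.0;
  for (float d : delta) sum_sq += static_cast<double>(d) * d;
  const double norm = std::sqrt(sum_sq);
  if (norm > max_change) {
    const float scale = static_cast<float>(max_change / norm);
    for (float &d : delta) d *= scale;
  }
  return delta;
}
```

Rescaling the whole change, rather than clipping individual entries, keeps the direction of the update and only shortens it.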

Detailed Description

Definition at line 638 of file nnet-simple-component.h.

Constructor & Destructor Documentation

◆ NaturalGradientRepeatedAffineComponent() [1/2]

◆ NaturalGradientRepeatedAffineComponent() [2/2]

Member Function Documentation

◆ ConsolidateMemory()

void ConsolidateMemory ( )

This virtual function relates to memory management, and avoiding fragmentation.

It is called only once per model, after we do the first minibatch of training. The default implementation does nothing, but child classes can override it to re-initialize quantities that may have been allocated during the forward pass (e.g. certain statistics, or OnlineNaturalGradient objects). We use our own CPU-based allocator (see cu-allocator.h); because we are not in control of the GPU page table it cannot do paging, so fragmentation can be a problem. The allocator always tries to put things in 'low-address memory' (i.e. at smaller memory addresses) near the beginning of the block it allocated, to avoid fragmentation; but if permanent things (belonging to the model) are allocated in the forward pass, they can stay in high memory permanently. This function helps to prevent that by re-allocating those things into low-address memory. It is important that it is called after all the temporary buffers for the forward-backward pass have been freed, so that there is low-address memory available.

Reimplemented from Component.

Definition at line 1656 of file

References NaturalGradientRepeatedAffineComponent::preconditioner_in_, and OnlineNaturalGradient::Swap().

1656  {
1657  OnlineNaturalGradient temp(preconditioner_in_);
1658  preconditioner_in_.Swap(&temp);
1659 }
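The listing re-allocates the preconditioner with a copy-then-swap: copying preconditioner_in_ into a temporary allocates fresh buffers at that moment, and swapping makes the member own them while the old buffers are released when the temporary goes out of scope. The sketch below shows only that idiom, using a hypothetical Preconditioner class backed by std::vector in place of OnlineNaturalGradient; the standard allocator gives no low-address guarantee, so it does not reproduce the placement behavior of Kaldi's CUDA allocator.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for an object whose buffers may have been allocated
// at an inconvenient time, e.g. during the first forward pass.
class Preconditioner {
 public:
  explicit Preconditioner(std::vector<float> state) : state_(std::move(state)) {}
  // Deep copy: the copy gets newly allocated buffers.
  Preconditioner(const Preconditioner &other) = default;
  // Exchanges buffers with *other without copying the contents.
  void Swap(Preconditioner *other) { state_.swap(other->state_); }
  const std::vector<float> &State() const { return state_; }

 private:
  std::vector<float> state_;
};

// Same pattern as ConsolidateMemory(): copy into a temporary (allocating now),
// then swap so *p owns the fresh buffers; the old buffers are freed when
// temp is destroyed at the end of this function.
void Consolidate(Preconditioner *p) {
  assert(p != nullptr);
  Preconditioner temp(*p);
  p->Swap(&temp);
}
```

The contents are unchanged by Consolidate(); only the storage behind them is replaced.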

◆ Copy()

Component * Copy ( ) const

◆ operator=()

◆ SetNaturalGradientConfigs()

void SetNaturalGradientConfigs ( )

Reimplemented from RepeatedAffineComponent.

Definition at line 1583 of file

References RepeatedAffineComponent::linear_params_.

1583  {
1584  int32 rank_in = 40;
1585  int32 input_dim = linear_params_.NumCols();
1586  if (rank_in > input_dim / 2)
1587  rank_in = input_dim / 2;
1588  if (rank_in < 1)
1589  rank_in = 1;
1590  preconditioner_in_.SetRank(rank_in);
1592 }
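The rank selection in the listing can be isolated as a small function: start from 40, cap it at half of linear_params_.NumCols() (the per-block input dimension, not the component's full InputDim()), and never go below 1. The function name below is hypothetical; it reproduces only the arithmetic shown in the listing and not the call to SetRank().

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the rank choice in SetNaturalGradientConfigs().
// input_dim is the per-block input dimension (linear_params_.NumCols()).
int32_t ChooseInputRank(int32_t input_dim) {
  assert(input_dim >= 0);
  int32_t rank_in = 40;
  if (rank_in > input_dim / 2)  // integer division, as in the listing
    rank_in = input_dim / 2;
  if (rank_in < 1)
    rank_in = 1;
  return rank_in;
}
```

With a per-block input dimension of 200 the rank stays at 40; with 30 it drops to 15; with 1 or 3 the integer division gives 0 or 1, and the final check keeps the rank at 1.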

◆ Type()

virtual std::string Type ( ) const

Returns a string such as "SigmoidComponent", describing the type of the object.

Reimplemented from RepeatedAffineComponent.

Definition at line 644 of file nnet-simple-component.h.


644  {
645  return "NaturalGradientRepeatedAffineComponent";
646  }

◆ Update()

void Update ( const CuMatrixBase< BaseFloat > & in_value,
const CuMatrixBase< BaseFloat > & out_deriv )

Reimplemented from RepeatedAffineComponent.

Definition at line 1604 of file

References RepeatedAffineComponent::bias_params_, CuMatrixBase< Real >::ColRange(), CuMatrixBase< Real >::CopyColFromVec(), CuMatrixBase< Real >::Data(), UpdatableComponent::is_gradient_, KALDI_ASSERT, KALDI_ERR, kaldi::kNoTrans, kaldi::kTrans, UpdatableComponent::learning_rate_, RepeatedAffineComponent::linear_params_, RepeatedAffineComponent::num_repeats_, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), OnlineNaturalGradient::PreconditionDirections(), NaturalGradientRepeatedAffineComponent::preconditioner_in_, CuMatrixBase< Real >::Row(), CuMatrixBase< Real >::Stride(), and CuMatrixBase< Real >::Sum().

1606  {
1607  KALDI_ASSERT(out_deriv.NumCols() == out_deriv.Stride() &&
1608  in_value.NumCols() == in_value.Stride() &&
1609  in_value.NumRows() == out_deriv.NumRows());
1611  int32 num_repeats = num_repeats_,
1612  num_rows = in_value.NumRows(),
1613  block_dim_out = linear_params_.NumRows(),
1614  block_dim_in = linear_params_.NumCols();
1616  CuSubMatrix<BaseFloat> in_value_reshaped(in_value.Data(),
1617  num_rows * num_repeats,
1618  block_dim_in, block_dim_in),
1619  out_deriv_reshaped(out_deriv.Data(),
1620  num_rows * num_repeats,
1621  block_dim_out, block_dim_out);
1623  CuVector<BaseFloat> bias_deriv(block_dim_out);
1624  bias_deriv.AddRowSumMat(1.0, out_deriv_reshaped);
1626  CuMatrix<BaseFloat> deriv(block_dim_out,
1627  block_dim_in + 1);
1628  deriv.ColRange(0, block_dim_in).AddMatMat(
1629  1.0, out_deriv_reshaped, kTrans,
1630  in_value_reshaped, kNoTrans, 1.0);
1631  deriv.CopyColFromVec(bias_deriv, block_dim_in);
1633  BaseFloat scale = 1.0;
1634  if (!is_gradient_) {
1635  try {
1636  // Only apply the preconditioning/natural-gradient if we're not computing
1637  // the exact gradient.
1638  preconditioner_in_.PreconditionDirections(&deriv, &scale);
1639  } catch (...) {
1640  int32 num_bad_rows = 0;
1641  for (int32 i = 0; i < out_deriv.NumRows(); i++) {
1642  BaseFloat f = out_deriv.Row(i).Sum();
1643  if (!(f - f == 0)) num_bad_rows++;
1644  }
1645  KALDI_ERR << "Preconditioning failed, in_value sum is "
1646  << in_value.Sum() << ", out_deriv sum is " << out_deriv.Sum()
1647  << ", out_deriv has " << num_bad_rows << " bad rows.";
1648  }
1649  }
1650  linear_params_.AddMat(learning_rate_ * scale,
1651  deriv.ColRange(0, block_dim_in));
1652  bias_deriv.CopyColFromMat(deriv, block_dim_in);
1653  bias_params_.AddVec(learning_rate_ * scale, bias_deriv);
1654 }
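The core of Update() is a reshape: because the stride of each matrix equals its number of columns (checked by the assertion), a num_rows by (num_repeats * block_dim) matrix can be viewed as a (num_rows * num_repeats) by block_dim matrix over the same memory. The shared block's gradient is then out_deriv_reshaped^T * in_value_reshaped, and the bias gradient is the column sum of out_deriv_reshaped; the listing packs both into one matrix (bias as the extra last column) before preconditioning. The CPU-only sketch below is an illustration, not Kaldi code: Mat and RepeatedAffineGradient are hypothetical names, plain loops replace the CUDA matrix calls, and the PreconditionDirections() and learning-rate steps are omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal row-major matrix with contiguous rows (stride == cols).
struct Mat {
  int rows, cols;
  std::vector<float> data;
  Mat(int r, int c)
      : rows(r), cols(c), data(static_cast<std::size_t>(r) * c, 0.0f) {}
  float &at(int r, int c) { return data[static_cast<std::size_t>(r) * cols + c]; }
  float at(int r, int c) const { return data[static_cast<std::size_t>(r) * cols + c]; }
};

// Computes the shared-block gradient as Update() does before preconditioning.
void RepeatedAffineGradient(const Mat &in_value, const Mat &out_deriv,
                            int num_repeats, Mat *linear_deriv,
                            std::vector<float> *bias_deriv) {
  assert(num_repeats > 0 && in_value.rows == out_deriv.rows);
  assert(in_value.cols % num_repeats == 0 && out_deriv.cols % num_repeats == 0);
  const int block_in = in_value.cols / num_repeats;
  const int block_out = out_deriv.cols / num_repeats;
  const int n = in_value.rows * num_repeats;  // rows of the reshaped views
  *linear_deriv = Mat(block_out, block_in);
  bias_deriv->assign(block_out, 0.0f);
  for (int r = 0; r < n; ++r) {
    // Row r of each reshaped view starts at offset r * block_dim in flat data.
    const float *x = in_value.data.data() + static_cast<std::size_t>(r) * block_in;
    const float *g = out_deriv.data.data() + static_cast<std::size_t>(r) * block_out;
    for (int o = 0; o < block_out; ++o) {
      (*bias_deriv)[o] += g[o];  // column sum of out_deriv_reshaped
      for (int i = 0; i < block_in; ++i)
        linear_deriv->at(o, i) += g[o] * x[i];  // (out_deriv^T * in_value)(o, i)
    }
  }
}
```

With one input row [1, 2, 3, 4], two repeats, and output derivative [10, 20], the two reshaped rows contribute 10 * [1, 2] + 20 * [3, 4] = [70, 100] to the shared 1 by 2 block, and the bias gradient is 30. Every repeat adds into the same block, which is why a single set of parameters is learned for all slices.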

Member Data Documentation

◆ preconditioner_in_

The documentation for this class was generated from the following files: