BatchNormComponent Class Reference

#include <nnet-simple-component.h>


Classes

struct  Memo
 

Public Member Functions

 BatchNormComponent ()
 
void SetTestMode (bool test_mode)
 
 BatchNormComponent (const BatchNormComponent &other)
 
virtual int32 InputDim () const
 Returns input-dimension of this component. More...
 
virtual int32 OutputDim () const
 Returns output-dimension of this component. More...
 
virtual std::string Info () const
 Returns some text-form information about this component, for diagnostics. More...
 
virtual void InitFromConfig (ConfigLine *cfl)
 Initialize, from a ConfigLine object. More...
 
virtual std::string Type () const
 Returns a string such as "SigmoidComponent", describing the type of the object. More...
 
virtual int32 Properties () const
 Return bitmask of the component's properties. More...
 
virtual void * Propagate (const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
 Propagate function. More...
 
virtual void Backprop (const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *, CuMatrixBase< BaseFloat > *in_deriv) const
 Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update. More...
 
virtual void Read (std::istream &is, bool binary)
 Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed. More...
 
virtual void Write (std::ostream &os, bool binary) const
 Write component to stream. More...
 
virtual Component * Copy () const
 Copies component (deep copy). More...
 
virtual void Scale (BaseFloat scale)
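 Scales the parameters by "scale" for an UpdatableComponent, or the activation stats for a component that stores them, like BatchNormComponent. More...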
 
virtual void Add (BaseFloat alpha, const Component &other)
 When called on an UpdatableComponent, adds the parameters of another updatable component, times some constant, to the current parameters. More...
 
virtual void ZeroStats ()
 Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero. More...
 
virtual void DeleteMemo (void *memo) const
 This virtual function only needs to be overwritten by Components that return a non-NULL memo from their Propagate() function. More...
 
virtual void StoreStats (const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo)
 This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity. More...
 
const CuVector< BaseFloat > & Offset () const
 
const CuVector< BaseFloat > & Scale () const
 
- Public Member Functions inherited from Component
virtual void GetInputIndexes (const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
 This function only does something interesting for non-simple Components. More...
 
virtual bool IsComputable (const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
 This function only does something interesting for non-simple Components, and it exists to make it possible to manage optionally-required inputs. More...
 
virtual void ReorderIndexes (std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
 This function only does something interesting for non-simple Components. More...
 
virtual ComponentPrecomputedIndexes * PrecomputeIndexes (const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
 This function must return NULL for simple Components. More...
 
 Component ()
 
virtual ~Component ()
 

Private Member Functions

void Check () const
 
void ComputeDerived ()
 

Static Private Member Functions

static void ComputeOffsetAndScale (double count, BaseFloat epsilon, const Vector< double > &stats_sum, const Vector< double > &stats_sumsq, Vector< BaseFloat > *offset, Vector< BaseFloat > *scale)
 

Private Attributes

int32 dim_
 
int32 block_dim_
 
BaseFloat epsilon_
 
BaseFloat target_rms_
 
bool test_mode_
 
double count_
 
CuVector< double > stats_sum_
 
CuVector< double > stats_sumsq_
 
CuVector< BaseFloat > offset_
 
CuVector< BaseFloat > scale_
 

Additional Inherited Members

- Static Public Member Functions inherited from Component
static Component * ReadNew (std::istream &is, bool binary)
 Read component from stream (works out its type). Dies on error. More...
 
static Component * NewComponentOfType (const std::string &type)
 Returns a new Component of the given type e.g. More...
 

Detailed Description

Definition at line 2144 of file nnet-simple-component.h.

Constructor & Destructor Documentation

Member Function Documentation

void Add ( BaseFloat  alpha,
const Component & other
)
virtual

When called on an UpdatableComponent, this virtual function adds the parameters of another updatable component, times some constant, to the current parameters.

When called on a NonlinearComponent, it relates to adding stats. Otherwise it should do nothing.

Reimplemented from Component.

Definition at line 5759 of file nnet-simple-component.cc.

References CuVectorBase< Real >::AddVec(), BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::stats_sum_, and BatchNormComponent::stats_sumsq_.

5759  {
5760  const BatchNormComponent *other =
5761  dynamic_cast<const BatchNormComponent*>(&other_in);
5762  count_ += alpha * other->count_;
5763  stats_sum_.AddVec(alpha, other->stats_sum_);
5764  stats_sumsq_.AddVec(alpha, other->stats_sumsq_);
5765  // this operation might change offset_ and scale_, so we recompute them
5766  // in this instance (but not in Scale()).
5767  ComputeDerived();
5768 }
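The accumulation in the listing above can be sketched without any Kaldi types. This is a minimal illustration, assuming plain `std::vector<double>` storage; `BatchNormStats` and `AddStats` are hypothetical names, not part of Kaldi.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the three statistics fields the listing updates.
struct BatchNormStats {
  double count = 0.0;
  std::vector<double> sum;    // plays the role of stats_sum_
  std::vector<double> sumsq;  // plays the role of stats_sumsq_
};

// stats += alpha * other, element by element.
void AddStats(double alpha, const BatchNormStats &other,
              BatchNormStats *stats) {
  assert(stats != nullptr && other.sum.size() == other.sumsq.size() &&
         stats->sum.size() == other.sum.size() &&
         stats->sumsq.size() == other.sumsq.size());
  stats->count += alpha * other.count;
  for (std::size_t d = 0; d < other.sum.size(); ++d) {
    stats->sum[d] += alpha * other.sum[d];
    stats->sumsq[d] += alpha * other.sumsq[d];
  }
}
```

Because the count and both sums are scaled by the same alpha, the mean and variance implied by the combined statistics are a count-weighted blend of the two inputs.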
void Backprop ( const std::string &  debug_info,
const ComponentPrecomputedIndexes * indexes,
const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
const CuMatrixBase< BaseFloat > &  out_deriv,
void *  memo,
Component * to_update,
CuMatrixBase< BaseFloat > *  in_deriv 
) const
virtual

Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update.

Parameters
[in] debug_info: The component name, to be printed out in any warning messages.
[in] indexes: A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in] in_value: The matrix that was given as input to the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsInput == 0.
[in] out_value: The matrix that was output from the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsOutput == 0.
[in] out_deriv: The derivative at the output of this component.
[in] memo: This will normally be NULL, but for component types that set the flag kUsesMemo, this will be the return value of the Propagate() function that corresponds to this Backprop() function. Ownership of any pointers is not transferred to the Backprop function; DeleteMemo() will be called to delete it.
[out] to_update: If model update is desired, the Component to be updated, else NULL. Does not have to be identical to this. If supplied, you can assume that to_update->Properties() & kUpdatableComponent is nonzero.
[out] in_deriv: The derivative at the input of this component, if needed (else NULL). If Properties()&kBackpropInPlace, may be the same matrix as out_deriv. If Properties()&kBackpropAdds, this is added to by the Backprop routine, else it is set. The component code chooses which mode to work in, based on convenience.

Implements Component.

Definition at line 5582 of file nnet-simple-component.cc.

References CuMatrixBase< Real >::AddMatDiagVec(), CuMatrixBase< Real >::AddVecToRows(), BatchNormComponent::block_dim_, CuMatrixBase< Real >::CopyFromMat(), CuMatrixBase< Real >::Data(), CuVectorBase< Real >::Dim(), BatchNormComponent::dim_, KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, BatchNormComponent::Memo::mean_uvar_scale, CuMatrixBase< Real >::MulColsVec(), BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), BatchNormComponent::offset_, kaldi::SameDim(), BatchNormComponent::scale_, CuMatrixBase< Real >::Stride(), BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

5590  {
5591 
5592  KALDI_ASSERT(SameDim(out_value, out_deriv) &&
5593  SameDim(out_value, *in_deriv) &&
5594  (out_value.NumCols() == dim_ ||
5595  out_value.NumCols() == block_dim_));
5596  if (out_value.NumCols() != block_dim_) {
5597  // if block_dim_ != dim_, we recurse; this helps keep the main code
5598  // simple.
5599  KALDI_ASSERT(out_value.Stride() == out_value.NumCols() &&
5600  out_deriv.Stride() == out_deriv.NumCols() &&
5601  in_deriv->Stride() == in_deriv->NumCols());
5602  int32 ratio = dim_ / block_dim_,
5603  orig_rows = out_value.NumRows(),
5604  orig_cols = out_value.NumCols(),
5605  new_rows = orig_rows * ratio, new_cols = orig_cols / ratio;
5606  CuSubMatrix<BaseFloat> out_value_reshaped(out_value.Data(), new_rows,
5607  new_cols, new_cols),
5608  out_deriv_reshaped(out_deriv.Data(), new_rows, new_cols, new_cols),
5609  in_deriv_reshaped(in_deriv->Data(), new_rows, new_cols, new_cols);
5610  // we'll never use in_value, so pass it in unchanged.
5611  Backprop(debug_info, indexes, in_value,
5612  out_value_reshaped, out_deriv_reshaped,
5613  memo_in, to_update, &in_deriv_reshaped);
5614  return;
5615  }
5616 
5617  Memo *memo = static_cast<Memo*>(memo_in);
5618 
5619  if (!test_mode_) {
5620  // search above for BACKWARD PASS for a comment describing the math.
5621  KALDI_ASSERT(memo != NULL && "memo not passed into backprop");
5622  int32 num_frames = memo->num_frames;
5623  KALDI_ASSERT(out_value.NumRows() == num_frames);
5624  CuSubVector<BaseFloat> temp(memo->mean_uvar_scale, 3),
5625  scale(memo->mean_uvar_scale, 2);
5626  temp.AddRowSumMat(-1.0 / num_frames, out_deriv, 0.0);
5627  // the following does no work if in_deriv and out_deriv are the same matrix.
5628  in_deriv->CopyFromMat(out_deriv);
5629  in_deriv->AddVecToRows(1.0, temp);
5630  in_deriv->MulColsVec(scale);
5631  // at this point, 'in_deriv' contains:
5632  // x_deriv_base(i) = (y'(i) - 1/n sum_i y'(i)) * scale
5633  temp.AddDiagMatMat(-1.0 / (num_frames * target_rms_ * target_rms_),
5634  out_value, kTrans, *in_deriv, kNoTrans, 0.0);
5635  // now, 'temp' contains the quantity which we described
5636  // in the math as:
5637  // alpha = - \sum_i y(i) x_deriv_base(i) / n.
5638  // The factor 1 / (target_rms_ * target_rms_) comes from following
5639  // this additional scaling factor through the math. In the comment I said
5640  // "we know that \sum_i y(i) y(i) = n". Taking target-rms into account
5641  // this becomes "we know that \sum_i y(i) y(i) = n * target-rms^2".
5642  in_deriv->AddMatDiagVec(1.0, out_value, kNoTrans, temp, 1.0);
5643  // At this point, in_deriv contains x'(i) = x_deriv_base(i) + alpha y(i).
5644 
5645  } else {
5647  // the next call does no work if they point to the same memory.
5648  in_deriv->CopyFromMat(out_deriv);
5649  in_deriv->MulColsVec(scale_);
5650  }
5651 }
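The training-mode branch of the listing above follows the two steps named in its comments: first x_deriv_base(i) = (y'(i) - 1/n sum_i y'(i)) * scale, then x'(i) = x_deriv_base(i) + alpha y(i). The sketch below reproduces those steps for a single column using `std::vector`. `BatchNormBackwardColumn` is a hypothetical name, and the per-column layout is a simplification of the matrix operations in the listing.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical single-column version of the training-mode backward pass.
// y is the normalized output, y_deriv the derivative w.r.t. y, and scale
// the per-column scale saved in the memo during the forward pass.
std::vector<float> BatchNormBackwardColumn(const std::vector<float> &y,
                                           const std::vector<float> &y_deriv,
                                           float scale, float target_rms) {
  const std::size_t n = y.size();
  assert(n > 0 && y_deriv.size() == n && target_rms > 0.0f);
  double mean_deriv = 0.0;
  for (float d : y_deriv) mean_deriv += d;
  mean_deriv /= n;

  // x_deriv_base(i) = (y'(i) - 1/n sum_i y'(i)) * scale
  std::vector<double> x_deriv_base(n);
  double alpha = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    x_deriv_base[i] = (y_deriv[i] - mean_deriv) * scale;
    alpha += static_cast<double>(y[i]) * x_deriv_base[i];
  }
  // alpha = - sum_i y(i) x_deriv_base(i) / (n * target_rms^2)
  alpha *= -1.0 / (n * static_cast<double>(target_rms) * target_rms);

  // x'(i) = x_deriv_base(i) + alpha * y(i)
  std::vector<float> x_deriv(n);
  for (std::size_t i = 0; i < n; ++i)
    x_deriv[i] = static_cast<float>(x_deriv_base[i] + alpha * y[i]);
  return x_deriv;
}
```

When the variance was not floored in the forward pass, the resulting input derivative sums to zero over the frames and is orthogonal to y, which is a convenient sanity check.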
void ComputeDerived ( )
private

Definition at line 5298 of file nnet-simple-component.cc.

References CuVectorBase< Real >::AddVecVec(), CuVectorBase< Real >::ApplyFloor(), CuVectorBase< Real >::ApplyPow(), BatchNormComponent::block_dim_, CuVectorBase< Real >::CopyFromVec(), BatchNormComponent::count_, BatchNormComponent::epsilon_, KALDI_WARN, CuVectorBase< Real >::MulElements(), BatchNormComponent::offset_, CuVector< Real >::Resize(), CuVectorBase< Real >::Scale(), BatchNormComponent::scale_, CuVectorBase< Real >::SetRandn(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

Referenced by BatchNormComponent::Add(), BatchNormComponent::BatchNormComponent(), BatchNormComponent::InitFromConfig(), BatchNormComponent::Read(), and BatchNormComponent::SetTestMode().

5298  {
5299  if (!test_mode_) {
5300  offset_.Resize(0);
5301  scale_.Resize(0);
5302  return;
5303  }
5304 
5305  if (count_ == 0.0) {
5306  KALDI_WARN << "Test-mode is set but there is no data count. "
5307  "Creating random counts. This only makes sense "
5308  "in unit-tests (or compute_prob_*.0.log). If you see this "
5309  "elsewhere, something is very wrong.";
5310  count_ = 1.0;
5311  stats_sum_.SetRandn();
5314  }
5315 
5319  offset_.Scale(-1.0 / count_);
5320  // now offset_ is -mean.
5322  scale_.Scale(1.0 / count_);
5323  scale_.AddVecVec(-1.0, offset_, offset_, 1.0);
5324  // now scale_ is variance.
5326  scale_.ApplyPow(-0.5);
5327  // now scale_ = max(variance, epsilon)^{-0.5}.
5328  // next, multiply by the target RMS (normally 1.0).
5331  // now offset_ is -(scale*mean).
5332 }
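The listing above omits some source lines, so the following is only a sketch of what its comments describe: offset is -(scale*mean) and scale is the inverse square root of the variance, multiplied by the target RMS. It assumes the variance is floored at epsilon, as in the training-mode Propagate code. `OffsetAndScale` and `ComputeTestModeParams` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct OffsetAndScale {
  std::vector<float> offset, scale;
};

// Hypothetical helper: derives test-mode offset and scale from accumulated
// statistics. Assumes the variance is floored at epsilon.
OffsetAndScale ComputeTestModeParams(double count,
                                     const std::vector<double> &stats_sum,
                                     const std::vector<double> &stats_sumsq,
                                     float epsilon, float target_rms) {
  assert(count > 0.0 && stats_sum.size() == stats_sumsq.size());
  OffsetAndScale result;
  result.offset.resize(stats_sum.size());
  result.scale.resize(stats_sum.size());
  for (std::size_t d = 0; d < stats_sum.size(); ++d) {
    const double mean = stats_sum[d] / count;
    const double var = stats_sumsq[d] / count - mean * mean;
    const double scale =
        target_rms / std::sqrt(std::max(var, static_cast<double>(epsilon)));
    result.scale[d] = static_cast<float>(scale);
    result.offset[d] = static_cast<float>(-mean * scale);  // -(scale * mean)
  }
  return result;
}
```

With these values, test-mode propagation reduces to y = x * scale + offset per column, matching the MulColsVec and AddVecToRows calls in the test-mode branch of Propagate.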
static void ComputeOffsetAndScale ( double  count,
BaseFloat  epsilon,
const Vector< double > &  stats_sum,
const Vector< double > &  stats_sumsq,
Vector< BaseFloat > *  offset,
Vector< BaseFloat > *  scale 
)
static private
virtual Component* Copy ( ) const
inline virtual

Copies component (deep copy).

Implements Component.

Definition at line 2202 of file nnet-simple-component.h.

References BatchNormComponent::BatchNormComponent().

virtual void DeleteMemo ( void *  memo) const
inline virtual

This virtual function only needs to be overwritten by Components that return a non-NULL memo from their Propagate() function.

It's called by NnetComputer in cases where Propagate returns a memo but there will be no backprop to consume it.

Reimplemented from Component.

Definition at line 2209 of file nnet-simple-component.h.

2209 { delete static_cast<Memo*>(memo); }
std::string Info ( ) const
virtual

Returns some text-form information about this component, for diagnostics.

Starts with the type of the component. E.g. "SigmoidComponent dim=900", although most components will have much more info.

Reimplemented from Component.

Definition at line 5354 of file nnet-simple-component.cc.

References VectorBase< Real >::AddVecVec(), VectorBase< Real >::ApplyFloor(), VectorBase< Real >::ApplyPow(), BatchNormComponent::block_dim_, BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, VectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, kaldi::nnet3::SummarizeVector(), BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, and BatchNormComponent::Type().

5354  {
5355  std::ostringstream stream;
5356  stream << Type() << ", dim=" << dim_ << ", block-dim=" << block_dim_
5357  << ", epsilon=" << epsilon_ << ", target-rms=" << target_rms_
5358  << ", count=" << count_
5359  << ", test-mode=" << (test_mode_ ? "true" : "false");
5360  if (count_ > 0) {
5361  Vector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
5362  mean.Scale(1.0 / count_);
5363  var.Scale(1.0 / count_);
5364  // subtract mean^2 from var.
5365  var.AddVecVec(-1.0, mean, mean, 1.0);
5366  var.ApplyFloor(0.0);
5367  var.ApplyPow(0.5); // make it the stddev.
5368  stream << ", data-mean=" << SummarizeVector(mean)
5369  << ", data-stddev=" << SummarizeVector(var);
5370  }
5371  return stream.str();
5372 }
void InitFromConfig ( ConfigLine * cfl)
virtual

Initialize, from a ConfigLine object.

Parameters
[in] cfl: A ConfigLine containing any parameters that are needed for initialization. For example: "dim=100 param-stddev=0.1"

Implements Component.

Definition at line 5374 of file nnet-simple-component.cc.

References BatchNormComponent::block_dim_, BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, ConfigLine::GetValue(), ConfigLine::HasUnusedValues(), KALDI_ERR, CuVector< Real >::Resize(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, and ConfigLine::UnusedValues().

5374  {
5375  dim_ = -1;
5376  block_dim_ = -1;
5377  epsilon_ = 1.0e-03;
5378  target_rms_ = 1.0;
5379  test_mode_ = false;
5380  bool ok = cfl->GetValue("dim", &dim_);
5381  cfl->GetValue("block-dim", &block_dim_);
5382  cfl->GetValue("epsilon", &epsilon_);
5383  cfl->GetValue("target-rms", &target_rms_);
5384  cfl->GetValue("test-mode", &test_mode_);
5385  if (!ok || dim_ <= 0) {
5386  KALDI_ERR << "BatchNormComponent must have 'dim' specified, and > 0";
5387  }
5388  if (block_dim_ == -1)
5389  block_dim_ = dim_;
5390  if (!(block_dim_ > 0 && dim_ % block_dim_ == 0 &&
5391  epsilon_ > 0 && target_rms_ > 0))
5392  KALDI_ERR << "Invalid configuration dim=" << dim_
5393  << ", block-dim=" << block_dim_ << ", epsilon=" << epsilon_
5394  << ", target-rms=" << target_rms_ << " in BatchNormComponent.";
5395  if (cfl->HasUnusedValues())
5396  KALDI_ERR << "Could not process these elements in initializer: "
5397  << cfl->UnusedValues();
5398  count_ = 0;
5401  if (test_mode_) {
5402  ComputeDerived();
5403  }
5404 }
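The defaults and validity checks in the listing above can be summarized in a small standalone sketch. `BatchNormConfig` and `ValidateBatchNormConfig` are hypothetical names. Unlike the listing, which reports problems through KALDI_ERR, this version returns an error string so that it has no Kaldi dependency.

```cpp
#include <cassert>
#include <string>

// Hypothetical plain struct holding the values InitFromConfig reads,
// initialized to the same defaults as the listing.
struct BatchNormConfig {
  int dim = -1;
  int block_dim = -1;
  float epsilon = 1.0e-03f;
  float target_rms = 1.0f;
  bool test_mode = false;
};

// Returns an empty string if the configuration is valid, otherwise a
// description of the problem. Fills in the block_dim default.
std::string ValidateBatchNormConfig(BatchNormConfig *config) {
  assert(config != nullptr);
  if (config->dim <= 0)
    return "BatchNormComponent must have 'dim' specified, and > 0";
  if (config->block_dim == -1)
    config->block_dim = config->dim;  // block-dim defaults to dim
  if (!(config->block_dim > 0 && config->dim % config->block_dim == 0 &&
        config->epsilon > 0 && config->target_rms > 0))
    return "Invalid configuration in BatchNormComponent";
  return "";
}
```

The key constraint is that block-dim must divide dim evenly, since the component reshapes its input into rows of block-dim columns.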
virtual int32 InputDim ( ) const
inline virtual

Returns input-dimension of this component.

Implements Component.

Definition at line 2161 of file nnet-simple-component.h.

References BatchNormComponent::dim_.

const CuVector<BaseFloat>& Offset ( ) const
inline

Definition at line 2217 of file nnet-simple-component.h.

References BatchNormComponent::offset_.

Referenced by ModelCollapser::CollapseComponentsBatchnorm().

2217 { return offset_; }
virtual int32 OutputDim ( ) const
inline virtual

Returns output-dimension of this component.

Implements Component.

Definition at line 2162 of file nnet-simple-component.h.

References BatchNormComponent::dim_.

void * Propagate ( const ComponentPrecomputedIndexes * indexes,
const CuMatrixBase< BaseFloat > &  in,
CuMatrixBase< BaseFloat > *  out 
) const
virtual

Propagate function.

Parameters
[in] indexes: A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in] in: The input to this component. Num-columns == InputDim().
[out] out: The output of this component. Num-columns == OutputDim(). Note: output of this component will be added to the initial value of "out" if Properties()&kPropagateAdds != 0; otherwise the output will be set and the initial value ignored. Each Component chooses whether it is more convenient implementation-wise to add or set, and the calling code has to deal with it.
Returns
Normally returns NULL, but may return a non-NULL value for components which have the flag kUsesMemo set. This value will be passed into the corresponding Backprop routine.

Implements Component.

Definition at line 5517 of file nnet-simple-component.cc.

References CuMatrixBase< Real >::AddVecToRows(), CuVectorBase< Real >::AddVecVec(), CuVectorBase< Real >::ApplyFloor(), CuVectorBase< Real >::ApplyPow(), BatchNormComponent::block_dim_, CuMatrixBase< Real >::CopyFromMat(), CuVectorBase< Real >::CopyFromVec(), BatchNormComponent::count_, CuMatrixBase< Real >::Data(), CuVectorBase< Real >::Dim(), BatchNormComponent::dim_, BatchNormComponent::epsilon_, KALDI_ASSERT, KALDI_ERR, kaldi::kTrans, BatchNormComponent::Memo::mean_uvar_scale, CuMatrixBase< Real >::MulColsVec(), BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), BatchNormComponent::offset_, CuMatrix< Real >::Resize(), kaldi::SameDim(), BatchNormComponent::scale_, CuMatrixBase< Real >::Stride(), BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

5519  {
5520  KALDI_ASSERT(SameDim(in, *out) &&
5521  (in.NumCols() == dim_ || in.NumCols() == block_dim_));
5522  if (in.NumCols() != block_dim_) {
5523  // if block_dim_ != dim_, we recurse; this helps keep the main code
5524  // simple.
5525  KALDI_ASSERT(in.Stride() == in.NumCols() && out->Stride() == out->NumCols());
5526  int32 ratio = dim_ / block_dim_, orig_rows = in.NumRows(),
5527  orig_cols = in.NumCols(), new_rows = orig_rows * ratio,
5528  new_cols = orig_cols / ratio;
5529  CuSubMatrix<BaseFloat> in_reshaped(in.Data(), new_rows, new_cols, new_cols),
5530  out_reshaped(out->Data(), new_rows, new_cols, new_cols);
5531  return Propagate(indexes, in_reshaped, &out_reshaped);
5532  }
5533 
5534  // From this point, we can assume that the num-cols of 'in' and 'out'
5535  // equals block_dim_.
5536 
5537  if (!test_mode_) {
5538  // search in the comment above for FORWARD PASS to see what is being
5539  // implemented here.
5540  // if this takes too much time due to multiple different CUDA calls,
5541  // we'll consider making a single kernel for some of it.
5542  Memo *memo = new Memo;
5543  int32 num_frames = in.NumRows(), dim = block_dim_;
5544  memo->num_frames = num_frames;
5545  memo->mean_uvar_scale.Resize(4, dim);
5546  CuSubVector<BaseFloat> mean(memo->mean_uvar_scale, 0),
5547  uvar(memo->mean_uvar_scale, 1),
5548  scale(memo->mean_uvar_scale, 2);
5549  mean.AddRowSumMat(1.0 / num_frames, in, 0.0);
5550  uvar.AddDiagMat2(1.0 / num_frames, in, kTrans, 0.0);
5551  scale.CopyFromVec(uvar);
5552  // by applying this scale at this point, we save a multiply later on.
5553  BaseFloat var_scale = 1.0 / (target_rms_ * target_rms_);
5554  scale.AddVecVec(-var_scale, mean, mean, var_scale);
5555  // at this point, 'scale' contains just the variance [divided by target-rms^2].
5556  scale.ApplyFloor(var_scale * epsilon_);
5557  // Now 'scale' contains the variance floored to epsilon [both divided by
5558  // target-rms^2]. We floor instead of adding so that we can be 100% sure
5559  // that it's positive even in the presence of roundoff.
5560  scale.ApplyPow(-0.5);
5561  // now 'scale' is the actual scale we'll use.
5562 
5563  // the next command will do no work if out == in, for in-place propagation.
5564  out->CopyFromMat(in);
5565  out->AddVecToRows(-1.0, mean, 1.0);
5566  out->MulColsVec(scale);
5567  return static_cast<void*>(memo);
5568  } else {
5569  if (offset_.Dim() != block_dim_) {
5570  if (count_ == 0)
5571  KALDI_ERR << "Test mode set in BatchNormComponent, but no stats.";
5572  else // why was ComputeDerived() not called?
5573  KALDI_ERR << "Code error in BatchNormComponent";
5574  }
5575  out->CopyFromMat(in);
5576  out->MulColsVec(scale_);
5577  out->AddVecToRows(1.0, offset_, 1.0);
5578  return NULL;
5579  }
5580 }
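The training-mode branch above computes a per-column mean and uncentered variance, floors the variance at epsilon, and scales so that the output has RMS equal to target-rms. The sketch below shows the same arithmetic on a plain row-major matrix. `Matrix` and `BatchNormForward` are hypothetical names, and the memo that the real component returns for Backprop is left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // rows are frames

// Hypothetical helper: normalizes each column of 'in' as the training-mode
// branch does, returning the normalized matrix.
Matrix BatchNormForward(const Matrix &in, float epsilon, float target_rms) {
  assert(!in.empty() && epsilon > 0.0f && target_rms > 0.0f);
  const std::size_t num_frames = in.size(), dim = in[0].size();
  Matrix out(num_frames, std::vector<float>(dim));
  const double var_scale =
      1.0 / (static_cast<double>(target_rms) * target_rms);
  for (std::size_t d = 0; d < dim; ++d) {
    double mean = 0.0, uvar = 0.0;  // mean and uncentered variance
    for (std::size_t t = 0; t < num_frames; ++t) {
      mean += in[t][d];
      uvar += static_cast<double>(in[t][d]) * in[t][d];
    }
    mean /= num_frames;
    uvar /= num_frames;
    // Variance divided by target-rms^2, floored at epsilon / target-rms^2.
    const double v =
        std::max(var_scale * (uvar - mean * mean), var_scale * epsilon);
    const double scale = 1.0 / std::sqrt(v);
    for (std::size_t t = 0; t < num_frames; ++t)
      out[t][d] = static_cast<float>((in[t][d] - mean) * scale);
  }
  return out;
}
```

A column with zero variance hits the epsilon floor, so its output is all zeros rather than a division by zero.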
virtual int32 Properties ( ) const
inline virtual

Return bitmask of the component's properties.

These properties depend only on the component's type. See enum ComponentProperties.

Implements Component.

Definition at line 2175 of file nnet-simple-component.h.

References BatchNormComponent::block_dim_, BatchNormComponent::dim_, kaldi::nnet3::kBackpropInPlace, kaldi::nnet3::kBackpropNeedsOutput, kaldi::nnet3::kInputContiguous, kaldi::nnet3::kOutputContiguous, kaldi::nnet3::kPropagateInPlace, kaldi::nnet3::kSimpleComponent, kaldi::nnet3::kStoresStats, kaldi::nnet3::kUsesMemo, and BatchNormComponent::test_mode_.

2175  {
2176  // If the block-dim is less than the dim, we need the input and output
2177  // matrices to be contiguous (stride==num-cols), as we'll be reshaping
2178  // internally. This is not much of a cost, because this will be used
2179  // in convnets where we have to do this anyway.
2184  }
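The contiguity requirement in the comment above exists because the component reinterprets a rows x dim matrix as a (rows * ratio) x block-dim matrix over the same memory, where ratio = dim / block-dim. The following sketch shows the index mapping that this reinterpretation implies. `RowCol` and `ReshapedIndex` are hypothetical names used only for illustration.

```cpp
#include <cassert>

struct RowCol {
  int row;
  int col;
};

// Maps element (row, col) of a contiguous rows x dim matrix to its position
// in the (rows * ratio) x block_dim view of the same buffer.
RowCol ReshapedIndex(int row, int col, int dim, int block_dim) {
  assert(block_dim > 0 && dim % block_dim == 0 && row >= 0 && col >= 0 &&
         col < dim);
  const int ratio = dim / block_dim;
  return RowCol{row * ratio + col / block_dim, col % block_dim};
}
```

Both views address the same flat offset only when stride equals num-cols. With padding between rows, the reshaped view would read the padding as data.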
void Read ( std::istream &  is,
bool  binary 
)
virtual

Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed.

Implements Component.

Definition at line 5693 of file nnet-simple-component.cc.

References CuVectorBase< Real >::AddVecVec(), BatchNormComponent::block_dim_, BatchNormComponent::Check(), BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, kaldi::nnet3::ExpectOneOrTwoTokens(), kaldi::nnet3::ExpectToken(), CuVector< Real >::Read(), kaldi::ReadBasicType(), CuVectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

5693  {
5694  ExpectOneOrTwoTokens(is, binary, "<BatchNormComponent>", "<Dim>");
5695  ReadBasicType(is, binary, &dim_);
5696  ExpectToken(is, binary, "<BlockDim>");
5697  ReadBasicType(is, binary, &block_dim_);
5698  ExpectToken(is, binary, "<Epsilon>");
5699  ReadBasicType(is, binary, &epsilon_);
5700  ExpectToken(is, binary, "<TargetRms>");
5701  ReadBasicType(is, binary, &target_rms_);
5702  ExpectToken(is, binary, "<TestMode>");
5703  ReadBasicType(is, binary, &test_mode_);
5704  ExpectToken(is, binary, "<Count>");
5705  ReadBasicType(is, binary, &count_);
5706  ExpectToken(is, binary, "<StatsMean>");
5707  stats_sum_.Read(is, binary);
5708  ExpectToken(is, binary, "<StatsVar>");
5709  stats_sumsq_.Read(is, binary);
5713  ExpectToken(is, binary, "</BatchNormComponent>");
5714  ComputeDerived();
5715  Check();
5716 }
void Scale ( BaseFloat  scale)
virtual

When called on an UpdatableComponent, this virtual function scales the parameters by "scale".

When called on a component that stores stats, like BatchNormComponent, it relates to scaling activation stats, not parameters.

Reimplemented from Component.

Definition at line 5746 of file nnet-simple-component.cc.

References BatchNormComponent::count_, CuVectorBase< Real >::Scale(), CuVectorBase< Real >::SetZero(), BatchNormComponent::stats_sum_, and BatchNormComponent::stats_sumsq_.

Referenced by ModelCollapser::CollapseComponentsBatchnorm().

5746  {
5747  if (scale == 0) {
5748  count_ = 0.0;
5749  stats_sum_.SetZero();
5751  } else {
5752  count_ *= scale;
5753  stats_sum_.Scale(scale);
5754  stats_sumsq_.Scale(scale);
5755  }
5756 }
const CuVector<BaseFloat>& Scale ( ) const
inline

Definition at line 2218 of file nnet-simple-component.h.

References BatchNormComponent::scale_.

2218 { return scale_; }
void SetTestMode ( bool  test_mode)
void StoreStats ( const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
void *  memo 
)
virtual

This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity.

It only does something for those components that have nonzero Properties()&kStoresStats.

Parameters
[in] in_value: The input to the Propagate() function. Note: if the component sets the flag kPropagateInPlace, this should not be used; the empty matrix will be provided here if in-place propagation was used.
[in] out_value: The output of the Propagate() function.
[in] memo: The 'memo' returned by the Propagate() function; this will usually be NULL.

Reimplemented from Component.

Definition at line 5653 of file nnet-simple-component.cc.

References CuVectorBase< Real >::AddVec(), BatchNormComponent::block_dim_, BatchNormComponent::count_, CuMatrixBase< Real >::Data(), CuVectorBase< Real >::Dim(), BatchNormComponent::dim_, KALDI_ASSERT, BatchNormComponent::Memo::mean_uvar_scale, BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), CuVector< Real >::Resize(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, CuMatrixBase< Real >::Stride(), and BatchNormComponent::test_mode_.

{
  // in test mode this component does not store stats, it doesn't provide the
  // kStoresStats flag.
  KALDI_ASSERT(!test_mode_);
  KALDI_ASSERT(out_value.NumCols() == dim_ || out_value.NumCols() == block_dim_);
  if (out_value.NumCols() != block_dim_) {
    // if block_dim_ != dim_, we recurse; this helps keep the main code
    // simple.
    KALDI_ASSERT(out_value.Stride() == out_value.NumCols());
    int32 ratio = dim_ / block_dim_,
        orig_rows = out_value.NumRows(),
        orig_cols = out_value.NumCols(),
        new_rows = orig_rows * ratio, new_cols = orig_cols / ratio;
    CuSubMatrix<BaseFloat> out_value_reshaped(out_value.Data(), new_rows,
                                              new_cols, new_cols);
    // we'll never use in_value, so just pass it in unchanged.
    StoreStats(in_value, out_value_reshaped, memo_in);
    return;
  }

  Memo *memo = static_cast<Memo*>(memo_in);
  KALDI_ASSERT(out_value.NumRows() == memo->num_frames);

  CuSubVector<BaseFloat> mean(memo->mean_uvar_scale, 0),
      uvar(memo->mean_uvar_scale, 1);
  KALDI_ASSERT(mean.Dim() == block_dim_ && memo->num_frames > 0);
  BaseFloat num_frames = memo->num_frames;
  if (stats_sum_.Dim() != block_dim_) {
    stats_sum_.Resize(block_dim_);
    stats_sumsq_.Resize(block_dim_);
    KALDI_ASSERT(count_ == 0);
  }
  count_ += num_frames;
  stats_sum_.AddVec(num_frames, mean, 1.0);
  stats_sumsq_.AddVec(num_frames, uvar, 1.0);
}
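The two steps of this function, viewing the output as rows of width block_dim_ and then adding the minibatch's frame-weighted mean and uncentered variance to the running sums, can be sketched without any Kaldi types. This is only an illustrative sketch: Shape, ReshapeToBlocks, BatchNormStats and AccumulateMinibatch are hypothetical names, and std::vector stands in for CuVector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Shape {
  int rows;
  int cols;
};

// A contiguous rows x dim matrix (stride == cols) can be viewed as
// (rows * ratio) x block_dim without copying, where ratio = dim / block_dim.
Shape ReshapeToBlocks(Shape s, int block_dim) {
  assert(block_dim > 0 && s.cols % block_dim == 0);
  int ratio = s.cols / block_dim;
  return Shape{s.rows * ratio, s.cols / ratio};
}

// Hypothetical container mirroring count_, stats_sum_ and stats_sumsq_.
struct BatchNormStats {
  double count = 0.0;
  std::vector<double> sum;    // running sum of x per dimension
  std::vector<double> sumsq;  // running sum of x^2 per dimension
};

// Adds one minibatch, summarized by its frame count, per-dimension mean and
// per-dimension uncentered variance (mean of x^2), to the running totals.
void AccumulateMinibatch(BatchNormStats *stats, int num_frames,
                         const std::vector<double> &mean,
                         const std::vector<double> &uvar) {
  assert(num_frames > 0 && mean.size() == uvar.size());
  if (stats->sum.size() != mean.size()) {
    // First call: allocate zeroed accumulators; the count must still be zero.
    assert(stats->count == 0.0);
    stats->sum.assign(mean.size(), 0.0);
    stats->sumsq.assign(mean.size(), 0.0);
  }
  stats->count += num_frames;
  for (std::size_t i = 0; i < mean.size(); ++i) {
    stats->sum[i] += num_frames * mean[i];
    stats->sumsq[i] += num_frames * uvar[i];
  }
}
```

Multiplying the mean and the uncentered variance by the frame count recovers the minibatch's raw sum of x and sum of x^2, so minibatches of different sizes are weighted by how many frames they contain.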
virtual std::string Type ( ) const
inline virtual

Returns a string such as "SigmoidComponent", describing the type of the object.

Implements Component.

Definition at line 2174 of file nnet-simple-component.h.

Referenced by BatchNormComponent::Info().

{ return "BatchNormComponent"; }
void Write ( std::ostream &  os,
bool  binary 
) const
virtual

Write component to stream.

Implements Component.

Definition at line 5718 of file nnet-simple-component.cc.

References CuVectorBase< Real >::AddVecVec(), BatchNormComponent::block_dim_, BatchNormComponent::Check(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, CuVectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, CuVector< Real >::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().

{
  Check();
  WriteToken(os, binary, "<BatchNormComponent>");
  WriteToken(os, binary, "<Dim>");
  WriteBasicType(os, binary, dim_);
  WriteToken(os, binary, "<BlockDim>");
  WriteBasicType(os, binary, block_dim_);
  WriteToken(os, binary, "<Epsilon>");
  WriteBasicType(os, binary, epsilon_);
  WriteToken(os, binary, "<TargetRms>");
  WriteBasicType(os, binary, target_rms_);
  WriteToken(os, binary, "<TestMode>");
  WriteBasicType(os, binary, test_mode_);
  WriteToken(os, binary, "<Count>");
  WriteBasicType(os, binary, count_);
  CuVector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
  if (count_ != 0) {
    mean.Scale(1.0 / count_);
    var.Scale(1.0 / count_);
    var.AddVecVec(-1.0, mean, mean, 1.0);
  }
  WriteToken(os, binary, "<StatsMean>");
  mean.Write(os, binary);
  WriteToken(os, binary, "<StatsVar>");
  var.Write(os, binary);
  WriteToken(os, binary, "</BatchNormComponent>");
}
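Note that the stream stores a mean and a variance under <StatsMean> and <StatsVar>, not the raw sums. The conversion done just before writing is mean = sum / count and var = sumsq / count - mean * mean, and it is skipped when the count is zero. A minimal standalone sketch of that arithmetic follows; MeanVar and StatsToMeanVar are hypothetical names, and std::vector stands in for CuVector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct MeanVar {
  std::vector<double> mean;
  std::vector<double> var;
};

// Converts accumulated sums into per-dimension mean and variance:
//   mean = sum / count,  var = sumsq / count - mean * mean.
// With count == 0 the inputs are passed through unchanged.
MeanVar StatsToMeanVar(double count, const std::vector<double> &sum,
                       const std::vector<double> &sumsq) {
  assert(sum.size() == sumsq.size());
  MeanVar out{sum, sumsq};
  if (count != 0.0) {
    for (std::size_t i = 0; i < sum.size(); ++i) {
      out.mean[i] /= count;
      out.var[i] = out.var[i] / count - out.mean[i] * out.mean[i];
    }
  }
  return out;
}
```

For example, count = 4, sum = {8, 2} and sumsq = {20, 3} give mean = {2, 0.5} and var = {1, 0.5}.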
void ZeroStats ( )
virtual

Components that implement StoreStats() should also implement ZeroStats(), which resets those stats to zero.

Components that store other kinds of statistics (e.g. regarding gradient clipping) should implement ZeroStats() as well.

Reimplemented from Component.

Definition at line 5770 of file nnet-simple-component.cc.

References BatchNormComponent::count_, CuVectorBase< Real >::SetZero(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, and BatchNormComponent::test_mode_.

{
  // We only zero the stats if we're not in test mode.  In test mode, this would
  // be dangerous as the stats are the source for the transform, and zeroing
  // them and then calling ComputeDerived() again would remove the transform
  // parameters (offset_ and scale_).
  if (!test_mode_) {
    count_ = 0.0;
    stats_sum_.SetZero();
    stats_sumsq_.SetZero();
  }
}
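The test-mode guard can be illustrated with a small standalone sketch. GuardedStats is a hypothetical type, not a Kaldi class, and std::vector stands in for CuVector. It shows only the behavior described in the comment: a reset request is ignored while test mode is on, so whatever is derived from the stats stays valid.

```cpp
#include <cassert>
#include <vector>

// Hypothetical holder for statistics that must survive in test mode.
struct GuardedStats {
  bool test_mode = false;
  double count = 0.0;
  std::vector<double> sum;
  std::vector<double> sumsq;

  // Resets the stats only outside test mode; in test mode it is a no-op.
  void ZeroStats() {
    if (test_mode) return;
    count = 0.0;
    sum.assign(sum.size(), 0.0);
    sumsq.assign(sumsq.size(), 0.0);
  }
};
```

Calling ZeroStats() on an object with test_mode set leaves count, sum and sumsq untouched; clearing the flag and calling it again zeroes all three.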

Member Data Documentation


The documentation for this class was generated from the following files: