BatchNormComponent Class Reference

#include <nnet-normalize-component.h>


Classes

struct  Memo
 

Public Member Functions

 BatchNormComponent ()
 
void SetTestMode (bool test_mode)
 
 BatchNormComponent (const BatchNormComponent &other)
 
virtual int32 InputDim () const
 Returns input-dimension of this component. More...
 
virtual int32 OutputDim () const
 Returns output-dimension of this component. More...
 
virtual std::string Info () const
 Returns some text-form information about this component, for diagnostics. More...
 
virtual void InitFromConfig (ConfigLine *cfl)
 Initialize, from a ConfigLine object. More...
 
virtual std::string Type () const
 Returns a string such as "SigmoidComponent", describing the type of the object. More...
 
virtual int32 Properties () const
 Return bitmask of the component's properties. More...
 
virtual void * Propagate (const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
 Propagate function. More...
 
virtual void Backprop (const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const
 Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update. More...
 
virtual void Read (std::istream &is, bool binary)
 Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed. More...
 
virtual void Write (std::ostream &os, bool binary) const
 Write component to stream. More...
 
virtual Component * Copy () const
 Copies component (deep copy). More...
 
virtual void Scale (BaseFloat scale)
 This virtual function, when called on an UpdatableComponent, scales the parameters by "scale". More...
 
virtual void Add (BaseFloat alpha, const Component &other)
 This virtual function, when called on an UpdatableComponent, adds the parameters of another updatable component, times some constant, to the current parameters. More...
 
virtual void ZeroStats ()
 Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero. More...
 
virtual void DeleteMemo (void *memo) const
 This virtual function only needs to be overwritten by Components that return a non-NULL memo from their Propagate() function. More...
 
virtual void StoreStats (const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo)
 This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity. More...
 
const CuVector< BaseFloat > & Offset () const
 
const CuVector< BaseFloat > & Scale () const
 
- Public Member Functions inherited from Component
virtual void GetInputIndexes (const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
 This function only does something interesting for non-simple Components. More...
 
virtual bool IsComputable (const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
 This function only does something interesting for non-simple Components, and it exists to make it possible to manage optionally-required inputs. More...
 
virtual void ReorderIndexes (std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
 This function only does something interesting for non-simple Components. More...
 
virtual ComponentPrecomputedIndexes * PrecomputeIndexes (const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
 This function must return NULL for simple Components. More...
 
 Component ()
 
virtual ~Component ()
 

Private Member Functions

void Check () const
 
void ComputeDerived ()
 

Static Private Member Functions

static void ComputeOffsetAndScale (double count, BaseFloat epsilon, const Vector< double > &stats_sum, const Vector< double > &stats_sumsq, Vector< BaseFloat > *offset, Vector< BaseFloat > *scale)
 

Private Attributes

int32 dim_
 
int32 block_dim_
 
BaseFloat epsilon_
 
BaseFloat target_rms_
 
bool test_mode_
 
double count_
 
CuVector< double > stats_sum_
 
CuVector< double > stats_sumsq_
 
CuVector< BaseFloat > offset_
 
CuVector< BaseFloat > scale_
 

Additional Inherited Members

- Static Public Member Functions inherited from Component
static Component * ReadNew (std::istream &is, bool binary)
 Read component from stream (works out its type). Dies on error. More...
 
static Component * NewComponentOfType (const std::string &type)
 Returns a new Component of the given type e.g. More...
 

Detailed Description

Definition at line 159 of file nnet-normalize-component.h.

Constructor & Destructor Documentation

BatchNormComponent ( )
inline

Definition at line 162 of file nnet-normalize-component.h.

Referenced by BatchNormComponent::Copy().

162 { }

Member Function Documentation

void Add ( BaseFloat  alpha,
const Component &  other 
)
virtual

When called on an UpdatableComponent, this virtual function adds the parameters of another updatable component, times some constant, to the current parameters.

When called on a NonlinearComponent (or another component that stores stats, like BatchNormComponent), it adds stats instead. Otherwise it will normally do nothing.

Reimplemented from Component.

Definition at line 655 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVec(), BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::stats_sum_, and BatchNormComponent::stats_sumsq_.

655  {
656  const BatchNormComponent *other =
657  dynamic_cast<const BatchNormComponent*>(&other_in);
658  count_ += alpha * other->count_;
659  stats_sum_.AddVec(alpha, other->stats_sum_);
660  stats_sumsq_.AddVec(alpha, other->stats_sumsq_);
661  // this operation might change offset_ and scale_, so we recompute them
662  // in this instance (but not in Scale()).
663  ComputeDerived();
664 }
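
The listing above adds another component's stats, weighted by alpha, to this component's count, sum and sum-of-squares. The following is a minimal standalone sketch of that accumulation using std::vector instead of CuVector; the BatchNormStats struct and the AddStats function are hypothetical names for illustration and are not part of Kaldi.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the statistics held by the component.
struct BatchNormStats {
  double count = 0.0;
  std::vector<double> sum;    // per-dimension sum of activations
  std::vector<double> sumsq;  // per-dimension sum of squared activations
};

// Adds alpha * other into *self, mirroring the count_ update and the two
// AddVec calls in the listing.
void AddStats(double alpha, const BatchNormStats &other, BatchNormStats *self) {
  assert(self->sum.size() == other.sum.size() &&
         self->sumsq.size() == other.sumsq.size());
  self->count += alpha * other.count;
  for (std::size_t d = 0; d < self->sum.size(); ++d) {
    self->sum[d] += alpha * other.sum[d];
    self->sumsq[d] += alpha * other.sumsq[d];
  }
}
```

Because all three quantities are scaled by the same alpha, the mean and variance implied by the other component's stats keep their values; alpha only changes how much weight they carry in the combined estimate.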
void Backprop ( const std::string &  debug_info,
const ComponentPrecomputedIndexes *  indexes,
const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
const CuMatrixBase< BaseFloat > &  out_deriv,
void *  memo,
Component *  to_update,
CuMatrixBase< BaseFloat > *  in_deriv 
) const
virtual

Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update.

Parameters
[in]  debug_info  The component name, to be printed out in any warning messages.
[in]  indexes  A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in]  in_value  The matrix that was given as input to the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsInput == 0.
[in]  out_value  The matrix that was output from the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsOutput == 0.
[in]  out_deriv  The derivative at the output of this component.
[in]  memo  This will normally be NULL, but for component types that set the flag kUsesMemo, this will be the return value of the Propagate() function that corresponds to this Backprop() function. Ownership of any pointers is not transferred to the Backprop function; DeleteMemo() will be called to delete it.
[out]  to_update  If model update is desired, the Component to be updated, else NULL. Does not have to be identical to this. If supplied, you can assume that to_update->Properties() & kUpdatableComponent is nonzero.
[out]  in_deriv  The derivative at the input of this component, if needed (else NULL). If Properties()&kBackpropInPlace, may be the same matrix as out_deriv. If Properties()&kBackpropAdds, this is added to by the Backprop routine, else it is set. The component code chooses which mode to work in, based on convenience.

Implements Component.

Definition at line 466 of file nnet-normalize-component.cc.

References CuMatrixBase< Real >::AddMatDiagVec(), CuMatrixBase< Real >::AddVecToRows(), BatchNormComponent::block_dim_, CuMatrixBase< Real >::CopyFromMat(), CuMatrixBase< Real >::Data(), CuVectorBase< Real >::Dim(), BatchNormComponent::dim_, KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, BatchNormComponent::Memo::mean_uvar_scale, CuMatrixBase< Real >::MulColsVec(), BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), BatchNormComponent::offset_, kaldi::SameDim(), BatchNormComponent::scale_, CuMatrixBase< Real >::Stride(), BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

474  {
475 
476  KALDI_ASSERT(SameDim(out_value, out_deriv) &&
477  SameDim(out_value, *in_deriv) &&
478  (out_value.NumCols() == dim_ ||
479  out_value.NumCols() == block_dim_));
480  if (out_value.NumCols() != block_dim_) {
481  // if block_dim_ != dim_, we recurse; this helps keep the main code
482  // simple.
483  KALDI_ASSERT(out_value.Stride() == out_value.NumCols() &&
484  out_deriv.Stride() == out_deriv.NumCols() &&
485  in_deriv->Stride() == in_deriv->NumCols());
486  int32 ratio = dim_ / block_dim_,
487  orig_rows = out_value.NumRows(),
488  orig_cols = out_value.NumCols(),
489  new_rows = orig_rows * ratio, new_cols = orig_cols / ratio;
490  CuSubMatrix<BaseFloat> out_value_reshaped(out_value.Data(), new_rows,
491  new_cols, new_cols),
492  out_deriv_reshaped(out_deriv.Data(), new_rows, new_cols, new_cols),
493  in_deriv_reshaped(in_deriv->Data(), new_rows, new_cols, new_cols);
494  // we'll never use in_value, so pass it in unchanged.
495  Backprop(debug_info, indexes, in_value,
496  out_value_reshaped, out_deriv_reshaped,
497  memo_in, to_update, &in_deriv_reshaped);
498  return;
499  }
500 
501  Memo *memo = static_cast<Memo*>(memo_in);
502 
503  if (!test_mode_) {
504  // search above for BACKWARD PASS for a comment describing the math.
505  KALDI_ASSERT(memo != NULL && "memo not passed into backprop");
506  int32 num_frames = memo->num_frames;
507  KALDI_ASSERT(out_value.NumRows() == num_frames);
508  CuSubVector<BaseFloat>
509  scale(memo->mean_uvar_scale, 2),
510  var_deriv_mod(memo->mean_uvar_scale, 3),
511  temp(memo->mean_uvar_scale, 4);
512 
513  // var_deriv_mod is going to contain:
514  // 2 * power * target-rms^{1/power} * (1/I \sum_i z'(i) z(i)) * scale^{-(1+power)/power}
515  // which for power = -0.5 simplifies to:
516  // -1.0 / (target_rms * target_rms).
517  // but for now we don't have the power of 'scale', we'll add that later.
518  BaseFloat coeff = -1.0 / (target_rms_ * target_rms_ * num_frames);
519 
520  var_deriv_mod.AddDiagMatMat(coeff, out_value, kTrans,
521  out_deriv, kNoTrans, 0.0);
522  var_deriv_mod.MulElements(scale);
523 
524  temp.AddRowSumMat(-1.0 / num_frames, out_deriv, 0.0);
525  // the following statement does no work if in_deriv and out_deriv are the
526  // same matrix.
527  in_deriv->CopyFromMat(out_deriv);
528  in_deriv->AddVecToRows(1.0, temp);
529  // At this point, *in_deriv contains
530  // (z'(i) - 1/I * \sum_i z'(i))
531  in_deriv->MulColsVec(scale);
532  // At this point, *in_deriv contains
533  // scale * (z'(i) - 1/I * \sum_i z'(i))
534 
535  in_deriv->AddMatDiagVec(1.0, out_value, kNoTrans,
536  var_deriv_mod, 1.0);
537 
538  // At this point, *in_deriv contains what we described in the comment
539  // starting BATCHNORM_MATH as:
540  // x'(i) = scale * (z'(i) - 1/I * \sum_i z'(i)) + z(i) var_deriv_mod
541  } else {
542  KALDI_ASSERT(offset_.Dim() == block_dim_);
543  // the next call does no work if they point to the same memory.
544  in_deriv->CopyFromMat(out_deriv);
545  in_deriv->MulColsVec(scale_);
546  }
547 }
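
The training-mode branch above computes x'(i) = scale * (z'(i) - 1/I * \sum_i z'(i)) + z(i) * var_deriv_mod. As a rough check of that formula, here is a standalone sketch for a single column using std::vector; the function names ForwardColumn, BackwardColumn and Objective are hypothetical and the forward pass is reconstructed from the Propagate listing on this page, so treat it as an illustration rather than the library code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Training-mode forward pass for ONE column: z(i) = scale * (x(i) - mean),
// with scale = target_rms * (variance + epsilon)^{-0.5}.
std::vector<double> ForwardColumn(const std::vector<double> &x, double epsilon,
                                  double target_rms, double *scale) {
  const double n = static_cast<double>(x.size());
  double mean = 0.0, uvar = 0.0;
  for (double v : x) {
    mean += v / n;
    uvar += v * v / n;
  }
  const double var = std::max(0.0, uvar - mean * mean);  // floor at zero
  *scale = target_rms / std::sqrt(var + epsilon);
  std::vector<double> z(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) z[i] = *scale * (x[i] - mean);
  return z;
}

// Backward pass for the same column:
//   x'(i) = scale * (z'(i) - 1/I sum_i z'(i)) + z(i) * var_deriv_mod
std::vector<double> BackwardColumn(const std::vector<double> &z,
                                   const std::vector<double> &z_deriv,
                                   double scale, double target_rms) {
  const double n = static_cast<double>(z.size());
  double sum_deriv = 0.0, sum_z_deriv = 0.0;
  for (std::size_t i = 0; i < z.size(); ++i) {
    sum_deriv += z_deriv[i];
    sum_z_deriv += z[i] * z_deriv[i];
  }
  // Same coefficient as in the listing, followed by the multiply by scale.
  const double var_deriv_mod =
      -1.0 / (target_rms * target_rms * n) * sum_z_deriv * scale;
  std::vector<double> x_deriv(z.size());
  for (std::size_t i = 0; i < z.size(); ++i)
    x_deriv[i] = scale * (z_deriv[i] - sum_deriv / n) + z[i] * var_deriv_mod;
  return x_deriv;
}

// Objective sum_i w(i) * z(i); only used for a finite-difference check.
double Objective(const std::vector<double> &x, const std::vector<double> &w,
                 double epsilon, double target_rms) {
  double scale = 0.0;
  const std::vector<double> z = ForwardColumn(x, epsilon, target_rms, &scale);
  double total = 0.0;
  for (std::size_t i = 0; i < z.size(); ++i) total += w[i] * z[i];
  return total;
}
```

Comparing BackwardColumn against a central finite difference of Objective is a convenient way to confirm the signs and the 1/I factors. The derivatives for a column should also sum to approximately zero, since adding a constant to every x(i) leaves z unchanged.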
void ComputeDerived ( )
private

Definition at line 208 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::Add(), CuVectorBase< Real >::AddVecVec(), CuVectorBase< Real >::ApplyFloor(), CuVectorBase< Real >::ApplyPow(), BatchNormComponent::block_dim_, CuVectorBase< Real >::CopyFromVec(), BatchNormComponent::count_, BatchNormComponent::epsilon_, KALDI_WARN, CuVectorBase< Real >::MulElements(), BatchNormComponent::offset_, CuVector< Real >::Resize(), CuVectorBase< Real >::Scale(), BatchNormComponent::scale_, CuVectorBase< Real >::SetRandn(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

Referenced by BatchNormComponent::Add(), BatchNormComponent::BatchNormComponent(), BatchNormComponent::InitFromConfig(), BatchNormComponent::Read(), and BatchNormComponent::SetTestMode().

208  {
209  if (!test_mode_) {
210  offset_.Resize(0);
211  scale_.Resize(0);
212  return;
213  }
214 
215  if (count_ == 0.0) {
216  KALDI_WARN << "Test-mode is set but there is no data count. "
217  "Creating random counts. This only makes sense "
218  "in unit-tests (or compute_prob_*.0.log). If you see this "
219  "elsewhere, something is very wrong.";
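220  count_ = 1.0;
221  stats_sum_.SetRandn();
222  stats_sumsq_.SetRandn();
223  stats_sumsq_.AddVecVec(1.0, stats_sum_, stats_sum_, 1.0);
224  }
225 
226  offset_.Resize(block_dim_);
227  scale_.Resize(block_dim_);
228  offset_.CopyFromVec(stats_sum_);
229  offset_.Scale(-1.0 / count_);
230  // now offset_ is -mean.
231  scale_.CopyFromVec(stats_sumsq_);
232  scale_.Scale(1.0 / count_);
233  scale_.AddVecVec(-1.0, offset_, offset_, 1.0);
234  // now scale_ is variance.
235  // Mathematically the ApplyFloor statement should be a no-op; this is in case
236  // of numerical roundoff.
237  scale_.ApplyFloor(0.0);
238  scale_.Add(epsilon_);
239  BaseFloat power = -0.5;
240  scale_.ApplyPow(power);
241  // now scale_ = (variance + epsilon)^power
242  // next, multiply by the target RMS (normally 1.0).
243  scale_.Scale(target_rms_);
244  offset_.MulElements(scale_);
245  // now offset_ is -(scale*mean).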
246 }
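
In test mode the component reduces its accumulated stats to a per-dimension scale and offset, so that propagation becomes out = in * scale + offset. The sketch below works through that arithmetic with std::vector; it assumes, as in the Propagate listing on this page, that epsilon is added to the floored variance before the -0.5 power is applied. The names OffsetAndScale, DeriveOffsetAndScale and ApplyTestMode are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct OffsetAndScale {
  std::vector<double> offset;  // -(scale * mean), per dimension
  std::vector<double> scale;   // target_rms * (variance + epsilon)^{-0.5}
};

// Derives the test-mode offset and scale from accumulated statistics.
OffsetAndScale DeriveOffsetAndScale(double count,
                                    const std::vector<double> &stats_sum,
                                    const std::vector<double> &stats_sumsq,
                                    double epsilon, double target_rms) {
  assert(count > 0.0 && stats_sum.size() == stats_sumsq.size());
  OffsetAndScale result;
  result.offset.resize(stats_sum.size());
  result.scale.resize(stats_sum.size());
  for (std::size_t d = 0; d < stats_sum.size(); ++d) {
    const double mean = stats_sum[d] / count;
    // Floor at zero to guard against roundoff, as ApplyFloor(0.0) does.
    const double var = std::max(0.0, stats_sumsq[d] / count - mean * mean);
    const double scale = target_rms * std::pow(var + epsilon, -0.5);
    result.scale[d] = scale;
    result.offset[d] = -mean * scale;
  }
  return result;
}

// Test-mode propagation of a single value: scale first, then add the offset.
double ApplyTestMode(double x, double offset, double scale) {
  return x * scale + offset;
}
```

For example, with count 4, sum 8, sum-of-squares 31 and epsilon 0.25, the mean is 2 and variance plus epsilon is 4, so the scale is 0.5 and the offset is -1 (for a target RMS of 1).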
static void ComputeOffsetAndScale ( double  count,
BaseFloat  epsilon,
const Vector< double > &  stats_sum,
const Vector< double > &  stats_sumsq,
Vector< BaseFloat > *  offset,
Vector< BaseFloat > *  scale 
)
static private
virtual Component* Copy ( ) const
inline virtual

Copies component (deep copy).

Implements Component.

Definition at line 209 of file nnet-normalize-component.h.

References BatchNormComponent::BatchNormComponent().

virtual void DeleteMemo ( void *  memo) const
inline virtual

This virtual function only needs to be overwritten by Components that return a non-NULL memo from their Propagate() function.

It's called by NnetComputer in cases where Propagate returns a memo but there will be no backprop to consume it.

Reimplemented from Component.

Definition at line 216 of file nnet-normalize-component.h.

216 { delete static_cast<Memo*>(memo); }
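
The memo is passed around as a void pointer, so whoever deletes it has to cast it back to the concrete type first. A small sketch of that ownership pattern follows; the Memo struct here is a hypothetical stand-in (with a live-object counter added purely for demonstration) and MakeMemo is an invented helper, not a Kaldi function.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for BatchNormComponent::Memo.
struct Memo {
  int num_frames = 0;
  std::vector<float> mean_uvar_scale;  // flattened 5 x dim storage
  static int live;                     // demo-only count of live objects
  Memo() { ++live; }
  ~Memo() { --live; }
};
int Memo::live = 0;

// Plays the role of a training-mode Propagate(): allocates a memo and
// returns it type-erased.
void *MakeMemo(int num_frames, int dim) {
  Memo *memo = new Memo;
  memo->num_frames = num_frames;
  memo->mean_uvar_scale.assign(5 * dim, 0.0f);
  return static_cast<void*>(memo);
}

// Mirrors the one-line implementation above: cast back, then delete.
void DeleteMemo(void *memo) { delete static_cast<Memo*>(memo); }
```

Casting back to Memo* before delete matters because deleting through a void pointer would not run the destructor. Passing a null pointer is harmless, which fits the test-mode case where Propagate returns NULL.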
std::string Info ( ) const
virtual

Returns some text-form information about this component, for diagnostics.

Starts with the type of the component. E.g. "SigmoidComponent dim=900", although most components will have much more info.

Reimplemented from Component.

Definition at line 268 of file nnet-normalize-component.cc.

References VectorBase< Real >::AddVecVec(), VectorBase< Real >::ApplyFloor(), VectorBase< Real >::ApplyPow(), BatchNormComponent::block_dim_, BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, VectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, kaldi::nnet3::SummarizeVector(), BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, and BatchNormComponent::Type().

268  {
269  std::ostringstream stream;
270  stream << Type() << ", dim=" << dim_ << ", block-dim=" << block_dim_
271  << ", epsilon=" << epsilon_ << ", target-rms=" << target_rms_
272  << ", count=" << count_
273  << ", test-mode=" << (test_mode_ ? "true" : "false");
274  if (count_ > 0) {
275  Vector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
276  mean.Scale(1.0 / count_);
277  var.Scale(1.0 / count_);
278  // subtract mean^2 from var.
279  var.AddVecVec(-1.0, mean, mean, 1.0);
280  var.ApplyFloor(0.0);
281  var.ApplyPow(0.5); // make it the stddev.
282  stream << ", data-mean=" << SummarizeVector(mean)
283  << ", data-stddev=" << SummarizeVector(var);
284  }
285  return stream.str();
286 }
void InitFromConfig ( ConfigLine *  cfl)
virtual

Initialize, from a ConfigLine object.

Parameters
[in]  cfl  A ConfigLine containing any parameters that are needed for initialization. For example: "dim=100 param-stddev=0.1"

Implements Component.

Definition at line 288 of file nnet-normalize-component.cc.

References BatchNormComponent::block_dim_, BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, ConfigLine::GetValue(), ConfigLine::HasUnusedValues(), KALDI_ERR, CuVector< Real >::Resize(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, and ConfigLine::UnusedValues().

288  {
289  dim_ = -1;
290  block_dim_ = -1;
291  epsilon_ = 1.0e-03;
292  target_rms_ = 1.0;
293  test_mode_ = false;
294  bool ok = cfl->GetValue("dim", &dim_);
295  cfl->GetValue("block-dim", &block_dim_);
296  cfl->GetValue("epsilon", &epsilon_);
297  cfl->GetValue("target-rms", &target_rms_);
298  cfl->GetValue("test-mode", &test_mode_);
299  if (!ok || dim_ <= 0) {
300  KALDI_ERR << "BatchNormComponent must have 'dim' specified, and > 0";
301  }
302  if (block_dim_ == -1)
303  block_dim_ = dim_;
304  if (!(block_dim_ > 0 && dim_ % block_dim_ == 0 &&
305  epsilon_ > 0 && target_rms_ > 0))
306  KALDI_ERR << "Invalid configuration in BatchNormComponent.";
307  if (cfl->HasUnusedValues())
308  KALDI_ERR << "Could not process these elements in initializer: "
309  << cfl->UnusedValues();
310  count_ = 0;
311  stats_sum_.Resize(block_dim_);
312  stats_sumsq_.Resize(block_dim_);
313  if (test_mode_) {
314  ComputeDerived();
315  }
316 }
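
The validation rules in the listing above can be summarized in a few lines. The sketch below restates them on a plain struct; BatchNormConfig and FinalizeConfig are hypothetical names, and the function returns false where the real code would raise KALDI_ERR.

```cpp
#include <cassert>

// Hypothetical plain-struct view of the configurable values, with the same
// defaults that the listing assigns before parsing.
struct BatchNormConfig {
  int dim = -1;
  int block_dim = -1;
  float epsilon = 1.0e-03f;
  float target_rms = 1.0f;
  bool test_mode = false;
};

// Applies the block-dim default and returns true if the configuration
// passes the same checks as the listing.
bool FinalizeConfig(BatchNormConfig *config) {
  if (config->dim <= 0) return false;  // 'dim' is mandatory and positive
  if (config->block_dim == -1) config->block_dim = config->dim;
  return config->block_dim > 0 && config->dim % config->block_dim == 0 &&
         config->epsilon > 0 && config->target_rms > 0;
}
```

Under these rules a config line such as "dim=100" is accepted with block-dim defaulting to 100, whereas "dim=100 block-dim=30" is rejected because 30 does not divide 100.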
virtual int32 InputDim ( ) const
inline virtual

Returns input-dimension of this component.

Implements Component.

Definition at line 176 of file nnet-normalize-component.h.

References BatchNormComponent::dim_.

const CuVector<BaseFloat>& Offset ( ) const
inline

Definition at line 224 of file nnet-normalize-component.h.

References BatchNormComponent::offset_.

Referenced by ModelCollapser::CollapseComponentsBatchnorm().

224 { return offset_; }
virtual int32 OutputDim ( ) const
inline virtual

Returns output-dimension of this component.

Implements Component.

Definition at line 177 of file nnet-normalize-component.h.

References BatchNormComponent::dim_.

void * Propagate ( const ComponentPrecomputedIndexes *  indexes,
const CuMatrixBase< BaseFloat > &  in,
CuMatrixBase< BaseFloat > *  out 
) const
virtual

Propagate function.

Parameters
[in]  indexes  A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in]  in  The input to this component. Num-columns == InputDim().
[out]  out  The output of this component. Num-columns == OutputDim(). Note: output of this component will be added to the initial value of "out" if Properties()&kPropagateAdds != 0; otherwise the output will be set and the initial value ignored. Each Component chooses whether it is more convenient implementation-wise to add or set, and the calling code has to deal with it.
Returns
Normally returns NULL, but may return a non-NULL value for components which have the flag kUsesMemo set. This value will be passed into the corresponding Backprop routine.

Implements Component.

Definition at line 400 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::Add(), CuMatrixBase< Real >::AddVecToRows(), CuVectorBase< Real >::AddVecVec(), CuVectorBase< Real >::ApplyFloor(), CuVectorBase< Real >::ApplyPow(), BatchNormComponent::block_dim_, CuMatrixBase< Real >::CopyFromMat(), CuVectorBase< Real >::CopyFromVec(), BatchNormComponent::count_, CuMatrixBase< Real >::Data(), CuVectorBase< Real >::Dim(), BatchNormComponent::dim_, BatchNormComponent::epsilon_, KALDI_ASSERT, KALDI_ERR, kaldi::kTrans, BatchNormComponent::Memo::mean_uvar_scale, CuMatrixBase< Real >::MulColsVec(), BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), BatchNormComponent::offset_, CuMatrix< Real >::Resize(), kaldi::SameDim(), BatchNormComponent::scale_, CuMatrixBase< Real >::Stride(), BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

402  {
403  KALDI_ASSERT(SameDim(in, *out) &&
404  (in.NumCols() == dim_ || in.NumCols() == block_dim_));
405  if (in.NumCols() != block_dim_) {
406  // if block_dim_ != dim_, we recurse; this helps keep the main code
407  // simple.
408  KALDI_ASSERT(in.Stride() == in.NumCols() && out->Stride() == out->NumCols());
409  int32 ratio = dim_ / block_dim_, orig_rows = in.NumRows(),
410  orig_cols = in.NumCols(), new_rows = orig_rows * ratio,
411  new_cols = orig_cols / ratio;
412  CuSubMatrix<BaseFloat> in_reshaped(in.Data(), new_rows, new_cols, new_cols),
413  out_reshaped(out->Data(), new_rows, new_cols, new_cols);
414  return Propagate(indexes, in_reshaped, &out_reshaped);
415  }
416 
417  // From this point, we can assume that the num-cols of 'in' and 'out'
418  // equals block_dim_.
419 
420  if (!test_mode_) {
421  // search in the comment above for FORWARD PASS to see what is being
422  // implemented here.
423  // if this takes too much time due to multiple different CUDA calls,
424  // we'll consider making a single kernel for some of it.
425  Memo *memo = new Memo;
426  int32 num_frames = in.NumRows(), dim = block_dim_;
427  memo->num_frames = num_frames;
428  memo->mean_uvar_scale.Resize(5, dim);
429  CuSubVector<BaseFloat> mean(memo->mean_uvar_scale, 0),
430  uvar(memo->mean_uvar_scale, 1),
431  scale(memo->mean_uvar_scale, 2);
432  mean.AddRowSumMat(1.0 / num_frames, in, 0.0);
433  uvar.AddDiagMat2(1.0 / num_frames, in, kTrans, 0.0);
434  scale.CopyFromVec(uvar);
435 
436  // by applying this scale at this point, we save a multiply later on.
437  BaseFloat var_scale = 1.0 / (target_rms_ * target_rms_);
438  scale.AddVecVec(-var_scale, mean, mean, var_scale);
439  // at this point, 'scale' contains just the variance (times target-rms^{-2}).
440  scale.ApplyFloor(0.0);
441  scale.Add(var_scale * epsilon_);
442  // Now 'scale' contains the variance floored to zero and then with epsilon
443  // added [both times 1/target-rms^2].
444  scale.ApplyPow(-0.5);
445  // now 'scale' is the actual scale we'll use.
446 
447  // the next command will do no work if out == in, for in-place propagation.
448  out->CopyFromMat(in);
449  out->AddVecToRows(-1.0, mean, 1.0);
450  out->MulColsVec(scale);
451  return static_cast<void*>(memo);
452  } else {
453  if (offset_.Dim() != block_dim_) {
454  if (count_ == 0)
455  KALDI_ERR << "Test mode set in BatchNormComponent, but no stats.";
456  else // why was ComputeDerived() not called?
457  KALDI_ERR << "Code error in BatchNormComponent";
458  }
459  out->CopyFromMat(in);
460  out->MulColsVec(scale_);
461  out->AddVecToRows(1.0, offset_, 1.0);
462  return NULL;
463  }
464 }
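
When block_dim_ is smaller than dim_, the listing above reinterprets the contiguous (rows x dim) matrix as (rows * ratio) x block_dim and recurses, so that each block of columns shares one set of statistics. The sketch below shows the effect of that reinterpretation on a row-major std::vector by computing the per-column means of the reshaped view; BlockColumnMeans is a hypothetical helper written for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Per-column means of a contiguous row-major (rows x dim) matrix when it is
// viewed as (rows * ratio) x block_dim, where ratio = dim / block_dim.
std::vector<double> BlockColumnMeans(const std::vector<double> &data, int rows,
                                     int dim, int block_dim) {
  assert(block_dim > 0 && dim % block_dim == 0);
  assert(data.size() == static_cast<std::size_t>(rows) * dim);
  const int ratio = dim / block_dim;
  // Same arithmetic as the listing: more rows, fewer columns, same memory.
  const int new_rows = rows * ratio, new_cols = dim / ratio;
  std::vector<double> mean(new_cols, 0.0);
  for (int r = 0; r < new_rows; ++r)
    for (int c = 0; c < new_cols; ++c)
      mean[c] += data[static_cast<std::size_t>(r) * new_cols + c] / new_rows;
  return mean;
}
```

For a 2 x 4 matrix holding 1 through 8 with block_dim 2, the view has rows (1,2), (3,4), (5,6), (7,8), so the two column means are 4 and 5. This only works when stride equals the number of columns, which is why the listing asserts it.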
virtual int32 Properties ( ) const
inline virtual

Return bitmask of the component's properties.

These properties depend only on the component's type. See enum ComponentProperties.

Implements Component.

Definition at line 182 of file nnet-normalize-component.h.

References BatchNormComponent::block_dim_, BatchNormComponent::dim_, kaldi::nnet3::kBackpropInPlace, kaldi::nnet3::kBackpropNeedsOutput, kaldi::nnet3::kInputContiguous, kaldi::nnet3::kOutputContiguous, kaldi::nnet3::kPropagateInPlace, kaldi::nnet3::kSimpleComponent, kaldi::nnet3::kStoresStats, kaldi::nnet3::kUsesMemo, and BatchNormComponent::test_mode_.

182  {
183  // If the block-dim is less than the dim, we need the input and output
184  // matrices to be contiguous (stride==num-cols), as we'll be reshaping
185  // internally. This is not much of a cost, because this will be used
186  // in convnets where we have to do this anyway.
187  return kSimpleComponent|kBackpropNeedsOutput|kPropagateInPlace|
188  kBackpropInPlace|
189  (block_dim_ < dim_ ? kInputContiguous|kOutputContiguous : 0)|
190  (test_mode_ ? 0 : kUsesMemo|kStoresStats);
191  }
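
The property bitmask depends on two pieces of state: whether block_dim_ is smaller than dim_ (which requires contiguous input and output) and whether the component is in test mode (in which case it neither returns a memo nor stores stats). The sketch below illustrates how such a mask could be assembled; the flag names are the ones referenced on this page, but the numeric values and the BatchNormProperties function are assumptions made for this example.

```cpp
#include <cassert>
#include <cstdint>

// Flag names as referenced on this page; the numeric values are
// illustrative assumptions, not taken from the Kaldi headers.
enum ComponentProperties : std::int32_t {
  kSimpleComponent = 0x001,
  kPropagateInPlace = 0x004,
  kBackpropNeedsOutput = 0x080,
  kBackpropInPlace = 0x100,
  kStoresStats = 0x200,
  kInputContiguous = 0x400,
  kOutputContiguous = 0x800,
  kUsesMemo = 0x1000
};

// Assembles the bitmask from the component's state.
std::int32_t BatchNormProperties(int dim, int block_dim, bool test_mode) {
  std::int32_t flags = kSimpleComponent | kBackpropNeedsOutput |
                       kPropagateInPlace | kBackpropInPlace;
  // Reshaping into blocks only works on contiguous matrices.
  if (block_dim < dim) flags |= kInputContiguous | kOutputContiguous;
  // Memos and stats accumulation are only used while training.
  if (!test_mode) flags |= kUsesMemo | kStoresStats;
  return flags;
}
```

Calling code would test individual bits with a bitwise AND, for example (BatchNormProperties(100, 25, false) & kInputContiguous) != 0.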
void Read ( std::istream &  is,
bool  binary 
)
virtual

Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed.

Implements Component.

Definition at line 589 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVecVec(), BatchNormComponent::block_dim_, BatchNormComponent::Check(), BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, kaldi::nnet3::ExpectOneOrTwoTokens(), kaldi::nnet3::ExpectToken(), CuVector< Real >::Read(), kaldi::ReadBasicType(), CuVectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

589  {
590  ExpectOneOrTwoTokens(is, binary, "<BatchNormComponent>", "<Dim>");
591  ReadBasicType(is, binary, &dim_);
592  ExpectToken(is, binary, "<BlockDim>");
593  ReadBasicType(is, binary, &block_dim_);
594  ExpectToken(is, binary, "<Epsilon>");
595  ReadBasicType(is, binary, &epsilon_);
596  ExpectToken(is, binary, "<TargetRms>");
597  ReadBasicType(is, binary, &target_rms_);
598  ExpectToken(is, binary, "<TestMode>");
599  ReadBasicType(is, binary, &test_mode_);
600  ExpectToken(is, binary, "<Count>");
601  ReadBasicType(is, binary, &count_);
602  ExpectToken(is, binary, "<StatsMean>");
603  stats_sum_.Read(is, binary);
604  ExpectToken(is, binary, "<StatsVar>");
605  stats_sumsq_.Read(is, binary);
606  stats_sumsq_.AddVecVec(1.0, stats_sum_, stats_sum_, 1.0);
607  stats_sum_.Scale(count_);
608  stats_sumsq_.Scale(count_);
609  ExpectToken(is, binary, "</BatchNormComponent>");
610  ComputeDerived();
611  Check();
612 }
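
The tokens <StatsMean> and <StatsVar> suggest that the file stores a per-dimension mean and variance, while the members stats_sum_ and stats_sumsq_ hold a sum and a sum of squares. Assuming that interpretation, the conversion after reading would look roughly like the sketch below; the Stats struct and FromMeanAndVariance are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Stats {
  double count;
  std::vector<double> sum;    // count * mean
  std::vector<double> sumsq;  // count * (variance + mean^2)
};

// Converts a stored (mean, variance) pair back into accumulated sums.
Stats FromMeanAndVariance(double count, const std::vector<double> &mean,
                          const std::vector<double> &var) {
  assert(mean.size() == var.size());
  Stats stats;
  stats.count = count;
  stats.sum.resize(mean.size());
  stats.sumsq.resize(mean.size());
  for (std::size_t d = 0; d < mean.size(); ++d) {
    stats.sum[d] = mean[d] * count;
    stats.sumsq[d] = (var[d] + mean[d] * mean[d]) * count;
  }
  return stats;
}
```

Storing mean and variance keeps the serialized values independent of how many frames were seen, while the count is written separately under <Count>.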
void Scale ( BaseFloat  scale)
virtual

When called on an UpdatableComponent, this virtual function scales the parameters by "scale".

When called on a NonlinearComponent (or another component that stores stats, like BatchNormComponent), it scales the activation stats, not parameters. Otherwise it will normally do nothing.

Reimplemented from Component.

Definition at line 642 of file nnet-normalize-component.cc.

References BatchNormComponent::count_, CuVectorBase< Real >::Scale(), CuVectorBase< Real >::SetZero(), BatchNormComponent::stats_sum_, and BatchNormComponent::stats_sumsq_.

Referenced by ModelCollapser::CollapseComponentsBatchnorm(), and kaldi::nnet3::ScaleBatchnormStats().

642  {
643  if (scale == 0) {
644  count_ = 0.0;
645  stats_sum_.SetZero();
646  stats_sumsq_.SetZero();
647  } else {
648  count_ *= scale;
649  stats_sum_.Scale(scale);
650  stats_sumsq_.Scale(scale);
651  }
652 }
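
Scaling count, sum and sum-of-squares by the same factor leaves the implied mean and variance unchanged and only reduces (or increases) the weight of the accumulated stats; a factor of zero resets them. A minimal sketch of that behaviour follows, using the hypothetical names RunningStats, ScaleStats and MeanOf.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct RunningStats {
  double count;
  std::vector<double> sum;
  std::vector<double> sumsq;
};

// Scales the accumulated stats; a scale of exactly zero clears them.
void ScaleStats(double scale, RunningStats *stats) {
  if (scale == 0.0) {
    stats->count = 0.0;
    std::fill(stats->sum.begin(), stats->sum.end(), 0.0);
    std::fill(stats->sumsq.begin(), stats->sumsq.end(), 0.0);
  } else {
    stats->count *= scale;
    for (double &v : stats->sum) v *= scale;
    for (double &v : stats->sumsq) v *= scale;
  }
}

// Mean of dimension d implied by the stats (requires count > 0).
double MeanOf(const RunningStats &stats, std::size_t d) {
  return stats.sum[d] / stats.count;
}
```

Handling zero as a special case avoids relying on multiplication to clear values that might be infinite or NaN, which is one plausible reason for the explicit branch in the listing.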
const CuVector<BaseFloat>& Scale ( ) const
inline

Definition at line 225 of file nnet-normalize-component.h.

References BatchNormComponent::scale_.

225 { return scale_; }
void SetTestMode ( bool  test_mode)
void StoreStats ( const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
void *  memo 
)
virtual

This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity.

It only does something for those components that have nonzero Properties()&kStoresStats.

Parameters
[in]  in_value   The input to the Propagate() function. Note: if the component sets the flag kPropagateInPlace, this should not be used; the empty matrix will be provided here if in-place propagation was used.
[in]  out_value  The output of the Propagate() function.
[in]  memo       The 'memo' returned by the Propagate() function; this will usually be NULL.

Reimplemented from Component.

Definition at line 549 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVec(), BatchNormComponent::block_dim_, BatchNormComponent::count_, CuMatrixBase< Real >::Data(), CuVectorBase< Real >::Dim(), BatchNormComponent::dim_, KALDI_ASSERT, BatchNormComponent::Memo::mean_uvar_scale, BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), CuVector< Real >::Resize(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, CuMatrixBase< Real >::Stride(), and BatchNormComponent::test_mode_.

549 void BatchNormComponent::StoreStats(
550     const CuMatrixBase<BaseFloat> &in_value,
551     const CuMatrixBase<BaseFloat> &out_value,
552     void *memo_in) {
553   // in test mode this component does not store stats, it doesn't provide the
554   // kStoresStats flag.
555   KALDI_ASSERT(!test_mode_);
556   KALDI_ASSERT(out_value.NumCols() == dim_ || out_value.NumCols() == block_dim_);
557   if (out_value.NumCols() != block_dim_) {
558     // if block_dim_ != dim_, we recurse; this helps keep the main code
559     // simple.
560     KALDI_ASSERT(out_value.Stride() == out_value.NumCols());
561     int32 ratio = dim_ / block_dim_,
562         orig_rows = out_value.NumRows(),
563         orig_cols = out_value.NumCols(),
564         new_rows = orig_rows * ratio, new_cols = orig_cols / ratio;
565     CuSubMatrix<BaseFloat> out_value_reshaped(out_value.Data(), new_rows,
566                                               new_cols, new_cols);
567     // we'll never use in_value, so just pass it in unchanged.
568     StoreStats(in_value, out_value_reshaped, memo_in);
569     return;
570   }
571 
572   Memo *memo = static_cast<Memo*>(memo_in);
573   KALDI_ASSERT(out_value.NumRows() == memo->num_frames);
574 
575   CuSubVector<BaseFloat> mean(memo->mean_uvar_scale, 0),
576       uvar(memo->mean_uvar_scale, 1);
577   KALDI_ASSERT(mean.Dim() == block_dim_ && memo->num_frames > 0);
578   BaseFloat num_frames = memo->num_frames;
579   if (stats_sum_.Dim() != block_dim_) {
580     stats_sum_.Resize(block_dim_);
581     stats_sumsq_.Resize(block_dim_);
582     KALDI_ASSERT(count_ == 0);
583   }
584   count_ += num_frames;
585   stats_sum_.AddVec(num_frames, mean, 1.0);
586   stats_sumsq_.AddVec(num_frames, uvar, 1.0);
587 }
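The accumulation performed by StoreStats(), including the reshaping used when block_dim_ differs from dim_, can be sketched in plain C++ as follows. This is an illustration under stated assumptions, not Kaldi code: `ToyStats` and `AccumulateStats` are hypothetical names, a flat `std::vector` stands in for a contiguous CuMatrix, and `uvar` is assumed to be the uncentered variance (the per-dimension mean of squares). Unlike the real function, which reads the mean and uvar from the memo, the sketch computes them from the data directly.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for count_, stats_sum_ and stats_sumsq_; not Kaldi code.
struct ToyStats {
  double count = 0.0;
  std::vector<double> sum;    // accumulates num_frames * mean
  std::vector<double> sumsq;  // accumulates num_frames * uncentered variance
};

// 'data' is a contiguous row-major matrix with 'dim' columns. When
// block_dim < dim, each row is viewed as dim / block_dim shorter rows, which
// is valid only because the storage is contiguous (stride == number of columns).
void AccumulateStats(const std::vector<double> &data, std::size_t dim,
                     std::size_t block_dim, ToyStats *stats) {
  assert(block_dim > 0 && dim % block_dim == 0);
  assert(!data.empty() && data.size() % dim == 0);
  const std::size_t num_frames = data.size() / block_dim;  // rows after reshaping

  // Per-minibatch mean and uncentered variance (mean of squares).
  std::vector<double> mean(block_dim, 0.0), uvar(block_dim, 0.0);
  for (std::size_t r = 0; r < num_frames; ++r) {
    for (std::size_t c = 0; c < block_dim; ++c) {
      const double x = data[r * block_dim + c];
      mean[c] += x / num_frames;
      uvar[c] += x * x / num_frames;
    }
  }

  if (stats->sum.size() != block_dim) {  // first call: allocate zeroed stats
    assert(stats->count == 0.0);
    stats->sum.assign(block_dim, 0.0);
    stats->sumsq.assign(block_dim, 0.0);
  }
  stats->count += num_frames;
  for (std::size_t c = 0; c < block_dim; ++c) {
    stats->sum[c] += num_frames * mean[c];    // like AddVec(num_frames, mean, 1.0)
    stats->sumsq[c] += num_frames * uvar[c];  // like AddVec(num_frames, uvar, 1.0)
  }
}
```

Weighting each minibatch's mean and uvar by its frame count means the stored sums equal the plain sums of x and x squared over all frames seen, so minibatches of different sizes combine correctly.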
virtual std::string Type ( ) const
inline virtual

Returns a string such as "SigmoidComponent", describing the type of the object.

Implements Component.

Definition at line 181 of file nnet-normalize-component.h.

Referenced by BatchNormComponent::Info().

181   virtual std::string Type() const { return "BatchNormComponent"; }
void Write ( std::ostream &  os,
bool  binary 
) const
virtual

Write component to stream.

Implements Component.

Definition at line 614 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVecVec(), BatchNormComponent::block_dim_, BatchNormComponent::Check(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, CuVectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, CuVector< Real >::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().

614 void BatchNormComponent::Write(std::ostream &os, bool binary) const {
615   Check();
616   WriteToken(os, binary, "<BatchNormComponent>");
617   WriteToken(os, binary, "<Dim>");
618   WriteBasicType(os, binary, dim_);
619   WriteToken(os, binary, "<BlockDim>");
620   WriteBasicType(os, binary, block_dim_);
621   WriteToken(os, binary, "<Epsilon>");
622   WriteBasicType(os, binary, epsilon_);
623   WriteToken(os, binary, "<TargetRms>");
624   WriteBasicType(os, binary, target_rms_);
625   WriteToken(os, binary, "<TestMode>");
626   WriteBasicType(os, binary, test_mode_);
627   WriteToken(os, binary, "<Count>");
628   WriteBasicType(os, binary, count_);
629   CuVector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
630   if (count_ != 0) {
631     mean.Scale(1.0 / count_);
632     var.Scale(1.0 / count_);
633     var.AddVecVec(-1.0, mean, mean, 1.0);
634   }
635   WriteToken(os, binary, "<StatsMean>");
636   mean.Write(os, binary);
637   WriteToken(os, binary, "<StatsVar>");
638   var.Write(os, binary);
639   WriteToken(os, binary, "</BatchNormComponent>");
640 }
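Note that Write() does not store the raw sums: it converts them to a mean and a variance first (lines 629 to 634). A minimal sketch of that conversion is shown below. `MeanVar` and `StatsToMeanVar` are hypothetical names introduced for illustration, and `std::vector` stands in for CuVector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type; not part of Kaldi.
struct MeanVar {
  std::vector<double> mean;
  std::vector<double> var;
};

// Converts accumulated sums into per-dimension mean and variance:
//   mean = sum / count
//   var  = sumsq / count - mean * mean
// If count is zero the sums are returned unchanged, as in the listing.
MeanVar StatsToMeanVar(double count, const std::vector<double> &sum,
                       const std::vector<double> &sumsq) {
  assert(sum.size() == sumsq.size());
  MeanVar out{sum, sumsq};
  if (count != 0.0) {
    for (std::size_t i = 0; i < sum.size(); ++i) {
      out.mean[i] /= count;
      // Same effect as var.Scale(1.0 / count) followed by
      // var.AddVecVec(-1.0, mean, mean, 1.0).
      out.var[i] = out.var[i] / count - out.mean[i] * out.mean[i];
    }
  }
  return out;
}
```

For example, with count 4, sum 16 and sumsq 84 in some dimension, the written mean is 4 and the written variance is 21 - 16 = 5.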
void ZeroStats ( )
virtual

Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero.

Other components that store other types of statistics (e.g. regarding gradient clipping) should implement ZeroStats() also.

Reimplemented from Component.

Definition at line 666 of file nnet-normalize-component.cc.

References BatchNormComponent::count_, CuVectorBase< Real >::SetZero(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, and BatchNormComponent::test_mode_.

666 void BatchNormComponent::ZeroStats() {
667   // We only zero the stats if we're not in test mode. In test mode, this would
668   // be dangerous as the stats are the source for the transform, and zeroing
669   // them and then calling ComputeDerived() again would remove the transform
670   // parameters (offset_ and scale_).
671   if (!test_mode_) {
672     count_ = 0.0;
673     stats_sum_.SetZero();
674     stats_sumsq_.SetZero();
675   }
676 }

Member Data Documentation


The documentation for this class was generated from the following files: