BatchNormComponent Class Reference

#include <nnet-normalize-component.h>


Classes

struct  Memo
 

Public Member Functions

 BatchNormComponent ()
 
void SetTestMode (bool test_mode)
 
 BatchNormComponent (const BatchNormComponent &other)
 
virtual int32 InputDim () const
 Returns input-dimension of this component. More...
 
virtual int32 OutputDim () const
 Returns output-dimension of this component. More...
 
virtual std::string Info () const
 Returns some text-form information about this component, for diagnostics. More...
 
virtual void InitFromConfig (ConfigLine *cfl)
 Initialize, from a ConfigLine object. More...
 
virtual std::string Type () const
 Returns a string such as "SigmoidComponent", describing the type of the object. More...
 
virtual int32 Properties () const
 Return bitmask of the component's properties. More...
 
virtual void * Propagate (const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
 Propagate function. More...
 
virtual void Backprop (const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *, CuMatrixBase< BaseFloat > *in_deriv) const
 Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update. More...
 
virtual void Read (std::istream &is, bool binary)
 Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed. More...
 
virtual void Write (std::ostream &os, bool binary) const
 Write component to stream. More...
 
virtual Component * Copy () const
 Copies component (deep copy). More...
 
virtual void Scale (BaseFloat scale)
 When called on an UpdatableComponent, this virtual function scales the parameters by "scale". More...
 
virtual void Add (BaseFloat alpha, const Component &other)
 When called on an UpdatableComponent, this virtual function adds the parameters of another updatable component, times some constant, to the current parameters. More...
 
virtual void ZeroStats ()
 Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero. More...
 
virtual void DeleteMemo (void *memo) const
 This virtual function only needs to be overridden by Components that return a non-NULL memo from their Propagate() function. More...
 
virtual void StoreStats (const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo)
 This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity. More...
 
const CuVector< BaseFloat > & Offset () const
 
const CuVector< BaseFloat > & Scale () const
 
- Public Member Functions inherited from Component
virtual void GetInputIndexes (const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
 This function only does something interesting for non-simple Components. More...
 
virtual bool IsComputable (const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
 This function only does something interesting for non-simple Components, and it exists to make it possible to manage optionally-required inputs. More...
 
virtual void ReorderIndexes (std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
 This function only does something interesting for non-simple Components. More...
 
virtual ComponentPrecomputedIndexes * PrecomputeIndexes (const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
 This function must return NULL for simple Components. More...
 
 Component ()
 
virtual ~Component ()
 

Private Member Functions

void Check () const
 
void ComputeDerived ()
 

Static Private Member Functions

static void ComputeOffsetAndScale (double count, BaseFloat epsilon, const Vector< double > &stats_sum, const Vector< double > &stats_sumsq, Vector< BaseFloat > *offset, Vector< BaseFloat > *scale)
 

Private Attributes

int32 dim_
 
int32 block_dim_
 
BaseFloat epsilon_
 
BaseFloat target_rms_
 
bool test_mode_
 
double count_
 
CuVector< double > stats_sum_
 
CuVector< double > stats_sumsq_
 
CuVector< BaseFloat > offset_
 
CuVector< BaseFloat > scale_
 

Additional Inherited Members

- Static Public Member Functions inherited from Component
static Component * ReadNew (std::istream &is, bool binary)
 Read component from stream (works out its type). Dies on error. More...
 
static Component * NewComponentOfType (const std::string &type)
 Returns a new Component of the given type e.g. More...
 

Detailed Description

Definition at line 159 of file nnet-normalize-component.h.

Constructor & Destructor Documentation

BatchNormComponent ( )
inline

Definition at line 162 of file nnet-normalize-component.h.

Referenced by BatchNormComponent::Copy().

162 { }

Member Function Documentation

void Add ( BaseFloat  alpha,
const Component &  other 
)
virtual

The behavior of this virtual function depends on the component type. When called on an UpdatableComponent, it adds the parameters of another updatable component, times some constant, to the current parameters.

When called on a NonlinearComponent (or another component that stores stats, like BatchNormComponent), it adds stats instead. Otherwise it will normally do nothing.

Reimplemented from Component.

Definition at line 653 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVec(), BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::stats_sum_, and BatchNormComponent::stats_sumsq_.

653  {
654  const BatchNormComponent *other =
655  dynamic_cast<const BatchNormComponent*>(&other_in);
656  count_ += alpha * other->count_;
657  stats_sum_.AddVec(alpha, other->stats_sum_);
658  stats_sumsq_.AddVec(alpha, other->stats_sumsq_);
659  // this operation might change offset_ and scale_, so we recompute them
660  // in this instance (but not in Scale()).
661  ComputeDerived();
662 }
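
As a rough illustration of the three accumulations above, the following standalone sketch merges weighted statistics using plain std::vector in place of Kaldi's CuVector. The names BatchNormStats and AddStats are hypothetical and introduced only for this example; they are not part of Kaldi.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for count_, stats_sum_ and stats_sumsq_;
// not a Kaldi type.
struct BatchNormStats {
  double count;
  std::vector<double> sum;    // per-dimension sum of activations
  std::vector<double> sumsq;  // per-dimension sum of squared activations
};

// Adds alpha * other into *stats, one accumulation per field as in Add().
void AddStats(double alpha, const BatchNormStats &other,
              BatchNormStats *stats) {
  assert(other.sum.size() == other.sumsq.size());
  assert(stats->sum.size() == other.sum.size());
  assert(stats->sumsq.size() == other.sumsq.size());
  stats->count += alpha * other.count;
  for (std::size_t i = 0; i < other.sum.size(); ++i) {
    stats->sum[i] += alpha * other.sum[i];
    stats->sumsq[i] += alpha * other.sumsq[i];
  }
}
```

Merging two accumulators this way generally changes the implied mean and variance, which is consistent with Add() calling ComputeDerived() afterwards.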
void Backprop ( const std::string &  debug_info,
const ComponentPrecomputedIndexes *  indexes,
const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
const CuMatrixBase< BaseFloat > &  out_deriv,
void *  memo,
Component *  to_update,
CuMatrixBase< BaseFloat > *  in_deriv 
) const
virtual

Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update.

Parameters
[in]  debug_info  The component name, to be printed out in any warning messages.
[in]  indexes  A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in]  in_value  The matrix that was given as input to the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsInput == 0.
[in]  out_value  The matrix that was output from the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsOutput == 0.
[in]  out_deriv  The derivative at the output of this component.
[in]  memo  This will normally be NULL, but for component types that set the flag kUsesMemo, this will be the return value of the Propagate() function that corresponds to this Backprop() function. Ownership of any pointers is not transferred to the Backprop function; DeleteMemo() will be called to delete it.
[out]  to_update  If model update is desired, the Component to be updated, else NULL. Does not have to be identical to this. If supplied, you can assume that to_update->Properties() & kUpdatableComponent is nonzero.
[out]  in_deriv  The derivative at the input of this component, if needed (else NULL). If Properties()&kBackpropInPlace, may be the same matrix as out_deriv. If Properties()&kBackpropAdds, this is added to by the Backprop routine, else it is set. The component code chooses which mode to work in, based on convenience.

Implements Component.

Definition at line 464 of file nnet-normalize-component.cc.

References CuMatrixBase< Real >::AddMatDiagVec(), CuMatrixBase< Real >::AddVecToRows(), BatchNormComponent::block_dim_, CuMatrixBase< Real >::CopyFromMat(), CuMatrixBase< Real >::Data(), CuVectorBase< Real >::Dim(), BatchNormComponent::dim_, KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, BatchNormComponent::Memo::mean_uvar_scale, CuMatrixBase< Real >::MulColsVec(), BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), BatchNormComponent::offset_, kaldi::SameDim(), BatchNormComponent::scale_, CuMatrixBase< Real >::Stride(), BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

472  {
473 
474  KALDI_ASSERT(SameDim(out_value, out_deriv) &&
475  SameDim(out_value, *in_deriv) &&
476  (out_value.NumCols() == dim_ ||
477  out_value.NumCols() == block_dim_));
478  if (out_value.NumCols() != block_dim_) {
479  // if block_dim_ != dim_, we recurse; this helps keep the main code
480  // simple.
481  KALDI_ASSERT(out_value.Stride() == out_value.NumCols() &&
482  out_deriv.Stride() == out_deriv.NumCols() &&
483  in_deriv->Stride() == in_deriv->NumCols());
484  int32 ratio = dim_ / block_dim_,
485  orig_rows = out_value.NumRows(),
486  orig_cols = out_value.NumCols(),
487  new_rows = orig_rows * ratio, new_cols = orig_cols / ratio;
488  CuSubMatrix<BaseFloat> out_value_reshaped(out_value.Data(), new_rows,
489  new_cols, new_cols),
490  out_deriv_reshaped(out_deriv.Data(), new_rows, new_cols, new_cols),
491  in_deriv_reshaped(in_deriv->Data(), new_rows, new_cols, new_cols);
492  // we'll never use in_value, so pass it in unchanged.
493  Backprop(debug_info, indexes, in_value,
494  out_value_reshaped, out_deriv_reshaped,
495  memo_in, to_update, &in_deriv_reshaped);
496  return;
497  }
498 
499  Memo *memo = static_cast<Memo*>(memo_in);
500 
501  if (!test_mode_) {
502  // search above for BACKWARD PASS for a comment describing the math.
503  KALDI_ASSERT(memo != NULL && "memo not passed into backprop");
504  int32 num_frames = memo->num_frames;
505  KALDI_ASSERT(out_value.NumRows() == num_frames);
506  CuSubVector<BaseFloat>
507  scale(memo->mean_uvar_scale, 2),
508  var_deriv_mod(memo->mean_uvar_scale, 3),
509  temp(memo->mean_uvar_scale, 4);
510 
511  // var_deriv_mod is going to contain:
512  // 2 * power * target-rms^{1/power} * (1/I \sum_i z'(i) z(i)) * scale^{-(1+power)/power}
513  // which for power = -0.5 simplifies to:
514  // -1.0 / (target_rms * target_rms).
515  // but for now we don't have the power of 'scale', we'll add that later.
516  BaseFloat coeff = -1.0 / (target_rms_ * target_rms_ * num_frames);
517 
518  var_deriv_mod.AddDiagMatMat(coeff, out_value, kTrans,
519  out_deriv, kNoTrans, 0.0);
520  var_deriv_mod.MulElements(scale);
521 
522  temp.AddRowSumMat(-1.0 / num_frames, out_deriv, 0.0);
523  // the following statement does no work if in_deriv and out_deriv are the
524  // same matrix.
525  in_deriv->CopyFromMat(out_deriv);
526  in_deriv->AddVecToRows(1.0, temp);
527  // At this point, *in_deriv contains
528  // (z'(i) - 1/I * \sum_i z'(i))
529  in_deriv->MulColsVec(scale);
530  // At this point, *in_deriv contains
531  // scale * (z'(i) - 1/I * \sum_i z'(i))
532 
533  in_deriv->AddMatDiagVec(1.0, out_value, kNoTrans,
534  var_deriv_mod, 1.0);
535 
536  // At this point, *in_deriv contains what we described in the comment
537  // starting BATCHNORM_MATH as:
538  // x'(i) = scale * (z'(i) - 1/I * \sum_i z'(i)) + z(i) var_deriv_mod
539  } else {
541  // the next call does no work if they point to the same memory.
542  in_deriv->CopyFromMat(out_deriv);
543  in_deriv->MulColsVec(scale_);
544  }
545 }
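
The training-mode branch above can be hard to follow through the matrix calls, so here is a minimal sketch of the same formula, x'(i) = scale * (z'(i) - 1/I * \sum_i z'(i)) + z(i) * var_deriv_mod, written for a single feature dimension (one matrix column) with plain std::vector. The function name BackpropColumn is hypothetical, and the sketch assumes the scale value is the one that Propagate() stored in the memo.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Training-mode backprop for one feature dimension (one matrix column).
//   z:       outputs of the forward pass, one value per frame
//   z_deriv: derivative of the objective w.r.t. those outputs
//   scale:   the scale computed in the forward pass for this dimension
// Returns the derivative w.r.t. the inputs.
std::vector<double> BackpropColumn(const std::vector<double> &z,
                                   const std::vector<double> &z_deriv,
                                   double scale, double target_rms) {
  assert(!z.empty() && z.size() == z_deriv.size());
  const double num_frames = static_cast<double>(z.size());
  double sum_deriv = 0.0, sum_z_deriv = 0.0;
  for (std::size_t i = 0; i < z.size(); ++i) {
    sum_deriv += z_deriv[i];
    sum_z_deriv += z[i] * z_deriv[i];
  }
  // var_deriv_mod = -scale / (target_rms^2 * I) * sum_i z(i) z'(i)
  const double var_deriv_mod =
      -scale * sum_z_deriv / (target_rms * target_rms * num_frames);
  std::vector<double> x_deriv(z.size());
  for (std::size_t i = 0; i < z.size(); ++i) {
    x_deriv[i] = scale * (z_deriv[i] - sum_deriv / num_frames) +
                 z[i] * var_deriv_mod;
  }
  return x_deriv;
}
```

Because the normalized outputs z have zero mean over the batch, the returned derivatives sum to zero across frames. In test mode none of this applies: the code above simply multiplies out_deriv by scale_ column-wise.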
void ComputeDerived ( )
private

Definition at line 206 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::Add(), CuVectorBase< Real >::AddVecVec(), CuVectorBase< Real >::ApplyFloor(), CuVectorBase< Real >::ApplyPow(), BatchNormComponent::block_dim_, CuVectorBase< Real >::CopyFromVec(), BatchNormComponent::count_, BatchNormComponent::epsilon_, KALDI_WARN, CuVectorBase< Real >::MulElements(), BatchNormComponent::offset_, CuVector< Real >::Resize(), CuVectorBase< Real >::Scale(), BatchNormComponent::scale_, CuVectorBase< Real >::SetRandn(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

Referenced by BatchNormComponent::Add(), BatchNormComponent::BatchNormComponent(), BatchNormComponent::InitFromConfig(), BatchNormComponent::Read(), and BatchNormComponent::SetTestMode().

206  {
207  if (!test_mode_) {
208  offset_.Resize(0);
209  scale_.Resize(0);
210  return;
211  }
212 
213  if (count_ == 0.0) {
214  KALDI_WARN << "Test-mode is set but there is no data count. "
215  "Creating random counts. This only makes sense "
216  "in unit-tests (or compute_prob_*.0.log). If you see this "
217  "elsewhere, something is very wrong.";
218  count_ = 1.0;
222  }
223 
227  offset_.Scale(-1.0 / count_);
228  // now offset_ is -mean.
230  scale_.Scale(1.0 / count_);
231  scale_.AddVecVec(-1.0, offset_, offset_, 1.0);
232  // now scale_ is variance.
233  // Mathematically the ApplyFloor statement should be a no-op; this is in case
234  // of numerical roundoff.
235  scale_.ApplyFloor(0.0);
237  BaseFloat power = -0.5;
238  scale_.ApplyPow(power);
239  // now scale_ = min(variance, epsilon)^power
240  // next, multiply by the target RMS (normally 1.0).
243  // now offset_ is -(scale*mean).
244 }
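
The listing above omits some lines, so the sketch below only illustrates the computation that its comments describe: offset ends up as -(scale * mean) and scale as a power -0.5 of the variance, multiplied by the target RMS. It assumes that epsilon is added to the floored variance, as the training-mode Propagate() code does; the names DerivedParams and DeriveOffsetAndScale are hypothetical and not part of Kaldi.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container for the two derived vectors.
struct DerivedParams {
  std::vector<double> offset;
  std::vector<double> scale;
};

// Derives per-dimension test-mode offset and scale from accumulated stats.
DerivedParams DeriveOffsetAndScale(double count, double epsilon,
                                   double target_rms,
                                   const std::vector<double> &stats_sum,
                                   const std::vector<double> &stats_sumsq) {
  assert(count > 0.0 && stats_sum.size() == stats_sumsq.size());
  DerivedParams p;
  p.offset.resize(stats_sum.size());
  p.scale.resize(stats_sum.size());
  for (std::size_t d = 0; d < stats_sum.size(); ++d) {
    const double mean = stats_sum[d] / count;
    double var = stats_sumsq[d] / count - mean * mean;
    if (var < 0.0) var = 0.0;  // guard against numerical roundoff
    // Assumption: epsilon is added to the floored variance.
    p.scale[d] = target_rms / std::sqrt(var + epsilon);
    p.offset[d] = -(p.scale[d] * mean);
  }
  return p;
}
```

With these values, the test-mode forward pass reduces to y = x * scale + offset per dimension, so an input equal to the stored mean maps to zero.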
static void ComputeOffsetAndScale ( double  count,
BaseFloat  epsilon,
const Vector< double > &  stats_sum,
const Vector< double > &  stats_sumsq,
Vector< BaseFloat > *  offset,
Vector< BaseFloat > *  scale 
)
static private
virtual Component* Copy ( ) const
inline virtual

Copies component (deep copy).

Implements Component.

Definition at line 209 of file nnet-normalize-component.h.

References BatchNormComponent::BatchNormComponent().

virtual void DeleteMemo ( void *  memo) const
inline virtual

This virtual function only needs to be overridden by Components that return a non-NULL memo from their Propagate() function.

It's called by NnetComputer in cases where Propagate returns a memo but there will be no backprop to consume it.

Reimplemented from Component.

Definition at line 216 of file nnet-normalize-component.h.

216 { delete static_cast<Memo*>(memo); }
std::string Info ( ) const
virtual

Returns some text-form information about this component, for diagnostics.

Starts with the type of the component. E.g. "SigmoidComponent dim=900", although most components will have much more info.

Reimplemented from Component.

Definition at line 266 of file nnet-normalize-component.cc.

References VectorBase< Real >::AddVecVec(), VectorBase< Real >::ApplyFloor(), VectorBase< Real >::ApplyPow(), BatchNormComponent::block_dim_, BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, VectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, kaldi::nnet3::SummarizeVector(), BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, and BatchNormComponent::Type().

266  {
267  std::ostringstream stream;
268  stream << Type() << ", dim=" << dim_ << ", block-dim=" << block_dim_
269  << ", epsilon=" << epsilon_ << ", target-rms=" << target_rms_
270  << ", count=" << count_
271  << ", test-mode=" << (test_mode_ ? "true" : "false");
272  if (count_ > 0) {
273  Vector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
274  mean.Scale(1.0 / count_);
275  var.Scale(1.0 / count_);
276  // subtract mean^2 from var.
277  var.AddVecVec(-1.0, mean, mean, 1.0);
278  var.ApplyFloor(0.0);
279  var.ApplyPow(0.5); // make it the stddev.
280  stream << ", data-mean=" << SummarizeVector(mean)
281  << ", data-stddev=" << SummarizeVector(var);
282  }
283  return stream.str();
284 }
void InitFromConfig ( ConfigLine *  cfl)
virtual

Initialize, from a ConfigLine object.

Parameters
[in]  cfl  A ConfigLine containing any parameters that are needed for initialization. For example: "dim=100 param-stddev=0.1"

Implements Component.

Definition at line 286 of file nnet-normalize-component.cc.

References BatchNormComponent::block_dim_, BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, ConfigLine::GetValue(), ConfigLine::HasUnusedValues(), KALDI_ERR, CuVector< Real >::Resize(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, and ConfigLine::UnusedValues().

286  {
287  dim_ = -1;
288  block_dim_ = -1;
289  epsilon_ = 1.0e-03;
290  target_rms_ = 1.0;
291  test_mode_ = false;
292  bool ok = cfl->GetValue("dim", &dim_);
293  cfl->GetValue("block-dim", &block_dim_);
294  cfl->GetValue("epsilon", &epsilon_);
295  cfl->GetValue("target-rms", &target_rms_);
296  cfl->GetValue("test-mode", &test_mode_);
297  if (!ok || dim_ <= 0) {
298  KALDI_ERR << "BatchNormComponent must have 'dim' specified, and > 0";
299  }
300  if (block_dim_ == -1)
301  block_dim_ = dim_;
302  if (!(block_dim_ > 0 && dim_ % block_dim_ == 0 &&
303  epsilon_ > 0 && target_rms_ > 0))
304  KALDI_ERR << "Invalid configuration in BatchNormComponent.";
305  if (cfl->HasUnusedValues())
306  KALDI_ERR << "Could not process these elements in initializer: "
307  << cfl->UnusedValues();
308  count_ = 0;
311  if (test_mode_) {
312  ComputeDerived();
313  }
314 }
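
The defaults and validity checks in the listing above can be summarized in a small standalone sketch. BatchNormConfig and FinalizeConfig are hypothetical names used only here; the real code reads the values from a ConfigLine and reports problems through KALDI_ERR rather than a return value.

```cpp
#include <cassert>

// Hypothetical plain struct holding the same defaults as InitFromConfig().
struct BatchNormConfig {
  int dim = -1;
  int block_dim = -1;
  double epsilon = 1.0e-03;
  double target_rms = 1.0;
  bool test_mode = false;
};

// Applies the block-dim default and returns true if the configuration
// passes the same checks as the listing above.
bool FinalizeConfig(BatchNormConfig *config) {
  assert(config != nullptr);
  if (config->dim <= 0) return false;  // 'dim' must be specified and > 0
  if (config->block_dim == -1) config->block_dim = config->dim;
  return config->block_dim > 0 && config->dim % config->block_dim == 0 &&
         config->epsilon > 0 && config->target_rms > 0;
}
```

For instance, dim=100 alone is accepted and yields block_dim=100, while dim=100 with block-dim=30 is rejected because 100 is not a multiple of 30.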
virtual int32 InputDim ( ) const
inline virtual

Returns input-dimension of this component.

Implements Component.

Definition at line 176 of file nnet-normalize-component.h.

References BatchNormComponent::dim_.

const CuVector<BaseFloat>& Offset ( ) const
inline

Definition at line 224 of file nnet-normalize-component.h.

References BatchNormComponent::offset_.

Referenced by ModelCollapser::CollapseComponentsBatchnorm().

224 { return offset_; }
virtual int32 OutputDim ( ) const
inline virtual

Returns output-dimension of this component.

Implements Component.

Definition at line 177 of file nnet-normalize-component.h.

References BatchNormComponent::dim_.

void * Propagate ( const ComponentPrecomputedIndexes *  indexes,
const CuMatrixBase< BaseFloat > &  in,
CuMatrixBase< BaseFloat > *  out 
) const
virtual

Propagate function.

Parameters
[in]  indexes  A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in]  in  The input to this component. Num-columns == InputDim().
[out]  out  The output of this component. Num-columns == OutputDim(). Note: output of this component will be added to the initial value of "out" if Properties()&kPropagateAdds != 0; otherwise the output will be set and the initial value ignored. Each Component chooses whether it is more convenient implementation-wise to add or set, and the calling code has to deal with it.
Returns
Normally returns NULL, but may return a non-NULL value for components which have the flag kUsesMemo set. This value will be passed into the corresponding Backprop routine.

Implements Component.

Definition at line 398 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::Add(), CuMatrixBase< Real >::AddVecToRows(), CuVectorBase< Real >::AddVecVec(), CuVectorBase< Real >::ApplyFloor(), CuVectorBase< Real >::ApplyPow(), BatchNormComponent::block_dim_, CuMatrixBase< Real >::CopyFromMat(), CuVectorBase< Real >::CopyFromVec(), BatchNormComponent::count_, CuMatrixBase< Real >::Data(), CuVectorBase< Real >::Dim(), BatchNormComponent::dim_, BatchNormComponent::epsilon_, KALDI_ASSERT, KALDI_ERR, kaldi::kTrans, BatchNormComponent::Memo::mean_uvar_scale, CuMatrixBase< Real >::MulColsVec(), BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), BatchNormComponent::offset_, CuMatrix< Real >::Resize(), kaldi::SameDim(), BatchNormComponent::scale_, CuMatrixBase< Real >::Stride(), BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

400  {
401  KALDI_ASSERT(SameDim(in, *out) &&
402  (in.NumCols() == dim_ || in.NumCols() == block_dim_));
403  if (in.NumCols() != block_dim_) {
404  // if block_dim_ != dim_, we recurse; this helps keep the main code
405  // simple.
406  KALDI_ASSERT(in.Stride() == in.NumCols() && out->Stride() == out->NumCols());
407  int32 ratio = dim_ / block_dim_, orig_rows = in.NumRows(),
408  orig_cols = in.NumCols(), new_rows = orig_rows * ratio,
409  new_cols = orig_cols / ratio;
410  CuSubMatrix<BaseFloat> in_reshaped(in.Data(), new_rows, new_cols, new_cols),
411  out_reshaped(out->Data(), new_rows, new_cols, new_cols);
412  return Propagate(indexes, in_reshaped, &out_reshaped);
413  }
414 
415  // From this point, we can assume that the num-cols of 'in' and 'out'
416  // equals block_dim_.
417 
418  if (!test_mode_) {
419  // search in the comment above for FORWARD PASS to see what is being
420  // implemented here.
421  // if this takes too much time due to multiple different CUDA calls,
422  // we'll consider making a single kernel for some of it.
423  Memo *memo = new Memo;
424  int32 num_frames = in.NumRows(), dim = block_dim_;
425  memo->num_frames = num_frames;
426  memo->mean_uvar_scale.Resize(5, dim);
427  CuSubVector<BaseFloat> mean(memo->mean_uvar_scale, 0),
428  uvar(memo->mean_uvar_scale, 1),
429  scale(memo->mean_uvar_scale, 2);
430  mean.AddRowSumMat(1.0 / num_frames, in, 0.0);
431  uvar.AddDiagMat2(1.0 / num_frames, in, kTrans, 0.0);
432  scale.CopyFromVec(uvar);
433 
434  // by applying this scale at this point, we save a multiply later on.
435  BaseFloat var_scale = 1.0 / (target_rms_ * target_rms_);
436  scale.AddVecVec(-var_scale, mean, mean, var_scale);
437  // at this point, 'scale' contains just the variance (times target-rms^{-2}).
438  scale.ApplyFloor(0.0);
439  scale.Add(var_scale * epsilon_);
440  // Now 'scale' contains the variance floored to zero and then with epsilon
441  // added [both times 1/target-rms^2].
442  scale.ApplyPow(-0.5);
443  // now 'scale' is the actual scale we'll use.
444 
445  // the next command will do no work if out == in, for in-place propagation.
446  out->CopyFromMat(in);
447  out->AddVecToRows(-1.0, mean, 1.0);
448  out->MulColsVec(scale);
449  return static_cast<void*>(memo);
450  } else {
451  if (offset_.Dim() != block_dim_) {
452  if (count_ == 0)
453  KALDI_ERR << "Test mode set in BatchNormComponent, but no stats.";
454  else // why was ComputeDerived() not called?
455  KALDI_ERR << "Code error in BatchNormComponent";
456  }
457  out->CopyFromMat(in);
458  out->MulColsVec(scale_);
459  out->AddVecToRows(1.0, offset_, 1.0);
460  return NULL;
461  }
462 }
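
To make the training-mode branch above easier to check by hand, the following sketch repeats its arithmetic for a single feature dimension using plain std::vector. ForwardResult and PropagateColumn are hypothetical names; the mean and scale fields correspond to rows 0 and 2 of the memo's mean_uvar_scale matrix in the listing.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result type for one feature dimension.
struct ForwardResult {
  double mean;
  double scale;
  std::vector<double> out;
};

// Training-mode forward pass for one column: out(i) = (x(i) - mean) * scale.
ForwardResult PropagateColumn(const std::vector<double> &x, double epsilon,
                              double target_rms) {
  assert(!x.empty() && epsilon > 0.0 && target_rms > 0.0);
  const double num_frames = static_cast<double>(x.size());
  double mean = 0.0, uvar = 0.0;  // mean and uncentered variance
  for (double v : x) {
    mean += v / num_frames;
    uvar += v * v / num_frames;
  }
  // Folding 1/target_rms^2 into the variance saves a multiply later on.
  const double var_scale = 1.0 / (target_rms * target_rms);
  double scaled_var = var_scale * (uvar - mean * mean);
  if (scaled_var < 0.0) scaled_var = 0.0;  // floor, in case of roundoff
  scaled_var += var_scale * epsilon;
  ForwardResult r;
  r.mean = mean;
  r.scale = 1.0 / std::sqrt(scaled_var);
  r.out.resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    r.out[i] = (x[i] - mean) * r.scale;
  }
  return r;
}
```

Note that the resulting scale equals target_rms / sqrt(variance + epsilon), so for small epsilon the outputs have an RMS value close to target_rms.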
virtual int32 Properties ( ) const
inline virtual

Return bitmask of the component's properties.

These properties depend only on the component's type. See enum ComponentProperties.

Implements Component.

Definition at line 182 of file nnet-normalize-component.h.

References BatchNormComponent::block_dim_, BatchNormComponent::dim_, kaldi::nnet3::kBackpropInPlace, kaldi::nnet3::kBackpropNeedsOutput, kaldi::nnet3::kInputContiguous, kaldi::nnet3::kOutputContiguous, kaldi::nnet3::kPropagateInPlace, kaldi::nnet3::kSimpleComponent, kaldi::nnet3::kStoresStats, kaldi::nnet3::kUsesMemo, and BatchNormComponent::test_mode_.

182  {
183  // If the block-dim is less than the dim, we need the input and output
184  // matrices to be contiguous (stride==num-cols), as we'll be reshaping
185  // internally. This is not much of a cost, because this will be used
186  // in convnets where we have to do this anyway.
191  }
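
The comment above explains why contiguous matrices are needed when block-dim is less than dim. The sketch below illustrates the shape arithmetic that Propagate(), Backprop() and StoreStats() use when they recurse: a contiguous row-major matrix with dim columns is reinterpreted as one with block_dim columns and proportionally more rows, without moving any data. Shape, ReshapedShape and FlatIndex are hypothetical helper names, not Kaldi functions.

```cpp
#include <cassert>

// Hypothetical shape descriptor.
struct Shape {
  int rows;
  int cols;
};

// Shape of the reinterpreted matrix, following the ratio/new_rows/new_cols
// arithmetic in the listings.
Shape ReshapedShape(int rows, int dim, int block_dim) {
  assert(block_dim > 0 && dim % block_dim == 0);
  const int ratio = dim / block_dim;
  return Shape{rows * ratio, dim / ratio};
}

// Offset of element (row, col) in a contiguous row-major buffer
// (stride == num_cols).
int FlatIndex(int row, int col, int num_cols) {
  return row * num_cols + col;
}
```

Element (r, c) of the original matrix lands at row r * ratio + c / block_dim and column c % block_dim of the reshaped view, at the same offset in memory. This only works when stride equals the number of columns, which is what the contiguity flags request.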
void Read ( std::istream &  is,
bool  binary 
)
virtual

Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed.

Implements Component.

Definition at line 587 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVecVec(), BatchNormComponent::block_dim_, BatchNormComponent::Check(), BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, kaldi::nnet3::ExpectOneOrTwoTokens(), kaldi::nnet3::ExpectToken(), CuVector< Real >::Read(), kaldi::ReadBasicType(), CuVectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

587  {
588  ExpectOneOrTwoTokens(is, binary, "<BatchNormComponent>", "<Dim>");
589  ReadBasicType(is, binary, &dim_);
590  ExpectToken(is, binary, "<BlockDim>");
591  ReadBasicType(is, binary, &block_dim_);
592  ExpectToken(is, binary, "<Epsilon>");
593  ReadBasicType(is, binary, &epsilon_);
594  ExpectToken(is, binary, "<TargetRms>");
595  ReadBasicType(is, binary, &target_rms_);
596  ExpectToken(is, binary, "<TestMode>");
597  ReadBasicType(is, binary, &test_mode_);
598  ExpectToken(is, binary, "<Count>");
599  ReadBasicType(is, binary, &count_);
600  ExpectToken(is, binary, "<StatsMean>");
601  stats_sum_.Read(is, binary);
602  ExpectToken(is, binary, "<StatsVar>");
603  stats_sumsq_.Read(is, binary);
607  ExpectToken(is, binary, "</BatchNormComponent>");
608  ComputeDerived();
609  Check();
610 }
void Scale ( BaseFloat  scale)
virtual

The behavior of this virtual function depends on the component type. When called on an UpdatableComponent, it scales the parameters by "scale".

When called on a NonlinearComponent (or another component that stores stats, like BatchNormComponent), it scales the activation stats, not parameters. Otherwise it will normally do nothing.

Reimplemented from Component.

Definition at line 640 of file nnet-normalize-component.cc.

References BatchNormComponent::count_, CuVectorBase< Real >::Scale(), CuVectorBase< Real >::SetZero(), BatchNormComponent::stats_sum_, and BatchNormComponent::stats_sumsq_.

Referenced by ModelCollapser::CollapseComponentsBatchnorm(), and kaldi::nnet3::ScaleBatchnormStats().

640  {
641  if (scale == 0) {
642  count_ = 0.0;
645  } else {
646  count_ *= scale;
647  stats_sum_.Scale(scale);
648  stats_sumsq_.Scale(scale);
649  }
650 }
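
One consequence of scaling the count and both statistics vectors by the same factor is that the implied mean and variance are unchanged for any nonzero scale. The sketch below demonstrates this with plain std::vector; ScalableStats, ScaleStats, Mean and Variance are hypothetical names introduced for the example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for count_, stats_sum_ and stats_sumsq_.
struct ScalableStats {
  double count;
  std::vector<double> sum;
  std::vector<double> sumsq;
};

// Scales the count and both statistics by the same factor.
void ScaleStats(double scale, ScalableStats *s) {
  assert(s != nullptr);
  s->count *= scale;
  for (double &v : s->sum) v *= scale;
  for (double &v : s->sumsq) v *= scale;
}

// Mean and variance implied by the stats for dimension d (count must be > 0).
double Mean(const ScalableStats &s, std::size_t d) {
  return s.sum[d] / s.count;
}
double Variance(const ScalableStats &s, std::size_t d) {
  const double m = Mean(s, d);
  return s.sumsq[d] / s.count - m * m;
}
```

A scale of zero is the exception: it discards the statistics altogether, which is why the listing above treats it as a separate case.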
const CuVector<BaseFloat>& Scale ( ) const
inline

Definition at line 225 of file nnet-normalize-component.h.

References BatchNormComponent::scale_.

225 { return scale_; }
void SetTestMode ( bool  test_mode)
void StoreStats ( const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
void *  memo 
)
virtual

This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity.

It only does something for those components that have nonzero Properties()&kStoresStats.

Parameters
[in]  in_value   The input to the Propagate() function. Note: if the component sets the flag kPropagateInPlace, this should not be used; the empty matrix will be provided here if in-place propagation was used.
[in]  out_value  The output of the Propagate() function.
[in]  memo       The 'memo' returned by the Propagate() function; this will usually be NULL.

Reimplemented from Component.

Definition at line 547 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVec(), BatchNormComponent::block_dim_, BatchNormComponent::count_, CuMatrixBase< Real >::Data(), CuVectorBase< Real >::Dim(), BatchNormComponent::dim_, KALDI_ASSERT, BatchNormComponent::Memo::mean_uvar_scale, BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), CuVector< Real >::Resize(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, CuMatrixBase< Real >::Stride(), and BatchNormComponent::test_mode_.

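547  void BatchNormComponent::StoreStats(
548  const CuMatrixBase<BaseFloat> &in_value,
549  const CuMatrixBase<BaseFloat> &out_value,
550  void *memo_in) {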
551  // in test mode this component does not store stats, it doesn't provide the
552  // kStoresStats flag.
553  KALDI_ASSERT(!test_mode_);
554  KALDI_ASSERT(out_value.NumCols() == dim_ || out_value.NumCols() == block_dim_);
555  if (out_value.NumCols() != block_dim_) {
556  // if block_dim_ != dim_, we recurse; this helps keep the main code
557  // simple.
558  KALDI_ASSERT(out_value.Stride() == out_value.NumCols());
559  int32 ratio = dim_ / block_dim_,
560  orig_rows = out_value.NumRows(),
561  orig_cols = out_value.NumCols(),
562  new_rows = orig_rows * ratio, new_cols = orig_cols / ratio;
563  CuSubMatrix<BaseFloat> out_value_reshaped(out_value.Data(), new_rows,
564  new_cols, new_cols);
565  // we'll never use in_value, so just pass it in unchanged.
566  StoreStats(in_value, out_value_reshaped, memo_in);
567  return;
568  }
569 
570  Memo *memo = static_cast<Memo*>(memo_in);
571  KALDI_ASSERT(out_value.NumRows() == memo->num_frames);
572 
573  CuSubVector<BaseFloat> mean(memo->mean_uvar_scale, 0),
574  uvar(memo->mean_uvar_scale, 1);
575  KALDI_ASSERT(mean.Dim() == block_dim_ && memo->num_frames > 0);
576  BaseFloat num_frames = memo->num_frames;
577  if (stats_sum_.Dim() != block_dim_) {
578  stats_sum_.Resize(block_dim_);
579  stats_sumsq_.Resize(block_dim_);
580  KALDI_ASSERT(count_ == 0);
581  }
582  count_ += num_frames;
583  stats_sum_.AddVec(num_frames, mean, 1.0);
584  stats_sumsq_.AddVec(num_frames, uvar, 1.0);
585 }
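The last three statements of the listing accumulate count-weighted statistics: each minibatch contributes its frame count to count_, and its per-dimension mean and uncentered variance, each multiplied by the frame count, to stats_sum_ and stats_sumsq_. The sketch below illustrates that accumulation under stated assumptions: ToyStats and Accumulate are hypothetical names, std::vector replaces CuVector, and the explicit loop stands in for AddVec(num_frames, vec, 1.0), which computes this = num_frames * vec + 1.0 * this.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical container for the running stats (not the Kaldi class).
struct ToyStats {
  double count = 0.0;
  std::vector<double> sum;    // plays the role of stats_sum_
  std::vector<double> sumsq;  // plays the role of stats_sumsq_
};

// Adds one minibatch, described by its frame count, per-dimension mean and
// per-dimension uncentered variance (the mean of x^2), to the running stats.
void Accumulate(ToyStats *stats, double num_frames,
                const std::vector<double> &mean,
                const std::vector<double> &uvar) {
  assert(num_frames > 0.0 && mean.size() == uvar.size());
  if (stats->sum.size() != mean.size()) {
    // First call: size the accumulators; nothing may have been counted yet.
    assert(stats->count == 0.0);
    stats->sum.assign(mean.size(), 0.0);
    stats->sumsq.assign(mean.size(), 0.0);
  }
  stats->count += num_frames;
  for (std::size_t i = 0; i < mean.size(); ++i) {
    stats->sum[i] += num_frames * mean[i];    // like AddVec(num_frames, mean, 1.0)
    stats->sumsq[i] += num_frames * uvar[i];  // like AddVec(num_frames, uvar, 1.0)
  }
}
```

Weighting by num_frames means that a minibatch of 6 frames counts three times as much as one of 2 frames, so dividing the sums by count later gives the statistics over all frames rather than an average of per-minibatch averages.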
virtual std::string Type ( ) const
inline virtual

Returns a string such as "SigmoidComponent", describing the type of the object.

Implements Component.

Definition at line 181 of file nnet-normalize-component.h.

Referenced by BatchNormComponent::Info().

181 virtual std::string Type() const { return "BatchNormComponent"; }
void Write ( std::ostream &  os,
bool  binary 
) const
virtual

Write component to stream.

Implements Component.

Definition at line 612 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVecVec(), BatchNormComponent::block_dim_, BatchNormComponent::Check(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, CuVectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, CuVector< Real >::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().

612  void BatchNormComponent::Write(std::ostream &os, bool binary) const {
613  Check();
614  WriteToken(os, binary, "<BatchNormComponent>");
615  WriteToken(os, binary, "<Dim>");
616  WriteBasicType(os, binary, dim_);
617  WriteToken(os, binary, "<BlockDim>");
618  WriteBasicType(os, binary, block_dim_);
619  WriteToken(os, binary, "<Epsilon>");
620  WriteBasicType(os, binary, epsilon_);
621  WriteToken(os, binary, "<TargetRms>");
622  WriteBasicType(os, binary, target_rms_);
623  WriteToken(os, binary, "<TestMode>");
624  WriteBasicType(os, binary, test_mode_);
625  WriteToken(os, binary, "<Count>");
626  WriteBasicType(os, binary, count_);
627  CuVector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
628  if (count_ != 0) {
629  mean.Scale(1.0 / count_);
630  var.Scale(1.0 / count_);
631  var.AddVecVec(-1.0, mean, mean, 1.0);
632  }
633  WriteToken(os, binary, "<StatsMean>");
634  mean.Write(os, binary);
635  WriteToken(os, binary, "<StatsVar>");
636  var.Write(os, binary);
637  WriteToken(os, binary, "</BatchNormComponent>");
638 }
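Note that Write() does not store the raw sums: when count_ is nonzero it converts them to a mean and a variance first, so the <StatsMean> and <StatsVar> fields hold normalized quantities. A minimal sketch of that conversion follows; StatsToMeanVar and the Vec alias are hypothetical names, and std::vector is used in place of CuVector.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using Vec = std::vector<double>;

// Hypothetical helper: converts accumulated sums into {mean, variance}.
// The arguments are taken by value so the caller's stats are left untouched,
// much as the listing works on copies of stats_sum_ and stats_sumsq_.
std::pair<Vec, Vec> StatsToMeanVar(double count, Vec sum, Vec sumsq) {
  assert(sum.size() == sumsq.size());
  if (count != 0.0) {
    for (std::size_t i = 0; i < sum.size(); ++i) {
      sum[i] /= count;                    // mean.Scale(1.0 / count_)
      sumsq[i] /= count;                  // var.Scale(1.0 / count_)
      sumsq[i] -= sum[i] * sum[i];        // var.AddVecVec(-1.0, mean, mean, 1.0)
    }
  }
  // With a zero count the vectors are returned unchanged, as in the listing.
  return {std::move(sum), std::move(sumsq)};
}
```

For example, count = 4, sum = {8} and sumsq = {20} give a mean of 2 and a variance of 20/4 - 2*2 = 1.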
void ZeroStats ( )
virtual

Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero.

Other components that store other types of statistics (e.g. regarding gradient clipping) should implement ZeroStats() also.

Reimplemented from Component.

Definition at line 664 of file nnet-normalize-component.cc.

References BatchNormComponent::count_, CuVectorBase< Real >::SetZero(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, and BatchNormComponent::test_mode_.

664  void BatchNormComponent::ZeroStats() {
665  // We only zero the stats if we're not in test mode. In test mode, this would
666  // be dangerous as the stats are the source for the transform, and zeroing
667  // them and then calling ComputeDerived() again would remove the transform
668  // parameters (offset_ and scale_).
669  if (!test_mode_) {
670  count_ = 0.0;
671  stats_sum_.SetZero();
672  stats_sumsq_.SetZero();
673  }
674 }

Member Data Documentation


The documentation for this class was generated from the following files: