BatchNormComponent Class Reference

#include <nnet-normalize-component.h>


Classes

struct  Memo
 

Public Member Functions

 BatchNormComponent ()
 
void SetTestMode (bool test_mode)
 
 BatchNormComponent (const BatchNormComponent &other)
 
virtual int32 InputDim () const
 Returns input-dimension of this component.
 
virtual int32 OutputDim () const
 Returns output-dimension of this component.
 
virtual std::string Info () const
 Returns some text-form information about this component, for diagnostics.
 
virtual void InitFromConfig (ConfigLine *cfl)
 Initialize, from a ConfigLine object.
 
virtual std::string Type () const
 Returns a string such as "SigmoidComponent", describing the type of the object.
 
virtual int32 Properties () const
 Return bitmask of the component's properties.
 
virtual void * Propagate (const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
 Propagate function.
 
virtual void Backprop (const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *, CuMatrixBase< BaseFloat > *in_deriv) const
 Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update.
 
virtual void Read (std::istream &is, bool binary)
 Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed.
 
virtual void Write (std::ostream &os, bool binary) const
 Write component to stream.
 
virtual Component * Copy () const
 Copies component (deep copy).
 
virtual void Scale (BaseFloat scale)
 When called on an UpdatableComponent, this virtual function scales the parameters by "scale"; on a component that stores stats, it scales the stats instead.
 
virtual void Add (BaseFloat alpha, const Component &other)
 When called on an UpdatableComponent, this virtual function adds the parameters of another updatable component, times some constant, to the current parameters; on a component that stores stats, it adds stats instead.
 
virtual void ZeroStats ()
 Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero.
 
virtual void DeleteMemo (void *memo) const
 This virtual function only needs to be overridden by Components that return a non-NULL memo from their Propagate() function.
 
virtual void StoreStats (const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo)
 This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity.
 
const CuVector< BaseFloat > & Offset () const
 
const CuVector< BaseFloat > & Scale () const
 
- Public Member Functions inherited from Component
virtual void GetInputIndexes (const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
 This function only does something interesting for non-simple Components.
 
virtual bool IsComputable (const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
 This function only does something interesting for non-simple Components, and it exists to make it possible to manage optionally-required inputs.
 
virtual void ReorderIndexes (std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
 This function only does something interesting for non-simple Components.
 
virtual ComponentPrecomputedIndexes * PrecomputeIndexes (const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
 This function must return NULL for simple Components.
 
virtual void ConsolidateMemory ()
 This virtual function relates to memory management, and avoiding fragmentation.
 
 Component ()
 
virtual ~Component ()
 

Private Member Functions

void Check () const
 
void ComputeDerived ()
 

Static Private Member Functions

static void ComputeOffsetAndScale (double count, BaseFloat epsilon, const Vector< double > &stats_sum, const Vector< double > &stats_sumsq, Vector< BaseFloat > *offset, Vector< BaseFloat > *scale)
 

Private Attributes

int32 dim_
 
int32 block_dim_
 
BaseFloat epsilon_
 
BaseFloat target_rms_
 
bool test_mode_
 
double count_
 
CuVector< double > stats_sum_
 
CuVector< double > stats_sumsq_
 
CuVector< BaseFloat > offset_
 
CuVector< BaseFloat > scale_
 

Additional Inherited Members

- Static Public Member Functions inherited from Component
static Component * ReadNew (std::istream &is, bool binary)
 Read component from stream (works out its type). Dies on error.
 
static Component * NewComponentOfType (const std::string &type)
 Returns a new Component of the given type e.g. More...
 

Detailed Description

Definition at line 159 of file nnet-normalize-component.h.
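The per-column computation implied by the Propagate() and ComputeDerived() listings below can be summarized as follows. This is a summary derived from those listings, not the original class comment. I is the number of rows after any block-dim reshaping, \rho is target-rms and \epsilon is epsilon. In training mode \mu and \sigma^2 come from the current minibatch; in test mode they come from count_, stats_sum_ and stats_sumsq_.

```latex
\begin{aligned}
\mu &= \tfrac{1}{I}\textstyle\sum_{i} x(i), \qquad
\sigma^2 = \max\big(\tfrac{1}{I}\textstyle\sum_{i} x(i)^2 - \mu^2,\ 0\big), \qquad
s = \rho\,(\sigma^2 + \epsilon)^{-1/2} \\
z(i) &= (x(i) - \mu)\,s \qquad \text{(training mode, minibatch statistics)} \\
z(i) &= x(i)\,s + o, \quad o = -s\,\mu \qquad \text{(test mode, stored statistics)}
\end{aligned}
```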

Constructor & Destructor Documentation

◆ BatchNormComponent() [1/2]

BatchNormComponent ( )
inline

Definition at line 162 of file nnet-normalize-component.h.

162 { }

◆ BatchNormComponent() [2/2]

Member Function Documentation

◆ Add()

void Add ( BaseFloat  alpha,
const Component &  other 
)
virtual

When called on an UpdatableComponent, this virtual function adds the parameters of another updatable component, times some constant, to the current parameters.

When called on a NonlinearComponent (or another component that stores stats, such as BatchNormComponent), it adds the other component's stats, times the constant, to the current stats instead. Otherwise it normally does nothing.

Reimplemented from Component.

Definition at line 657 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVec(), BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::stats_sum_, and BatchNormComponent::stats_sumsq_.

657  {
658  const BatchNormComponent *other =
659  dynamic_cast<const BatchNormComponent*>(&other_in);
660  count_ += alpha * other->count_;
661  stats_sum_.AddVec(alpha, other->stats_sum_);
662  stats_sumsq_.AddVec(alpha, other->stats_sumsq_);
663  // this operation might change offset_ and scale_, so we recompute them
664  // in this instance (but not in Scale()).
665  ComputeDerived();
666 }
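Add() merges sufficient statistics rather than parameters: count_, stats_sum_ and stats_sumsq_ are each scaled by alpha and accumulated, and only then are offset_ and scale_ recomputed. The following is a minimal sketch of that arithmetic using std::vector in place of CuVector. The BnStats type is a hypothetical stand-in, not part of Kaldi.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the stats kept by BatchNormComponent:
// count_, stats_sum_ (per-column sum of x) and stats_sumsq_ (sum of x^2).
struct BnStats {
  double count;
  std::vector<double> sum, sumsq;

  // Same shape as BatchNormComponent::Add(alpha, other): every statistic
  // of 'other' is scaled by alpha and added to ours.  (The real code then
  // calls ComputeDerived() to refresh offset_ and scale_.)
  void Add(double alpha, const BnStats &other) {
    assert(sum.size() == other.sum.size() &&
           sumsq.size() == other.sumsq.size());
    count += alpha * other.count;
    for (std::size_t d = 0; d < sum.size(); ++d) {
      sum[d] += alpha * other.sum[d];
      sumsq[d] += alpha * other.sumsq[d];
    }
  }

  // Mean implied by the accumulated stats for column d.
  double Mean(std::size_t d) const { return sum[d] / count; }
};
```

Because the stats are plain sums, adding with alpha = 1.0 pools the data. For example, stats for frames {1, 3} plus stats for frame {5} give the stats (and mean, 3) of all three frames.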

◆ Backprop()

void Backprop ( const std::string &  debug_info,
const ComponentPrecomputedIndexes *  indexes,
const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
const CuMatrixBase< BaseFloat > &  out_deriv,
void *  memo,
Component *  to_update,
CuMatrixBase< BaseFloat > *  in_deriv 
) const
virtual

Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update.

Parameters
[in] debug_info: The component name, to be printed out in any warning messages.
[in] indexes: A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in] in_value: The matrix that was given as input to the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsInput == 0.
[in] out_value: The matrix that was output from the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsOutput == 0.
[in] out_deriv: The derivative at the output of this component.
[in] memo: This will normally be NULL, but for component types that set the flag kUsesMemo, this will be the return value of the Propagate() function that corresponds to this Backprop() function. Ownership of any pointers is not transferred to the Backprop function; DeleteMemo() will be called to delete it.
[out] to_update: If model update is desired, the Component to be updated, else NULL. Does not have to be identical to this. If supplied, you can assume that to_update->Properties() & kUpdatableComponent is nonzero.
[out] in_deriv: The derivative at the input of this component, if needed (else NULL). If Properties()&kBackpropInPlace, may be the same matrix as out_deriv. If Properties()&kBackpropAdds, this is added to by the Backprop routine, else it is set. The component code chooses which mode to work in, based on convenience.

Implements Component.

Definition at line 467 of file nnet-normalize-component.cc.

References CuMatrixBase< Real >::AddMatDiagVec(), CuMatrixBase< Real >::AddVecToRows(), BatchNormComponent::block_dim_, CuMatrixBase< Real >::CopyFromMat(), CuMatrixBase< Real >::Data(), BatchNormComponent::dim_, KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, BatchNormComponent::Memo::mean_uvar_scale, CuMatrixBase< Real >::MulColsVec(), BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), NVTX_RANGE, BatchNormComponent::offset_, kaldi::SameDim(), BatchNormComponent::scale_, CuMatrixBase< Real >::Stride(), BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

475  {
476  NVTX_RANGE("BatchNormComponent::Backprop");
477 
478  KALDI_ASSERT(SameDim(out_value, out_deriv) &&
479  SameDim(out_value, *in_deriv) &&
480  (out_value.NumCols() == dim_ ||
481  out_value.NumCols() == block_dim_));
482  if (out_value.NumCols() != block_dim_) {
483  // if block_dim_ != dim_, we recurse; this helps keep the main code
484  // simple.
485  KALDI_ASSERT(out_value.Stride() == out_value.NumCols() &&
486  out_deriv.Stride() == out_deriv.NumCols() &&
487  in_deriv->Stride() == in_deriv->NumCols());
488  int32 ratio = dim_ / block_dim_,
489  orig_rows = out_value.NumRows(),
490  orig_cols = out_value.NumCols(),
491  new_rows = orig_rows * ratio, new_cols = orig_cols / ratio;
492  CuSubMatrix<BaseFloat> out_value_reshaped(out_value.Data(), new_rows,
493  new_cols, new_cols),
494  out_deriv_reshaped(out_deriv.Data(), new_rows, new_cols, new_cols),
495  in_deriv_reshaped(in_deriv->Data(), new_rows, new_cols, new_cols);
496  // we'll never use in_value, so pass it in unchanged.
497  Backprop(debug_info, indexes, in_value,
498  out_value_reshaped, out_deriv_reshaped,
499  memo_in, to_update, &in_deriv_reshaped);
500  return;
501  }
502 
503  Memo *memo = static_cast<Memo*>(memo_in);
504 
505  if (!test_mode_) {
506  // search above for BACKWARD PASS for a comment describing the math.
507  KALDI_ASSERT(memo != NULL && "memo not passed into backprop");
508  int32 num_frames = memo->num_frames;
509  KALDI_ASSERT(out_value.NumRows() == num_frames);
510  CuSubVector<BaseFloat>
511  scale(memo->mean_uvar_scale, 2),
512  var_deriv_mod(memo->mean_uvar_scale, 3),
513  temp(memo->mean_uvar_scale, 4);
514 
515  // var_deriv_mod is going to contain:
516  // 2 * power * target-rms^{1/power} * (1/I \sum_i z'(i) z(i)) * scale^{-(1+power)/power}
517  // which for power = -0.5 simplifies to:
518  // -1.0 / (target_rms * target_rms).
519  // but for now we don't have the power of 'scale', we'll add that later.
520  BaseFloat coeff = -1.0 / (target_rms_ * target_rms_ * num_frames);
521 
522  var_deriv_mod.AddDiagMatMat(coeff, out_value, kTrans,
523  out_deriv, kNoTrans, 0.0);
524  var_deriv_mod.MulElements(scale);
525 
526  temp.AddRowSumMat(-1.0 / num_frames, out_deriv, 0.0);
527  // the following statement does no work if in_deriv and out_deriv are the
528  // same matrix.
529  in_deriv->CopyFromMat(out_deriv);
530  in_deriv->AddVecToRows(1.0, temp);
531  // At this point, *in_deriv contains
532  // (z'(i) - 1/I * \sum_i z'(i))
533  in_deriv->MulColsVec(scale);
534  // At this point, *in_deriv contains
535  // scale * (z'(i) - 1/I * \sum_i z'(i))
536 
537  in_deriv->AddMatDiagVec(1.0, out_value, kNoTrans,
538  var_deriv_mod, 1.0);
539 
540  // At this point, *in_deriv contains what we described in the comment
541  // starting BATCHNORM_MATH as:
542  // x'(i) = scale * (z'(i) - 1/I * \sum_i z'(i)) + z(i) var_deriv_mod
543  } else {
544  KALDI_ASSERT(offset_.Dim() == block_dim_);
545  // the next call does no work if they point to the same memory.
546  in_deriv->CopyFromMat(out_deriv);
547  in_deriv->MulColsVec(scale_);
548  }
549 }
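The training-mode branch above computes x'(i) = scale * (z'(i) - 1/I \sum_i z'(i)) + z(i) var_deriv_mod, where var_deriv_mod works out to -scale / (target_rms^2 I) \sum_i z'(i) z(i). The sketch below restates both passes for a single column, using std::vector instead of CuMatrix, so the formula can be checked against central finite differences. BnForward() and BnBackward() are hypothetical helpers, not Kaldi functions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: training-mode forward pass for one column,
//   z(i) = (x(i) - mean) * scale,  scale = ((var + eps) / rms^2)^(-1/2).
std::vector<double> BnForward(const std::vector<double> &x, double eps,
                              double rms, double *scale_out) {
  const double n = static_cast<double>(x.size());
  double mean = 0.0, uvar = 0.0;  // uvar: uncentered variance, E[x^2]
  for (double v : x) {
    mean += v / n;
    uvar += v * v / n;
  }
  const double var = std::fmax(uvar - mean * mean, 0.0);
  const double scale = std::pow((var + eps) / (rms * rms), -0.5);
  std::vector<double> z(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) z[i] = (x[i] - mean) * scale;
  *scale_out = scale;
  return z;
}

// Hypothetical helper: the training-mode backward formula from Backprop():
//   x'(i) = scale * (z'(i) - 1/I sum_j z'(j)) + z(i) * var_deriv_mod,
//   var_deriv_mod = -scale / (rms^2 * I) * sum_j z'(j) z(j).
// Only the output z is needed, never the input x.
std::vector<double> BnBackward(const std::vector<double> &z,
                               const std::vector<double> &z_deriv,
                               double scale, double rms) {
  assert(z.size() == z_deriv.size());
  const double n = static_cast<double>(z.size());
  double mean_deriv = 0.0, dot = 0.0;
  for (std::size_t i = 0; i < z.size(); ++i) {
    mean_deriv += z_deriv[i] / n;
    dot += z_deriv[i] * z[i];
  }
  const double var_deriv_mod = -scale * dot / (rms * rms * n);
  std::vector<double> x_deriv(z.size());
  for (std::size_t i = 0; i < z.size(); ++i)
    x_deriv[i] = scale * (z_deriv[i] - mean_deriv) + z[i] * var_deriv_mod;
  return x_deriv;
}
```

BnBackward() needs only z, the scale and target-rms. This is consistent with Backprop() never using in_value and with kBackpropNeedsOutput appearing among the referenced property flags.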

◆ Check()

◆ ComputeDerived()

void ComputeDerived ( )
private

Definition at line 209 of file nnet-normalize-component.cc.

References NormalizeComponent::block_dim_, KALDI_WARN, and NormalizeComponent::target_rms_.

Referenced by BatchNormComponent::Add(), BatchNormComponent::BatchNormComponent(), BatchNormComponent::InitFromConfig(), and BatchNormComponent::Read().

209  {
210  if (!test_mode_) {
211  offset_.Resize(0);
212  scale_.Resize(0);
213  return;
214  }
215 
216  if (count_ == 0.0) {
217  KALDI_WARN << "Test-mode is set but there is no data count. "
218  "Creating random counts. This is NOT A PROBLEM if the message "
219  "appears in unit-tests or in compute_prob_*.0.log. If you see this "
220  "elsewhere, something is very wrong.";
221  count_ = 1.0;
225  }
226 
227  offset_.Resize(block_dim_);
228  scale_.Resize(block_dim_);
229  offset_.CopyFromVec(stats_sum_);
230  offset_.Scale(-1.0 / count_);
231  // now offset_ is -mean.
232  scale_.CopyFromVec(stats_sumsq_);
233  scale_.Scale(1.0 / count_);
234  scale_.AddVecVec(-1.0, offset_, offset_, 1.0);
235  // now scale_ is variance.
236  // Mathematically the ApplyFloor statement should be a no-op; this is in case
237  // of numerical roundoff.
238  scale_.ApplyFloor(0.0);
239  scale_.Add(epsilon_);
240  BaseFloat power = -0.5;
241  scale_.ApplyPow(power);
242  // now scale_ = (max(variance, 0) + epsilon)^power
243  // next, multiply by the target RMS (normally 1.0).
244  scale_.Scale(target_rms_);
245  offset_.MulElements(scale_);
246  // now offset_ is -(scale*mean).
247 }
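ComputeDerived() folds the stored stats into a per-column affine map. The test-mode branch of Propagate() then applies it as y = x * scale_ + offset_, which equals (x - mean) * scale_ because offset_ = -(scale_ * mean). A minimal sketch for plain vectors follows. ComputeOffsetScale() is a hypothetical helper; the class also declares a static ComputeOffsetAndScale() whose listing is not shown here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical mirror of ComputeDerived() for plain vectors: turns the
// accumulated stats into the test-mode transform y = x * scale + offset.
void ComputeOffsetScale(double count, double epsilon, double target_rms,
                        const std::vector<double> &stats_sum,
                        const std::vector<double> &stats_sumsq,
                        std::vector<double> *offset,
                        std::vector<double> *scale) {
  assert(count > 0.0);
  const std::size_t dim = stats_sum.size();
  offset->assign(dim, 0.0);
  scale->assign(dim, 0.0);
  for (std::size_t d = 0; d < dim; ++d) {
    double mean = stats_sum[d] / count;
    // Floor at zero in case of round-off, as ApplyFloor(0.0) does.
    double var = std::max(stats_sumsq[d] / count - mean * mean, 0.0);
    (*scale)[d] = target_rms * std::pow(var + epsilon, -0.5);
    (*offset)[d] = -mean * (*scale)[d];  // offset = -(scale * mean)
  }
}
```

For example, stats for frames {1, 3} (count 2, sum 4, sum of squares 10) give mean 2 and variance 1. With epsilon 0 and target-rms 2, that yields scale 2 and offset -4.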

◆ ComputeOffsetAndScale()

static void ComputeOffsetAndScale ( double  count,
BaseFloat  epsilon,
const Vector< double > &  stats_sum,
const Vector< double > &  stats_sumsq,
Vector< BaseFloat > *  offset,
Vector< BaseFloat > *  scale 
)
static private

◆ Copy()

virtual Component* Copy ( ) const
inline virtual

Copies component (deep copy).

Implements Component.

Definition at line 209 of file nnet-normalize-component.h.

References Component::Add(), Component::Scale(), and Component::ZeroStats().

◆ DeleteMemo()

virtual void DeleteMemo ( void *  memo) const
inline virtual

This virtual function only needs to be overridden by Components that return a non-NULL memo from their Propagate() function.

It's called by NnetComputer in cases where Propagate returns a memo but there will be no backprop to consume it.

Reimplemented from Component.

Definition at line 216 of file nnet-normalize-component.h.

References Component::StoreStats().

216 { delete static_cast<Memo*>(memo); }
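The Propagate(), Backprop() and StoreStats() listings access the memo's mean_uvar_scale matrix by row. Row 0 is the mean, row 1 is the uncentered variance and row 2 is the scale (all filled by Propagate()). Rows 3 and 4 are var_deriv_mod and temp (workspace used by Backprop()). The sketch below models that layout and the void* ownership round trip. MemoSketch, NewMemoSketch() and DeleteMemoSketch() are hypothetical names, not Kaldi's Memo.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of BatchNormComponent::Memo: num_frames plus a
// 5 x block_dim matrix whose rows are used as named below.
struct MemoSketch {
  enum Row { kMean = 0, kUvar = 1, kScale = 2, kVarDerivMod = 3, kTemp = 4 };
  int num_frames = 0;
  std::vector<std::vector<float>> mean_uvar_scale;
  explicit MemoSketch(std::size_t block_dim)
      : mean_uvar_scale(5, std::vector<float>(block_dim, 0.0f)) {}
};

// The memo travels as void*; it must be deleted through the concrete type
// (as DeleteMemo() does with static_cast<Memo*>), since deleting a void*
// is undefined behaviour in C++.
void *NewMemoSketch(std::size_t block_dim) { return new MemoSketch(block_dim); }
void DeleteMemoSketch(void *memo) { delete static_cast<MemoSketch *>(memo); }
```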

◆ Info()

std::string Info ( ) const
virtual

Returns some text-form information about this component, for diagnostics.

Starts with the type of the component. E.g. "SigmoidComponent dim=900", although most components will have much more info.

Reimplemented from Component.

Definition at line 269 of file nnet-normalize-component.cc.

References VectorBase< Real >::AddVecVec(), VectorBase< Real >::ApplyFloor(), VectorBase< Real >::ApplyPow(), BatchNormComponent::block_dim_, BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, VectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, kaldi::nnet3::SummarizeVector(), BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, and BatchNormComponent::Type().

269  {
270  std::ostringstream stream;
271  stream << Type() << ", dim=" << dim_ << ", block-dim=" << block_dim_
272  << ", epsilon=" << epsilon_ << ", target-rms=" << target_rms_
273  << ", count=" << count_
274  << ", test-mode=" << (test_mode_ ? "true" : "false");
275  if (count_ > 0) {
276  Vector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
277  mean.Scale(1.0 / count_);
278  var.Scale(1.0 / count_);
279  // subtract mean^2 from var.
280  var.AddVecVec(-1.0, mean, mean, 1.0);
281  var.ApplyFloor(0.0);
282  var.ApplyPow(0.5); // make it the stddev.
283  stream << ", data-mean=" << SummarizeVector(mean)
284  << ", data-stddev=" << SummarizeVector(var);
285  }
286  return stream.str();
287 }

◆ InitFromConfig()

void InitFromConfig ( ConfigLine *  cfl)
virtual

Initialize, from a ConfigLine object.

Parameters
[in] cfl: A ConfigLine containing any parameters that are needed for initialization. For example: "dim=100 param-stddev=0.1"

Implements Component.

Definition at line 289 of file nnet-normalize-component.cc.

References BatchNormComponent::block_dim_, BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, ConfigLine::GetValue(), ConfigLine::HasUnusedValues(), KALDI_ERR, CuVector< Real >::Resize(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, and ConfigLine::UnusedValues().

289  {
290  dim_ = -1;
291  block_dim_ = -1;
292  epsilon_ = 1.0e-03;
293  target_rms_ = 1.0;
294  test_mode_ = false;
295  bool ok = cfl->GetValue("dim", &dim_);
296  cfl->GetValue("block-dim", &block_dim_);
297  cfl->GetValue("epsilon", &epsilon_);
298  cfl->GetValue("target-rms", &target_rms_);
299  cfl->GetValue("test-mode", &test_mode_);
300  if (!ok || dim_ <= 0) {
301  KALDI_ERR << "BatchNormComponent must have 'dim' specified, and > 0";
302  }
303  if (block_dim_ == -1)
304  block_dim_ = dim_;
305  if (!(block_dim_ > 0 && dim_ % block_dim_ == 0 &&
306  epsilon_ > 0 && target_rms_ > 0))
307  KALDI_ERR << "Invalid configuration in BatchNormComponent.";
308  if (cfl->HasUnusedValues())
309  KALDI_ERR << "Could not process these elements in initializer: "
310  << cfl->UnusedValues();
311  count_ = 0;
314  if (test_mode_) {
315  ComputeDerived();
316  }
317 }
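InitFromConfig() requires dim, defaults block-dim to dim, and rejects configurations where block-dim does not divide dim or where epsilon or target-rms is not positive. Below is a minimal sketch of the same defaults and checks over a "key=value" string. It assumes simple whitespace-separated tokens and returns false where the real code calls KALDI_ERR. BnConfig and ParseBnConfig() are hypothetical; Kaldi itself uses ConfigLine.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>

// Hypothetical options struct with the defaults used by InitFromConfig().
struct BnConfig {
  int dim = -1, block_dim = -1;
  double epsilon = 1.0e-03, target_rms = 1.0;
  bool test_mode = false;
};

// Returns false where the real InitFromConfig() would call KALDI_ERR.
bool ParseBnConfig(const std::string &line, BnConfig *cfg) {
  *cfg = BnConfig();  // reset to defaults
  std::istringstream is(line);
  std::string tok;
  while (is >> tok) {
    std::size_t eq = tok.find('=');
    if (eq == std::string::npos) return false;
    std::string key = tok.substr(0, eq), val = tok.substr(eq + 1);
    if (key == "dim") cfg->dim = std::stoi(val);
    else if (key == "block-dim") cfg->block_dim = std::stoi(val);
    else if (key == "epsilon") cfg->epsilon = std::stod(val);
    else if (key == "target-rms") cfg->target_rms = std::stod(val);
    else if (key == "test-mode" && (val == "true" || val == "false"))
      cfg->test_mode = (val == "true");
    else return false;  // like HasUnusedValues(): unknown or bad option
  }
  if (cfg->dim <= 0) return false;  // 'dim' is required and must be > 0
  if (cfg->block_dim == -1) cfg->block_dim = cfg->dim;
  return cfg->block_dim > 0 && cfg->dim % cfg->block_dim == 0 &&
         cfg->epsilon > 0 && cfg->target_rms > 0;
}
```

For instance, "dim=512 block-dim=64" is accepted, while "dim=100 block-dim=30" is rejected because 30 does not divide 100.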

◆ InputDim()

virtual int32 InputDim ( ) const
inline virtual

Returns input-dimension of this component.

Implements Component.

Definition at line 176 of file nnet-normalize-component.h.

◆ Offset()

const CuVector<BaseFloat>& Offset ( ) const
inline

Definition at line 224 of file nnet-normalize-component.h.

Referenced by ModelCollapser::CollapseComponentsBatchnorm().

224 { return offset_; }

◆ OutputDim()

virtual int32 OutputDim ( ) const
inline virtual

Returns output-dimension of this component.

Implements Component.

Definition at line 177 of file nnet-normalize-component.h.

References NormalizeComponent::Info(), and NormalizeComponent::InitFromConfig().

◆ Propagate()

void * Propagate ( const ComponentPrecomputedIndexes *  indexes,
const CuMatrixBase< BaseFloat > &  in,
CuMatrixBase< BaseFloat > *  out 
) const
virtual

Propagate function.

Parameters
[in] indexes: A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in] in: The input to this component. Num-columns == InputDim().
[out] out: The output of this component. Num-columns == OutputDim(). Note: output of this component will be added to the initial value of "out" if Properties()&kPropagateAdds != 0; otherwise the output will be set and the initial value ignored. Each Component chooses whether it is more convenient implementation-wise to add or set, and the calling code has to deal with it.
Returns
Normally returns NULL, but may return a non-NULL value for components which have the flag kUsesMemo set. This value will be passed into the corresponding Backprop routine.

Implements Component.

Definition at line 401 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::Add(), CuMatrixBase< Real >::AddVecToRows(), CuVectorBase< Real >::AddVecVec(), CuVectorBase< Real >::ApplyFloor(), CuVectorBase< Real >::ApplyPow(), BatchNormComponent::block_dim_, CuMatrixBase< Real >::CopyFromMat(), CuVectorBase< Real >::CopyFromVec(), BatchNormComponent::count_, CuMatrixBase< Real >::Data(), BatchNormComponent::dim_, BatchNormComponent::epsilon_, KALDI_ASSERT, KALDI_ERR, kaldi::kTrans, BatchNormComponent::Memo::mean_uvar_scale, CuMatrixBase< Real >::MulColsVec(), BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), BatchNormComponent::offset_, kaldi::SameDim(), BatchNormComponent::scale_, CuMatrixBase< Real >::Stride(), BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

403  {
404  KALDI_ASSERT(SameDim(in, *out) &&
405  (in.NumCols() == dim_ || in.NumCols() == block_dim_));
406  if (in.NumCols() != block_dim_) {
407  // if block_dim_ != dim_, we recurse; this helps keep the main code
408  // simple.
409  KALDI_ASSERT(in.Stride() == in.NumCols() && out->Stride() == out->NumCols());
410  int32 ratio = dim_ / block_dim_, orig_rows = in.NumRows(),
411  orig_cols = in.NumCols(), new_rows = orig_rows * ratio,
412  new_cols = orig_cols / ratio;
413  CuSubMatrix<BaseFloat> in_reshaped(in.Data(), new_rows, new_cols, new_cols),
414  out_reshaped(out->Data(), new_rows, new_cols, new_cols);
415  return Propagate(indexes, in_reshaped, &out_reshaped);
416  }
417 
418  // From this point, we can assume that the num-cols of 'in' and 'out'
419  // equals block_dim_.
420 
421  if (!test_mode_) {
422  // search in the comment above for FORWARD PASS to see what is being
423  // implemented here.
424  // if this takes too much time due to multiple different CUDA calls,
425  // we'll consider making a single kernel for some of it.
426  Memo *memo = new Memo;
427  int32 num_frames = in.NumRows(), dim = block_dim_;
428  memo->num_frames = num_frames;
429  memo->mean_uvar_scale.Resize(5, dim);
430  CuSubVector<BaseFloat> mean(memo->mean_uvar_scale, 0),
431  uvar(memo->mean_uvar_scale, 1),
432  scale(memo->mean_uvar_scale, 2);
433  mean.AddRowSumMat(1.0 / num_frames, in, 0.0);
434  uvar.AddDiagMat2(1.0 / num_frames, in, kTrans, 0.0);
435  scale.CopyFromVec(uvar);
436 
437  // by applying this scale at this point, we save a multiply later on.
438  BaseFloat var_scale = 1.0 / (target_rms_ * target_rms_);
439  scale.AddVecVec(-var_scale, mean, mean, var_scale);
440  // at this point, 'scale' contains just the variance (times target-rms^{-2}).
441  scale.ApplyFloor(0.0);
442  scale.Add(var_scale * epsilon_);
443  // Now 'scale' contains the variance floored to zero and then with epsilon
444  // added [both times 1/target-rms^2].
445  scale.ApplyPow(-0.5);
446  // now 'scale' is the actual scale we'll use.
447 
448  // the next command will do no work if out == in, for in-place propagation.
449  out->CopyFromMat(in);
450  out->AddVecToRows(-1.0, mean, 1.0);
451  out->MulColsVec(scale);
452  return static_cast<void*>(memo);
453  } else {
454  if (offset_.Dim() != block_dim_) {
455  if (count_ == 0)
456  KALDI_ERR << "Test mode set in BatchNormComponent, but no stats.";
457  else // why was ComputeDerived() not called?
458  KALDI_ERR << "Code error in BatchNormComponent";
459  }
460  out->CopyFromMat(in);
461  out->MulColsVec(scale_);
462  out->AddVecToRows(1.0, offset_, 1.0);
463  return NULL;
464  }
465 }
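When block_dim_ < dim_, Propagate() reinterprets a contiguous rows x dim matrix as a (rows * dim / block_dim) x block_dim matrix. As a result, reshaped column c pools original columns c, c + block_dim, c + 2 * block_dim, and so on, over all frames. The sketch below performs the training-mode normalization directly on such a flat row-major buffer. BnPropagateTrain() is a hypothetical helper that works in double precision on std::vector rather than CuMatrix.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical plain-C++ version of the training-mode Propagate() path.
// 'data' is a row-major rows x dim matrix with stride == dim, so it can be
// viewed as (rows * dim / block_dim) x block_dim without copying.
// Normalizes in place and returns the per-channel scale (memo row 2).
std::vector<double> BnPropagateTrain(std::vector<double> *data,
                                     std::size_t rows, std::size_t dim,
                                     std::size_t block_dim, double epsilon,
                                     double target_rms) {
  assert(dim % block_dim == 0 && data->size() == rows * dim);
  const std::size_t new_rows = rows * (dim / block_dim);
  const double n = static_cast<double>(new_rows);
  std::vector<double> mean(block_dim, 0.0), uvar(block_dim, 0.0),
      scale(block_dim, 0.0);
  for (std::size_t r = 0; r < new_rows; ++r)
    for (std::size_t c = 0; c < block_dim; ++c) {
      double v = (*data)[r * block_dim + c];
      mean[c] += v / n;
      uvar[c] += v * v / n;  // uncentered variance, E[x^2]
    }
  const double var_scale = 1.0 / (target_rms * target_rms);
  for (std::size_t c = 0; c < block_dim; ++c) {
    double var = std::fmax(uvar[c] - mean[c] * mean[c], 0.0);
    scale[c] = std::pow(var_scale * (var + epsilon), -0.5);
  }
  for (std::size_t r = 0; r < new_rows; ++r)
    for (std::size_t c = 0; c < block_dim; ++c) {
      double &v = (*data)[r * block_dim + c];
      v = (v - mean[c]) * scale[c];
    }
  return scale;
}
```

After the call, each channel has mean close to 0 and mean square close to target_rms^2 * var / (var + epsilon).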

◆ Properties()

virtual int32 Properties ( ) const
inline virtual

Return bitmask of the component's properties.

These properties depend only on the component's type. See enum ComponentProperties.

Implements Component.

Definition at line 182 of file nnet-normalize-component.h.

References NormalizeComponent::Backprop(), NormalizeComponent::block_dim_, kaldi::nnet3::kBackpropInPlace, kaldi::nnet3::kBackpropNeedsOutput, kaldi::nnet3::kInputContiguous, kaldi::nnet3::kOutputContiguous, kaldi::nnet3::kPropagateInPlace, kaldi::nnet3::kSimpleComponent, kaldi::nnet3::kStoresStats, kaldi::nnet3::kUsesMemo, NormalizeComponent::Propagate(), NormalizeComponent::Read(), and NormalizeComponent::Write().

182  {
183  // If the block-dim is less than the dim, we need the input and output
184  // matrices to be contiguous (stride==num-cols), as we'll be reshaping
185  // internally. This is not much of a cost, because this will be used
186  // in convnets where we have to do this anyway.
191  }
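The listing omits the return expression, but the References list names the flags involved. The surrounding comments explain two of the conditions: contiguity is needed only when block-dim is less than dim, and in test mode the component stores no stats and Propagate() returns NULL. The following is a hypothetical reconstruction of how those flags could be combined. The enum values here are illustrative only and are not Kaldi's actual ComponentProperties values.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative flag values only; the real ones are defined by Kaldi's
// enum ComponentProperties and may differ.
enum SketchProps : std::int32_t {
  kSimpleComponent = 1 << 0,
  kPropagateInPlace = 1 << 1,
  kBackpropNeedsOutput = 1 << 2,
  kBackpropInPlace = 1 << 3,
  kStoresStats = 1 << 4,
  kInputContiguous = 1 << 5,
  kOutputContiguous = 1 << 6,
  kUsesMemo = 1 << 7,
};

// Hypothetical: how the referenced flags might depend on block_dim_ and
// test_mode_ for a BatchNormComponent-like class.
std::int32_t SketchBatchNormProperties(int dim, int block_dim,
                                       bool test_mode) {
  std::int32_t props = kSimpleComponent | kPropagateInPlace |
                       kBackpropNeedsOutput | kBackpropInPlace;
  if (block_dim < dim)  // reshaping requires stride == num-cols
    props |= kInputContiguous | kOutputContiguous;
  if (!test_mode)  // memo and stats exist only in training mode
    props |= kUsesMemo | kStoresStats;
  return props;
}
```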

◆ Read()

void Read ( std::istream &  is,
bool  binary 
)
virtual

Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed.

Implements Component.

Definition at line 591 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVecVec(), BatchNormComponent::block_dim_, BatchNormComponent::Check(), BatchNormComponent::ComputeDerived(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, kaldi::ExpectOneOrTwoTokens(), kaldi::nnet3::ExpectToken(), CuVector< Real >::Read(), kaldi::ReadBasicType(), CuVectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, and BatchNormComponent::test_mode_.

591  {
592  ExpectOneOrTwoTokens(is, binary, "<BatchNormComponent>", "<Dim>");
593  ReadBasicType(is, binary, &dim_);
594  ExpectToken(is, binary, "<BlockDim>");
595  ReadBasicType(is, binary, &block_dim_);
596  ExpectToken(is, binary, "<Epsilon>");
597  ReadBasicType(is, binary, &epsilon_);
598  ExpectToken(is, binary, "<TargetRms>");
599  ReadBasicType(is, binary, &target_rms_);
600  ExpectToken(is, binary, "<TestMode>");
601  ReadBasicType(is, binary, &test_mode_);
602  ExpectToken(is, binary, "<Count>");
603  ReadBasicType(is, binary, &count_);
604  ExpectToken(is, binary, "<StatsMean>");
605  stats_sum_.Read(is, binary);
606  ExpectToken(is, binary, "<StatsVar>");
607  stats_sumsq_.Read(is, binary);
611  ExpectToken(is, binary, "</BatchNormComponent>");
612  ComputeDerived();
613  Check();
614 }
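Read() reads the <StatsMean> and <StatsVar> fields into stats_sum_ and stats_sumsq_. Lines 608 to 610 of the listing are missing, although the References list mentions AddVecVec() and Scale(). If the file stores a mean and a centered variance, as the token names suggest, those values would need converting back into the count-weighted sums that StoreStats() accumulates. That is an assumption this listing does not confirm. The hypothetical helpers below sketch both directions of that conversion.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// ASSUMPTION: on disk, <StatsMean>/<StatsVar> hold the mean and centered
// variance.  Hypothetical helpers converting in place between that form
// and the (count, sum, sumsq) form accumulated by the component.
void SumsToMeanVar(double count, std::vector<double> *sum_to_mean,
                   std::vector<double> *sumsq_to_var) {
  assert(count > 0.0 && sum_to_mean->size() == sumsq_to_var->size());
  for (std::size_t d = 0; d < sum_to_mean->size(); ++d) {
    double mean = (*sum_to_mean)[d] / count;
    (*sumsq_to_var)[d] = (*sumsq_to_var)[d] / count - mean * mean;
    (*sum_to_mean)[d] = mean;
  }
}

void MeanVarToSums(double count, std::vector<double> *mean_to_sum,
                   std::vector<double> *var_to_sumsq) {
  assert(mean_to_sum->size() == var_to_sumsq->size());
  for (std::size_t d = 0; d < mean_to_sum->size(); ++d) {
    double mean = (*mean_to_sum)[d];
    // add mean^2 back to get E[x^2], then weight by the frame count
    (*var_to_sumsq)[d] = count * ((*var_to_sumsq)[d] + mean * mean);
    (*mean_to_sum)[d] = count * mean;
  }
}
```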

◆ Scale() [1/2]

void Scale ( BaseFloat  scale)
virtual

When called on an UpdatableComponent, this virtual function scales the parameters by "scale".

When called on a NonlinearComponent (or another component that stores stats, such as BatchNormComponent), it scales the activation stats rather than parameters. Otherwise it normally does nothing.

Reimplemented from Component.

Definition at line 644 of file nnet-normalize-component.cc.

References BatchNormComponent::count_, CuVectorBase< Real >::Scale(), CuVectorBase< Real >::SetZero(), BatchNormComponent::stats_sum_, and BatchNormComponent::stats_sumsq_.

Referenced by ModelCollapser::CollapseComponentsBatchnorm(), and kaldi::nnet3::ScaleBatchnormStats().

644  {
645  if (scale == 0) {
646  count_ = 0.0;
649  } else {
650  count_ *= scale;
651  stats_sum_.Scale(scale);
652  stats_sumsq_.Scale(scale);
653  }
654 }
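Scaling count_, stats_sum_ and stats_sumsq_ by the same nonzero factor leaves the implied mean and variance unchanged. This is presumably why the comment in Add() says offset_ and scale_ are recomputed there but not in Scale(). The sketch below uses a hypothetical StatsSketch type. It assumes that the scale == 0 branch also zeroes the sums, inferred from the SetZero() entry in the References list, since those listing lines are missing.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the stats part of BatchNormComponent::Scale().
struct StatsSketch {
  double count;
  std::vector<double> sum, sumsq;

  void Scale(double s) {
    if (s == 0.0) {  // reset (assumed to zero the sums as well)
      count = 0.0;
      sum.assign(sum.size(), 0.0);
      sumsq.assign(sumsq.size(), 0.0);
      return;
    }
    count *= s;
    for (double &v : sum) v *= s;
    for (double &v : sumsq) v *= s;
  }
  double Mean(std::size_t d) const { return sum[d] / count; }
  double Var(std::size_t d) const {
    return sumsq[d] / count - Mean(d) * Mean(d);
  }
};
```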

◆ Scale() [2/2]

const CuVector<BaseFloat>& Scale ( ) const
inline

Definition at line 225 of file nnet-normalize-component.h.

225 { return scale_; }

◆ SetTestMode()

void SetTestMode ( bool  test_mode)

◆ StoreStats()

void StoreStats ( const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
void *  memo 
)
virtual

This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity.

It only does something for those components that have nonzero Properties()&kStoresStats.

Parameters
[in] in_value: The input to the Propagate() function. Note: if the component sets the flag kPropagateInPlace, this should not be used; the empty matrix will be provided here if in-place propagation was used.
[in] out_value: The output of the Propagate() function.
[in] memo: The 'memo' returned by the Propagate() function; this will usually be NULL.

Reimplemented from Component.

Definition at line 551 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVec(), BatchNormComponent::block_dim_, BatchNormComponent::count_, CuMatrixBase< Real >::Data(), CuVectorBase< Real >::Dim(), BatchNormComponent::dim_, KALDI_ASSERT, BatchNormComponent::Memo::mean_uvar_scale, BatchNormComponent::Memo::num_frames, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), CuVector< Real >::Resize(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, CuMatrixBase< Real >::Stride(), and BatchNormComponent::test_mode_.

{
  // in test mode this component does not store stats, it doesn't provide the
  // kStoresStats flag.
  KALDI_ASSERT(!test_mode_);
  KALDI_ASSERT(out_value.NumCols() == dim_ || out_value.NumCols() == block_dim_);
  if (out_value.NumCols() != block_dim_) {
    // if block_dim_ != dim_, we recurse; this helps keep the main code
    // simple.
    KALDI_ASSERT(out_value.Stride() == out_value.NumCols());
    int32 ratio = dim_ / block_dim_,
        orig_rows = out_value.NumRows(),
        orig_cols = out_value.NumCols(),
        new_rows = orig_rows * ratio, new_cols = orig_cols / ratio;
    CuSubMatrix<BaseFloat> out_value_reshaped(out_value.Data(), new_rows,
                                              new_cols, new_cols);
    // we'll never use in_value, so just pass it in unchanged.
    StoreStats(in_value, out_value_reshaped, memo_in);
    return;
  }

  Memo *memo = static_cast<Memo*>(memo_in);
  KALDI_ASSERT(out_value.NumRows() == memo->num_frames);

  CuSubVector<BaseFloat> mean(memo->mean_uvar_scale, 0),
      uvar(memo->mean_uvar_scale, 1);
  KALDI_ASSERT(mean.Dim() == block_dim_ && memo->num_frames > 0);
  BaseFloat num_frames = memo->num_frames;
  if (stats_sum_.Dim() != block_dim_) {
    stats_sum_.Resize(block_dim_);
    stats_sumsq_.Resize(block_dim_);
    KALDI_ASSERT(count_ == 0);
  }
  count_ += num_frames;
  stats_sum_.AddVec(num_frames, mean, 1.0);
  stats_sumsq_.AddVec(num_frames, uvar, 1.0);
}
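The recursive branch above relies on a memory-layout fact: when the matrix is contiguous (stride equal to the number of columns) and block_dim_ divides dim_, reading the same buffer as a matrix with ratio = dim_ / block_dim_ times as many rows and block_dim_ columns puts each block of columns on its own row, so every block is treated as extra frames of a single block_dim_-dimensional normalizer. The standalone sketch below illustrates that index mapping with a hypothetical helper, ReshapedIndex, which is not part of Kaldi and works on plain indices rather than CuMatrix.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not part of Kaldi). Element (row, col) of a contiguous
// row-major matrix with `dim` columns sits at buffer offset row * dim + col.
// When the same buffer is viewed with `block_dim` columns (also contiguous),
// that offset corresponds to the (row, col) returned here.
struct Index {
  std::size_t row;
  std::size_t col;
};

Index ReshapedIndex(std::size_t row, std::size_t col,
                    std::size_t dim, std::size_t block_dim) {
  assert(block_dim > 0 && dim % block_dim == 0);
  std::size_t ratio = dim / block_dim;  // number of blocks per original row
  // Block (col / block_dim) of original row `row` becomes its own new row.
  return Index{row * ratio + col / block_dim, col % block_dim};
}
```

For a 2 x 6 matrix viewed with block_dim = 3, element (0, 4) maps to (1, 1) of the 4 x 3 view, and both refer to buffer offset 4.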
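The accumulation at the end of StoreStats() can be illustrated without Kaldi's CUDA types. In this hedged sketch, std::vector<double> stands in for CuVector, the struct BatchNormStats and its Add() method are hypothetical, and the memo's uvar row is assumed to hold each dimension's mean of x² over the minibatch (its uncentered variance), which is consistent with Write() later subtracting mean² from the normalized sum of squares.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the statistics StoreStats() accumulates.
// Member names echo count_, stats_sum_ and stats_sumsq_, but this is an
// illustration, not Kaldi's implementation.
struct BatchNormStats {
  double count = 0.0;          // total frames seen (cf. count_)
  std::vector<double> sum;     // sum over frames of x (cf. stats_sum_)
  std::vector<double> sumsq;   // sum over frames of x^2 (cf. stats_sumsq_)

  // Adds one minibatch summarized by its per-dimension mean and uncentered
  // variance, each weighted by the minibatch frame count, as in
  // stats_sum_.AddVec(num_frames, mean, 1.0).
  void Add(double num_frames, const std::vector<double> &mean,
           const std::vector<double> &uvar) {
    assert(num_frames > 0 && mean.size() == uvar.size());
    if (sum.size() != mean.size()) {
      // First call: size the accumulators (cf. the Resize() calls).
      assert(count == 0.0);
      sum.assign(mean.size(), 0.0);
      sumsq.assign(mean.size(), 0.0);
    }
    count += num_frames;
    for (std::size_t i = 0; i < mean.size(); ++i) {
      sum[i] += num_frames * mean[i];
      sumsq[i] += num_frames * uvar[i];
    }
  }
};
```

Weighting each minibatch by its frame count means a 2-frame and a 3-frame minibatch contribute in proportion 2:3, so the totals do not depend on how the frames were split into minibatches.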

◆ Type()

virtual std::string Type ( ) const
inline virtual

Returns a string such as "SigmoidComponent", describing the type of the object.

Implements Component.

Definition at line 181 of file nnet-normalize-component.h.

Referenced by BatchNormComponent::Info().

{ return "BatchNormComponent"; }

◆ Write()

void Write (std::ostream &os, bool binary) const
virtual

Write component to stream.

Implements Component.

Definition at line 616 of file nnet-normalize-component.cc.

References CuVectorBase< Real >::AddVecVec(), BatchNormComponent::block_dim_, BatchNormComponent::Check(), BatchNormComponent::count_, BatchNormComponent::dim_, BatchNormComponent::epsilon_, CuVectorBase< Real >::Scale(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, BatchNormComponent::target_rms_, BatchNormComponent::test_mode_, CuVector< Real >::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().

{
  Check();
  WriteToken(os, binary, "<BatchNormComponent>");
  WriteToken(os, binary, "<Dim>");
  WriteBasicType(os, binary, dim_);
  WriteToken(os, binary, "<BlockDim>");
  WriteBasicType(os, binary, block_dim_);
  WriteToken(os, binary, "<Epsilon>");
  WriteBasicType(os, binary, epsilon_);
  WriteToken(os, binary, "<TargetRms>");
  WriteBasicType(os, binary, target_rms_);
  WriteToken(os, binary, "<TestMode>");
  WriteBasicType(os, binary, test_mode_);
  WriteToken(os, binary, "<Count>");
  WriteBasicType(os, binary, count_);
  CuVector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
  if (count_ != 0) {
    mean.Scale(1.0 / count_);
    var.Scale(1.0 / count_);
    var.AddVecVec(-1.0, mean, mean, 1.0);
  }
  WriteToken(os, binary, "<StatsMean>");
  mean.Write(os, binary);
  WriteToken(os, binary, "<StatsVar>");
  var.Write(os, binary);
  WriteToken(os, binary, "</BatchNormComponent>");
}
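The conversion in Write() can be checked with a short derivation. Write n_b for the frame count of minibatch b and N = Σ_b n_b for count_, and assume the memo's mean and uvar rows hold the minibatch mean x̄_b and mean of squares (x²)‾_b for one dimension. Then the saved StatsMean and StatsVar are the exact mean and biased (divide-by-N) variance over all N frames t seen, not an average of per-minibatch variances:

```latex
\mu = \frac{1}{N}\sum_b n_b\,\bar{x}_b = \frac{1}{N}\sum_t x_t,
\qquad
\sigma^2 = \frac{1}{N}\sum_b n_b\,\overline{x^2}_b - \mu^2
         = \frac{1}{N}\sum_t x_t^2 - \mu^2
         = \frac{1}{N}\sum_t (x_t - \mu)^2 .
```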
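The same arithmetic in code: the hedged sketch below mirrors the Scale() and AddVecVec() calls in Write() on plain std::vector<double> values. FinalizeStats and MeanVar are hypothetical names, not Kaldi API. As in Write(), nothing is divided when the count is zero, so zero-filled accumulators are returned unchanged.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type: per-dimension mean and (biased) variance.
struct MeanVar {
  std::vector<double> mean;
  std::vector<double> var;
};

// Element-wise: mean = sum / count, var = sumsq / count - mean^2.
// Skipped when count == 0, matching the `if (count_ != 0)` guard in Write().
MeanVar FinalizeStats(double count, const std::vector<double> &sum,
                      const std::vector<double> &sumsq) {
  assert(sum.size() == sumsq.size());
  MeanVar out{sum, sumsq};
  if (count != 0.0) {
    for (std::size_t i = 0; i < out.mean.size(); ++i) {
      out.mean[i] /= count;                                   // mean.Scale(1 / count)
      out.var[i] = out.var[i] / count - out.mean[i] * out.mean[i];  // var.Scale; AddVecVec(-1, mean, mean, 1)
    }
  }
  return out;
}
```

For example, a count of 4 with sums {8, 4} and sums of squares {20, 8} gives means {2, 1} and variances {1, 1}.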

◆ ZeroStats()

void ZeroStats ( )
virtual

Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero.

Other components that store other types of statistics (e.g. regarding gradient clipping) should implement ZeroStats() also.

Reimplemented from Component.

Definition at line 668 of file nnet-normalize-component.cc.

References BatchNormComponent::count_, CuVectorBase< Real >::SetZero(), BatchNormComponent::stats_sum_, BatchNormComponent::stats_sumsq_, and BatchNormComponent::test_mode_.

{
  // We only zero the stats if we're not in test mode. In test mode, this would
  // be dangerous as the stats are the source for the transform, and zeroing
  // them and then calling ComputeDerived() again would remove the transform
  // parameters (offset_ and scale_).
  if (!test_mode_) {
    count_ = 0.0;
    stats_sum_.SetZero();
    stats_sumsq_.SetZero();
  }
}
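The test-mode guard in ZeroStats() can be sketched on its own. StatsHolder below is a hypothetical stand-in that uses std::vector<double> in place of CuVector; as with CuVector::SetZero(), clearing keeps each vector's dimension and only zeroes its values. In test mode the call does nothing, which preserves the statistics the fixed transform is derived from.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch (not Kaldi code) of the ZeroStats() guard.
struct StatsHolder {
  bool test_mode = false;
  double count = 0.0;
  std::vector<double> sum;
  std::vector<double> sumsq;

  void ZeroStats() {
    assert(sum.size() == sumsq.size());
    if (!test_mode) {
      count = 0.0;
      sum.assign(sum.size(), 0.0);      // keep the dimension, clear the values
      sumsq.assign(sumsq.size(), 0.0);
    }
    // In test mode the stats are left intact on purpose.
  }
};
```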

Member Data Documentation

◆ block_dim_

◆ count_

◆ dim_

◆ epsilon_

◆ offset_

CuVector<BaseFloat> offset_
private

◆ scale_

◆ stats_sum_

◆ stats_sumsq_

◆ target_rms_

◆ test_mode_


The documentation for this class was generated from the following files: