CompositeComponent Class Reference

CompositeComponent is a component representing a sequence of [simple] components.

#include <nnet-simple-component.h>

Inheritance: CompositeComponent derives from UpdatableComponent, which derives from Component.

Public Member Functions

virtual int32 InputDim () const
 Returns input-dimension of this component.

virtual int32 OutputDim () const
 Returns output-dimension of this component.

virtual std::string Info () const
 Returns some text-form information about this component, for diagnostics.

virtual void InitFromConfig (ConfigLine *cfl)
 Initialize, from a ConfigLine object.

virtual Component * Copy () const
 Copies component (deep copy).

 CompositeComponent ()

void Init (const std::vector< Component *> &components, int32 max_rows_process)

virtual std::string Type () const
 Returns a string such as "SigmoidComponent", describing the type of the object.

virtual int32 Properties () const
 Return bitmask of the component's properties.

virtual void * Propagate (const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
 Propagate function.

virtual void Backprop (const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const
 Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update.

virtual void ZeroStats ()
 Components that provide an implementation of StoreStats() should also provide an implementation of ZeroStats(), to set those stats to zero.

virtual void Read (std::istream &is, bool binary)
 Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed.

virtual void Write (std::ostream &os, bool binary) const
 Write component to stream.

virtual void SetUnderlyingLearningRate (BaseFloat lrate)
 Sets the learning rate of gradient descent; it gets multiplied by learning_rate_factor_.

virtual void SetActualLearningRate (BaseFloat lrate)
 Sets the learning rate directly, bypassing learning_rate_factor_.

virtual void SetAsGradient ()
 Sets is_gradient_ to true and sets learning_rate_ to 1, ignoring learning_rate_factor_.

virtual void Scale (BaseFloat scale)
 When called on an UpdatableComponent, this virtual function scales the parameters by "scale".

virtual void Add (BaseFloat alpha, const Component &other)
 When called on an UpdatableComponent, this virtual function adds the parameters of another updatable component, times some constant, to the current parameters.

virtual void PerturbParams (BaseFloat stddev)
 This function is to be used in testing.

virtual BaseFloat DotProduct (const UpdatableComponent &other) const
 Computes dot-product between parameters of two instances of a Component.

virtual int32 NumParameters () const
 Returns the total dimension of the parameters in this class.

virtual void Vectorize (VectorBase< BaseFloat > *params) const
 Turns the parameters into vector form.

virtual void UnVectorize (const VectorBase< BaseFloat > &params)
 Converts the parameters from vector form.

virtual void FreezeNaturalGradient (bool freeze)
 Calls FreezeNaturalGradient(freeze) on each updatable sub-component.

int32 NumComponents () const

const Component * GetComponent (int32 i) const
 Gets the ith component in this component.

void SetComponent (int32 i, Component *component)
 Sets the ith component.

virtual ~CompositeComponent ()
 
- Public Member Functions inherited from UpdatableComponent
 UpdatableComponent (const UpdatableComponent &other)
 
 UpdatableComponent ()
 
virtual ~UpdatableComponent ()
 
virtual BaseFloat LearningRateFactor ()
 
virtual void SetLearningRateFactor (BaseFloat lrate_factor)
 
void SetUpdatableConfigs (const UpdatableComponent &other)
 
BaseFloat LearningRate () const
 Gets the learning rate to be used in gradient descent.

BaseFloat MaxChange () const
 Returns the per-component max-change value, which is interpreted as the maximum change (in l2 norm) in parameters that is allowed per minibatch for this component.

void SetMaxChange (BaseFloat max_change)

BaseFloat L2Regularization () const
 Returns the l2 regularization constant, which may be set in any updatable component (usually from the config file).

void SetL2Regularization (BaseFloat a)

- Public Member Functions inherited from Component
virtual void StoreStats (const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo)
 This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity.

virtual void GetInputIndexes (const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
 This function only does something interesting for non-simple Components.

virtual bool IsComputable (const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
 This function only does something interesting for non-simple Components, and it exists to make it possible to manage optionally-required inputs.

virtual void ReorderIndexes (std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
 This function only does something interesting for non-simple Components.

virtual ComponentPrecomputedIndexes * PrecomputeIndexes (const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
 This function must return NULL for simple Components.

virtual void DeleteMemo (void *memo) const
 This virtual function only needs to be overridden by Components that return a non-NULL memo from their Propagate() function.

virtual void ConsolidateMemory ()
 This virtual function relates to memory management, and avoiding fragmentation.
 
 Component ()
 
virtual ~Component ()
 

Private Member Functions

MatrixStrideType GetStrideType (int32 i) const
 
bool IsUpdatable () const
 

Private Attributes

int32 max_rows_process_
 
std::vector< Component * > components_
 

Additional Inherited Members

- Static Public Member Functions inherited from Component
static Component * ReadNew (std::istream &is, bool binary)
 Read component from stream (works out its type). Dies on error.

static Component * NewComponentOfType (const std::string &type)
 Returns a new Component of the given type (for example "SigmoidComponent"), or NULL if no such component type exists.
 
- Protected Member Functions inherited from UpdatableComponent
void InitLearningRatesFromConfig (ConfigLine *cfl)
 
std::string ReadUpdatableCommon (std::istream &is, bool binary)
 
void WriteUpdatableCommon (std::ostream &is, bool binary) const
 
- Protected Attributes inherited from UpdatableComponent
BaseFloat learning_rate_
 learning rate (typically 0.0..0.01)

BaseFloat learning_rate_factor_
 learning rate factor (normally 1.0, but can be set to another value so that when you call SetLearningRate(), that value will be scaled by this factor)

BaseFloat l2_regularize_
 L2 regularization constant.

bool is_gradient_
 True if this component is to be treated as a gradient rather than as parameters.

BaseFloat max_change_
 configuration value for imposing max-change
 

Detailed Description

CompositeComponent is a component representing a sequence of [simple] components.

The config line would be something like the following (imagine this is all on one line):

component name=composite1 type=CompositeComponent max-rows-process=2048 num-components=3 \
  component1='type=BlockAffineComponent input-dim=1000 output-dim=10000 num-blocks=100' \
  component2='type=RectifiedLinearComponent dim=10000' \
  component3='type=BlockAffineComponent input-dim=10000 output-dim=1000 num-blocks=100'

The reason you might want to use this component, instead of directly using the same sequence of components in the config file, is to save GPU memory at the expense of more compute. Parts of the forward pass are redone during the backprop phase, so the intermediate outputs of the sub-components never have to be kept in memory for long, and you can make the memory usage very small by making max-rows-process small (rows are then processed in chunks of at most max-rows-process). We inherit from UpdatableComponent just in case one or more of the components in the sequence are updatable.

It is an error to nest a CompositeComponent inside a CompositeComponent. The same effect can be accomplished by specifying a smaller max-rows-process in a single CompositeComponent.
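To make the memory argument concrete, here is a minimal sketch of a chunked forward pass through a chain of row-wise layers. It is not Kaldi code: `Matrix`, `Layer` and `PropagateSketch` are hypothetical names, a std::vector of rows stands in for CuMatrix, and std::function stands in for Component. Splitting by rows is only valid because every sub-component must be simple, so each output row depends only on the corresponding input row.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-ins: a matrix is a vector of rows, and a "layer" maps a
// block of rows to a block of rows (row-wise, like a simple component).
using Matrix = std::vector<std::vector<double>>;
using Layer = std::function<Matrix(const Matrix &)>;

// Sketch of a chunked forward pass: rows are processed in blocks of at most
// max_rows_process (<= 0 means "no limit"), and each block goes through all
// layers before the next block starts, so only one block's intermediate
// outputs are alive at any time.
Matrix PropagateSketch(const std::vector<Layer> &layers, const Matrix &in,
                       int max_rows_process) {
  assert(!layers.empty());
  const int num_rows = static_cast<int>(in.size());
  const int chunk =
      (max_rows_process > 0 ? max_rows_process : std::max(num_rows, 1));
  Matrix out;
  out.reserve(in.size());
  for (int offset = 0; offset < num_rows; offset += chunk) {
    const int this_num_rows = std::min(chunk, num_rows - offset);
    Matrix part(in.begin() + offset, in.begin() + offset + this_num_rows);
    for (const Layer &layer : layers)
      part = layer(part);  // intermediate output of this layer, this block only
    out.insert(out.end(), part.begin(), part.end());
  }
  return out;
}
```

With max_rows_process = 2, a 3-row input is handled as blocks of 2 and 1 rows and gives the same result as the unchunked pass; only the size of the per-block intermediates changes.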

Definition at line 1971 of file nnet-simple-component.h.

Constructor & Destructor Documentation

◆ CompositeComponent()

CompositeComponent ( )
inline

Definition at line 1982 of file nnet-simple-component.h.

References CompositeComponent::Init().

1982 { } // use Init() or InitFromConfig() to really initialize.

◆ ~CompositeComponent()

virtual ~CompositeComponent ( )
inline virtual

Definition at line 2050 of file nnet-simple-component.h.

References kaldi::DeletePointers().

2050 { DeletePointers(&components_); }

Member Function Documentation

◆ Add()

void Add ( BaseFloat  alpha,
const Component &  other 
)
virtual

When called on:
- an UpdatableComponent, this virtual function adds the parameters of another updatable component, times some constant, to the current parameters;
- a NonlinearComponent (or another component that stores stats, like BatchNormComponent), it relates to adding stats.
Otherwise it will normally do nothing.

CompositeComponent forwards the call to each of its sub-components in turn; other must be a CompositeComponent with the same number of sub-components.

Reimplemented from Component.

Definition at line 4423 of file nnet-simple-component.cc.

References Component::Add(), CompositeComponent::components_, and KALDI_ASSERT.

4423  {
4424  const CompositeComponent *other = dynamic_cast<const CompositeComponent*>(
4425  &other_in);
4426  KALDI_ASSERT(other != NULL && other->components_.size() ==
4427  components_.size() && "Mismatching nnet topologies");
4428  for (size_t i = 0; i < components_.size(); i++)
4429  components_[i]->Add(alpha, *(other->components_[i]));
4430 }
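A minimal sketch of the forwarding pattern in Add(), assuming a toy sub-component that is just a flat parameter vector plus an updatable flag. `ToySubComponent` and `CompositeAdd` are hypothetical names, and a thrown std::invalid_argument stands in for the KALDI_ASSERT topology check.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a sub-component: a flat parameter vector plus a
// flag saying whether it is updatable. Real Kaldi components are far richer.
struct ToySubComponent {
  bool updatable;
  std::vector<double> params;

  // params += alpha * other.params; non-updatable toys ignore the call.
  void Add(double alpha, const ToySubComponent &other) {
    if (!updatable) return;
    assert(params.size() == other.params.size());
    for (std::size_t j = 0; j < params.size(); j++)
      params[j] += alpha * other.params[j];
  }
};

// Sketch of the composite version: both sides must have the same number of
// sub-components, and the call is forwarded pairwise.
void CompositeAdd(double alpha, std::vector<ToySubComponent> *self,
                  const std::vector<ToySubComponent> &other) {
  if (self->size() != other.size())
    throw std::invalid_argument("Mismatching nnet topologies");
  for (std::size_t i = 0; i < self->size(); i++)
    (*self)[i].Add(alpha, other[i]);
}
```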

◆ Backprop()

void Backprop ( const std::string &  debug_info,
const ComponentPrecomputedIndexes *  indexes,
const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
const CuMatrixBase< BaseFloat > &  out_deriv,
void *  memo,
Component *  to_update,
CuMatrixBase< BaseFloat > *  in_deriv 
) const
virtual

Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update.

Parameters
[in] debug_info  The component name, to be printed out in any warning messages.
[in] indexes  A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in] in_value  The matrix that was given as input to the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsInput == 0.
[in] out_value  The matrix that was output from the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsOutput == 0.
[in] out_deriv  The derivative at the output of this component.
[in] memo  This will normally be NULL, but for component types that set the flag kUsesMemo, this will be the return value of the Propagate() function that corresponds to this Backprop() function. Ownership of any pointers is not transferred to the Backprop function; DeleteMemo() will be called to delete it.
[out] to_update  If model update is desired, the Component to be updated, else NULL. Does not have to be identical to this. If supplied, you can assume that to_update->Properties() & kUpdatableComponent is nonzero.
[out] in_deriv  The derivative at the input of this component, if needed (else NULL). If Properties()&kBackpropInPlace, may be the same matrix as out_deriv. If Properties()&kBackpropAdds, this is added to by the Backprop routine, else it is set. The component code chooses which mode to work in, based on convenience.

Implements Component.

Definition at line 4277 of file nnet-simple-component.cc.

References Component::Backprop(), CompositeComponent::InputDim(), KALDI_ASSERT, kaldi::nnet3::kBackpropAdds, kaldi::nnet3::kBackpropNeedsInput, kaldi::nnet3::kBackpropNeedsOutput, kaldi::nnet3::kPropagateAdds, kaldi::kSetZero, kaldi::nnet3::kStoresStats, kaldi::kUndefined, kaldi::nnet3::kUpdatableComponent, kaldi::nnet3::kUsesMemo, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), NVTX_RANGE, CompositeComponent::OutputDim(), Component::Properties(), and Component::StoreStats().

4284  {
4285  NVTX_RANGE("CompositeComponent::Backprop");
4286  KALDI_ASSERT(in_value.NumRows() == out_deriv.NumRows() &&
4287  in_value.NumCols() == InputDim() &&
4288  out_deriv.NumCols() == OutputDim());
4289  int32 num_rows = in_value.NumRows(),
4290  num_components = components_.size();
4291  if (max_rows_process_ > 0 && num_rows > max_rows_process_) {
4293  // recurse and process smaller parts of the data, to save memory.
4294  for (int32 row_offset = 0; row_offset < num_rows;
4295  row_offset += max_rows_process_) {
4296  bool have_output_value = (out_value.NumRows() != 0);
4297  int32 this_num_rows = std::min<int32>(max_rows_process_,
4298  num_rows - row_offset);
4299  // out_value_part will only be used if out_value is nonempty; otherwise we
4300  // make it a submatrix of 'out_deriv' to avoid errors in the constructor.
4301  const CuSubMatrix<BaseFloat> out_value_part(have_output_value ? out_value : out_deriv,
4302  row_offset, this_num_rows,
4303  0, out_deriv.NumCols());
4304  // in_deriv_part will only be used if in_deriv != NULL; otherwise we
4305  // make it a submatrix of 'in_value' to avoid errors in the constructor.
4306  CuSubMatrix<BaseFloat> in_deriv_part(in_deriv != NULL ? *in_deriv : in_value,
4307  row_offset, this_num_rows,
4308  0, in_value.NumCols());
4309  CuSubMatrix<BaseFloat> in_value_part(in_value, row_offset, this_num_rows,
4310  0, in_value.NumCols());
4311  const CuSubMatrix<BaseFloat> out_deriv_part(out_deriv,
4312  row_offset, this_num_rows,
4313  0, out_deriv.NumCols());
4314  CuMatrix<BaseFloat> empty_mat;
4315  this->Backprop(debug_info, NULL, in_value_part,
4316  (have_output_value ? static_cast<const CuMatrixBase<BaseFloat>&>(out_value_part) :
4317  static_cast<const CuMatrixBase<BaseFloat>&>(empty_mat)),
4318  out_deriv_part, NULL, to_update,
4319  in_deriv != NULL ? &in_deriv_part : NULL);
4320  }
4321  return;
4322  }
4323  // For now, assume all intermediate values and derivatives need to be
4324  // computed. in_value and out_deriv will always be supplied.
4325 
4326  // intermediate_outputs[i] contains the output of component i.
4327  std::vector<CuMatrix<BaseFloat> > intermediate_outputs(num_components);
4328  // intermediate_derivs[i] contains the derivative at the output of component i.
4329  std::vector<CuMatrix<BaseFloat> > intermediate_derivs(num_components - 1);
4330 
4331  KALDI_ASSERT(memo == NULL);
4332  // note: only a very few components use memos, but we need to support them.
4333  std::vector<void*> memos(num_components, NULL);
4334 
4335  int32 num_components_to_propagate = num_components;
4336  if (!(components_[num_components - 1]->Properties() & kUsesMemo)) {
4337  // we only need to propagate the very last component if it uses a memo.
4338  num_components_to_propagate--;
4339  if (num_components > 1) {
4340  // skip the last-but-one component's propagate if the last component's
4341  // backprop doesn't need the input and the last-but-one component's
4342  // backprop doesn't need the output. This is the lowest hanging fruit for
4343  // optimization; other propagates might also be skippable.
4344  int32 properties = components_[num_components - 2]->Properties(),
4345  next_properties = components_[num_components - 1]->Properties();
4346  if (!(properties & (kBackpropNeedsOutput | kUsesMemo)) &&
4347  !(next_properties & kBackpropNeedsInput)) {
4348  num_components_to_propagate--;
4349  }
4350  }
4351  }
4352 
4353 
4354  // Do the propagation again.
4355  for (int32 i = 0; i < num_components_to_propagate; i++) {
4356  MatrixResizeType resize_type =
4357  ((components_[i]->Properties() & kPropagateAdds) ?
4358  kSetZero : kUndefined);
4359  intermediate_outputs[i].Resize(num_rows, components_[i]->OutputDim(),
4360  resize_type, GetStrideType(i));
4361  memos[i] =
4362  components_[i]->Propagate(NULL,
4363  (i == 0 ? in_value : intermediate_outputs[i-1]),
4364  &(intermediate_outputs[i]));
4365  }
4366 
4367  for (int32 i = num_components - 1; i >= 0; i--) {
4368  const CuMatrixBase<BaseFloat> &this_in_value =
4369  (i == 0 ? in_value : intermediate_outputs[i-1]),
4370  &this_out_value =
4371  (i == num_components - 1 ? out_value : intermediate_outputs[i]);
4372 
4373  Component *component_to_update =
4374  (to_update == NULL ? NULL :
4375  dynamic_cast<CompositeComponent*>(to_update)->components_[i]);
4376 
4377  if (component_to_update != NULL &&
4378  (component_to_update->Properties() & kStoresStats))
4379  component_to_update->StoreStats(this_in_value, this_out_value, memos[i]);
4380 
4381  if (i > 0) {
4382  MatrixResizeType resize_type =
4383  ((components_[i]->Properties() & kBackpropAdds) ?
4384  kSetZero : kUndefined);
4385  intermediate_derivs[i-1].Resize(num_rows, components_[i]->InputDim(),
4386  resize_type, GetStrideType(i - 1));
4387  }
4388  // skip the first component's backprop if it's not updatable and in_deriv is
4389  // not requested. Again, this is the lowest-hanging fruit to optimize.
4390  if (!(i == 0 && !(components_[0]->Properties() & kUpdatableComponent) &&
4391  in_deriv == NULL)) {
4392  components_[i]->Backprop(debug_info, NULL,
4393  this_in_value, this_out_value,
4394  (i + 1 == num_components ? out_deriv : intermediate_derivs[i]),
4395  memos[i], component_to_update,
4396  (i == 0 ? in_deriv : &(intermediate_derivs[i-1])));
4397  }
4398  if (memos[i] != NULL)
4399  components_[i]->DeleteMemo(memos[i]);
4400  }
4401 }
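The logic that decides how many sub-components to re-propagate at the start of Backprop() (listing lines 4335 to 4351) can be isolated as a pure function of the property bitmasks. In this sketch the flag names and bit values are hypothetical stand-ins for kUsesMemo, kBackpropNeedsInput and kBackpropNeedsOutput; only the decision rule is meant to follow the listing.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative flag values only; Kaldi defines its own property enum and the
// numeric values there may differ.
enum ToyProperty : uint32_t {
  kToyUsesMemo = 1u << 0,
  kToyBackpropNeedsInput = 1u << 1,
  kToyBackpropNeedsOutput = 1u << 2,
};

// props[i] is the property bitmask of sub-component i. Returns how many of
// the leading sub-components must be propagated again before backprop.
int NumComponentsToPropagate(const std::vector<uint32_t> &props) {
  const int num_components = static_cast<int>(props.size());
  assert(num_components > 0);
  int n = num_components;
  if (!(props[num_components - 1] & kToyUsesMemo)) {
    // The last output is already supplied as out_value, so the last
    // component only needs re-propagating if it produces a memo.
    n--;
    if (num_components > 1) {
      const uint32_t properties = props[num_components - 2],
                     next_properties = props[num_components - 1];
      // The last-but-one output is needed only if its own backprop needs its
      // output (or memo), or the last component's backprop needs its input.
      if (!(properties & (kToyBackpropNeedsOutput | kToyUsesMemo)) &&
          !(next_properties & kToyBackpropNeedsInput))
        n--;
    }
  }
  return n;
}
```

For two sub-components where only the first needs its input for backprop, the function returns 0: that input is in_value itself, and nothing downstream needs the first component's output.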

◆ Copy()

Component * Copy ( ) const
virtual

Copies component (deep copy).

Implements Component.

Definition at line 4567 of file nnet-simple-component.cc.

References Component::Copy() and CompositeComponent::Init().

4567  {
4568  std::vector<Component*> components(components_.size());
4569  for (size_t i = 0; i < components_.size(); i++)
4570  components[i] = components_[i]->Copy();
4571  CompositeComponent *ans = new CompositeComponent();
4572  ans->Init(components, max_rows_process_);
4573  return ans;
4574 }
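The deep-copy-plus-ownership pattern used by Copy() and the destructor can be sketched with a tiny hypothetical hierarchy (`ToyComponent`, `ToyScale`, `ToyComposite`, none of which are Kaldi classes). The sketch deletes sub-components in the destructor the way DeletePointers() does, and disables the implicit copy constructor so that a pointer is never shared between two owners.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical minimal hierarchy; only Copy() and ownership are modelled.
struct ToyComponent {
  virtual ~ToyComponent() = default;
  virtual ToyComponent *Copy() const = 0;  // deep copy; caller owns result
};

struct ToyScale : ToyComponent {
  double scale;
  explicit ToyScale(double s) : scale(s) {}
  ToyComponent *Copy() const override { return new ToyScale(scale); }
};

// Owns its sub-components, as CompositeComponent owns components_.
struct ToyComposite : ToyComponent {
  std::vector<ToyComponent *> components;

  ToyComposite() = default;
  ToyComposite(const ToyComposite &) = delete;  // no shared owners
  ToyComposite &operator=(const ToyComposite &) = delete;
  ~ToyComposite() override {
    for (ToyComponent *c : components) delete c;  // like DeletePointers()
  }

  // Copy every sub-component, then give the new pointers to a new composite.
  ToyComponent *Copy() const override {
    ToyComposite *ans = new ToyComposite();
    ans->components.reserve(components.size());
    for (std::size_t i = 0; i < components.size(); i++)
      ans->components.push_back(components[i]->Copy());
    return ans;
  }
};
```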

◆ DotProduct()

BaseFloat DotProduct ( const UpdatableComponent &  other) const
virtual

Computes dot-product between parameters of two instances of a Component.

Can be used for computing parameter-norm of an UpdatableComponent.

Implements UpdatableComponent.

Definition at line 4534 of file nnet-simple-component.cc.

References CompositeComponent::components_, UpdatableComponent::DotProduct(), KALDI_ASSERT, kaldi::nnet3::kUpdatableComponent, and Component::Properties().

4535  {
4536  const CompositeComponent *other = dynamic_cast<const CompositeComponent*>(
4537  &other_in);
4538  KALDI_ASSERT(other != NULL && other->components_.size() ==
4539  components_.size() && "Mismatching nnet topologies");
4540  BaseFloat ans = 0.0;
4541  for (size_t i = 0; i < components_.size(); i++) {
4542  if (components_[i]->Properties() & kUpdatableComponent) {
4543  UpdatableComponent *uc =
4544  dynamic_cast<UpdatableComponent*>(components_[i]);
4545  const UpdatableComponent *uc_other =
4546  dynamic_cast<UpdatableComponent*>(other->components_[i]);
4547  KALDI_ASSERT(uc != NULL && uc_other != NULL);
4548  ans += uc->DotProduct(*uc_other);
4549  }
4550  }
4551  return ans;
4552 }
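A sketch of how DotProduct() aggregates over sub-components, assuming each toy sub-component is a flat parameter vector with an updatable flag (`ToySub` and `CompositeDotProduct` are hypothetical). Non-updatable sub-components have no parameters and contribute nothing; the square root of `CompositeDotProduct(a, a)` would be the parameter norm mentioned above.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical toy sub-component: a flat parameter vector and a flag.
struct ToySub {
  bool updatable;
  std::vector<double> params;
};

// Sum of per-sub-component dot products; non-updatable sub-components are
// skipped, as CompositeComponent::DotProduct() only visits updatable ones.
double CompositeDotProduct(const std::vector<ToySub> &a,
                           const std::vector<ToySub> &b) {
  if (a.size() != b.size())
    throw std::invalid_argument("Mismatching nnet topologies");
  double ans = 0.0;
  for (std::size_t i = 0; i < a.size(); i++) {
    if (!a[i].updatable) continue;
    assert(b[i].updatable && a[i].params.size() == b[i].params.size());
    for (std::size_t j = 0; j < a[i].params.size(); j++)
      ans += a[i].params[j] * b[i].params[j];
  }
  return ans;
}
```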

◆ FreezeNaturalGradient()

void FreezeNaturalGradient ( bool  freeze)
virtual

Calls FreezeNaturalGradient(freeze) on each updatable sub-component.

Reimplemented from UpdatableComponent.

Definition at line 4555 of file nnet-simple-component.cc.

References UpdatableComponent::FreezeNaturalGradient(), KALDI_ASSERT, kaldi::nnet3::kUpdatableComponent, and Component::Properties().

4555  {
4556  for (size_t i = 0; i < components_.size(); i++) {
4557  if (components_[i]->Properties() & kUpdatableComponent) {
4558  UpdatableComponent *uc =
4559  dynamic_cast<UpdatableComponent*>(components_[i]);
4560  KALDI_ASSERT(uc != NULL);
4561  uc->FreezeNaturalGradient(freeze);
4562  }
4563  }
4564 }

◆ GetComponent()

const Component * GetComponent ( int32  i) const

Gets the ith component in this component.

The ordering is the same as in the config line. The caller does not own the received component.

Definition at line 4633 of file nnet-simple-component.cc.

References KALDI_ASSERT.

Referenced by kaldi::nnet3::ConvertRepeatedToBlockAffine(), and kaldi::nnet3::UnitTestConvertRepeatedToBlockAffineComposite().

4633  {
4634  KALDI_ASSERT(static_cast<size_t>(i) < components_.size());
4635  return components_[i];
4636 }

◆ GetStrideType()

MatrixStrideType GetStrideType ( int32  i) const
inline private

Definition at line 4125 of file nnet-simple-component.cc.

References kaldi::kDefaultStride, kaldi::nnet3::kInputContiguous, kaldi::nnet3::kOutputContiguous, kaldi::kStrideEqualNumCols, and Component::Properties().

4125  {
4126  int32 num_components = components_.size();
4127  if ((components_[i]->Properties() & kOutputContiguous) ||
4128  (i + 1 < num_components &&
4129  (components_[i + 1]->Properties() & kInputContiguous)))
4130  return kStrideEqualNumCols;
4131  else
4132  return kDefaultStride;
4133 }
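In Backprop(), intermediate_outputs[i] is both the output of sub-component i and the input of sub-component i + 1, so its stride has to satisfy both. A sketch of that rule, with hypothetical names and bit values standing in for kDefaultStride, kStrideEqualNumCols, kOutputContiguous and kInputContiguous:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins (values are illustrative, not Kaldi's).
enum ToyStrideType { kToyDefaultStride, kToyStrideEqualNumCols };
constexpr uint32_t kToyInputContiguous = 1u << 0;
constexpr uint32_t kToyOutputContiguous = 1u << 1;

// Stride for the matrix between components i and i + 1: contiguous rows are
// required if component i wants a contiguous output or component i + 1 wants
// a contiguous input.
ToyStrideType GetStrideTypeSketch(const std::vector<uint32_t> &props,
                                  std::size_t i) {
  assert(i < props.size());
  if ((props[i] & kToyOutputContiguous) ||
      (i + 1 < props.size() && (props[i + 1] & kToyInputContiguous)))
    return kToyStrideEqualNumCols;
  return kToyDefaultStride;
}
```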

◆ Info()

std::string Info ( ) const
virtual

Returns some text-form information about this component, for diagnostics.

Starts with the type of the component. E.g. "SigmoidComponent dim=900", although most components will have much more info.

Reimplemented from UpdatableComponent.

Definition at line 4405 of file nnet-simple-component.cc.

References Component::Info() and CompositeComponent::Type().

4405  {
4406  std::ostringstream stream;
4407  stream << Type() << " ";
4408  for (size_t i = 0; i < components_.size(); i++) {
4409  if (i > 0) stream << ", ";
4410  stream << "sub-component" << (i+1) << " = { "
4411  << components_[i]->Info() << " }";
4412  }
4413  return stream.str();
4414 }
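A sketch reproducing the string layout that Info() builds, with the type name and the sub-components' own descriptions passed in as plain strings (the function name `CompositeInfoSketch` and the sample strings are hypothetical; real sub-component Info() strings are usually longer):

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Builds "<type> sub-component1 = { ... }, sub-component2 = { ... }" from the
// sub-components' Info() strings, following the layout in the listing above.
std::string CompositeInfoSketch(const std::string &type,
                                const std::vector<std::string> &sub_infos) {
  std::ostringstream stream;
  stream << type << " ";
  for (std::size_t i = 0; i < sub_infos.size(); i++) {
    if (i > 0) stream << ", ";
    stream << "sub-component" << (i + 1) << " = { " << sub_infos[i] << " }";
  }
  return stream.str();
}
```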

◆ Init()

void Init ( const std::vector< Component *> &  components,
int32  max_rows_process 
)

Definition at line 4184 of file nnet-simple-component.cc.

References kaldi::DeletePointers(), Component::InputDim(), KALDI_ASSERT, kaldi::nnet3::kSimpleComponent, Component::OutputDim(), and Component::Properties().

Referenced by CompositeComponent::Copy().

4185  {
4186  DeletePointers(&components_); // clean up.
4187  components_ = components;
4188  KALDI_ASSERT(!components.empty());
4189  max_rows_process_ = max_rows_process;
4190 
4191  for (size_t i = 0; i < components_.size(); i++) {
4192  // make sure all constituent components are simple.
4193  KALDI_ASSERT(components_[i]->Properties() & kSimpleComponent);
4194  if (i > 0) {
4195  // make sure all the internal dimensions match up.
4196  KALDI_ASSERT(components_[i]->InputDim() ==
4197  components_[i-1]->OutputDim());
4198  }
4199  }
4200 }
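Init() requires the internal dimensions to chain, and InputDim()/OutputDim() simply read the first and last sub-components. A sketch of that check, where `ToyDims` and `CheckChain` are hypothetical and return a bool instead of dying via KALDI_ASSERT; the dimensions in the usage come from the example config line in the Detailed Description.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical record of a sub-component's dimensions.
struct ToyDims {
  int input_dim;
  int output_dim;
};

// Returns false if the list is empty or two neighbours disagree; otherwise
// reports the composite's dims, taken from the first and last entries (as
// InputDim() and OutputDim() do with components_.front() and back()).
bool CheckChain(const std::vector<ToyDims> &c, int *input_dim,
                int *output_dim) {
  if (c.empty()) return false;
  for (std::size_t i = 1; i < c.size(); i++)
    if (c[i].input_dim != c[i - 1].output_dim) return false;
  *input_dim = c.front().input_dim;
  *output_dim = c.back().output_dim;
  return true;
}
```

For the 1000 -> 10000 -> 10000 -> 1000 example, the check passes and the composite reports input and output dimensions of 1000.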

◆ InitFromConfig()

void InitFromConfig ( ConfigLine *  cfl)
virtual

Initialize, from a ConfigLine object.

Parameters
[in] cfl  A ConfigLine containing any parameters that are needed for initialization. For example: "dim=100 param-stddev=0.1"

Implements Component.

Definition at line 4578 of file nnet-simple-component.cc.

References kaldi::DeletePointers(), ConfigLine::FirstToken(), ConfigLine::GetValue(), ConfigLine::HasUnusedValues(), CompositeComponent::Init(), Component::InitFromConfig(), KALDI_ERR, kaldi::nnet3::kRandomComponent, kaldi::nnet3::kSimpleComponent, Component::NewComponentOfType(), ConfigLine::ParseLine(), Component::Properties(), Component::Type(), ConfigLine::UnusedValues(), and ConfigLine::WholeLine().

4578  {
4579  int32 max_rows_process = 4096, num_components = -1;
4580  cfl->GetValue("max-rows-process", &max_rows_process);
4581  if (!cfl->GetValue("num-components", &num_components) ||
4582  num_components < 1)
4583  KALDI_ERR << "Expected num-components to be defined in "
4584  << "CompositeComponent config line '" << cfl->WholeLine() << "'";
4585  std::vector<Component*> components;
4586  for (int32 i = 1; i <= num_components; i++) {
4587  std::ostringstream name_stream;
4588  name_stream << "component" << i;
4589  std::string component_config;
4590  if (!cfl->GetValue(name_stream.str(), &component_config)) {
4591  DeletePointers(&components);
4592  KALDI_ERR << "Expected '" << name_stream.str() << "' to be defined in "
4593  << "CompositeComponent config line '" << cfl->WholeLine() << "'";
4594  }
4595  ConfigLine nested_line;
4596  // note: the nested line may not contain comments.
4597  std::string component_type;
4598  Component *this_component = NULL;
4599  if (!nested_line.ParseLine(component_config) ||
4600  !nested_line.GetValue("type", &component_type) ||
4601  !(this_component = NewComponentOfType(component_type)) ||
4602  nested_line.FirstToken() != "") {
4603  DeletePointers(&components);
4604  KALDI_ERR << "Could not parse config line for '" << name_stream.str()
4605  << "(or undefined or bad component type [type=xxx]), in "
4606  << "CompositeComponent config line '" << cfl->WholeLine() << "'";
4607  }
4608  if(this_component->Type() == "CompositeComponent") {
4609  DeletePointers(&components);
4610  delete this_component;
4611  // This is not allowed. If memory is too much with just one
4612  // CompositeComponent, try decreasing max-rows-process instead.
4613  KALDI_ERR << "Found CompositeComponent nested within CompositeComponent."
4614  << "Nested line: '" << nested_line.WholeLine() << "'\n"
4615  << "Toplevel CompositeComponent line '" << cfl->WholeLine()
4616  << "'";
4617  }
4618  this_component->InitFromConfig(&nested_line);
4619  int32 props = this_component->Properties();
4620  if ((props & kRandomComponent) != 0 ||
4621  (props & kSimpleComponent) == 0) {
4622  KALDI_ERR << "CompositeComponent contains disallowed component type: "
4623  << nested_line.WholeLine();
4624  }
4625  components.push_back(this_component);
4626  }
4627  if (cfl->HasUnusedValues())
4628  KALDI_ERR << "Could not process these elements in initializer: "
4629  << cfl->UnusedValues();
4630  this->Init(components, max_rows_process);
4631 }
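A simplified sketch of how InitFromConfig() walks the component1, component2, ... keys. Here a std::map stands in for ConfigLine, the type= lookup is a naive whitespace split (the real ConfigLine handles quoting), errors are thrown as std::runtime_error rather than KALDI_ERR, and the unused-values check and the kSimpleComponent/kRandomComponent property checks are omitted. All names are hypothetical.

```cpp
#include <cassert>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical result of parsing; not a Kaldi type.
struct ToyCompositeConfig {
  int max_rows_process = 4096;            // same default as InitFromConfig()
  std::vector<std::string> nested_lines;  // values of component1..componentN
};

// Returns the value of the first "type=..." token in a nested config line,
// or "" if there is none. Naive: splits on whitespace, ignores quoting.
std::string NestedType(const std::string &line) {
  std::istringstream is(line);
  std::string tok;
  while (is >> tok)
    if (tok.compare(0, 5, "type=") == 0) return tok.substr(5);
  return "";
}

// 'values' plays the role of the key=value pairs held by a ConfigLine.
ToyCompositeConfig ParseCompositeConfig(
    const std::map<std::string, std::string> &values) {
  ToyCompositeConfig cfg;
  auto it = values.find("max-rows-process");
  if (it != values.end()) cfg.max_rows_process = std::stoi(it->second);
  it = values.find("num-components");
  const int num_components = (it == values.end() ? -1 : std::stoi(it->second));
  if (num_components < 1)
    throw std::runtime_error("Expected num-components to be defined");
  for (int i = 1; i <= num_components; i++) {
    const std::string name = "component" + std::to_string(i);
    auto c = values.find(name);
    if (c == values.end())
      throw std::runtime_error("Expected '" + name + "' to be defined");
    const std::string type = NestedType(c->second);
    if (type.empty())
      throw std::runtime_error("No type= given for '" + name + "'");
    if (type == "CompositeComponent")  // nesting is an error
      throw std::runtime_error("Nested CompositeComponent in '" + name + "'");
    cfg.nested_lines.push_back(c->second);
  }
  return cfg;
}
```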

◆ InputDim()

int32 InputDim ( ) const
virtual

Returns input-dimension of this component.

Implements Component.

Definition at line 4091 of file nnet-simple-component.cc.

References KALDI_ASSERT.

4091  {
4092  KALDI_ASSERT(!components_.empty());
4093  return components_.front()->InputDim();
4094 }

◆ IsUpdatable()

bool IsUpdatable ( ) const
private

Definition at line 4082 of file nnet-simple-component.cc.

References kaldi::nnet3::kUpdatableComponent.

4082  {
4083  for (std::vector<Component*>::const_iterator iter = components_.begin(),
4084  end = components_.end(); iter != end; ++iter)
4085  if (((*iter)->Properties() & kUpdatableComponent) != 0)
4086  return true;
4087  return false;
4088 }

◆ NumComponents()

int32 NumComponents ( ) const
inline

◆ NumParameters()

int32 NumParameters ( ) const
virtual

Returns the total dimension of the parameters in this class. For CompositeComponent this is the sum of NumParameters() over the updatable sub-components.

Reimplemented from UpdatableComponent.

Definition at line 4486 of file nnet-simple-component.cc.

References KALDI_ASSERT, kaldi::nnet3::kUpdatableComponent, UpdatableComponent::NumParameters(), and Component::Properties().

4486  {
4487  KALDI_ASSERT(this->IsUpdatable()); // or should not be called.
4488  int32 ans = 0;
4489  for (size_t i = 0; i < components_.size(); i++) {
4490  if (components_[i]->Properties() & kUpdatableComponent) {
4491  UpdatableComponent *uc =
4492  dynamic_cast<UpdatableComponent*>(components_[i]);
4493  ans += uc->NumParameters();
4494  }
4495  }
4496  return ans;
4497 }
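IsUpdatable() and NumParameters() reduce to an any-of test and a sum over the updatable sub-components. A sketch with a hypothetical `ToySub` record holding a property bitmask and a parameter count; the flag value is illustrative, not Kaldi's kUpdatableComponent.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-in for kUpdatableComponent (bit value is arbitrary).
constexpr uint32_t kToyUpdatable = 1u << 0;

// Hypothetical record: a sub-component's property bitmask and the number of
// parameters it would report through NumParameters().
struct ToySub {
  uint32_t properties;
  int num_params;
};

bool IsUpdatableSketch(const std::vector<ToySub> &c) {
  return std::any_of(c.begin(), c.end(), [](const ToySub &s) {
    return (s.properties & kToyUpdatable) != 0;
  });
}

// Total parameter dimension = sum over the updatable sub-components only.
int NumParametersSketch(const std::vector<ToySub> &c) {
  assert(IsUpdatableSketch(c));  // mirrors the assertion in the listing
  int ans = 0;
  for (const ToySub &s : c)
    if (s.properties & kToyUpdatable) ans += s.num_params;
  return ans;
}
```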

◆ OutputDim()

int32 OutputDim ( ) const
virtual

Returns output-dimension of this component.

Implements Component.

Definition at line 4097 of file nnet-simple-component.cc.

References KALDI_ASSERT.

{
  KALDI_ASSERT(!components_.empty());
  return components_.back()->OutputDim();
}
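Together, InputDim() and OutputDim() treat the sequence as one function: input from the first sub-component, output from the last. A minimal sketch using hypothetical (input_dim, output_dim) pairs; the check that adjacent sub-components agree is an assumption added here for illustration, and this page does not show whether Kaldi performs it in Init():

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical (input_dim, output_dim) pair for one sub-component.
using Dims = std::pair<int, int>;

// Returns the composite's (input_dim, output_dim): the first component's
// input and the last component's output. Also asserts that each component's
// output dim matches the next component's input dim (an assumption of this
// sketch).
Dims CompositeDims(const std::vector<Dims> &subs) {
  assert(!subs.empty());
  for (std::size_t i = 0; i + 1 < subs.size(); i++)
    assert(subs[i].second == subs[i + 1].first);
  return {subs.front().first, subs.back().second};
}
```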

◆ PerturbParams()

void PerturbParams ( BaseFloat  stddev)
virtual

This function is to be used in testing.

It adds unit noise times "stddev" to the parameters of the component.

Implements UpdatableComponent.

Definition at line 4433 of file nnet-simple-component.cc.

References KALDI_ASSERT, kaldi::nnet3::kUpdatableComponent, UpdatableComponent::PerturbParams(), and Component::Properties().

{
  KALDI_ASSERT(this->IsUpdatable());  // or should not be called.
  for (size_t i = 0; i < components_.size(); i++) {
    if (components_[i]->Properties() & kUpdatableComponent) {
      UpdatableComponent *uc =
          dynamic_cast<UpdatableComponent*>(components_[i]);
      uc->PerturbParams(stddev);
    }
  }
}
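"Unit noise times stddev" means adding stddev times a standard-normal sample to every parameter. The composite just forwards the call, so each updatable sub-component perturbs its own parameters. A minimal sketch of that per-parameter operation, assuming parameters held in a plain std::vector<float> and a caller-supplied std::mt19937 (Kaldi uses its own matrix and random-number types):

```cpp
#include <cassert>
#include <random>
#include <vector>

// Adds stddev * N(0, 1) noise independently to each parameter.
void PerturbParamsSketch(std::vector<float> *params, float stddev,
                         std::mt19937 *rng) {
  std::normal_distribution<float> unit_normal(0.0f, 1.0f);
  for (float &p : *params)
    p += stddev * unit_normal(*rng);
}
```

With stddev equal to 0 the parameters are unchanged, which is a convenient sanity check in tests.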

◆ Propagate()

void * Propagate ( const ComponentPrecomputedIndexes *  indexes,
const CuMatrixBase< BaseFloat > &  in,
CuMatrixBase< BaseFloat > *  out 
) const
virtual

Propagate function.

Parameters
[in] indexes: A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in] in: The input to this component. Num-columns == InputDim().
[out] out: The output of this component. Num-columns == OutputDim(). Note: the output of this component will be added to the initial value of "out" if (Properties() & kPropagateAdds) != 0; otherwise the output will be set and the initial value ignored. Each Component chooses whether it is more convenient implementation-wise to add or set, and the calling code has to deal with it.
Returns
Normally returns NULL, but may return a non-NULL value for components which have the flag kUsesMemo set. This value will be passed into the corresponding Backprop routine.

Implements Component.

Definition at line 4137 of file nnet-simple-component.cc.

References InputDim(), KALDI_ASSERT, kaldi::nnet3::kPropagateAdds, kaldi::kSetZero, kaldi::kUndefined, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), OutputDim(), and Component::Propagate().

{
  KALDI_ASSERT(in.NumRows() == out->NumRows() && in.NumCols() == InputDim() &&
               out->NumCols() == OutputDim());
  int32 num_rows = in.NumRows(),
      num_components = components_.size();
  if (max_rows_process_ > 0 && num_rows > max_rows_process_) {
    // recurse and process smaller parts of the data, to save memory.
    for (int32 row_offset = 0; row_offset < num_rows;
         row_offset += max_rows_process_) {
      int32 this_num_rows = std::min<int32>(max_rows_process_,
                                            num_rows - row_offset);
      const CuSubMatrix<BaseFloat> in_part(in, row_offset, this_num_rows,
                                           0, in.NumCols());
      CuSubMatrix<BaseFloat> out_part(*out, row_offset, this_num_rows,
                                      0, out->NumCols());
      this->Propagate(NULL, in_part, &out_part);
    }
    return NULL;
  }
  std::vector<CuMatrix<BaseFloat> > intermediate_outputs(num_components - 1);
  for (int32 i = 0; i < num_components; i++) {
    if (i + 1 < num_components) {
      MatrixResizeType resize_type =
          ((components_[i]->Properties() & kPropagateAdds) ?
           kSetZero : kUndefined);
      intermediate_outputs[i].Resize(num_rows, components_[i]->OutputDim(),
                                     resize_type, GetStrideType(i));
    }
    const CuMatrixBase<BaseFloat> &this_in = (i == 0 ? in :
                                              intermediate_outputs[i-1]);
    CuMatrixBase<BaseFloat> *this_out = (i + 1 == num_components ?
                                         out : &(intermediate_outputs[i]));
    void *memo = components_[i]->Propagate(NULL, this_in, this_out);
    // we'll re-do the forward propagation in the backprop, and we can
    // regenerate any memos there, so no need to keep them.
    if (memo != NULL)
      components_[i]->DeleteMemo(memo);
    if (i > 0)
      intermediate_outputs[i-1].Resize(0, 0);
  }
  return NULL;
}
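When max_rows_process_ is positive and the input has more rows than that, Propagate() splits the rows into consecutive chunks and recurses on each. Intermediate outputs are then allocated for at most max_rows_process_ rows at a time, which bounds memory. The hypothetical helper below reproduces only the loop bounds, returning (row_offset, num_rows) pairs:

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Lists the (row_offset, num_rows) chunks visited by the chunking branch of
// Propagate(); if chunking does not apply, the whole input is one chunk.
std::vector<std::pair<int, int>> RowChunks(int num_rows, int max_rows_process) {
  std::vector<std::pair<int, int>> chunks;
  if (max_rows_process > 0 && num_rows > max_rows_process) {
    for (int row_offset = 0; row_offset < num_rows;
         row_offset += max_rows_process)
      chunks.emplace_back(row_offset,
                          std::min(max_rows_process, num_rows - row_offset));
  } else {
    chunks.emplace_back(0, num_rows);
  }
  return chunks;
}
```

For 2500 rows with max_rows_process = 1000, this yields chunks starting at rows 0, 1000 and 2000, the last one holding the remaining 500 rows.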

◆ Properties()

int32 Properties ( ) const
virtual

Return bitmask of the component's properties.

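For most components these properties depend only on the component's type; for CompositeComponent they are derived from the properties of the first and last sub-components and from whether any sub-component is updatable. See enum ComponentProperties.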

Implements Component.

Definition at line 4103 of file nnet-simple-component.cc.

References KALDI_ASSERT, kaldi::nnet3::kBackpropAdds, kaldi::nnet3::kBackpropNeedsInput, kaldi::nnet3::kBackpropNeedsOutput, kaldi::nnet3::kInputContiguous, kaldi::nnet3::kOutputContiguous, kaldi::nnet3::kPropagateAdds, kaldi::nnet3::kSimpleComponent, kaldi::nnet3::kStoresStats, and kaldi::nnet3::kUpdatableComponent.

{
  KALDI_ASSERT(!components_.empty());
  int32 last_component_properties = components_.back()->Properties(),
      first_component_properties = components_.front()->Properties();
  // We always assume backprop needs the input, as this would be necessary to
  // get the activations at intermediate layers, if these were not needed in
  // backprop, there would be no reason to use a CompositeComponent.
  int32 ans = kSimpleComponent | kBackpropNeedsInput |
      (last_component_properties &
       (kPropagateAdds|kBackpropNeedsOutput|kOutputContiguous)) |
      (first_component_properties &
       (kBackpropAdds|kInputContiguous)) |
      (IsUpdatable() ? kUpdatableComponent : 0);
  // note, we don't return the kStoresStats property because that function is
  // not implemented; instead, for efficiency, we call StoreStats() on any
  // sub-components as part of the backprop phase.
  if (last_component_properties & kStoresStats)
    ans |= kBackpropNeedsOutput;
  return ans;
}
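Properties() builds the composite's bitmask from the two ends of the sequence: output-side flags come from the last sub-component, input-side flags from the first, and kStoresStats on the last sub-component becomes kBackpropNeedsOutput instead of being passed through. The sketch below is illustrative. The flag values are hypothetical, not Kaldi's, and the assignment of each flag to the first or last sub-component is an assumption based on the flags named in the References list:

```cpp
#include <cassert>

// Hypothetical flag values; Kaldi's enum ComponentProperties may differ.
enum : int {
  kSimple = 0x001,
  kUpdatable = 0x002,
  kPropagateAdds = 0x004,
  kBackpropAdds = 0x008,
  kBackpropNeedsInput = 0x010,
  kBackpropNeedsOutput = 0x020,
  kStoresStats = 0x040,
  kInputContiguous = 0x080,
  kOutputContiguous = 0x100
};

// Combines the first/last sub-component bitmasks into the composite's.
int CompositeProperties(int first, int last, bool any_updatable) {
  int ans = kSimple | kBackpropNeedsInput |
            (last & (kPropagateAdds | kBackpropNeedsOutput | kOutputContiguous)) |
            (first & (kBackpropAdds | kInputContiguous)) |
            (any_updatable ? kUpdatable : 0);
  // Stats are stored by the sub-components during backprop, so the composite
  // needs the output there instead of advertising kStoresStats itself.
  if (last & kStoresStats)
    ans |= kBackpropNeedsOutput;
  return ans;
}
```

A flag such as kPropagateAdds on the first sub-component is dropped, since only the last sub-component writes into the caller's output matrix.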

◆ Read()

void Read ( std::istream &  is,
bool  binary 
)
virtual

Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed.

Implements Component.

Definition at line 4203 of file nnet-simple-component.cc.

References kaldi::ExpectToken(), Init(), UpdatableComponent::is_gradient_, KALDI_ERR, UpdatableComponent::learning_rate_, UpdatableComponent::learning_rate_factor_, kaldi::ReadBasicType(), Component::ReadNew(), kaldi::ReadToken(), and UpdatableComponent::ReadUpdatableCommon().

{
  // Because we didn't previously write out the learning rate,
  // we need some temporary code.
  int32 max_rows_process;
  if (false) {
    ReadUpdatableCommon(is, binary);
    ExpectToken(is, binary, "<MaxRowsProcess>");
    ReadBasicType(is, binary, &max_rows_process);
  } else {  // temporary code.
    std::string token;
    ReadToken(is, binary, &token);
    if (token == "<CompositeComponent>") {
      // if the first token is the opening tag, then
      // ignore it and get the next tag.
      ReadToken(is, binary, &token);
    }
    if (token == "<LearningRateFactor>") {
      ReadBasicType(is, binary, &learning_rate_factor_);
      ReadToken(is, binary, &token);
    } else {
      learning_rate_factor_ = 1.0;
    }
    if (token == "<IsGradient>") {
      ReadBasicType(is, binary, &is_gradient_);
      ReadToken(is, binary, &token);
    } else {
      is_gradient_ = false;
    }
    if (token == "<LearningRate>") {
      ReadBasicType(is, binary, &learning_rate_);
      ReadToken(is, binary, &token);
    }
    if (token != "<MaxRowsProcess>") {
      KALDI_ERR << "Expected token <MaxRowsProcess>, got "
                << token;
    }
    ReadBasicType(is, binary, &max_rows_process);
  }
  ExpectToken(is, binary, "<NumComponents>");
  int32 num_components;
  ReadBasicType(is, binary, &num_components);  // Read dimension.
  if (num_components < 0 || num_components > 100000)
    KALDI_ERR << "Bad num-components";
  std::vector<Component*> components(num_components);
  for (int32 i = 0; i < num_components; i++)
    components[i] = ReadNew(is, binary);
  Init(components, max_rows_process);
  ExpectToken(is, binary, "</CompositeComponent>");
}
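The "temporary code" branch of Read() parses a header in which several tokens are optional: each is consumed only if present, and otherwise a default is used. The sketch below reproduces that pattern for text input only. It assumes whitespace-separated tokens and booleans written as T or F. The struct and function names are hypothetical; Kaldi's ReadToken and ReadBasicType also handle binary mode.

```cpp
#include <cassert>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical holder for the header fields read before <NumComponents>.
struct CompositeHeader {
  float learning_rate_factor = 1.0f;  // default when the token is absent
  bool is_gradient = false;           // default when the token is absent
  float learning_rate = 0.0f;         // left unchanged when absent
  int max_rows_process = 0;
};

// Text-mode parse of the optional-token header, in the same order as Read().
CompositeHeader ReadHeaderSketch(std::istream &is) {
  CompositeHeader h;
  std::string token;
  is >> token;
  if (token == "<CompositeComponent>")  // skip the opening tag if present
    is >> token;
  if (token == "<LearningRateFactor>")
    is >> h.learning_rate_factor >> token;
  if (token == "<IsGradient>") {
    std::string value;  // assumed text encoding of bools: T or F
    is >> value >> token;
    h.is_gradient = (value == "T");
  }
  if (token == "<LearningRate>")
    is >> h.learning_rate >> token;
  if (token != "<MaxRowsProcess>")
    throw std::runtime_error("Expected token <MaxRowsProcess>, got " + token);
  is >> h.max_rows_process;
  return h;
}
```

Because each optional block reads the next token before the following check, a missing field simply falls through to the next comparison. This is how Read() can accept models written before the learning rate was stored.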

◆ Scale()

void Scale ( BaseFloat  scale)
virtual

This virtual function, when called on:
- an UpdatableComponent, scales the parameters by "scale";
- a nonlinear component (or another component that stores stats, such as BatchNormComponent), scales the activation stats rather than any parameters.

Otherwise it will normally do nothing. CompositeComponent forwards the call to each of its sub-components.

Reimplemented from Component.

Definition at line 4417 of file nnet-simple-component.cc.

References Component::Scale().

{
  for (size_t i = 0; i < components_.size(); i++)
    components_[i]->Scale(scale);
}

◆ SetActualLearningRate()

void SetActualLearningRate ( BaseFloat  lrate)
virtual

Sets the learning rate directly, bypassing learning_rate_factor_.

Reimplemented from UpdatableComponent.

Definition at line 4460 of file nnet-simple-component.cc.

References KALDI_ASSERT, kaldi::nnet3::kUpdatableComponent, Component::Properties(), and UpdatableComponent::SetActualLearningRate().

{
  KALDI_ASSERT(this->IsUpdatable());  // or should not be called.
  UpdatableComponent::SetActualLearningRate(lrate);
  for (size_t i = 0; i < components_.size(); i++) {
    if (components_[i]->Properties() & kUpdatableComponent) {
      UpdatableComponent *uc =
          dynamic_cast<UpdatableComponent*>(components_[i]);
      uc->SetActualLearningRate(lrate);
    }
  }
}

◆ SetAsGradient()

void SetAsGradient ( )
virtual

Sets is_gradient_ to true and sets learning_rate_ to 1, ignoring learning_rate_factor_.

Reimplemented from UpdatableComponent.

Definition at line 4473 of file nnet-simple-component.cc.

References KALDI_ASSERT, kaldi::nnet3::kUpdatableComponent, Component::Properties(), and UpdatableComponent::SetAsGradient().

{
  KALDI_ASSERT(this->IsUpdatable());  // or should not be called.
  UpdatableComponent::SetAsGradient();
  for (size_t i = 0; i < components_.size(); i++) {
    if (components_[i]->Properties() & kUpdatableComponent) {
      UpdatableComponent *uc =
          dynamic_cast<UpdatableComponent*>(components_[i]);
      uc->SetAsGradient();
    }
  }
}

◆ SetComponent()

void SetComponent ( int32  i,
Component *  component 
)

Sets the ith component.

After this call, CompositeComponent takes ownership of the pointer passed as "component"; the previous ith component is deleted.

Definition at line 4638 of file nnet-simple-component.cc.

References KALDI_ASSERT.

Referenced by kaldi::nnet3::ConvertRepeatedToBlockAffine().

{
  KALDI_ASSERT(static_cast<size_t>(i) < components_.size());
  delete components_[i];
  components_[i] = component;
}
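SetComponent() transfers ownership: the composite deletes the old ith component and keeps the new pointer. The sketch below expresses the same contract with std::unique_ptr::reset, which frees the previously held object and takes ownership of the new one. The Node and Holder types are hypothetical; Node counts live instances so the freeing is observable. Kaldi itself stores raw pointers and calls delete explicitly.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical component type that tracks how many instances are alive.
struct Node {
  static int live;
  Node() { ++live; }
  ~Node() { --live; }
};
int Node::live = 0;

// Hypothetical container with SetComponent-style replacement semantics.
class Holder {
 public:
  explicit Holder(std::size_t n) : items_(n) {
    for (auto &p : items_) p = std::make_unique<Node>();
  }
  // Frees the previous ith item and takes ownership of "node".
  void Set(std::size_t i, Node *node) {
    assert(i < items_.size());
    items_[i].reset(node);
  }

 private:
  std::vector<std::unique_ptr<Node>> items_;
};
```

As with the original, the caller must not delete the pointer after handing it over, and must not pass the pointer that is already stored at index i, since it would be freed.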

◆ SetUnderlyingLearningRate()

void SetUnderlyingLearningRate ( BaseFloat  lrate)
virtual

Sets the learning rate of gradient descent; it gets multiplied by learning_rate_factor_.

Reimplemented from UpdatableComponent.

Definition at line 4444 of file nnet-simple-component.cc.

References KALDI_ASSERT, kaldi::nnet3::kUpdatableComponent, UpdatableComponent::LearningRate(), Component::Properties(), and UpdatableComponent::SetUnderlyingLearningRate().

{
  KALDI_ASSERT(this->IsUpdatable());  // or should not be called.
  UpdatableComponent::SetUnderlyingLearningRate(lrate);

  // apply any learning-rate-factor that's set at this level (ill-advised, but
  // we'll do it.)
  BaseFloat effective_lrate = LearningRate();
  for (size_t i = 0; i < components_.size(); i++) {
    if (components_[i]->Properties() & kUpdatableComponent) {
      UpdatableComponent *uc =
          dynamic_cast<UpdatableComponent*>(components_[i]);
      uc->SetUnderlyingLearningRate(effective_lrate);
    }
  }
}
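The comment in this function says the composite applies its own learning-rate factor before passing the rate down. If the base-class UpdatableComponent::SetUnderlyingLearningRate() (named in the References list) stores lrate * learning_rate_factor_, then the composite's factor and each sub-component's factor multiply together. A minimal model of that bookkeeping, with a hypothetical LrNode type:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical model of one learning-rate holder: setting the underlying
// rate stores rate * factor.
struct LrNode {
  float factor = 1.0f;
  float learning_rate = 0.0f;
  void SetUnderlying(float lrate) { learning_rate = lrate * factor; }
};

// Sets the composite's own rate first, then passes the already-scaled
// (effective) rate to every sub-component, which applies its own factor.
void SetUnderlyingComposite(LrNode *composite, std::vector<LrNode> *subs,
                            float lrate) {
  composite->SetUnderlying(lrate);
  float effective_lrate = composite->learning_rate;
  for (LrNode &sub : *subs)
    sub.SetUnderlying(effective_lrate);
}
```

With a composite factor of 0.5, sub-component factors of 1.0 and 0.2, and lrate = 0.01, the sub-components end up with rates of about 0.005 and 0.001.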

◆ Type()

virtual std::string Type ( ) const
inline virtual

Returns "CompositeComponent".

◆ UnVectorize()

void UnVectorize ( const VectorBase< BaseFloat > &  params)
virtual

Converts the parameters from vector form.

Reimplemented from UpdatableComponent.

Definition at line 4517 of file nnet-simple-component.cc.

References VectorBase< Real >::Dim(), KALDI_ASSERT, kaldi::nnet3::kUpdatableComponent, UpdatableComponent::NumParameters(), Component::Properties(), and UpdatableComponent::UnVectorize().

{
  int32 cur_offset = 0;
  KALDI_ASSERT(this->IsUpdatable());  // or should not be called.
  for (size_t i = 0; i < components_.size(); i++) {
    if (components_[i]->Properties() & kUpdatableComponent) {
      UpdatableComponent *uc =
          dynamic_cast<UpdatableComponent*>(components_[i]);
      int32 this_size = uc->NumParameters();
      SubVector<BaseFloat> params_range(params, cur_offset, this_size);
      uc->UnVectorize(params_range);
      cur_offset += this_size;
    }
  }
  KALDI_ASSERT(cur_offset == params.Dim());
}

◆ Vectorize()

void Vectorize ( VectorBase< BaseFloat > *  params) const
virtual

Turns the parameters into vector form.

We put the vector form on the CPU, because in the kinds of situations where we do this, we'll tend to use too much memory for the GPU.

Reimplemented from UpdatableComponent.

Definition at line 4500 of file nnet-simple-component.cc.

References VectorBase< Real >::Dim(), KALDI_ASSERT, kaldi::nnet3::kUpdatableComponent, UpdatableComponent::NumParameters(), Component::Properties(), and UpdatableComponent::Vectorize().

{
  int32 cur_offset = 0;
  KALDI_ASSERT(this->IsUpdatable());  // or should not be called.
  for (size_t i = 0; i < components_.size(); i++) {
    if (components_[i]->Properties() & kUpdatableComponent) {
      UpdatableComponent *uc =
          dynamic_cast<UpdatableComponent*>(components_[i]);
      int32 this_size = uc->NumParameters();
      SubVector<BaseFloat> params_range(*params, cur_offset, this_size);
      uc->Vectorize(&params_range);
      cur_offset += this_size;
    }
  }
  KALDI_ASSERT(cur_offset == params->Dim());
}
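Vectorize() and UnVectorize() are inverses: each walks the updatable sub-components in order, handling a contiguous range of length NumParameters() at a running offset, and asserts at the end that the whole vector was used. A sketch of that layout, assuming each sub-component's parameters are a plain std::vector<float> (the Params alias and function names are hypothetical):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Params = std::vector<float>;  // hypothetical per-component parameters

// Copies each sub-component's parameters into consecutive ranges of *flat.
void VectorizeSketch(const std::vector<Params> &subs, Params *flat) {
  std::size_t cur_offset = 0;
  for (const Params &p : subs) {
    assert(cur_offset + p.size() <= flat->size());
    std::copy(p.begin(), p.end(), flat->begin() + cur_offset);
    cur_offset += p.size();
  }
  assert(cur_offset == flat->size());  // every element accounted for
}

// Inverse: fills each sub-component (already sized) from consecutive ranges.
void UnVectorizeSketch(const Params &flat, std::vector<Params> *subs) {
  std::size_t cur_offset = 0;
  for (Params &p : *subs) {
    assert(cur_offset + p.size() <= flat.size());
    std::copy(flat.begin() + cur_offset,
              flat.begin() + cur_offset + p.size(), p.begin());
    cur_offset += p.size();
  }
  assert(cur_offset == flat.size());
}
```

The sizes come from the sub-components themselves, so no layout information is stored in the flat vector; it only round-trips correctly into a composite with the same structure.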

◆ Write()

void Write ( std::ostream &  os,
bool  binary 
) const
virtual

Write component to stream.

Implements Component.

Definition at line 4263 of file nnet-simple-component.cc.

References Component::Write(), kaldi::WriteBasicType(), kaldi::WriteToken(), and UpdatableComponent::WriteUpdatableCommon().

{
  WriteUpdatableCommon(os, binary);  // Write opening tag and learning rate.
  WriteToken(os, binary, "<MaxRowsProcess>");
  WriteBasicType(os, binary, max_rows_process_);
  WriteToken(os, binary, "<NumComponents>");
  int32 num_components = components_.size();
  WriteBasicType(os, binary, num_components);
  for (int32 i = 0; i < num_components; i++)
    components_[i]->Write(os, binary);
  WriteToken(os, binary, "</CompositeComponent>");
}
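After the common updatable header, Write() emits the token sequence that Read() expects back: <MaxRowsProcess>, its value, <NumComponents>, the count, each sub-component's own serialization, then the closing tag. The text-mode sketch below shows only that ordering. Sub-component bodies are passed in as hypothetical pre-rendered strings, and the single-space separators are this sketch's choice, not a claim about Kaldi's exact byte layout.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Renders the part of the layout written after WriteUpdatableCommon().
std::string WriteSketch(int max_rows_process,
                        const std::vector<std::string> &sub_bodies) {
  std::ostringstream os;
  os << "<MaxRowsProcess> " << max_rows_process << ' ';
  os << "<NumComponents> " << sub_bodies.size() << ' ';
  for (const std::string &body : sub_bodies)
    os << body << ' ';  // each sub-component writes itself in turn
  os << "</CompositeComponent> ";
  return os.str();
}
```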

◆ ZeroStats()

void ZeroStats ( )
virtual

Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero.

Other components that store other types of statistics (e.g. regarding gradient clipping) should implement ZeroStats() also.

Reimplemented from Component.

Definition at line 4254 of file nnet-simple-component.cc.

References Component::ZeroStats().

{
  // we call ZeroStats() on all components without checking their flags; this
  // will do nothing if the component doesn't store stats. (components like
  // ReLU and sigmoid and tanh store stats on activations).
  for (size_t i = 0; i < components_.size(); i++)
    components_[i]->ZeroStats();
}

Member Data Documentation

◆ components_

std::vector<Component*> components_
private

◆ max_rows_process_

int32 max_rows_process_
private

Definition at line 2061 of file nnet-simple-component.h.


The documentation for this class was generated from the following files: