UpdatableComponent Class Reference (abstract)

UpdatableComponent is a Component that has trainable parameters. It holds the SGD training hyper-parameters in an NnetTrainOptions object.

#include <nnet-component.h>


Public Member Functions

 UpdatableComponent (int32 input_dim, int32 output_dim)
 
virtual ~UpdatableComponent ()
 
bool IsUpdatable () const
 Check whether the component contains trainable parameters.

virtual int32 NumParams () const =0
 Number of trainable parameters.

virtual void GetGradient (VectorBase< BaseFloat > *gradient) const =0
 Get the gradient, reshaped as a vector.

virtual void GetParams (VectorBase< BaseFloat > *params) const =0
 Get the trainable parameters, reshaped as a vector.

virtual void SetParams (const VectorBase< BaseFloat > &params)=0
 Set the trainable parameters from a flat vector.

virtual void Update (const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff)=0
 Compute the gradient and update the parameters.

virtual void SetTrainOptions (const NnetTrainOptions &opts)
 Set the training options of the component.

const NnetTrainOptions & GetTrainOptions () const
 Get the training options of the component.

virtual void SetLearnRateCoef (BaseFloat val)
 Set the learn-rate coefficient.

virtual void SetBiasLearnRateCoef (BaseFloat val)
 Set the learn-rate coefficient for the bias.

virtual void InitData (std::istream &is)=0
 Initialize the content of the component from the 'line' of the prototype.
 
- Public Member Functions inherited from Component
 Component (int32 input_dim, int32 output_dim)
 Generic interface of a component.

virtual ~Component ()

virtual Component * Copy () const =0
 Copy the component (deep copy).

virtual ComponentType GetType () const =0
 Get the type identification of the component.

virtual bool IsMultistream () const
 Check whether the component has the 'Recurrent' interface (trainable and recurrent).

int32 InputDim () const
 Get the dimension of the input.

int32 OutputDim () const
 Get the dimension of the output.

void Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
 Perform forward-pass propagation 'in' -> 'out'.

void Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
 Perform backward-pass propagation 'out_diff' -> 'in_diff'.

void Write (std::ostream &os, bool binary) const
 Write the component to a stream.

virtual std::string Info () const
 Print some additional info (after <ComponentName> and the dims).

virtual std::string InfoGradient () const
 Print some additional info about the gradient (after <...> and the dims).
 

Protected Attributes

NnetTrainOptions opts_
 Option class with the training hyper-parameters.

BaseFloat learn_rate_coef_
 Scalar applied to the learning rate for weight matrices (to be used in the ::Update method).

BaseFloat bias_learn_rate_coef_
 Scalar applied to the learning rate for the bias (to be used in the ::Update method).

- Protected Attributes inherited from Component
int32 input_dim_
 Dimension of the input of the Component.

int32 output_dim_
 Dimension of the output of the Component.
 

Additional Inherited Members

- Public Types inherited from Component
enum  ComponentType {
  kUnknown = 0x0, kUpdatableComponent = 0x0100, kAffineTransform, kLinearTransform,
  kConvolutionalComponent, kConvolutional2DComponent, kLstmProjected, kBlstmProjected,
  kRecurrentComponent, kActivationFunction = 0x0200, kSoftmax, kHiddenSoftmax,
  kBlockSoftmax, kSigmoid, kTanh, kParametricRelu,
  kDropout, kLengthNormComponent, kTranform = 0x0400, kRbm,
  kSplice, kCopy, kTranspose, kBlockLinearity,
  kAddShift, kRescale, kKlHmm = 0x0800, kSentenceAveragingComponent,
  kSimpleSentenceAveragingComponent, kAveragePoolingComponent, kAveragePooling2DComponent, kMaxPoolingComponent,
  kMaxPooling2DComponent, kFramePoolingComponent, kParallelComponent, kMultiBasisComponent
}
 Component type identification mechanism.
 
- Static Public Member Functions inherited from Component
static const char * TypeToMarker (ComponentType t)
 Converts a component type to its marker.

static ComponentType MarkerToType (const std::string &s)
 Converts a marker to a component type (case insensitive).

static Component * Init (const std::string &conf_line)
 Initialize a component from a line in the config file.

static Component * Read (std::istream &is, bool binary)
 Read a component from a stream (static method).

- Static Public Attributes inherited from Component
static const struct key_value kMarkerMap []
 The table with pairs of Component types and markers (defined in nnet-component.cc).
 
- Protected Member Functions inherited from Component
virtual void PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)=0
 Abstract interface for propagation/backpropagation.

virtual void BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)=0
 Backward-pass transformation (to be implemented by the derived class).

virtual void ReadData (std::istream &is, bool binary)
 Reads the component content.

virtual void WriteData (std::ostream &os, bool binary) const
 Writes the component content.
 

Detailed Description

UpdatableComponent is a Component that has trainable parameters. It holds the SGD training hyper-parameters in an NnetTrainOptions object.

The constants 'learn_rate_coef_' and 'bias_learn_rate_coef_' are kept separate from NnetTrainOptions and should be stored by ::WriteData(...).

Definition at line 211 of file nnet-component.h.
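
The parameter accessors (NumParams, GetParams, SetParams) expose the trainable parameters as one flat vector. The following is a minimal, self-contained sketch of that idea, not Kaldi code: UpdatableSketch and AffineSketch are hypothetical names, std::vector<float> stands in for VectorBase< BaseFloat >, and the layout (row-major weights followed by the bias) is an assumption chosen for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for Kaldi's BaseFloat (assumption: single precision).
using BaseFloat = float;

// Hypothetical, cut-down version of the parameter-access part of the interface.
class UpdatableSketch {
 public:
  virtual ~UpdatableSketch() {}
  bool IsUpdatable() const { return true; }
  virtual int NumParams() const = 0;
  virtual void GetParams(std::vector<BaseFloat> *params) const = 0;
  virtual void SetParams(const std::vector<BaseFloat> &params) = 0;
};

// Hypothetical affine layer: a weight matrix (row-major) plus a bias vector.
class AffineSketch : public UpdatableSketch {
 public:
  AffineSketch(int input_dim, int output_dim)
      : weights_(static_cast<std::size_t>(input_dim) *
                     static_cast<std::size_t>(output_dim), 0.0f),
        bias_(static_cast<std::size_t>(output_dim), 0.0f) {}

  int NumParams() const override {
    return static_cast<int>(weights_.size() + bias_.size());
  }

  // Flatten the parameters as [weights..., bias...].
  void GetParams(std::vector<BaseFloat> *params) const override {
    params->assign(weights_.begin(), weights_.end());
    params->insert(params->end(), bias_.begin(), bias_.end());
  }

  // Inverse of GetParams: the vector must hold exactly NumParams() values.
  void SetParams(const std::vector<BaseFloat> &params) override {
    assert(static_cast<int>(params.size()) == NumParams());
    const std::ptrdiff_t num_w = static_cast<std::ptrdiff_t>(weights_.size());
    std::copy(params.begin(), params.begin() + num_w, weights_.begin());
    std::copy(params.begin() + num_w, params.end(), bias_.begin());
  }

 private:
  std::vector<BaseFloat> weights_;
  std::vector<BaseFloat> bias_;
};
```

With this layout, SetParams followed by GetParams returns the same vector, which is the property a caller relies on when copying parameters between models.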

Constructor & Destructor Documentation

UpdatableComponent ( int32  input_dim,
int32  output_dim 
)
inline

Definition at line 213 of file nnet-component.h.

213  UpdatableComponent(int32 input_dim, int32 output_dim) :
214  Component(input_dim, output_dim),
215  learn_rate_coef_(1.0),
216  bias_learn_rate_coef_(1.0)
217  { }

virtual ~UpdatableComponent ( )
inline virtual

Definition at line 219 of file nnet-component.h.

220  { }

Member Function Documentation

const NnetTrainOptions& GetTrainOptions ( ) const
inline

Get the training options of the component.

Definition at line 249 of file nnet-component.h.

References UpdatableComponent::opts_.

249  {
250  return opts_;
251  }

virtual void InitData ( std::istream &  is)
pure virtual

Initialize the content of the component from the 'line' of the prototype.

bool IsUpdatable ( ) const
inline virtual

Check whether the component contains trainable parameters.

Reimplemented from Component.

Definition at line 223 of file nnet-component.h.

223  {
224  return true;
225  }

virtual void SetBiasLearnRateCoef ( BaseFloat  val)
inline virtual

Set the learn-rate coefficient for the bias.

Reimplemented in MultiBasisComponent and ParallelComponent.

Definition at line 259 of file nnet-component.h.

References UpdatableComponent::bias_learn_rate_coef_.

Referenced by main(), ParallelComponent::SetBiasLearnRateCoef(), and MultiBasisComponent::SetBiasLearnRateCoef().

259  {
260  bias_learn_rate_coef_ = val;
261  }

virtual void SetLearnRateCoef ( BaseFloat  val)
inline virtual

Set the learn-rate coefficient.

Reimplemented in Rescale, MultiBasisComponent, AddShift, and ParallelComponent.

Definition at line 254 of file nnet-component.h.

References UpdatableComponent::learn_rate_coef_.

Referenced by main(), ParallelComponent::SetLearnRateCoef(), and MultiBasisComponent::SetLearnRateCoef().

254  {
255  learn_rate_coef_ = val;
256  }
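
The two coefficients are described as scalars applied to the learning rate inside the ::Update method. A minimal sketch of how a derived class might use them in a plain SGD step is given here, under stated assumptions: TrainOptionsSketch is a hypothetical stand-in for NnetTrainOptions with a single learn_rate field, SgdStepSketch is a hypothetical helper, and momentum or regularization terms are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for Kaldi's BaseFloat (assumption: single precision).
using BaseFloat = float;

// Hypothetical stand-in for NnetTrainOptions; only a global learning rate.
struct TrainOptionsSketch {
  BaseFloat learn_rate = 0.008f;
};

// Plain SGD step: weights and bias each get their own scaled learning rate.
inline void SgdStepSketch(const TrainOptionsSketch &opts,
                          BaseFloat learn_rate_coef,
                          BaseFloat bias_learn_rate_coef,
                          const std::vector<BaseFloat> &weight_grad,
                          const std::vector<BaseFloat> &bias_grad,
                          std::vector<BaseFloat> *weights,
                          std::vector<BaseFloat> *bias) {
  assert(weight_grad.size() == weights->size());
  assert(bias_grad.size() == bias->size());
  // Effective rates: global learning rate times the per-component scalar.
  const BaseFloat lr = opts.learn_rate * learn_rate_coef;
  const BaseFloat lr_bias = opts.learn_rate * bias_learn_rate_coef;
  for (std::size_t i = 0; i < weights->size(); ++i) {
    (*weights)[i] -= lr * weight_grad[i];
  }
  for (std::size_t i = 0; i < bias->size(); ++i) {
    (*bias)[i] -= lr_bias * bias_grad[i];
  }
}
```

In this sketch a coefficient of 0.0 freezes the corresponding parameters, while 1.0 (the constructor default) leaves the global learning rate unchanged.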

virtual void SetParams ( const VectorBase< BaseFloat > &  params)
pure virtual

Set the trainable parameters from a flat vector.

virtual void SetTrainOptions ( const NnetTrainOptions &  opts)
inline virtual

Set the training options of the component.

Reimplemented in MultiBasisComponent, ParallelComponent, and SentenceAveragingComponent.

Definition at line 244 of file nnet-component.h.

References UpdatableComponent::opts_.

Referenced by SentenceAveragingComponent::SetTrainOptions().

244  {
245  opts_ = opts;
246  }
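
The setters are virtual and are reimplemented in container-like classes such as ParallelComponent and MultiBasisComponent. One plausible reason, sketched here as an assumption rather than as the actual Kaldi implementation, is that a container has to forward the values to the components nested inside it. All class names and the LearnRateCoef() accessor are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Stand-in for Kaldi's BaseFloat (assumption: single precision).
using BaseFloat = float;

// Hypothetical stand-in for NnetTrainOptions.
struct TrainOptionsSketch {
  BaseFloat learn_rate = 0.008f;
};

// Hypothetical base class that mirrors the setters documented above.
class UpdatableSketch {
 public:
  virtual ~UpdatableSketch() {}
  virtual void SetTrainOptions(const TrainOptionsSketch &opts) { opts_ = opts; }
  const TrainOptionsSketch &GetTrainOptions() const { return opts_; }
  virtual void SetLearnRateCoef(BaseFloat val) { learn_rate_coef_ = val; }
  BaseFloat LearnRateCoef() const { return learn_rate_coef_; }  // illustration only

 protected:
  TrainOptionsSketch opts_;
  BaseFloat learn_rate_coef_ = 1.0f;
};

// A leaf component that simply uses the base-class behaviour.
class LeafSketch : public UpdatableSketch {};

// A container that forwards both setters to every nested component.
class ContainerSketch : public UpdatableSketch {
 public:
  void Add(std::unique_ptr<UpdatableSketch> c) { nested_.push_back(std::move(c)); }

  void SetTrainOptions(const TrainOptionsSketch &opts) override {
    UpdatableSketch::SetTrainOptions(opts);  // keep own copy
    for (auto &c : nested_) c->SetTrainOptions(opts);
  }

  void SetLearnRateCoef(BaseFloat val) override {
    UpdatableSketch::SetLearnRateCoef(val);  // keep own copy
    for (auto &c : nested_) c->SetLearnRateCoef(val);
  }

  const UpdatableSketch &Nested(std::size_t i) const { return *nested_[i]; }

 private:
  std::vector<std::unique_ptr<UpdatableSketch>> nested_;
};
```

Because the setters are virtual, code that only holds a pointer to the base class reaches the nested components without knowing that it is talking to a container.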

Member Data Documentation

BaseFloat learn_rate_coef_
protected

Scalar applied to the learning rate for weight matrices (to be used in the ::Update method).

Definition at line 272 of file nnet-component.h.

Referenced by LinearTransform::Info(), AffineTransform::Info(), ConvolutionalComponent::Info(), Convolutional2DComponent::Info(), LstmProjected::Info(), AddShift::Info(), BlstmProjected::Info(), Rescale::Info(), LinearTransform::InfoGradient(), AffineTransform::InfoGradient(), RecurrentComponent::InfoGradient(), FramePoolingComponent::InfoGradient(), ConvolutionalComponent::InfoGradient(), Convolutional2DComponent::InfoGradient(), LstmProjected::InfoGradient(), AddShift::InfoGradient(), BlstmProjected::InfoGradient(), Rescale::InfoGradient(), AffineTransform::InitData(), LinearTransform::InitData(), RecurrentComponent::InitData(), FramePoolingComponent::InitData(), LstmProjected::InitData(), BlstmProjected::InitData(), ConvolutionalComponent::InitData(), Convolutional2DComponent::InitData(), AddShift::InitData(), Rescale::InitData(), AffineTransform::ReadData(), RecurrentComponent::ReadData(), LinearTransform::ReadData(), LstmProjected::ReadData(), FramePoolingComponent::ReadData(), ConvolutionalComponent::ReadData(), BlstmProjected::ReadData(), Convolutional2DComponent::ReadData(), AddShift::ReadData(), Rescale::ReadData(), UpdatableComponent::SetLearnRateCoef(), AddShift::SetLearnRateCoef(), Rescale::SetLearnRateCoef(), LinearTransform::Update(), AffineTransform::Update(), FramePoolingComponent::Update(), RecurrentComponent::Update(), AddShift::Update(), ConvolutionalComponent::Update(), Convolutional2DComponent::Update(), Rescale::Update(), LstmProjected::Update(), BlstmProjected::Update(), AffineTransform::WriteData(), RecurrentComponent::WriteData(), LinearTransform::WriteData(), FramePoolingComponent::WriteData(), LstmProjected::WriteData(), ConvolutionalComponent::WriteData(), Convolutional2DComponent::WriteData(), BlstmProjected::WriteData(), AddShift::WriteData(), and Rescale::WriteData().
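
The list above includes the ReadData() and WriteData() methods of the derived classes, in line with the note that the coefficients should be stored by ::WriteData(...). A minimal text-mode sketch of such a round trip follows. It uses plain iostreams rather than Kaldi's I/O helpers, and the token names <LearnRateCoef> and <BiasLearnRateCoef>, the CoefsSketch struct, and both functions are assumptions made for illustration.

```cpp
#include <cassert>
#include <istream>
#include <ostream>
#include <sstream>
#include <string>

// Stand-in for Kaldi's BaseFloat (assumption: single precision).
using BaseFloat = float;

// Hypothetical holder for the two per-component coefficients.
struct CoefsSketch {
  BaseFloat learn_rate_coef = 1.0f;
  BaseFloat bias_learn_rate_coef = 1.0f;
};

// Write each coefficient after a marker token (text mode only).
inline void WriteCoefsSketch(std::ostream &os, const CoefsSketch &c) {
  os << "<LearnRateCoef> " << c.learn_rate_coef << ' '
     << "<BiasLearnRateCoef> " << c.bias_learn_rate_coef << '\n';
}

// Read the coefficients back; returns false on a missing or unexpected token.
inline bool ReadCoefsSketch(std::istream &is, CoefsSketch *c) {
  std::string token;
  if (!(is >> token) || token != "<LearnRateCoef>") return false;
  if (!(is >> c->learn_rate_coef)) return false;
  if (!(is >> token) || token != "<BiasLearnRateCoef>") return false;
  if (!(is >> c->bias_learn_rate_coef)) return false;
  return true;
}
```

Storing the coefficients with the component means a model that is written out and read back resumes training with the same per-component scaling.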


The documentation for this class was generated from the following file: nnet-component.h