Component Class Reference (abstract)

Abstract class, the basic element of the network: a box with defined inputs and outputs and an interface for transformation functions.

#include <nnet-component.h>


Public Member Functions

 Component ()

virtual std::string Type () const =0

virtual int32 Index () const
 Returns the index in the sequence of layers in the neural net; intended only to be used in debugging information.

virtual void SetIndex (int32 index)

virtual void InitFromString (std::string args)=0
 Initialize, typically from a line of a config file.

virtual int32 InputDim () const =0
 Get size of input vectors.

virtual int32 OutputDim () const =0
 Get size of output vectors.

virtual std::vector< int32 > Context () const
 Return a vector describing the temporal context this component requires for each frame of output, as a sorted list.

virtual void Propagate (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const =0
 Perform forward pass propagation Input->Output.

void Propagate (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out) const
 A non-virtual propagate function that first resizes output if necessary.

virtual void Backprop (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, Component *to_update, CuMatrix< BaseFloat > *in_deriv) const =0
 Perform backward pass propagation of the derivative, and also either update the model (if to_update == this) or update another model or compute the model derivative (otherwise).

virtual bool BackpropNeedsInput () const

virtual bool BackpropNeedsOutput () const

virtual Component * Copy () const =0
 Copy component (deep copy).

virtual void Read (std::istream &is, bool binary)=0

virtual void Write (std::ostream &os, bool binary) const =0
 Write component to stream.

virtual std::string Info () const

virtual ~Component ()
 

Static Public Member Functions

static Component * ReadNew (std::istream &is, bool binary)
 Read component from stream.

static Component * NewFromString (const std::string &initializer_line)
 Initialize the Component from one line that contains first the type, e.g. SigmoidComponent, and then a number of tokens used to initialize it.

static Component * NewComponentOfType (const std::string &type)
 Return a new Component of the given type, e.g. "SoftmaxComponent", or NULL if no such type exists.
 

Private Member Functions

 KALDI_DISALLOW_COPY_AND_ASSIGN (Component)
 

Private Attributes

int32 index_
 

Detailed Description

Abstract class, the basic element of the network: a box with defined inputs and outputs and an interface for transformation functions.

It can propagate and backpropagate; the exact implementation is left to the derived classes.

Definition at line 157 of file nnet-component.h.
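The split between the abstract interface and its descendants can be sketched outside Kaldi. The following is a simplified stand-in, not Kaldi's code: the hypothetical MiniComponent and MiniScale classes use a vector of rows in place of CuMatrixBase< BaseFloat >, drop the ChunkInfo and to_update arguments, and return std::unique_ptr from Copy() where the real method returns a raw Component pointer.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical simplified matrix: a vector of rows (stands in for CuMatrix).
using Matrix = std::vector<std::vector<float>>;

// Hypothetical stand-in for the abstract Component interface.
class MiniComponent {
 public:
  virtual ~MiniComponent() = default;
  virtual std::string Type() const = 0;
  virtual int InputDim() const = 0;
  virtual int OutputDim() const = 0;
  // Forward pass: in -> out.
  virtual void Propagate(const Matrix &in, Matrix *out) const = 0;
  // Backward pass: out_deriv -> in_deriv (no parameter update in this sketch).
  virtual void Backprop(const Matrix &in_value, const Matrix &out_value,
                        const Matrix &out_deriv, Matrix *in_deriv) const = 0;
  // Deep copy.
  virtual std::unique_ptr<MiniComponent> Copy() const = 0;
};

// A trivial descendant that multiplies every element by a fixed scalar.
class MiniScale : public MiniComponent {
 public:
  MiniScale(int dim, float scale) : dim_(dim), scale_(scale) {}
  std::string Type() const override { return "MiniScale"; }
  int InputDim() const override { return dim_; }
  int OutputDim() const override { return dim_; }
  void Propagate(const Matrix &in, Matrix *out) const override {
    *out = in;
    for (auto &row : *out)
      for (float &x : row) x *= scale_;
  }
  void Backprop(const Matrix &, const Matrix &, const Matrix &out_deriv,
                Matrix *in_deriv) const override {
    // d(scale * x)/dx = scale, so the derivative is scaled the same way;
    // neither in_value nor out_value is needed.
    *in_deriv = out_deriv;
    for (auto &row : *in_deriv)
      for (float &x : row) x *= scale_;
  }
  std::unique_ptr<MiniComponent> Copy() const override {
    return std::make_unique<MiniScale>(*this);
  }

 private:
  int dim_;
  float scale_;
};
```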

Constructor & Destructor Documentation

Component ( )
inline

Definition at line 159 of file nnet-component.h.

Component() : index_(-1) { }
virtual ~Component ( )
inline, virtual

Definition at line 264 of file nnet-component.h.

virtual ~Component() { }

Member Function Documentation

virtual void Backprop ( const ChunkInfo & in_info,
const ChunkInfo & out_info,
const CuMatrixBase< BaseFloat > & in_value,
const CuMatrixBase< BaseFloat > & out_value,
const CuMatrixBase< BaseFloat > & out_deriv,
Component * to_update,
CuMatrix< BaseFloat > * in_deriv
) const
pure virtual

Perform backward pass propagation of the derivative, and also either update the model (if to_update == this) or update another model or compute the model derivative (otherwise).

Note: in_value and out_value are the values of the input and output of the component. These may be dummy variables if BackpropNeedsInput() or BackpropNeedsOutput(), respectively, returns false for that component (not all components need them).

The chunk structure (num_chunks) lets the input matrix be treated as contiguous-in-time chunks of equal size; it only matters if splicing is involved. There is no num_chunks parameter in this signature; the chunk layout is described by in_info and out_info.

Implemented in Convolutional1dComponent, AdditiveNoiseComponent, DropoutComponent, FixedBiasComponent, FixedScaleComponent, FixedAffineComponent, FixedLinearComponent, DctComponent, PermuteComponent, SumGroupComponent, BlockAffineComponent, SpliceMaxComponent, SpliceComponent, AffineComponent, LogSoftmaxComponent, SoftmaxComponent, ScaleComponent, SoftHingeComponent, RectifiedLinearComponent, PowerComponent, TanhComponent, SigmoidComponent, NormalizeComponent, PnormComponent, MaxpoolingComponent, and MaxoutComponent.

Referenced by NnetComputer::Backprop(), NnetDiscriminativeUpdater::Backprop(), NnetUpdater::Backprop(), NnetRescaler::RescaleComponent(), and kaldi::nnet2::UnitTestGenericComponentInternal().
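For a sigmoid nonlinearity the derivative can be computed from the output alone, since dy/dx = y(1 - y). A component of that kind needs out_value but not in_value, so it could report BackpropNeedsOutput() true and BackpropNeedsInput() false. A minimal per-row sketch, with hypothetical function names and std::vector in place of Kaldi matrices:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Forward pass for one row: y = 1 / (1 + exp(-x)).
std::vector<float> SigmoidForward(const std::vector<float> &in) {
  std::vector<float> out(in.size());
  for (size_t i = 0; i < in.size(); ++i)
    out[i] = 1.0f / (1.0f + std::exp(-in[i]));
  return out;
}

// Backward pass for one row: in_deriv = out_deriv * y * (1 - y).
// Only out_value is read; the input value is never needed.
std::vector<float> SigmoidBackprop(const std::vector<float> &out_value,
                                   const std::vector<float> &out_deriv) {
  std::vector<float> in_deriv(out_value.size());
  for (size_t i = 0; i < out_value.size(); ++i) {
    const float y = out_value[i];
    in_deriv[i] = out_deriv[i] * y * (1.0f - y);
  }
  return in_deriv;
}
```

At x = 0 the output is 0.5 and the backpropagated derivative of a unit out_deriv is 0.25.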

virtual std::vector<int32> Context ( ) const
inline, virtual

Return a vector describing the temporal context this component requires for each frame of output, as a sorted list.

The default implementation returns the vector (0); a splicing layer might return, e.g., (-2, -1, 0, 1, 2), and the context does not have to be contiguous. Note: the context needed by the entire network is a function of the contexts needed by all the components. It is required that Context().front() <= 0 and Context().back() >= 0.

Reimplemented in SpliceMaxComponent, and SpliceComponent.

Definition at line 188 of file nnet-component.h.

Referenced by Nnet::ComputeChunkInfo(), and NnetOnlineComputer::Propagate().

{ return std::vector<int32>(1, 0); }
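The requirements on Context() can be checked mechanically, and for a plain stack of components the overall window follows from them. The sketch below assumes that the network's left and right context are the sums of each component's -Context().front() and Context().back(); IsValidContext and NetworkContext are hypothetical helpers, not Kaldi functions.

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <utility>
#include <vector>

// Checks the documented requirements: sorted, front() <= 0, back() >= 0.
// Contiguity is not required, so (-4, 0, 4) is valid.
bool IsValidContext(const std::vector<int> &ctx) {
  return !ctx.empty() && std::is_sorted(ctx.begin(), ctx.end()) &&
         ctx.front() <= 0 && ctx.back() >= 0;
}

// Assumption: in a plain feed-forward stack, each component widens the
// required input window by its own extent, so the extents add up.
std::pair<int, int> NetworkContext(
    const std::vector<std::vector<int>> &contexts) {
  int left = 0, right = 0;
  for (const auto &ctx : contexts) {
    if (!IsValidContext(ctx)) throw std::invalid_argument("bad context");
    left += -ctx.front();
    right += ctx.back();
  }
  return {left, right};
}
```

Under that assumption, a (-2 .. 2) splice, a default (0) component and a (-1, 0, 1) splice need 3 frames of left and 3 frames of right context.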
virtual int32 Index ( ) const
inline, virtual

Returns the index in the sequence of layers in the neural net; intended only to be used in debugging information.

Definition at line 166 of file nnet-component.h.

References Component::index_.

Referenced by AffineComponentPreconditioned::GetScalingFactor(), and AffineComponentPreconditionedOnline::GetScalingFactor().

{ return index_; }
KALDI_DISALLOW_COPY_AND_ASSIGN ( Component  )
private
Component * NewComponentOfType ( const std::string &  type)
static

Return a new Component of the given type, e.g. "SoftmaxComponent", or NULL if no such type exists.

Definition at line 51 of file nnet-component.cc.

Referenced by Component::NewFromString(), and Component::ReadNew().

{
  Component *ans = NULL;
  if (component_type == "SigmoidComponent") {
    ans = new SigmoidComponent();
  } else if (component_type == "TanhComponent") {
    ans = new TanhComponent();
  } else if (component_type == "PowerComponent") {
    ans = new PowerComponent();
  } else if (component_type == "SoftmaxComponent") {
    ans = new SoftmaxComponent();
  } else if (component_type == "LogSoftmaxComponent") {
    ans = new LogSoftmaxComponent();
  } else if (component_type == "RectifiedLinearComponent") {
    ans = new RectifiedLinearComponent();
  } else if (component_type == "NormalizeComponent") {
    ans = new NormalizeComponent();
  } else if (component_type == "SoftHingeComponent") {
    ans = new SoftHingeComponent();
  } else if (component_type == "PnormComponent") {
    ans = new PnormComponent();
  } else if (component_type == "MaxoutComponent") {
    ans = new MaxoutComponent();
  } else if (component_type == "ScaleComponent") {
    ans = new ScaleComponent();
  } else if (component_type == "AffineComponent") {
    ans = new AffineComponent();
  } else if (component_type == "AffineComponentPreconditioned") {
    ans = new AffineComponentPreconditioned();
  } else if (component_type == "AffineComponentPreconditionedOnline") {
    ans = new AffineComponentPreconditionedOnline();
  } else if (component_type == "SumGroupComponent") {
    ans = new SumGroupComponent();
  } else if (component_type == "BlockAffineComponent") {
    ans = new BlockAffineComponent();
  } else if (component_type == "BlockAffineComponentPreconditioned") {
    ans = new BlockAffineComponentPreconditioned();
  } else if (component_type == "PermuteComponent") {
    ans = new PermuteComponent();
  } else if (component_type == "DctComponent") {
    ans = new DctComponent();
  } else if (component_type == "FixedLinearComponent") {
    ans = new FixedLinearComponent();
  } else if (component_type == "FixedAffineComponent") {
    ans = new FixedAffineComponent();
  } else if (component_type == "FixedScaleComponent") {
    ans = new FixedScaleComponent();
  } else if (component_type == "FixedBiasComponent") {
    ans = new FixedBiasComponent();
  } else if (component_type == "SpliceComponent") {
    ans = new SpliceComponent();
  } else if (component_type == "SpliceMaxComponent") {
    ans = new SpliceMaxComponent();
  } else if (component_type == "DropoutComponent") {
    ans = new DropoutComponent();
  } else if (component_type == "AdditiveNoiseComponent") {
    ans = new AdditiveNoiseComponent();
  } else if (component_type == "Convolutional1dComponent") {
    ans = new Convolutional1dComponent();
  } else if (component_type == "MaxpoolingComponent") {
    ans = new MaxpoolingComponent();
  }
  return ans;
}
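The if/else chain above maps a type name to a constructor. One common alternative is a lookup table; the sketch below uses hypothetical stand-in classes (MiniComponent, SigmoidLike, TanhLike) and std::unique_ptr rather than Kaldi's Component types and raw pointers, and keeps the convention of returning null for an unknown type.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-ins for Component and two of its descendants.
struct MiniComponent {
  virtual ~MiniComponent() = default;
  virtual std::string Type() const = 0;
};
struct SigmoidLike : MiniComponent {
  std::string Type() const override { return "SigmoidComponent"; }
};
struct TanhLike : MiniComponent {
  std::string Type() const override { return "TanhComponent"; }
};

// Table-driven factory: type name -> constructor function.
// Returns nullptr for an unknown name, like NewComponentOfType() returns NULL.
std::unique_ptr<MiniComponent> NewOfType(const std::string &type) {
  using Maker = std::function<std::unique_ptr<MiniComponent>()>;
  static const std::map<std::string, Maker> kRegistry = {
      {"SigmoidComponent",
       [] { return std::unique_ptr<MiniComponent>(new SigmoidLike()); }},
      {"TanhComponent",
       [] { return std::unique_ptr<MiniComponent>(new TanhLike()); }},
  };
  auto it = kRegistry.find(type);
  if (it == kRegistry.end()) return nullptr;
  return it->second();
}
```

With a table, registering a new type is a single entry rather than another branch; the behaviour for unknown names is unchanged.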
Component * NewFromString ( const std::string &  initializer_line)
static

Initialize the Component from one line that contains first the type, e.g. SigmoidComponent, and then a number of tokens (typically integers or floats) that will be used to initialize the component.

Definition at line 116 of file nnet-component.cc.

References Component::InitFromString(), KALDI_ERR, and Component::NewComponentOfType().

Referenced by Nnet::Init().

{
  std::istringstream istr(initializer_line);
  std::string component_type;  // e.g. "SigmoidComponent".
  istr >> component_type >> std::ws;
  std::string rest_of_line;
  getline(istr, rest_of_line);
  Component *ans = NewComponentOfType(component_type);
  if (ans == NULL)
    KALDI_ERR << "Bad initializer line (no such type of Component): "
              << initializer_line;
  ans->InitFromString(rest_of_line);
  return ans;
}
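The parsing step in NewFromString() splits the line into a type token and the remainder that is passed to InitFromString(). The hypothetical helper below isolates that step using the same istringstream, std::ws and getline calls as the listing above; it does not construct any component.

```cpp
#include <cassert>
#include <istream>
#include <sstream>
#include <string>
#include <utility>

// Returns (type, rest): the first whitespace-delimited token, and the rest
// of the line with leading whitespace skipped (what InitFromString() gets).
std::pair<std::string, std::string> SplitInitializerLine(
    const std::string &line) {
  std::istringstream istr(line);
  std::string type, rest;
  istr >> type >> std::ws;
  std::getline(istr, rest);
  return {type, rest};
}
```

For example, SplitInitializerLine("SigmoidComponent dim=1024") yields the pair ("SigmoidComponent", "dim=1024"); the argument string here is illustrative only, not a documented option list.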
virtual void Propagate ( const ChunkInfo & in_info,
const ChunkInfo & out_info,
const CuMatrixBase< BaseFloat > & in,
CuMatrixBase< BaseFloat > * out
) const
pure virtual

Perform forward pass propagation Input->Output.

Each row of the input is one frame or training example. The rows are interpreted as "num_chunks" equally sized chunks of frames; this only matters for layers that do things like context splicing. Typically num_chunks is either 1 (when processing a single contiguous chunk of data) or equal to the number of rows, in.NumRows(), but other values are possible if some layers do splicing.

Implemented in Convolutional1dComponent, AdditiveNoiseComponent, DropoutComponent, FixedBiasComponent, FixedScaleComponent, FixedAffineComponent, FixedLinearComponent, DctComponent, PermuteComponent, SumGroupComponent, BlockAffineComponent, SpliceMaxComponent, SpliceComponent, AffineComponent, LogSoftmaxComponent, SoftmaxComponent, ScaleComponent, SoftHingeComponent, RectifiedLinearComponent, PowerComponent, TanhComponent, SigmoidComponent, NormalizeComponent, PnormComponent, MaxpoolingComponent, and MaxoutComponent.

Referenced by NnetComputer::Propagate(), NnetDiscriminativeUpdater::Propagate(), NnetOnlineComputer::Propagate(), NnetUpdater::Propagate(), Component::Propagate(), NnetRescaler::Rescale(), NnetRescaler::RescaleComponent(), and kaldi::nnet2::UnitTestGenericComponentInternal().
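To make the chunk interpretation concrete, the sketch below splices frames over num_chunks equal chunks so that no output frame draws on frames from a neighbouring chunk. It is a simplified illustration under stated assumptions (a vector-of-rows matrix, a sorted context with front() <= 0 <= back(), each chunk shrinking by the total context), and SpliceChunks is a hypothetical name, not SpliceComponent's implementation.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Output frame t of a chunk concatenates input frames (t + left + c) of the
// same chunk, for each offset c in context. Each chunk loses left + right
// frames, and frames are never taken across a chunk boundary.
Matrix SpliceChunks(const Matrix &in, int num_chunks,
                    const std::vector<int> &context) {
  if (context.empty() || num_chunks <= 0 || in.size() % num_chunks != 0)
    throw std::invalid_argument("bad context or uneven chunks");
  const int in_chunk = static_cast<int>(in.size()) / num_chunks;
  const int left = -context.front(), right = context.back();
  const int out_chunk = in_chunk - left - right;
  if (out_chunk <= 0) throw std::invalid_argument("chunk too short");
  Matrix out;
  for (int chunk = 0; chunk < num_chunks; ++chunk) {
    for (int t = 0; t < out_chunk; ++t) {
      std::vector<float> row;
      for (int c : context) {
        const auto &src = in[chunk * in_chunk + t + left + c];
        row.insert(row.end(), src.begin(), src.end());
      }
      out.push_back(row);
    }
  }
  return out;
}
```

With six one-dimensional frames 0..5 and context (-1, 0, 1), num_chunks = 1 gives four output rows, (0,1,2) through (3,4,5), while num_chunks = 2 gives only (0,1,2) and (3,4,5), because frames are never spliced across the chunk boundary.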

void Propagate ( const ChunkInfo & in_info,
const ChunkInfo & out_info,
const CuMatrixBase< BaseFloat > & in,
CuMatrix< BaseFloat > * out
) const
inline

A non-virtual propagate function that first resizes output if necessary.

Definition at line 203 of file nnet-component.h.

References ChunkInfo::NumCols(), CuMatrixBase< Real >::NumCols(), ChunkInfo::NumRows(), CuMatrixBase< Real >::NumRows(), Component::Propagate(), and CuMatrix< Real >::Resize().

{
  if (out->NumRows() != out_info.NumRows() ||
      out->NumCols() != out_info.NumCols()) {
    out->Resize(out_info.NumRows(), out_info.NumCols());
  }

  // Cast to CuMatrixBase to use the virtual version of propagate function.
  Propagate(in_info, out_info, in,
            static_cast<CuMatrixBase<BaseFloat>*>(out));
}
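The non-virtual wrapper relies on overload resolution: because out is a CuMatrix pointer, an uncast call would match the wrapper itself again (an exact match) rather than the virtual overload, and recurse. The sketch below reproduces the pattern with hypothetical minimal types: Shape stands in for ChunkInfo, MatBase and Mat for CuMatrixBase and CuMatrix, and Doubler is an illustrative descendant.

```cpp
#include <cassert>
#include <vector>

struct Shape { int rows, cols; };  // stands in for ChunkInfo

struct MatBase {  // stands in for CuMatrixBase: fixed size, row-major
  int rows = 0, cols = 0;
  std::vector<float> data;
  float &At(int r, int c) { return data[r * cols + c]; }
  float At(int r, int c) const { return data[r * cols + c]; }
};

struct Mat : MatBase {  // stands in for CuMatrix: resizable
  void Resize(int r, int c) { rows = r; cols = c; data.assign(r * c, 0.0f); }
};

struct Comp {
  virtual ~Comp() = default;
  // Virtual version: assumes *out already has the right shape.
  virtual void Propagate(const Shape &out_info, const MatBase &in,
                         MatBase *out) const = 0;
  // Non-virtual wrapper: resize if needed, then dispatch. Without the cast,
  // Mat* would select this wrapper again and recurse forever.
  void Propagate(const Shape &out_info, const MatBase &in, Mat *out) const {
    if (out->rows != out_info.rows || out->cols != out_info.cols)
      out->Resize(out_info.rows, out_info.cols);
    Propagate(out_info, in, static_cast<MatBase *>(out));
  }
};

struct Doubler : Comp {
  using Comp::Propagate;  // keep the wrapper visible next to the override
  void Propagate(const Shape &, const MatBase &in,
                 MatBase *out) const override {
    for (int r = 0; r < in.rows; ++r)
      for (int c = 0; c < in.cols; ++c) out->At(r, c) = 2.0f * in.At(r, c);
  }
};
```

The using-declaration in Doubler matters: in C++, declaring one Propagate overload in a derived class hides the base-class overloads of the same name, so without it a call through a Doubler object would bind directly to the override and skip the resize.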
Component * ReadNew ( std::istream & is,
bool binary
)
static

Read component from stream.

Definition at line 37 of file nnet-component.cc.

References KALDI_ERR, Component::NewComponentOfType(), Component::Read(), and kaldi::ReadToken().

Referenced by Nnet::Read(), and kaldi::nnet2::UnitTestGenericComponentInternal().

{
  std::string token;
  ReadToken(is, binary, &token);  // e.g. "<SigmoidComponent>".
  token.erase(0, 1);  // erase "<".
  token.erase(token.length()-1);  // erase ">".
  Component *ans = NewComponentOfType(token);
  if (!ans)
    KALDI_ERR << "Unknown component type " << token;
  ans->Read(is, binary);
  return ans;
}
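ReadNew() expects the type token to be wrapped in angle brackets and strips them with two erase() calls, without checking that the brackets are present. A defensive variant of just that step is sketched below; StripAngleBrackets is hypothetical, and returning an empty string for a malformed token is an assumption of this sketch (ReadNew() itself reports an unknown type through KALDI_ERR).

```cpp
#include <cassert>
#include <string>

// "<SigmoidComponent>" -> "SigmoidComponent"; anything not of the form
// "<name>" with a non-empty name yields "" (assumed error convention).
std::string StripAngleBrackets(const std::string &token) {
  if (token.size() < 3 || token.front() != '<' || token.back() != '>')
    return "";
  return token.substr(1, token.size() - 2);
}
```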
virtual void SetIndex ( int32  index)
inline, virtual

Definition at line 168 of file nnet-component.h.

References Component::index_.

{ index_ = index; }

Member Data Documentation

int32 index_
private

Definition at line 267 of file nnet-component.h.

Referenced by Component::Index(), and Component::SetIndex().


The documentation for this class was generated from the following files: