Abstract class, building block of the network.
#include <nnet-component.h>
Classes
struct | key_value
A pair of type and marker.
Public Types
enum | ComponentType { kUnknown = 0x0, kUpdatableComponent = 0x0100, kAffineTransform, kLinearTransform, kConvolutionalComponent, kLstmProjected, kBlstmProjected, kRecurrentComponent, kActivationFunction = 0x0200, kSoftmax, kHiddenSoftmax, kBlockSoftmax, kSigmoid, kTanh, kParametricRelu, kDropout, kLengthNormComponent, kTranform = 0x0400, kRbm, kSplice, kCopy, kTranspose, kBlockLinearity, kAddShift, kRescale, kKlHmm = 0x0800, kSentenceAveragingComponent, kSimpleSentenceAveragingComponent, kAveragePoolingComponent, kMaxPoolingComponent, kFramePoolingComponent, kParallelComponent, kMultiBasisComponent }
Component type identification mechanism.
Public Member Functions
Component (int32 input_dim, int32 output_dim)
Generic interface of a component.
virtual | ~Component ()
virtual Component * | Copy () const =0
Copy component (deep copy).
virtual ComponentType | GetType () const =0
Get the type identification of the component.
virtual bool | IsUpdatable () const
Check if the component has the 'Updatable' interface (trainable components).
virtual bool | IsMultistream () const
Check if the component has the 'Recurrent' interface (trainable and recurrent).
int32 | InputDim () const
Get the dimension of the input.
int32 | OutputDim () const
Get the dimension of the output.
void | Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward-pass propagation 'in' -> 'out'.
void | Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward-pass propagation 'out_diff' -> 'in_diff'.
void | Write (std::ostream &os, bool binary) const
Write the component to a stream.
virtual std::string | Info () const
Print some additional info (after <ComponentName> and the dims).
virtual std::string | InfoGradient () const
Print some additional info about the gradient (after <...> and the dims).
Static Public Member Functions
static const char * | TypeToMarker (ComponentType t)
Converts component type to marker.
static ComponentType | MarkerToType (const std::string &s)
Converts marker to component type (case insensitive).
static Component * | Init (const std::string &conf_line)
Initialize component from a line in config file.
static Component * | Read (std::istream &is, bool binary)
Read the component from a stream (static method).
Static Public Attributes
static const struct key_value | kMarkerMap []
The table with pairs of Component types and markers (defined in nnet-component.cc).
Protected Member Functions
virtual void | PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)=0
Abstract interface for propagation/backpropagation.
virtual void | BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)=0
Backward pass transformation (to be implemented by the descendant class).
virtual void | InitData (std::istream &is)
Virtual interface for initialization and I/O.
virtual void | ReadData (std::istream &is, bool binary)
Reads the component content.
virtual void | WriteData (std::ostream &os, bool binary) const
Writes the component content.
Protected Attributes
int32 | input_dim_
Data members.
int32 | output_dim_
Dimension of the output of the Component.
Static Private Member Functions
static Component * | NewComponentOfType (ComponentType t, int32 input_dim, int32 output_dim)
Private members (descendant classes cannot call this).
Abstract class, building block of the network.
It can propagate (PropagateFnc: compute the output from its input) and backpropagate (BackpropagateFnc: transform the loss derivative w.r.t. the output into the derivative w.r.t. the input). The formulas are implemented in descendant classes (AffineTransform, Sigmoid, Softmax, ...).
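The split between the public Propagate() wrapper and the protected PropagateFnc() hook can be illustrated with a minimal, self-contained sketch. Everything below is an assumption made for illustration: `Matrix`, `SketchComponent` and `SketchSigmoid` are hypothetical stand-ins, and plain `std::vector` replaces `CuMatrix< BaseFloat >`. It is not the Kaldi implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for CuMatrix<BaseFloat>: one frame per row.
using Matrix = std::vector<std::vector<float>>;

// Simplified sketch of the Component idea (not the Kaldi class itself).
class SketchComponent {
 public:
  SketchComponent(int input_dim, int output_dim)
      : input_dim_(input_dim), output_dim_(output_dim) {}
  virtual ~SketchComponent() {}

  // Public wrapper: checks the input width, sizes and zeroes the output,
  // then delegates the actual formula to the descendant class.
  void Propagate(const Matrix &in, Matrix *out) {
    for (const auto &row : in) {
      assert(static_cast<int>(row.size()) == input_dim_);
    }
    out->assign(in.size(), std::vector<float>(output_dim_, 0.0f));
    PropagateFnc(in, out);
  }

 protected:
  // The formula lives in the descendant class.
  virtual void PropagateFnc(const Matrix &in, Matrix *out) = 0;

  int input_dim_;
  int output_dim_;
};

// Element-wise sigmoid: input and output dimensions are equal.
class SketchSigmoid : public SketchComponent {
 public:
  explicit SketchSigmoid(int dim) : SketchComponent(dim, dim) {}

 protected:
  void PropagateFnc(const Matrix &in, Matrix *out) override {
    for (std::size_t r = 0; r < in.size(); ++r) {
      for (std::size_t c = 0; c < in[r].size(); ++c) {
        (*out)[r][c] = 1.0f / (1.0f + std::exp(-in[r][c]));
      }
    }
  }
};
```

Keeping the dimension checks in one non-virtual wrapper means each descendant only has to supply the formula.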
Definition at line 51 of file nnet-component.cc.
enum ComponentType
Component type identification mechanism.
Types of Components.
Definition at line 47 of file nnet-component.h.
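The hexadecimal base values in the enum (0x0100, 0x0200, 0x0400, 0x0800) start a new group, and the enumerators that follow each base are numbered consecutively by the compiler. The sketch below repeats only a few enumerators to show the resulting values. The `GroupOf` helper is a hypothetical illustration of how the high byte could be used to recover the group; the source does not say the library does this.

```cpp
#include <cassert>

// A reduced copy of the enum above; only a few enumerators are repeated.
enum SketchComponentType {
  kUnknown = 0x0,
  kUpdatableComponent = 0x0100,
  kAffineTransform,  // 0x0101
  kLinearTransform,  // 0x0102
  kActivationFunction = 0x0200,
  kSoftmax,          // 0x0201
  kHiddenSoftmax     // 0x0202
};

// Hypothetical helper (an assumption, not part of the documented API):
// recover the group base by masking off the low byte.
constexpr int GroupOf(SketchComponentType t) {
  return static_cast<int>(t) & 0xFF00;
}

static_assert(kAffineTransform == 0x0101, "implicit values count up from the base");
static_assert(GroupOf(kLinearTransform) == kUpdatableComponent, "same group as the base");

// Runtime version of the same checks.
inline void CheckSketchGroups() {
  assert(GroupOf(kSoftmax) == kActivationFunction);
  assert(GroupOf(kUnknown) == 0);
}
```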
Component (int32 input_dim, int32 output_dim)
Generic interface of a component.
Definition at line 105 of file nnet-component.h.
virtual ~Component () [inline, virtual]
Definition at line 110 of file nnet-component.h.
References Component::Copy(), and Component::GetType().
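void Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff) [inline]
Perform backward-pass propagation 'out_diff' -> 'in_diff'.
Note: 'in' and 'out' are used only by some components.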
Definition at line 325 of file nnet-component.h.
References Component::BackpropagateFnc(), Component::InputDim(), KALDI_ASSERT, KALDI_ERR, kaldi::kSetZero, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), Component::OutputDim(), and CuMatrix< Real >::Resize().
Referenced by Component::OutputDim(), kaldi::nnet1::UnitTestConvolutionalComponent3x3(), kaldi::nnet1::UnitTestConvolutionalComponentUnity(), kaldi::nnet1::UnitTestDropoutComponent(), kaldi::nnet1::UnitTestMaxPoolingComponent(), and kaldi::nnet1::UnitTestSimpleSentenceAveragingComponent().
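As an illustration of a backward pass that needs 'out' but not 'in', consider a sigmoid, whose derivative can be written purely in terms of its output. The function below is a hypothetical, self-contained sketch: `Row` and `SigmoidBackpropagateSketch` are names invented here, and a single `std::vector` row replaces the matrix types.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one matrix row.
using Row = std::vector<float>;

// Sigmoid backward pass: d(loss)/d(in) = out_diff * out * (1 - out).
// 'in' is only used for the size check; the formula itself does not need it.
void SigmoidBackpropagateSketch(const Row &in, const Row &out,
                                const Row &out_diff, Row *in_diff) {
  assert(in.size() == out.size());
  assert(out.size() == out_diff.size());
  in_diff->assign(out.size(), 0.0f);
  for (std::size_t i = 0; i < out.size(); ++i) {
    (*in_diff)[i] = out_diff[i] * out[i] * (1.0f - out[i]);
  }
}
```

A component such as an affine transform would, by contrast, need 'in' to compute its weight gradient, which is one reason the interface passes all three matrices.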
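virtual void BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff) [protected, pure virtual]
Backward pass transformation (to be implemented by the descendant class).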
Implemented in BlstmProjected, Rescale, LstmProjected, AddShift, ConvolutionalComponent, Dropout, MultiBasisComponent, LengthNormComponent, SentenceAveragingComponent, ParallelComponent, Tanh, RecurrentComponent, FramePoolingComponent, Sigmoid, CopyComponent, BlockSoftmax, AffineTransform, LinearTransform, ParametricRelu, SimpleSentenceAveragingComponent, AveragePoolingComponent, MaxPoolingComponent, Splice, KlHmm, HiddenSoftmax, RbmBase, and Softmax.
Referenced by Component::Backpropagate(), and Component::InfoGradient().
virtual Component * Copy () const [pure virtual]
Copy component (deep copy).
Implemented in Rescale, AddShift, Dropout, LengthNormComponent, Tanh, Sigmoid, SentenceAveragingComponent, CopyComponent, BlockSoftmax, Rbm, ConvolutionalComponent, HiddenSoftmax, BlstmProjected, LstmProjected, FramePoolingComponent, SimpleSentenceAveragingComponent, AveragePoolingComponent, MaxPoolingComponent, RecurrentComponent, Splice, ParametricRelu, KlHmm, Softmax, MultiBasisComponent, ParallelComponent, AffineTransform, and LinearTransform.
Referenced by Nnet::AppendComponent(), Nnet::Nnet(), Nnet::operator=(), Nnet::ReplaceComponent(), and Component::~Component().
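A deep copy through a virtual Copy() lets a container duplicate components without knowing their concrete type. The following is a minimal sketch of that pattern with hypothetical classes (`SketchBase`, `SketchRescale`); it shows the idiom only and does not reproduce any Kaldi class.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical base class with a virtual deep-copy method.
class SketchBase {
 public:
  virtual ~SketchBase() {}
  virtual SketchBase *Copy() const = 0;  // caller owns the returned object
  virtual float At(std::size_t i) const = 0;
};

// Hypothetical descendant holding its own parameter vector.
class SketchRescale : public SketchBase {
 public:
  explicit SketchRescale(std::vector<float> scale) : scale_(std::move(scale)) {}

  // The copy constructor duplicates scale_, so the copy shares no state.
  SketchBase *Copy() const override { return new SketchRescale(*this); }

  float At(std::size_t i) const override { return scale_[i]; }
  void Set(std::size_t i, float v) { scale_[i] = v; }

 private:
  std::vector<float> scale_;
};

// Demonstrates that changing the original leaves the copy untouched.
inline bool CopyIsIndependent() {
  SketchRescale original(std::vector<float>{1.0f, 2.0f});
  std::unique_ptr<SketchBase> copy(original.Copy());
  original.Set(0, 9.0f);
  return copy->At(0) == 1.0f && original.At(0) == 9.0f;
}
```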
virtual ComponentType GetType () const [pure virtual]
Get the type identification of the component.
Implemented in Rescale, AddShift, Dropout, LengthNormComponent, Tanh, Sigmoid, SentenceAveragingComponent, CopyComponent, BlockSoftmax, Rbm, ConvolutionalComponent, HiddenSoftmax, BlstmProjected, LstmProjected, SimpleSentenceAveragingComponent, FramePoolingComponent, AveragePoolingComponent, MaxPoolingComponent, RecurrentComponent, Splice, ParametricRelu, KlHmm, Softmax, MultiBasisComponent, ParallelComponent, AffineTransform, and LinearTransform.
Referenced by kaldi::ConvertComponent(), main(), Component::Propagate(), Nnet::SetDropoutRate(), Component::Write(), and Component::~Component().
virtual std::string Info () const [inline, virtual]
Print some additional info (after <ComponentName> and the dims).
Reimplemented in Rescale, BlstmProjected, AddShift, Dropout, LstmProjected, ConvolutionalComponent, MultiBasisComponent, SentenceAveragingComponent, CopyComponent, BlockSoftmax, ParallelComponent, FramePoolingComponent, RecurrentComponent, AffineTransform, LinearTransform, ParametricRelu, SimpleSentenceAveragingComponent, and Splice.
Definition at line 159 of file nnet-component.h.
virtual std::string InfoGradient () const [inline, virtual]
Print some additional info about the gradient (after <...> and the dims).
Reimplemented in Rescale, BlstmProjected, AddShift, LstmProjected, MultiBasisComponent, ConvolutionalComponent, SentenceAveragingComponent, ParallelComponent, FramePoolingComponent, RecurrentComponent, AffineTransform, LinearTransform, ParametricRelu, and SimpleSentenceAveragingComponent.
Definition at line 162 of file nnet-component.h.
References Component::BackpropagateFnc(), and Component::PropagateFnc().
static Component * Init (const std::string &conf_line) [static]
Initialize component from a line in config file.
Definition at line 203 of file nnet-component.cc.
References kaldi::ExpectToken(), Component::InitData(), Component::MarkerToType(), Component::NewComponentOfType(), kaldi::ReadBasicType(), and kaldi::ReadToken().
Referenced by Nnet::Init(), Component::OutputDim(), and kaldi::nnet1::UnitTestMaxPoolingComponent().
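The references above suggest that Init() reads a marker token, expects dimension tokens, reads the two dimensions, and hands the remainder of the line to InitData(). The sketch below imitates that flow with the standard library only. The line layout `<Marker> <InputDim> N <OutputDim> M ...`, the struct `SketchHeader` and the function `ParseSketchHeader` are all assumptions made for illustration; consult nnet-component.cc for the real format.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical result of parsing the head of a config line.
struct SketchHeader {
  std::string marker;
  int input_dim = 0;
  int output_dim = 0;
};

// Parses "<Marker> <InputDim> N <OutputDim> M ..." (assumed layout).
// On success, 'rest' receives the unparsed tail, which a real component
// would consume in its own InitData().
bool ParseSketchHeader(const std::string &conf_line, SketchHeader *h,
                       std::string *rest) {
  assert(h != nullptr && rest != nullptr);
  rest->clear();
  std::istringstream is(conf_line);
  std::string token;
  if (!(is >> h->marker)) return false;
  if (!(is >> token) || token != "<InputDim>") return false;
  if (!(is >> h->input_dim)) return false;
  if (!(is >> token) || token != "<OutputDim>") return false;
  if (!(is >> h->output_dim)) return false;
  std::getline(is, *rest);  // leaves 'rest' empty if nothing follows
  return true;
}
```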
virtual void InitData (std::istream &is) [inline, protected, virtual]
Virtual interface for initialization and I/O.
Initialize internal data of a component.
Reimplemented in Rescale, AddShift, Dropout, UpdatableComponent, SentenceAveragingComponent, CopyComponent, BlockSoftmax, Rbm, ConvolutionalComponent, BlstmProjected, LstmProjected, SimpleSentenceAveragingComponent, FramePoolingComponent, AveragePoolingComponent, MaxPoolingComponent, RecurrentComponent, Splice, ParametricRelu, ParallelComponent, MultiBasisComponent, AffineTransform, and LinearTransform.
Definition at line 180 of file nnet-component.h.
Referenced by Component::Init(), and UpdatableComponent::SetBiasLearnRateCoef().
int32 InputDim () const [inline]
Get the dimension of the input.
Definition at line 130 of file nnet-component.h.
References Component::input_dim_.
Referenced by Component::Backpropagate(), kaldi::ConvertSigmoidComponent(), kaldi::ConvertSoftmaxComponent(), kaldi::ConvertSpliceComponent(), LinearTransform::InitData(), AffineTransform::InitData(), MultiBasisComponent::InitData(), ParallelComponent::InitData(), Splice::InitData(), FramePoolingComponent::InitData(), Rbm::InitData(), CopyComponent::InitData(), SentenceAveragingComponent::InitData(), AddShift::InitData(), Rescale::InitData(), SimpleSentenceAveragingComponent::PropagateFnc(), Nnet::Read(), Splice::ReadData(), ParallelComponent::ReadData(), FramePoolingComponent::ReadData(), MultiBasisComponent::ReadData(), SentenceAveragingComponent::ReadData(), AddShift::Update(), Rescale::Update(), Component::Write(), and Rbm::WriteAsNnet().
virtual bool IsMultistream () const [inline, virtual]
Check if the component has the 'Recurrent' interface (trainable and recurrent).
Reimplemented in MultistreamComponent.
Definition at line 125 of file nnet-component.h.
Referenced by Nnet::ResetStreams(), and Nnet::SetSeqLengths().
virtual bool IsUpdatable () const [inline, virtual]
Check if the component has the 'Updatable' interface (trainable components).
Reimplemented in UpdatableComponent.
Definition at line 120 of file nnet-component.h.
Referenced by main(), and Nnet::SetTrainOptions().
static ComponentType MarkerToType (const std::string &s) [static]
Converts marker to component type (case insensitive).
Definition at line 94 of file nnet-component.cc.
References rnnlm::i, KALDI_ERR, Component::kMarkerMap, Component::kUnknown, and Component::key_value::value.
Referenced by Component::Init(), and Component::Read().
static Component * NewComponentOfType (ComponentType t, int32 input_dim, int32 output_dim) [static, private]
Private members (descendant classes cannot call this).
Create a new instance of a component.
Definition at line 110 of file nnet-component.cc.
References Component::kAddShift, Component::kAffineTransform, KALDI_ERR, Component::kAveragePoolingComponent, Component::kBlockSoftmax, Component::kBlstmProjected, Component::kConvolutionalComponent, Component::kCopy, Component::kDropout, Component::kFramePoolingComponent, Component::kHiddenSoftmax, Component::kKlHmm, Component::kLengthNormComponent, Component::kLinearTransform, Component::kLstmProjected, Component::kMaxPoolingComponent, Component::kMultiBasisComponent, Component::kParallelComponent, Component::kParametricRelu, Component::kRbm, Component::kRecurrentComponent, Component::kRescale, Component::kSentenceAveragingComponent, Component::kSigmoid, Component::kSimpleSentenceAveragingComponent, Component::kSoftmax, Component::kSplice, Component::kTanh, Component::kUnknown, kaldi::cu::Splice(), and Component::TypeToMarker().
Referenced by Component::Init(), and Component::Read().
int32 OutputDim () const [inline]
Get the dimension of the output.
Definition at line 135 of file nnet-component.h.
References Component::Backpropagate(), Component::Init(), Component::output_dim_, Component::Propagate(), Component::Read(), and Component::Write().
Referenced by Component::Backpropagate(), SimpleSentenceAveragingComponent::BackpropagateFnc(), MultiBasisComponent::BackpropagateFnc(), AffineTransform::InitData(), LinearTransform::InitData(), MultiBasisComponent::InitData(), ParallelComponent::InitData(), Splice::InitData(), Rbm::InitData(), BlockSoftmax::InitData(), CopyComponent::InitData(), SentenceAveragingComponent::InitData(), Splice::ReadData(), ParallelComponent::ReadData(), MultiBasisComponent::ReadData(), BlockSoftmax::ReadData(), CopyComponent::ReadData(), SentenceAveragingComponent::ReadData(), AffineTransform::Update(), RecurrentComponent::Update(), Component::Write(), and Rbm::WriteAsNnet().
void Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out) [inline]
Perform forward-pass propagation 'in' -> 'out'.
Definition at line 311 of file nnet-component.h.
References Component::GetType(), Component::input_dim_, KALDI_ERR, kaldi::kSetZero, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), Component::output_dim_, Component::PropagateFnc(), CuMatrix< Real >::Resize(), and Component::TypeToMarker().
Referenced by Component::OutputDim(), kaldi::nnet1::UnitTestConvolutionalComponent3x3(), kaldi::nnet1::UnitTestConvolutionalComponentUnity(), kaldi::nnet1::UnitTestDropoutComponent(), kaldi::nnet1::UnitTestLengthNorm(), kaldi::nnet1::UnitTestMaxPoolingComponent(), and kaldi::nnet1::UnitTestSimpleSentenceAveragingComponent().
virtual void PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) [protected, pure virtual]
Abstract interface for propagation/backpropagation.
Forward pass transformation (to be implemented by the descendant class).
Implemented in BlstmProjected, Rescale, AddShift, LstmProjected, Dropout, MultiBasisComponent, ConvolutionalComponent, LengthNormComponent, Tanh, SentenceAveragingComponent, ParallelComponent, Rbm, Sigmoid, FramePoolingComponent, CopyComponent, RecurrentComponent, AffineTransform, BlockSoftmax, LinearTransform, ParametricRelu, SimpleSentenceAveragingComponent, Splice, AveragePoolingComponent, MaxPoolingComponent, HiddenSoftmax, KlHmm, and Softmax.
Referenced by Component::InfoGradient(), and Component::Propagate().
static Component * Read (std::istream &is, bool binary) [static]
Read the component from a stream (static method).
Definition at line 224 of file nnet-component.cc.
References kaldi::ExpectToken(), Component::MarkerToType(), Component::NewComponentOfType(), kaldi::Peek(), kaldi::PeekToken(), kaldi::ReadBasicType(), Component::ReadData(), and kaldi::ReadToken().
Referenced by Component::OutputDim(), Nnet::Read(), and kaldi::nnet1::ReadComponentFromString().
virtual void ReadData (std::istream &is, bool binary) [inline, protected, virtual]
Reads the component content.
Reimplemented in Rescale, AddShift, Dropout, SentenceAveragingComponent, CopyComponent, Rbm, BlockSoftmax, MultiBasisComponent, BlstmProjected, ConvolutionalComponent, FramePoolingComponent, KlHmm, ParallelComponent, LstmProjected, LinearTransform, Splice, RecurrentComponent, AffineTransform, SimpleSentenceAveragingComponent, ParametricRelu, AveragePoolingComponent, and MaxPoolingComponent.
Definition at line 183 of file nnet-component.h.
Referenced by Component::Read().
static const char * TypeToMarker (ComponentType t) [static]
Converts component type to marker.
Definition at line 84 of file nnet-component.cc.
References rnnlm::i, KALDI_ERR, Component::key_value::key, and Component::kMarkerMap.
Referenced by CopyComponent::BackpropagateFnc(), Nnet::Check(), kaldi::ConvertComponent(), Nnet::Info(), Nnet::InfoBackPropagate(), Nnet::InfoGradient(), Nnet::InfoPropagate(), main(), Component::NewComponentOfType(), Component::Propagate(), Component::Write(), and Rbm::WriteAsNnet().
void Write (std::ostream &os, bool binary) const
Write the component to a stream.
Definition at line 260 of file nnet-component.cc.
References Component::GetType(), Component::InputDim(), Component::OutputDim(), Component::TypeToMarker(), kaldi::WriteBasicType(), Component::WriteData(), and kaldi::WriteToken().
Referenced by Component::OutputDim().
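Judging by the references, Write() emits a header (the type marker and the two dimensions) and then delegates the payload to WriteData(), while Read() reverses the process through ReadData(). The text-mode round trip below is a hedged sketch of that division of labour. The struct `SketchShift`, the marker spelling `<AddShift>`, the order of the dimensions and the on-disk layout are all assumptions; binary mode is not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical component payload, loosely modelled on an add-shift layer.
struct SketchShift {
  int input_dim = 0;
  int output_dim = 0;
  std::vector<float> shift;
};

// Header first (marker and dims), then the component-specific content.
void WriteSketch(std::ostream &os, const SketchShift &c) {
  os << "<AddShift> " << c.input_dim << ' ' << c.output_dim << '\n';
  for (float v : c.shift) os << v << ' ';  // the "WriteData" step
  os << '\n';
}

// Mirror image of WriteSketch; returns false on any malformed input.
bool ReadSketch(std::istream &is, SketchShift *c) {
  assert(c != nullptr);
  std::string marker;
  if (!(is >> marker) || marker != "<AddShift>") return false;
  if (!(is >> c->input_dim >> c->output_dim)) return false;
  if (c->input_dim < 0) return false;
  c->shift.assign(static_cast<std::size_t>(c->input_dim), 0.0f);
  for (float &v : c->shift) {  // the "ReadData" step
    if (!(is >> v)) return false;
  }
  return true;
}
```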
virtual void WriteData (std::ostream &os, bool binary) const [inline, protected, virtual]
Writes the component content.
Reimplemented in Rescale, AddShift, Dropout, SentenceAveragingComponent, Rbm, CopyComponent, BlstmProjected, ConvolutionalComponent, MultiBasisComponent, BlockSoftmax, FramePoolingComponent, LstmProjected, ParallelComponent, KlHmm, LinearTransform, RecurrentComponent, AffineTransform, SimpleSentenceAveragingComponent, ParametricRelu, Splice, AveragePoolingComponent, and MaxPoolingComponent.
Definition at line 186 of file nnet-component.h.
Referenced by Component::Write().
int32 input_dim_ [protected]
Data members.
Dimension of the input of the Component.
Definition at line 190 of file nnet-component.h.
Referenced by MaxPoolingComponent::BackpropagateFnc(), AveragePoolingComponent::BackpropagateFnc(), LstmProjected::BackpropagateFnc(), BlstmProjected::BackpropagateFnc(), RecurrentComponent::InitData(), LstmProjected::InitData(), BlstmProjected::InitData(), ConvolutionalComponent::InitData(), Component::InputDim(), Component::Propagate(), AveragePoolingComponent::PropagateFnc(), MaxPoolingComponent::PropagateFnc(), ConvolutionalComponent::PropagateFnc(), Rbm::RbmUpdate(), AveragePoolingComponent::ReadData(), MaxPoolingComponent::ReadData(), AffineTransform::ReadData(), LinearTransform::ReadData(), KlHmm::ReadData(), FramePoolingComponent::ReadData(), ConvolutionalComponent::ReadData(), Rbm::ReadData(), Rbm::Reconstruct(), ConvolutionalComponent::ReverseIndexes(), and KlHmm::SetStats().
const struct key_value kMarkerMap [] [static]
The table with pairs of Component types and markers (defined in nnet-component.cc).
Definition at line 95 of file nnet-component.h.
Referenced by Component::MarkerToType(), and Component::TypeToMarker().
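TypeToMarker() and MarkerToType() both reference this table, with `key` used for the type lookup and `value` for the marker lookup. A linear scan over a small static array is enough for such a mapping. The sketch below shows one way to write both directions, including the case-insensitive comparison. The enum values, marker spellings and all `Sketch`-prefixed names are assumptions for illustration, and the sketch returns a fallback where the real functions reference KALDI_ERR.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>

// Hypothetical reduced type enum for the sketch.
enum SketchType {
  kSketchUnknown = 0x0,
  kSketchAffineTransform = 0x0101,
  kSketchSoftmax = 0x0201
};

// Mirrors the idea of key_value: 'key' is the type, 'value' is the marker.
struct SketchKeyValue {
  SketchType key;
  const char *value;
};

// Marker spellings are assumptions made for this sketch.
static const SketchKeyValue kSketchMarkerMap[] = {
  {kSketchAffineTransform, "<AffineTransform>"},
  {kSketchSoftmax, "<Softmax>"},
};
static const std::size_t kSketchMarkerMapSize =
    sizeof(kSketchMarkerMap) / sizeof(kSketchMarkerMap[0]);

// ASCII lower-casing helper used for the case-insensitive match.
static std::string SketchToLower(std::string s) {
  for (char &c : s) {
    c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
  }
  return s;
}

// Type -> marker: scan for a matching key.
const char *SketchTypeToMarker(SketchType t) {
  for (std::size_t i = 0; i < kSketchMarkerMapSize; ++i) {
    if (kSketchMarkerMap[i].key == t) return kSketchMarkerMap[i].value;
  }
  return "<Unknown>";  // sketch-only fallback
}

// Marker -> type: scan for a matching value, ignoring case.
SketchType SketchMarkerToType(const std::string &s) {
  const std::string lowered = SketchToLower(s);
  for (std::size_t i = 0; i < kSketchMarkerMapSize; ++i) {
    if (lowered == SketchToLower(kSketchMarkerMap[i].value)) {
      return kSketchMarkerMap[i].key;
    }
  }
  return kSketchUnknown;  // sketch-only fallback
}

// Round-trip check usable as a quick demo.
inline void CheckSketchMarkerMap() {
  assert(SketchMarkerToType(SketchTypeToMarker(kSketchSoftmax)) == kSketchSoftmax);
}
```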
int32 output_dim_ [protected]
Dimension of the output of the Component.
Definition at line 191 of file nnet-component.h.
Referenced by RecurrentComponent::InitData(), ConvolutionalComponent::InitData(), Component::OutputDim(), Component::Propagate(), Rbm::RbmUpdate(), MaxPoolingComponent::ReadData(), AveragePoolingComponent::ReadData(), ParametricRelu::ReadData(), AffineTransform::ReadData(), LinearTransform::ReadData(), KlHmm::ReadData(), FramePoolingComponent::ReadData(), ConvolutionalComponent::ReadData(), Rbm::ReadData(), Rbm::Reconstruct(), and KlHmm::SetStats().