RbmBase Class Reference (abstract)

#include <nnet-rbm.h>


Public Types

enum  RbmNodeType { Bernoulli, Gaussian }
 
- Public Types inherited from Component
enum  ComponentType {
  kUnknown = 0x0, kUpdatableComponent = 0x0100, kAffineTransform, kLinearTransform,
  kConvolutionalComponent, kLstmProjected, kBlstmProjected, kRecurrentComponent,
  kActivationFunction = 0x0200, kSoftmax, kHiddenSoftmax, kBlockSoftmax,
  kSigmoid, kTanh, kParametricRelu, kDropout,
  kLengthNormComponent, kTranform = 0x0400, kRbm, kSplice,
  kCopy, kTranspose, kBlockLinearity, kAddShift,
  kRescale, kKlHmm = 0x0800, kSentenceAveragingComponent, kSimpleSentenceAveragingComponent,
  kAveragePoolingComponent, kMaxPoolingComponent, kFramePoolingComponent, kParallelComponent,
  kMultiBasisComponent
}
 Component type identification mechanism.
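The enumerator values above encode a category in the 0x0F00 bits: updatable components start at 0x0100, activation functions at 0x0200, transforms (including kRbm) at 0x0400, and the kKlHmm group at 0x0800. Each enumerator without an explicit value takes the previous value plus one. A minimal sketch of reading that category back, assuming a partial copy of the enum. The helper `GroupOf` and its labels are hypothetical and not part of the Component API, which exposes queries such as IsUpdatable() instead.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Partial copy of the ComponentType values listed above (illustrative only).
enum ComponentType : std::uint32_t {
  kUnknown = 0x0,
  kUpdatableComponent = 0x0100, kAffineTransform, kLinearTransform,
  kActivationFunction = 0x0200, kSoftmax, kHiddenSoftmax, kBlockSoftmax, kSigmoid,
  kTranform = 0x0400, kRbm, kSplice,
  kKlHmm = 0x0800
};

// Hypothetical helper: the category lives in bits 8..11 of the value.
std::string GroupOf(ComponentType t) {
  switch (t & 0x0F00u) {
    case 0x0100u: return "updatable";
    case 0x0200u: return "activation";
    case 0x0400u: return "transform";
    case 0x0800u: return "other";
    default:      return "unknown";
  }
}
```

Under this encoding kRbm evaluates to 0x0401, so an RBM falls in the transform group rather than among the updatable components.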
 

Public Member Functions

 RbmBase (int32 dim_in, int32 dim_out)
 
virtual void Reconstruct (const CuMatrixBase< BaseFloat > &hid_state, CuMatrix< BaseFloat > *vis_probs)=0
 
virtual void RbmUpdate (const CuMatrixBase< BaseFloat > &pos_vis, const CuMatrixBase< BaseFloat > &pos_hid, const CuMatrixBase< BaseFloat > &neg_vis, const CuMatrixBase< BaseFloat > &neg_hid)=0
 
virtual RbmNodeType VisType () const =0
 
virtual RbmNodeType HidType () const =0
 
virtual void WriteAsNnet (std::ostream &os, bool binary) const =0
 
void SetRbmTrainOptions (const RbmTrainOptions &opts)
 Set the RBM training hyper-parameters (stored in rbm_opts_).
 
const RbmTrainOptions & GetRbmTrainOptions () const
 Get training hyper-parameters from the network.
 
- Public Member Functions inherited from Component
 Component (int32 input_dim, int32 output_dim)
 Generic interface of a component.
 
virtual ~Component ()
 
virtual Component * Copy () const =0
 Copy the component (deep copy).
 
virtual ComponentType GetType () const =0
 Get the type identification of the component.
 
virtual bool IsUpdatable () const
 Check if the component has the 'Updatable' interface (trainable components).
 
virtual bool IsMultistream () const
 Check if the component has the 'Recurrent' interface (trainable and recurrent).
 
int32 InputDim () const
 Get the dimension of the input.
 
int32 OutputDim () const
 Get the dimension of the output.
 
void Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
 Perform forward-pass propagation 'in' -> 'out'.
 
void Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
 Perform backward-pass propagation 'out_diff' -> 'in_diff'.
 
void Write (std::ostream &os, bool binary) const
 Write the component to a stream.
 
virtual std::string Info () const
 Print some additional info (after <ComponentName> and the dims).
 
virtual std::string InfoGradient () const
 Print some additional info about the gradient (after <...> and the dims).
 

Protected Attributes

RbmTrainOptions rbm_opts_
 
- Protected Attributes inherited from Component
int32 input_dim_
 Dimension of the input of the Component.
 
int32 output_dim_
 Dimension of the output of the Component.
 

Private Member Functions

void Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
 
void BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
 Backward pass transformation (to be implemented by the derived class).
 

Additional Inherited Members

- Static Public Member Functions inherited from Component
static const char * TypeToMarker (ComponentType t)
 Converts a component type to its marker.
 
static ComponentType MarkerToType (const std::string &s)
 Converts a marker to a component type (case insensitive).
 
static Component * Init (const std::string &conf_line)
 Initialize a component from a line in a config file.
 
static Component * Read (std::istream &is, bool binary)
 Read the component from a stream (static method).
 
- Static Public Attributes inherited from Component
static const struct key_value kMarkerMap []
 The table with pairs of Component types and markers (defined in nnet-component.cc).
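TypeToMarker() and MarkerToType() are described as lookups over the kMarkerMap table of (type, marker) pairs, with the reverse direction being case-insensitive. A minimal sketch of that idea, assuming a three-entry table. The marker strings (`<AffineTransform>`, `<Sigmoid>`, `<Rbm>`), the numeric values, and the `KeyValue` struct are illustrative assumptions; the real table lives in nnet-component.cc, and this sketch returns `nullptr` / `kUnknown` on a miss rather than reporting an error.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>

enum ComponentType { kUnknown = 0x0, kAffineTransform = 0x0101,
                     kSigmoid = 0x0204, kRbm = 0x0401 };

// Hypothetical stand-in for the key_value pairs in kMarkerMap.
struct KeyValue { ComponentType key; const char *value; };

const KeyValue kMarkerMap[] = {
  {kAffineTransform, "<AffineTransform>"},
  {kSigmoid, "<Sigmoid>"},
  {kRbm, "<Rbm>"},
};

// Forward lookup: type -> marker (nullptr if the type is not in the table).
const char *TypeToMarker(ComponentType t) {
  for (const KeyValue &kv : kMarkerMap)
    if (kv.key == t) return kv.value;
  return nullptr;
}

std::string ToLower(std::string s) {
  std::transform(s.begin(), s.end(), s.begin(),
                 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
  return s;
}

// Reverse lookup: marker -> type, comparing both sides in lower case.
ComponentType MarkerToType(const std::string &s) {
  const std::string needle = ToLower(s);
  for (const KeyValue &kv : kMarkerMap)
    if (ToLower(kv.value) == needle) return kv.key;
  return kUnknown;
}
```

A linear scan is adequate here because the table is small and the lookup runs only when a model is read or written.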
 
- Protected Member Functions inherited from Component
virtual void PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)=0
 Abstract interface for propagation/backpropagation.
 
virtual void InitData (std::istream &is)
 Virtual interface for initialization and I/O.
 
virtual void ReadData (std::istream &is, bool binary)
 Reads the component content.
 
virtual void WriteData (std::ostream &os, bool binary) const
 Writes the component content.
 

Detailed Description

Definition at line 35 of file nnet-rbm.h.

Member Enumeration Documentation

◆ RbmNodeType

Enumerator
Bernoulli 
Gaussian 

Definition at line 37 of file nnet-rbm.h.

Constructor & Destructor Documentation

◆ RbmBase()

RbmBase ( int32  dim_in,
int32  dim_out 
)
inline

Definition at line 42 of file nnet-rbm.h.

References RbmBase::HidType(), RbmBase::RbmUpdate(), RbmBase::Reconstruct(), RbmBase::VisType(), and RbmBase::WriteAsNnet().

    RbmBase(int32 dim_in, int32 dim_out) :
      Component(dim_in, dim_out)
    { }
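The constructor only forwards the dimensions to Component, so a concrete RBM must supply the pure virtual members itself. A minimal sketch of that chain, assuming simplified stand-ins for Component and RbmBase that keep only the dimension bookkeeping and the VisType()/HidType() queries. The matrix-based members (Reconstruct, RbmUpdate, WriteAsNnet) are omitted, and `ToyRbm` is a hypothetical subclass, not Kaldi's Rbm.

```cpp
#include <cassert>
#include <cstdint>

using int32 = std::int32_t;

// Simplified stand-in for Component: only the dimension bookkeeping.
class Component {
 public:
  Component(int32 input_dim, int32 output_dim)
      : input_dim_(input_dim), output_dim_(output_dim) {}
  virtual ~Component() {}
  int32 InputDim() const { return input_dim_; }
  int32 OutputDim() const { return output_dim_; }
 protected:
  int32 input_dim_;
  int32 output_dim_;
};

// Simplified stand-in for RbmBase: forwards dims, declares node-type queries.
class RbmBase : public Component {
 public:
  enum RbmNodeType { Bernoulli, Gaussian };
  RbmBase(int32 dim_in, int32 dim_out) : Component(dim_in, dim_out) {}
  virtual RbmNodeType VisType() const = 0;
  virtual RbmNodeType HidType() const = 0;
};

// Hypothetical concrete subclass: Gaussian visible, Bernoulli hidden units.
class ToyRbm : public RbmBase {
 public:
  ToyRbm(int32 dim_in, int32 dim_out) : RbmBase(dim_in, dim_out) {}
  RbmNodeType VisType() const override { return Gaussian; }
  RbmNodeType HidType() const override { return Bernoulli; }
};
```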

Member Function Documentation

◆ Backpropagate()

void Backpropagate ( const CuMatrixBase< BaseFloat > &  in,
const CuMatrixBase< BaseFloat > &  out,
const CuMatrixBase< BaseFloat > &  out_diff,
CuMatrix< BaseFloat > *  in_diff 
)
inline, private

Definition at line 81 of file nnet-rbm.h.

    { }
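Component lists Backpropagate() as a public, non-virtual member, while RbmBase re-declares a function with the same signature as private with an empty body. In C++ this hides the base-class name: calling Backpropagate() on an RbmBase (or derived) object does not compile, but a call through a Component reference still reaches the base version. A minimal sketch with simplified stand-ins, assuming plain `int` in place of the CuMatrix arguments; the `HasPublicBackprop` trait is a hypothetical helper used only to show the effect.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

class Component {
 public:
  virtual ~Component() {}
  // Public, non-virtual entry point (simplified: copies the gradient through).
  void Backpropagate(int out_diff, int *in_diff) { *in_diff = out_diff; }
};

class RbmBase : public Component {
 private:
  // Same name and signature, private, empty body: hides Component::Backpropagate.
  void Backpropagate(int /*out_diff*/, int * /*in_diff*/) {}
};

// Detects whether T exposes an accessible Backpropagate(int, int*).
template <typename T, typename = void>
struct HasPublicBackprop : std::false_type {};

template <typename T>
struct HasPublicBackprop<
    T, std::void_t<decltype(std::declval<T &>().Backpropagate(
           0, std::declval<int *>()))>> : std::true_type {};
```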

◆ BackpropagateFnc()

void BackpropagateFnc ( const CuMatrixBase< BaseFloat > &  in,
const CuMatrixBase< BaseFloat > &  out,
const CuMatrixBase< BaseFloat > &  out_diff,
CuMatrixBase< BaseFloat > *  in_diff 
)
inline, private, virtual

Backward pass transformation (to be implemented by the derived class).

Implements Component.

Definition at line 86 of file nnet-rbm.h.

    { }
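BackpropagateFnc() is a pure virtual hook in Component, and RbmBase overrides it as private with an empty body. C++ allows an override to use a different access level than the base declaration, because access is checked against the static type at the call site. A Component member that calls the hook therefore still dispatches to the private override. A minimal sketch, assuming (not confirmed by this page) that the public wrapper zeroes the output and then calls the hook. Scalars replace the matrices, `Identity` is a hypothetical comparison class, and the name hiding of Backpropagate() is left out here.

```cpp
#include <cassert>

class Component {
 public:
  virtual ~Component() {}
  // Public non-virtual wrapper; the component-specific work is in the hook.
  void Backpropagate(float out_diff, float *in_diff) {
    *in_diff = 0.0f;                      // wrapper prepares (zeroes) the output
    BackpropagateFnc(out_diff, in_diff);  // virtual dispatch to the subclass
  }
 protected:
  virtual void BackpropagateFnc(float out_diff, float *in_diff) = 0;
};

class RbmBase : public Component {
 private:
  // Private override with an empty body: no gradient is propagated.
  void BackpropagateFnc(float, float *) override {}
};

class Identity : public Component {
 private:
  void BackpropagateFnc(float out_diff, float *in_diff) override {
    *in_diff = out_diff;
  }
};
```

Making the override private keeps derived classes and external callers from invoking the no-op directly, while the base-class wrapper can still reach it.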

◆ GetRbmTrainOptions()

const RbmTrainOptions& GetRbmTrainOptions ( ) const
inline

Get training hyper-parameters from the network.

Definition at line 71 of file nnet-rbm.h.

References RbmBase::rbm_opts_.

    const RbmTrainOptions& GetRbmTrainOptions() const {
      return rbm_opts_;
    }

◆ HidType()

virtual RbmNodeType HidType ( ) const
pure virtual

Implemented in Rbm.

Referenced by RbmBase::RbmBase(), and Rbm::WriteAsNnet().

◆ RbmUpdate()

virtual void RbmUpdate ( const CuMatrixBase< BaseFloat > &  pos_vis,
const CuMatrixBase< BaseFloat > &  pos_hid,
const CuMatrixBase< BaseFloat > &  neg_vis,
const CuMatrixBase< BaseFloat > &  neg_hid 
)
pure virtual

Implemented in Rbm.

Referenced by RbmBase::RbmBase().
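RbmUpdate() receives positive-phase statistics (pos_vis, pos_hid) and negative-phase statistics (neg_vis, neg_hid). Reading the names as data/hidden probabilities and reconstruction/hidden probabilities, this is the input of a contrastive-divergence (CD-1) step. A minimal sketch of a plain CD-1 update, assuming row-major frames-by-dims matrices, weights stored as hid_dim x vis_dim, and a learning rate averaged over the minibatch. Momentum and L2 penalty, which the real Rbm may apply, are omitted, and `Mat`/`CdUpdate` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<float>>;  // rows = frames

// One CD-1 step on weights W (hid x vis) and biases, averaged over N frames.
void CdUpdate(const Mat &pos_vis, const Mat &pos_hid,
              const Mat &neg_vis, const Mat &neg_hid, float learn_rate,
              Mat *W, std::vector<float> *vis_bias, std::vector<float> *hid_bias) {
  const std::size_t N = pos_vis.size();
  const float scale = learn_rate / static_cast<float>(N);
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t j = 0; j < hid_bias->size(); ++j) {
      for (std::size_t k = 0; k < vis_bias->size(); ++k) {
        // <h v> under the data minus <h v> under the reconstruction.
        (*W)[j][k] += scale * (pos_hid[n][j] * pos_vis[n][k] -
                               neg_hid[n][j] * neg_vis[n][k]);
      }
      (*hid_bias)[j] += scale * (pos_hid[n][j] - neg_hid[n][j]);
    }
    for (std::size_t k = 0; k < vis_bias->size(); ++k)
      (*vis_bias)[k] += scale * (pos_vis[n][k] - neg_vis[n][k]);
  }
}
```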

◆ Reconstruct()

virtual void Reconstruct ( const CuMatrixBase< BaseFloat > &  hid_state,
CuMatrix< BaseFloat > *  vis_probs 
)
pure virtual

Implemented in Rbm.

Referenced by RbmBase::RbmBase().
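Reconstruct() maps a hidden-state matrix back to visible-unit values. For a standard RBM this is an affine map through the transposed weights followed by an activation chosen by VisType(): a logistic sigmoid for Bernoulli units and the identity (the Gaussian mean) for Gaussian units. A minimal sketch, assuming weights stored as hid_dim x vis_dim and unit-variance Gaussian visible units. `ReconstructSketch` is a hypothetical name, and the real implementation works on CuMatrix.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

enum RbmNodeType { Bernoulli, Gaussian };
using Mat = std::vector<std::vector<float>>;  // rows = frames

// vis = act(hid * W + vis_bias), with W stored as hid_dim x vis_dim.
Mat ReconstructSketch(const Mat &hid_state, const Mat &W,
                      const std::vector<float> &vis_bias, RbmNodeType vis_type) {
  Mat vis(hid_state.size(), std::vector<float>(vis_bias.size(), 0.0f));
  for (std::size_t n = 0; n < hid_state.size(); ++n) {
    for (std::size_t k = 0; k < vis_bias.size(); ++k) {
      float a = vis_bias[k];
      for (std::size_t j = 0; j < W.size(); ++j) a += hid_state[n][j] * W[j][k];
      // Bernoulli: logistic sigmoid; Gaussian: linear mean (unit variance assumed).
      vis[n][k] = (vis_type == Bernoulli) ? 1.0f / (1.0f + std::exp(-a)) : a;
    }
  }
  return vis;
}
```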

◆ SetRbmTrainOptions()

void SetRbmTrainOptions ( const RbmTrainOptions &  opts)
inline

Set the RBM training hyper-parameters (copied into rbm_opts_).

Definition at line 67 of file nnet-rbm.h.

References RbmBase::rbm_opts_.

    void SetRbmTrainOptions(const RbmTrainOptions& opts) {
      rbm_opts_ = opts;
    }
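SetRbmTrainOptions() copies the options into rbm_opts_, and GetRbmTrainOptions() returns a const reference to that member. A reference obtained earlier therefore reflects later calls to the setter, while later changes to the caller's own options object do not. A minimal sketch, assuming a hypothetical options struct: the field names `learn_rate` and `momentum` are illustrative and may not match the actual RbmTrainOptions, and `RbmOptsHolder` is a hypothetical class that keeps only this part of RbmBase.

```cpp
#include <cassert>

// Hypothetical stand-in; the real RbmTrainOptions may have different fields.
struct RbmTrainOptions {
  float learn_rate = 0.0f;
  float momentum = 0.0f;
};

class RbmOptsHolder {
 public:
  // Copies the options (mirrors `rbm_opts_ = opts;`).
  void SetRbmTrainOptions(const RbmTrainOptions &opts) { rbm_opts_ = opts; }
  // Read-only view of the stored copy (mirrors `return rbm_opts_;`).
  const RbmTrainOptions &GetRbmTrainOptions() const { return rbm_opts_; }
 protected:
  RbmTrainOptions rbm_opts_;
};
```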

◆ VisType()

virtual RbmNodeType VisType ( ) const
pure virtual

Implemented in Rbm.

Referenced by RbmBase::RbmBase().

◆ WriteAsNnet()

virtual void WriteAsNnet ( std::ostream &  os,
bool  binary 
) const
pure virtual

Implemented in Rbm.

Referenced by RbmBase::RbmBase().
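WriteAsNnet() serializes the RBM so that it can be loaded as an ordinary feed-forward network. Since Rbm::WriteAsNnet() is listed among the users of HidType(), a plausible reading is an affine transform followed by a sigmoid only when the hidden units are Bernoulli. A minimal text-only sketch of that decision, assuming a header format of "<Marker> out_dim in_dim" and the markers `<AffineTransform>` and `<Sigmoid>`. These are assumptions, `binary` output is not modeled, and the weight and bias values that would follow the affine header are omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <ostream>
#include <sstream>
#include <string>

enum RbmNodeType { Bernoulli, Gaussian };

// Hypothetical writer: emits only the component headers.
void WriteAsNnetSketch(std::ostream &os, std::int32_t dim_in, std::int32_t dim_out,
                       RbmNodeType hid_type) {
  os << "<AffineTransform> " << dim_out << " " << dim_in << "\n";
  // Bernoulli hidden units: p(h=1|v) = sigmoid(W v + b), so append a sigmoid.
  if (hid_type == Bernoulli) os << "<Sigmoid> " << dim_out << " " << dim_out << "\n";
  // Gaussian hidden units: the affine output is used as-is.
}
```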

Member Data Documentation

◆ rbm_opts_

RbmTrainOptions rbm_opts_
protected

Definition at line 76 of file nnet-rbm.h.

The documentation for this class was generated from the following file: