SimpleForwardingDescriptor Class Reference

SimpleForwardingDescriptor is the base case of ForwardingDescriptor, consisting of a source node in the graph with a given scalar weight (normally 1.0).

#include <nnet-descriptor.h>


Public Member Functions

virtual Cindex MapToInput (const Index &index) const
 
virtual int32 Dim (const Nnet &nnet) const
 
virtual ForwardingDescriptor * Copy () const
 
virtual void GetNodeDependencies (std::vector< int32 > *node_indexes) const
 This function appends to "node_indexes" all the node indexes this descriptor may access.
 
virtual BaseFloat GetScaleForNode (int32 node_index) const
 This function returns the scale on the node-index 'node_index' when it appears in expressions inside this descriptor, or +infinity if it does not appear.
 
virtual void WriteConfig (std::ostream &os, const std::vector< std::string > &node_names) const
 
 SimpleForwardingDescriptor (int32 src_node, BaseFloat scale=1.0)
 
virtual ~SimpleForwardingDescriptor ()
 
- Public Member Functions inherited from ForwardingDescriptor
virtual int32 Modulus () const
 This function is for use in things like clockwork RNNs, where shifting the time of the inputs and outputs of the network by an integer multiple of some n would leave things the same, but shifting by a non-multiple of n would change the network structure.
 
virtual ~ForwardingDescriptor ()
 
 ForwardingDescriptor ()
 

Private Attributes

int32 src_node_
 
BaseFloat scale_
 

Detailed Description

SimpleForwardingDescriptor is the base case of ForwardingDescriptor, consisting of a source node in the graph with a given scalar weight (normally 1.0).

In the normal case (scale = 1.0) the string representation is just the node name, e.g. `tdnn2`; if the weight is not 1.0, it has the form `Scale(2.0, tdnn2)`, with the scale before the node name.

Definition at line 144 of file nnet-descriptor.h.
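The two textual forms above can be told apart with a few string operations. The following is a minimal, hypothetical C++ sketch (ParseSimpleDescriptor and ParsedSimpleDescriptor are illustrative names, not Kaldi's own descriptor parser), assuming the input has already been isolated to a single simple descriptor and that BaseFloat is float, as in Kaldi's default build.

```cpp
#include <cassert>
#include <cstdlib>
#include <optional>
#include <string>

// Hypothetical result type: a node name plus its scalar weight.
struct ParsedSimpleDescriptor {
  std::string node_name;
  float scale;  // assumes BaseFloat == float (Kaldi's default build)
};

// Hypothetical parser for the two forms "name" and "Scale(<scale>, name)".
// Returns std::nullopt for input it does not recognize.
std::optional<ParsedSimpleDescriptor> ParseSimpleDescriptor(const std::string &s) {
  const std::string prefix = "Scale(";
  if (s.compare(0, prefix.size(), prefix) != 0) {
    // Plain node name: implicit scale of 1.0.
    if (s.empty()) return std::nullopt;
    return ParsedSimpleDescriptor{s, 1.0f};
  }
  if (s.back() != ')') return std::nullopt;
  const std::string::size_type comma = s.find(',', prefix.size());
  if (comma == std::string::npos) return std::nullopt;

  // The scale comes first, between "Scale(" and the comma.
  const std::string scale_str = s.substr(prefix.size(), comma - prefix.size());
  char *end = nullptr;
  const float scale = std::strtof(scale_str.c_str(), &end);
  if (end == scale_str.c_str() || *end != '\0') return std::nullopt;

  // The node name follows the comma (after optional spaces), up to ')'.
  const std::string::size_type name_start = s.find_first_not_of(' ', comma + 1);
  if (name_start == std::string::npos || name_start >= s.size() - 1) return std::nullopt;
  return ParsedSimpleDescriptor{s.substr(name_start, s.size() - 1 - name_start), scale};
}
```

Because the scale must come first, this sketch rejects the reversed form `Scale(tdnn2, 2.0)`.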

Constructor & Destructor Documentation

◆ SimpleForwardingDescriptor()

SimpleForwardingDescriptor (int32 src_node, BaseFloat scale = 1.0)
inline

Definition at line 158 of file nnet-descriptor.h.

References KALDI_ASSERT.

159  :
160      src_node_(src_node), scale_(scale) {
161    KALDI_ASSERT(src_node >= 0);
162  }

KALDI_ASSERT(cond) is a macro defined at line 185 of kaldi-error.h.

◆ ~SimpleForwardingDescriptor()

virtual ~SimpleForwardingDescriptor ( )
inline virtual

Definition at line 163 of file nnet-descriptor.h.

163 { }

Member Function Documentation

◆ Copy()

ForwardingDescriptor * Copy ( ) const
virtual

Implements ForwardingDescriptor.

Definition at line 92 of file nnet-descriptor.cc.

◆ Dim()

int32 Dim (const Nnet &nnet) const
virtual

Implements ForwardingDescriptor.

Definition at line 79 of file nnet-descriptor.cc.

References NetworkNode::Dim(), and Nnet::GetNode().

79  {
80    return nnet.GetNode(src_node_).Dim(nnet);
81  }

◆ GetNodeDependencies()

void GetNodeDependencies ( std::vector< int32 > *  node_indexes) const
virtual

This function appends to "node_indexes" all the node indexes this descriptor may access. For this class it appends exactly one index, the source node src_node_.

Implements ForwardingDescriptor.

Definition at line 96 of file nnet-descriptor.cc.

Referenced by ModelCollapser::SumDescriptorIsCollapsible().

97  {
98    node_indexes->push_back(src_node_);
99  }
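Because GetNodeDependencies() only appends (it never clears the vector), a caller that queries several descriptors may see the same node index more than once. A minimal, hypothetical C++ sketch of that pattern, using a stand-in struct (SimpleDescriptorSketch, CollectDependencies) rather than the real class, and plain `int` in place of Kaldi's int32:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in holding only the source node index.
struct SimpleDescriptorSketch {
  int src_node;

  // Same behavior as the listing above: append, never clear.
  void GetNodeDependencies(std::vector<int> *node_indexes) const {
    assert(node_indexes != nullptr);
    node_indexes->push_back(src_node);
  }
};

// Hypothetical helper: gathers dependencies from several descriptors, then
// sorts and removes duplicates, since each call only appends.
std::vector<int> CollectDependencies(const std::vector<SimpleDescriptorSketch> &descs) {
  std::vector<int> deps;
  for (const SimpleDescriptorSketch &d : descs) d.GetNodeDependencies(&deps);
  std::sort(deps.begin(), deps.end());
  deps.erase(std::unique(deps.begin(), deps.end()), deps.end());
  return deps;
}
```

Any entries already in the vector are preserved, so the caller decides whether to clear it first.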

◆ GetScaleForNode()

BaseFloat GetScaleForNode ( int32  node_index) const
virtual

This function returns the scale on the node-index 'node_index' when it appears in expressions inside this descriptor, or +infinity if it does not appear.

E.g. if the descriptor is just `Scale(2.0, tdnn2)` and the node index for `tdnn2` is 4, then GetScaleForNode(4) would return 2.0. If a particular node index appears in different sub-expressions of the descriptor with different scales, that is an error (it is not supported) and this function would crash.

Implements ForwardingDescriptor.

Definition at line 83 of file nnet-descriptor.cc.

83  {
84    if (node_index == src_node_) return scale_;
85    else return std::numeric_limits<BaseFloat>::infinity();
86  }
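The +infinity return value acts as a "not present" sentinel, so a caller can test whether a node is referenced without a separate lookup. A minimal C++ sketch, assuming BaseFloat is float (Kaldi's default build); ScaledSource and ReferencesNode are hypothetical names used only for illustration:

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical stand-in for the two private members (src_node_, scale_).
struct ScaledSource {
  int src_node;
  float scale;

  // Same logic as the listing above: the scale if node_index is the source,
  // otherwise +infinity meaning "does not appear".
  float GetScaleForNode(int node_index) const {
    if (node_index == src_node) return scale;
    return std::numeric_limits<float>::infinity();
  }
};

// Hypothetical caller-side helper: an infinite result means the node is
// not referenced by this descriptor.
bool ReferencesNode(const ScaledSource &d, int node_index) {
  return !std::isinf(d.GetScaleForNode(node_index));
}
```

Using infinity rather than, say, 0.0 keeps a legitimate zero weight distinguishable from "absent".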

◆ MapToInput()

Cindex MapToInput (const Index &index) const
virtual

Implements ForwardingDescriptor.

Definition at line 88 of file nnet-descriptor.cc.

88  {
89    return Cindex(src_node_, index);
90  }
Here Cindex is a typedef for std::pair< int32, Index >, defined at line 115 of nnet-common.h.
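MapToInput() passes the requested Index through unchanged and only attaches the source node, so no time offset is applied. A minimal C++ sketch with simplified stand-ins for Index and Cindex (Kaldi's Index carries the n, t and x fields; the real types in nnet-common.h have more members and operators); MapToInputSketch is a hypothetical name:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Simplified stand-in for Kaldi's Index (minibatch index n, time t, extra x).
struct Index {
  int32_t n = 0;
  int32_t t = 0;
  int32_t x = 0;
};

// Same shape as Kaldi's typedef: (node index, Index).
using Cindex = std::pair<int32_t, Index>;

// Mirrors the listing above: the output Index is copied unchanged and
// paired with the descriptor's source node.
Cindex MapToInputSketch(int32_t src_node, const Index &index) {
  assert(src_node >= 0);  // same invariant the constructor enforces
  return Cindex(src_node, index);
}
```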

◆ WriteConfig()

void WriteConfig (std::ostream &os, const std::vector< std::string > &node_names) const
virtual

Implements ForwardingDescriptor.

Definition at line 101 of file nnet-descriptor.cc.

References KALDI_ASSERT.

103  {
104    KALDI_ASSERT(static_cast<size_t>(src_node_) < node_names.size());
105    if (scale_ == 1.0) {
106      os << node_names[src_node_];
107    } else {
108      os << "Scale(" << scale_ << ", " << node_names[src_node_] << ")";
109    }
110  }
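The same formatting logic can be reproduced outside the class to check what a given (node, scale) pair will look like in a config file. A minimal, hypothetical C++ sketch (WriteSimpleDescriptorConfig and ToConfigString are illustrative names, not Kaldi functions), assuming BaseFloat is float and the stream uses default formatting:

```cpp
#include <cassert>
#include <cstddef>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical free-function version of WriteConfig(): writes the bare node
// name when scale == 1.0, otherwise "Scale(<scale>, <name>)".
void WriteSimpleDescriptorConfig(std::ostream &os, int src_node, float scale,
                                 const std::vector<std::string> &node_names) {
  assert(src_node >= 0 && static_cast<std::size_t>(src_node) < node_names.size());
  if (scale == 1.0f) {
    os << node_names[src_node];
  } else {
    os << "Scale(" << scale << ", " << node_names[src_node] << ")";
  }
}

// Convenience wrapper that returns the formatted text as a string.
std::string ToConfigString(int src_node, float scale,
                           const std::vector<std::string> &node_names) {
  std::ostringstream oss;
  WriteSimpleDescriptorConfig(oss, src_node, scale, node_names);
  return oss.str();
}
```

With default stream formatting, a scale of 2.0 is printed as `Scale(2, tdnn2)` and 0.5 as `Scale(0.5, tdnn2)`; a stream whose precision or flags have been changed would print different digits.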

Member Data Documentation

◆ scale_

BaseFloat scale_
private

Definition at line 166 of file nnet-descriptor.h.

◆ src_node_

int32 src_node_
private

Definition at line 165 of file nnet-descriptor.h.


The documentation for this class was generated from the following files: