RoundingForwardingDescriptor Class Reference

For use in clockwork RNNs and the like, this forwarding-descriptor rounds the time-index t down to the closest t' <= t that is an exact multiple of t_modulus_.

#include <nnet-descriptor.h>


Public Member Functions

virtual Cindex MapToInput (const Index &ind) const
 
virtual int32 Dim (const Nnet &nnet) const
 
virtual ForwardingDescriptor * Copy () const
 
virtual void WriteConfig (std::ostream &os, const std::vector< std::string > &node_names) const
 
virtual int32 Modulus () const
 This function is for use in things like clockwork RNNs, where shifting the time of the network's inputs and outputs by any integer multiple of some n leaves things the same, but shifting by a non-multiple would change the network structure.
 
virtual void GetNodeDependencies (std::vector< int32 > *node_indexes) const
 This function appends to "node_indexes" all the node indexes that this descriptor depends on.
 
virtual BaseFloat GetScaleForNode (int32 node_index) const
 This function returns the scale on the node-index 'node_index' when it appears in expressions inside this descriptor, or +infinity if it does not appear.
 
 RoundingForwardingDescriptor (ForwardingDescriptor *src, int32 t_modulus)
 
virtual ~RoundingForwardingDescriptor ()
 
- Public Member Functions inherited from ForwardingDescriptor
virtual ~ForwardingDescriptor ()
 
 ForwardingDescriptor ()
 

Private Attributes

ForwardingDescriptor * src_
 
int32 t_modulus_
 

Detailed Description

For use in clockwork RNNs and the like, this forwarding-descriptor rounds the time-index t down to the closest t' <= t that is an exact multiple of t_modulus_.

Written form is: `Round(<descriptor>, <t-modulus>)` e.g.: `Round(tdnn2, 3)`

Definition at line 242 of file nnet-descriptor.h.
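The rounding rule can be illustrated with a small standalone sketch. This is not Kaldi's code: `RoundDownToMultiple` and `RoundedFrames` are hypothetical helpers, and `int32_t` stands in for Kaldi's `int32`. It shows that with a modulus of 3 each block of 3 consecutive frames collapses onto a single input frame, and that negative times round down (away from zero), so t = -1 maps to -3.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not Kaldi code): largest multiple of `modulus` <= t.
// ((t % m) + m) % m is the non-negative ("mathematical") remainder, even
// though C++'s % itself can return a negative value for negative t.
int32_t RoundDownToMultiple(int32_t t, int32_t modulus) {
  assert(modulus >= 1);
  int32_t mod = ((t % modulus) + modulus) % modulus;
  return t - mod;
}

// Hypothetical helper: applies the rounding to frames begin..end-1, i.e. it
// lists which time index each output frame would request from its source.
std::vector<int32_t> RoundedFrames(int32_t begin, int32_t end,
                                   int32_t modulus) {
  std::vector<int32_t> ans;
  for (int32_t t = begin; t < end; ++t)
    ans.push_back(RoundDownToMultiple(t, modulus));
  return ans;
}
```

For example, `RoundedFrames(-2, 4, 3)` yields {-3, -3, 0, 0, 0, 3}.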

Constructor & Destructor Documentation

◆ RoundingForwardingDescriptor()

RoundingForwardingDescriptor ( ForwardingDescriptor * src,
int32  t_modulus 
)
inline

Definition at line 259 of file nnet-descriptor.h.

RoundingForwardingDescriptor(ForwardingDescriptor *src, int32 t_modulus):
    src_(src), t_modulus_(t_modulus) { }

◆ ~RoundingForwardingDescriptor()

virtual ~RoundingForwardingDescriptor ( )
inline virtual

Definition at line 263 of file nnet-descriptor.h.

virtual ~RoundingForwardingDescriptor() { delete src_; }

Member Function Documentation

◆ Copy()

ForwardingDescriptor * Copy ( ) const
virtual

Implements ForwardingDescriptor.

Definition at line 203 of file nnet-descriptor.cc.

ForwardingDescriptor *RoundingForwardingDescriptor::Copy() const {
  return new RoundingForwardingDescriptor(src_->Copy(), t_modulus_);
}
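Because the destructor deletes `src_`, `Copy()` has to deep-copy the source descriptor rather than share the pointer; otherwise deleting either object would leave the other with a dangling pointer. The sketch below illustrates that ownership pattern with hypothetical stripped-down classes (`MiniDescriptor`, `MiniNode`, `MiniRound`), not Kaldi's API. Deleting the copy constructor and assignment is an extra safeguard of this sketch only; whether Kaldi's class does the same is not stated here.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for ForwardingDescriptor: only Copy() and a way to
// inspect the leaf node are modelled.
class MiniDescriptor {
 public:
  virtual ~MiniDescriptor() {}
  virtual MiniDescriptor *Copy() const = 0;  // caller owns the result
  virtual int32_t Node() const = 0;
};

// Hypothetical leaf that refers directly to one node index.
class MiniNode : public MiniDescriptor {
 public:
  explicit MiniNode(int32_t node) : node_(node) {}
  MiniDescriptor *Copy() const override { return new MiniNode(node_); }
  int32_t Node() const override { return node_; }
 private:
  int32_t node_;
};

// Owns src_, as RoundingForwardingDescriptor does, so Copy() deep-copies it.
class MiniRound : public MiniDescriptor {
 public:
  MiniRound(MiniDescriptor *src, int32_t t_modulus)
      : src_(src), t_modulus_(t_modulus) {}
  ~MiniRound() override { delete src_; }
  // Shallow copies would double-delete src_, so forbid them in this sketch.
  MiniRound(const MiniRound &) = delete;
  MiniRound &operator=(const MiniRound &) = delete;
  MiniDescriptor *Copy() const override {
    return new MiniRound(src_->Copy(), t_modulus_);
  }
  int32_t Node() const override { return src_->Node(); }
  const MiniDescriptor *Src() const { return src_; }
  int32_t Modulus() const { return t_modulus_; }
 private:
  MiniDescriptor *src_;
  int32_t t_modulus_;
};
```

Deleting the original after calling `Copy()` leaves the copy fully usable, since each object owns its own source tree.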

◆ Dim()

virtual int32 Dim ( const Nnet & nnet) const
inline virtual

Implements ForwardingDescriptor.

Definition at line 245 of file nnet-descriptor.h.

References ForwardingDescriptor::Dim().

virtual int32 Dim(const Nnet &nnet) const { return src_->Dim(nnet); }

◆ GetNodeDependencies()

void GetNodeDependencies ( std::vector< int32 > *  node_indexes) const
virtual

This function appends to "node_indexes" all the node indexes that this descriptor depends on.

Implements ForwardingDescriptor.

Definition at line 182 of file nnet-descriptor.cc.

void RoundingForwardingDescriptor::GetNodeDependencies(
    std::vector<int32> *node_indexes) const {
  src_->GetNodeDependencies(node_indexes);
}
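Since rounding only alters the time index t, the node dependencies of a rounding descriptor are exactly those of its source, and the function appends to the vector rather than overwriting it. The minimal sketch below uses hypothetical classes (`MiniDescriptor`, `MiniNode`, `MiniRound`), not Kaldi's; unlike the raw `src_` pointer documented here, it holds the source in a `std::unique_ptr` for brevity.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for ForwardingDescriptor.
class MiniDescriptor {
 public:
  virtual ~MiniDescriptor() = default;
  // Appends (never clears) the node indexes this descriptor depends on.
  virtual void GetNodeDependencies(std::vector<int32_t> *node_indexes) const = 0;
};

// Hypothetical leaf: depends on exactly one node.
class MiniNode : public MiniDescriptor {
 public:
  explicit MiniNode(int32_t node) : node_(node) {}
  void GetNodeDependencies(std::vector<int32_t> *node_indexes) const override {
    node_indexes->push_back(node_);
  }
 private:
  int32_t node_;
};

// Hypothetical rounding wrapper: forwards the query to its source unchanged.
class MiniRound : public MiniDescriptor {
 public:
  MiniRound(std::unique_ptr<MiniDescriptor> src, int32_t t_modulus)
      : src_(std::move(src)), t_modulus_(t_modulus) {
    assert(t_modulus_ >= 1);
  }
  void GetNodeDependencies(std::vector<int32_t> *node_indexes) const override {
    src_->GetNodeDependencies(node_indexes);
  }
 private:
  std::unique_ptr<MiniDescriptor> src_;
  int32_t t_modulus_;
};
```

Any indexes already in the vector are kept; the source's dependencies are added after them.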

◆ GetScaleForNode()

BaseFloat GetScaleForNode ( int32  node_index) const
virtual

This function returns the scale on the node-index 'node_index' when it appears in expressions inside this descriptor, or +infinity if it does not appear.

E.g. if the descriptor is just `Scale(tdnn2, 2.0)` and the node index for `tdnn2` is 4, then GetScaleForNode(4) would return 2.0. If a particular node_index > 0 appears in different sub-expressions of the descriptor with different scales, it is an error (it's not supported) and this function would crash.

Implements ForwardingDescriptor.

Definition at line 187 of file nnet-descriptor.cc.

BaseFloat RoundingForwardingDescriptor::GetScaleForNode(
    int32 node_index) const {
  return src_->GetScaleForNode(node_index);
}
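The sketch below shows why wrapping a scaled node in a rounding descriptor leaves the reported scale unchanged: the rounding layer simply forwards the query. It uses hypothetical classes (`MiniNode`, `MiniScale`, `MiniRound`) in place of Kaldi's descriptors, assumes for illustration that a leaf reports a scale of 1.0 for its own node, and uses `float` in place of `BaseFloat`.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <utility>

// Hypothetical stand-in for ForwardingDescriptor.
class MiniDescriptor {
 public:
  virtual ~MiniDescriptor() = default;
  // Scale on node_index, or +infinity if the node does not appear.
  virtual float GetScaleForNode(int32_t node_index) const = 0;
};

// Hypothetical leaf: scale 1.0 for its own node, +infinity otherwise.
class MiniNode : public MiniDescriptor {
 public:
  explicit MiniNode(int32_t node) : node_(node) {}
  float GetScaleForNode(int32_t node_index) const override {
    return node_index == node_ ? 1.0f
                               : std::numeric_limits<float>::infinity();
  }
 private:
  int32_t node_;
};

// Hypothetical scaling wrapper: multiplies the source's scale.
class MiniScale : public MiniDescriptor {
 public:
  MiniScale(std::unique_ptr<MiniDescriptor> src, float scale)
      : src_(std::move(src)), scale_(scale) {}
  float GetScaleForNode(int32_t node_index) const override {
    float s = src_->GetScaleForNode(node_index);
    // Keep "does not appear" as +infinity instead of multiplying it.
    return std::isinf(s) ? s : s * scale_;
  }
 private:
  std::unique_ptr<MiniDescriptor> src_;
  float scale_;
};

// Hypothetical rounding wrapper: rounding changes only t, so it forwards.
class MiniRound : public MiniDescriptor {
 public:
  MiniRound(std::unique_ptr<MiniDescriptor> src, int32_t t_modulus)
      : src_(std::move(src)), t_modulus_(t_modulus) {
    assert(t_modulus_ >= 1);
  }
  float GetScaleForNode(int32_t node_index) const override {
    return src_->GetScaleForNode(node_index);
  }
 private:
  std::unique_ptr<MiniDescriptor> src_;
  int32_t t_modulus_;
};
```

With node 4 scaled by 2.0 and then rounded with modulus 3, querying node 4 returns 2.0 and querying any other node returns +infinity.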

◆ MapToInput()

Cindex MapToInput ( const Index & ind) const
virtual

Implements ForwardingDescriptor.

Definition at line 192 of file nnet-descriptor.cc.

References KALDI_ASSERT, and Index::t.

Cindex RoundingForwardingDescriptor::MapToInput(const Index &ind) const {
  KALDI_ASSERT(t_modulus_ >= 1);
  Index ind_mod(ind);
  // unfortunately doing "mathematical" modulus is a bit painful in C.
  int32 mod = ind_mod.t % t_modulus_;
  if (mod < 0)
    mod += t_modulus_;
  ind_mod.t -= mod;
  return src_->MapToInput(ind_mod);
}
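The `if (mod < 0)` step matters because C++'s `%` truncates toward zero, so `-1 % 3 == -1`. The sketch below contrasts a naive `t - t % m` with the adjusted computation used in `MapToInput`. `MiniIndex` is a hypothetical stand-in for Kaldi's `Index` (assumed here to carry `n`, `t` and `x`), `NaiveRound` and `RoundIndex` are hypothetical names, and the final call to `src_->MapToInput` is omitted.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for Index: n = sequence, t = time, x = extra index.
struct MiniIndex {
  int32_t n, t, x;
};

// Naive rounding: wrong for negative t, because C++ % truncates toward zero,
// so t - t % m rounds toward zero (i.e. up) instead of down.
int32_t NaiveRound(int32_t t, int32_t t_modulus) {
  return t - t % t_modulus;
}

// Mirrors the arithmetic of MapToInput above: only t changes, while n and x
// are copied through unchanged.
MiniIndex RoundIndex(const MiniIndex &ind, int32_t t_modulus) {
  assert(t_modulus >= 1);
  MiniIndex ind_mod(ind);
  int32_t mod = ind_mod.t % t_modulus;
  if (mod < 0) mod += t_modulus;  // shift remainder into [0, t_modulus)
  ind_mod.t -= mod;
  return ind_mod;
}
```

With a modulus of 3, the naive version maps t = -1 to 0, which violates t' <= t, while `RoundIndex` maps it to -3.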

◆ Modulus()

virtual int32 Modulus ( ) const
inline virtual

This function is for use in things like clockwork RNNs, where shifting the time of the network's inputs and outputs by any integer multiple of some n leaves things the same, but shifting by a non-multiple would change the network structure.

It returns the smallest modulus to which this descriptor is invariant; the lowest common multiple of the moduli of all descriptors in the network gives you the modulus for the whole network.

Reimplemented from ForwardingDescriptor.

Definition at line 251 of file nnet-descriptor.h.

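To illustrate how per-descriptor moduli combine into a network modulus, the sketch below takes their least common multiple with C++17's `std::lcm`. `NetworkModulus` is a hypothetical helper, not a Kaldi function, and it assumes that a descriptor without any rounding contributes a modulus of 1.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>  // std::lcm (C++17)
#include <vector>

// Hypothetical helper: network modulus = lcm of all descriptor moduli.
int32_t NetworkModulus(const std::vector<int32_t> &descriptor_moduli) {
  int32_t ans = 1;  // the modulus of a network with no rounding at all
  for (int32_t m : descriptor_moduli) {
    assert(m >= 1);
    ans = std::lcm(ans, m);
  }
  return ans;
}
```

For instance, descriptors with moduli 1, 3 and 2 give a network modulus of 6, so shifting all times by a multiple of 6 frames leaves the network structure unchanged.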

◆ WriteConfig()

void WriteConfig ( std::ostream &  os,
const std::vector< std::string > &  node_names 
) const
virtual

Implements ForwardingDescriptor.

Definition at line 207 of file nnet-descriptor.cc.

void RoundingForwardingDescriptor::WriteConfig(
    std::ostream &os,
    const std::vector<std::string> &node_names) const {
  os << "Round(";
  src_->WriteConfig(os, node_names);
  os << ", " << t_modulus_ << ")";
}
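The sketch below reproduces the output format of `WriteConfig` with hypothetical classes (`MiniNode` writes its node name, `MiniRound` wraps a source), showing how the written form `Round(<descriptor>, <t-modulus>)` nests. `ToConfigString` is a hypothetical convenience wrapper, not part of Kaldi.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ForwardingDescriptor.
class MiniDescriptor {
 public:
  virtual ~MiniDescriptor() = default;
  virtual void WriteConfig(std::ostream &os,
                           const std::vector<std::string> &node_names) const = 0;
};

// Hypothetical leaf: writes the name of the node it refers to.
class MiniNode : public MiniDescriptor {
 public:
  explicit MiniNode(int32_t node) : node_(node) {}
  void WriteConfig(std::ostream &os,
                   const std::vector<std::string> &node_names) const override {
    os << node_names.at(node_);
  }
 private:
  int32_t node_;
};

// Hypothetical rounding wrapper using the same format as WriteConfig above.
class MiniRound : public MiniDescriptor {
 public:
  MiniRound(std::unique_ptr<MiniDescriptor> src, int32_t t_modulus)
      : src_(std::move(src)), t_modulus_(t_modulus) {
    assert(t_modulus_ >= 1);
  }
  void WriteConfig(std::ostream &os,
                   const std::vector<std::string> &node_names) const override {
    os << "Round(";
    src_->WriteConfig(os, node_names);
    os << ", " << t_modulus_ << ")";
  }
 private:
  std::unique_ptr<MiniDescriptor> src_;
  int32_t t_modulus_;
};

// Hypothetical wrapper that renders a descriptor to a string.
std::string ToConfigString(const MiniDescriptor &d,
                           const std::vector<std::string> &node_names) {
  std::ostringstream oss;
  d.WriteConfig(oss, node_names);
  return oss.str();
}
```

With node names {"input", "tdnn1", "tdnn2"}, rounding node 2 with modulus 3 renders as `Round(tdnn2, 3)`, and nesting one rounding inside another renders as `Round(Round(tdnn1, 2), 6)`.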

Member Data Documentation

◆ src_

ForwardingDescriptor* src_
private

Definition at line 265 of file nnet-descriptor.h.

◆ t_modulus_

int32 t_modulus_
private

Definition at line 266 of file nnet-descriptor.h.


The documentation for this class was generated from the following files: nnet-descriptor.h, nnet-descriptor.cc