SwitchingForwardingDescriptor Class Reference

Chooses from different inputs based on the time index modulo the number of ForwardingDescriptors given as inputs.

#include <nnet-descriptor.h>


Public Member Functions

virtual Cindex MapToInput (const Index &ind) const
 
virtual int32 Dim (const Nnet &nnet) const
 
virtual ForwardingDescriptor * Copy () const
 
virtual void WriteConfig (std::ostream &os, const std::vector< std::string > &node_names) const
 
virtual int32 Modulus () const
 This function is for use in things like clockwork RNNs, where shifting the time of the network's inputs and outputs by any integer multiple of some n would leave things the same, but shifting by a non-multiple of n would change the network structure.
 
virtual void GetNodeDependencies (std::vector< int32 > *node_indexes) const
 This function appends to "node_indexes" all the node indexes that this descriptor may access.
 
virtual BaseFloat GetScaleForNode (int32 node_index) const
 This function returns the scale on the node-index 'node_index' when it appears in expressions inside this descriptor, or +infinity if it does not appear.
 
 SwitchingForwardingDescriptor (std::vector< ForwardingDescriptor *> &src)
 
virtual ~SwitchingForwardingDescriptor ()
 
- Public Member Functions inherited from ForwardingDescriptor
virtual ~ForwardingDescriptor ()
 
 ForwardingDescriptor ()
 

Private Attributes

std::vector< ForwardingDescriptor * > src_
 

Detailed Description

Chooses from different inputs based on the time index modulo the number of ForwardingDescriptors given as inputs.

This is rarely if ever used. Written form is: `Switch(<descriptor>, <descriptor> [, <descriptor> ...])` e.g. `Switch(tdnn2a, tdnn2b, tdnn2c)`

Definition at line 210 of file nnet-descriptor.h.
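
The sketch below illustrates the written form. It splits a `Switch(...)` expression into its top-level input descriptors. It is not Kaldi's actual descriptor parser: `SplitSwitchArgs` and `TrimSpaces` are hypothetical helpers. The sketch assumes that commas nested inside parentheses, as in a sub-descriptor such as `Offset(tdnn1, -1)`, belong to that sub-descriptor. It also does not validate parenthesis balance.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: strips leading and trailing spaces/tabs.
std::string TrimSpaces(const std::string &s) {
  const std::string::size_type begin = s.find_first_not_of(" \t");
  if (begin == std::string::npos) return "";
  const std::string::size_type end = s.find_last_not_of(" \t");
  return s.substr(begin, end - begin + 1);
}

// Hypothetical helper (not Kaldi's parser): splits "Switch(<d1>, <d2>, ...)"
// into its top-level argument strings. Commas at nesting depth > 0 are kept.
std::vector<std::string> SplitSwitchArgs(const std::string &text) {
  const std::string prefix = "Switch(";
  if (text.compare(0, prefix.size(), prefix) != 0 || text.back() != ')')
    throw std::invalid_argument("expected Switch(...): " + text);
  const std::string body =
      text.substr(prefix.size(), text.size() - prefix.size() - 1);
  std::vector<std::string> args;
  std::string current;
  int depth = 0;  // current parenthesis nesting depth inside the body
  for (char c : body) {
    if (c == '(') ++depth;
    if (c == ')') --depth;
    if (c == ',' && depth == 0) {
      args.push_back(TrimSpaces(current));
      current.clear();
    } else {
      current += c;
    }
  }
  args.push_back(TrimSpaces(current));
  // The written form documented above lists at least two descriptors.
  if (args.size() < 2)
    throw std::invalid_argument("Switch needs at least two inputs: " + text);
  return args;
}
```

For example, `SplitSwitchArgs("Switch(tdnn2a, tdnn2b, tdnn2c)")` yields the three names `tdnn2a`, `tdnn2b` and `tdnn2c`, in that order.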

Constructor & Destructor Documentation

◆ SwitchingForwardingDescriptor()

SwitchingForwardingDescriptor ( std::vector< ForwardingDescriptor *> &  src)
inline

Definition at line 227 of file nnet-descriptor.h.

SwitchingForwardingDescriptor(std::vector<ForwardingDescriptor*> &src):
    src_(src) { }

◆ ~SwitchingForwardingDescriptor()

virtual ~SwitchingForwardingDescriptor ( )
inline virtual

Definition at line 229 of file nnet-descriptor.h.

References kaldi::DeletePointers().

virtual ~SwitchingForwardingDescriptor() { DeletePointers(&src_); }
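
The constructor copies the pointer vector, and the destructor frees the pointers with `DeletePointers()`. According to its Kaldi documentation, `DeletePointers()` deletes any non-NULL pointers in the vector and sets the corresponding entries to NULL. This suggests that ownership of the input descriptors passes to the `SwitchingForwardingDescriptor`.

The minimal sketch below models that pattern in plain C++. `DeletePointersSketch`, `Tracked` and `OwnerSketch` are hypothetical stand-ins, not Kaldi code.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for kaldi::DeletePointers(): deletes every non-null
// pointer in *v and sets the corresponding entry to nullptr.
template <typename A>
void DeletePointersSketch(std::vector<A*> *v) {
  for (A *&p : *v) {
    if (p != nullptr) {
      delete p;
      p = nullptr;
    }
  }
}

// Toy element type that counts how many instances have been destroyed.
struct Tracked {
  static int destroyed;
  ~Tracked() { ++destroyed; }
};
int Tracked::destroyed = 0;

// Toy owner with the same ownership pattern as SwitchingForwardingDescriptor:
// the constructor copies the pointers, the destructor frees them.
class OwnerSketch {
 public:
  explicit OwnerSketch(const std::vector<Tracked*> &src) : src_(src) {}
  ~OwnerSketch() { DeletePointersSketch(&src_); }
  // Copying would make two owners delete the same pointers, so forbid it.
  OwnerSketch(const OwnerSketch &) = delete;
  OwnerSketch &operator=(const OwnerSketch &) = delete;

 private:
  std::vector<Tracked*> src_;
};
```

After constructing an owner from a vector, the caller's vector still holds the same pointers. They become dangling once the owner is destroyed, so the caller should neither use nor delete them.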

Member Function Documentation

◆ Copy()

ForwardingDescriptor * Copy ( ) const
virtual

Implements ForwardingDescriptor.

Definition at line 161 of file nnet-descriptor.cc.

References kaldi::cu::Copy(), and rnnlm::i.

ForwardingDescriptor *SwitchingForwardingDescriptor::Copy() const {
  std::vector<ForwardingDescriptor*> src_copy(src_.size());
  for (size_t i = 0; i < src_.size(); i++)
    src_copy[i] = src_[i]->Copy();
  return new SwitchingForwardingDescriptor(src_copy);
}
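
`Copy()` performs a deep copy. Each input is cloned through its own virtual `Copy()`, so the new object owns fresh sub-descriptors rather than sharing the originals.

The sketch below reproduces that shape on a hypothetical toy hierarchy (`ToyDescriptor`, `ToyLeaf`, `ToySwitch`), assuming the same "caller owns the returned pointer" convention.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy base class standing in for ForwardingDescriptor.
class ToyDescriptor {
 public:
  virtual ~ToyDescriptor() = default;
  // Returns a newly allocated deep copy; the caller owns it.
  virtual ToyDescriptor *Copy() const = 0;
};

class ToyLeaf : public ToyDescriptor {
 public:
  explicit ToyLeaf(std::string name) : name_(std::move(name)) {}
  ToyDescriptor *Copy() const override { return new ToyLeaf(name_); }
  const std::string &Name() const { return name_; }

 private:
  std::string name_;
};

class ToySwitch : public ToyDescriptor {
 public:
  // Takes ownership of the pointers in src.
  explicit ToySwitch(const std::vector<ToyDescriptor*> &src) : src_(src) {}
  ~ToySwitch() override { for (ToyDescriptor *p : src_) delete p; }
  ToySwitch(const ToySwitch &) = delete;
  ToySwitch &operator=(const ToySwitch &) = delete;

  // Same shape as SwitchingForwardingDescriptor::Copy(): clone every input
  // through its own virtual Copy(), then wrap the clones in a new object.
  ToyDescriptor *Copy() const override {
    std::vector<ToyDescriptor*> src_copy(src_.size());
    for (std::size_t i = 0; i < src_.size(); i++)
      src_copy[i] = src_[i]->Copy();
    return new ToySwitch(src_copy);
  }
  const std::vector<ToyDescriptor*> &Inputs() const { return src_; }

 private:
  std::vector<ToyDescriptor*> src_;
};
```

Wrapping the returned pointer in a `std::unique_ptr<ToyDescriptor>` makes the ownership of the copy explicit on the caller's side.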

◆ Dim()

virtual int32 Dim ( const Nnet & nnet) const
inline virtual

◆ GetNodeDependencies()

void GetNodeDependencies ( std::vector< int32 > *  node_indexes) const
virtual

This function appends to "node_indexes" all the node indexes that this descriptor may access.

Implements ForwardingDescriptor.

Definition at line 146 of file nnet-descriptor.cc.

References Descriptor::GetNodeDependencies(), and rnnlm::i.

void SwitchingForwardingDescriptor::GetNodeDependencies(
    std::vector<int32> *node_indexes) const {
  for (size_t i = 0; i < src_.size(); i++)
    src_[i]->GetNodeDependencies(node_indexes);
}
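
The loop above visits every input, because for some value of t each input will be selected. It also appends to the vector it is given; it neither replaces the existing contents nor removes duplicates.

The minimal sketch below shows this behaviour. It assumes, as a hypothetical simplification, that each input is reduced to the list of node indexes it reads. `NodeList` and `AppendSwitchDependencies` are illustrative names only.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical simplification: each input is just the node indexes it reads.
using NodeList = std::vector<int32_t>;

// Mirrors the loop in GetNodeDependencies(): every input contributes its
// nodes, appended after whatever *node_indexes already holds. No sorting or
// de-duplication is done here.
void AppendSwitchDependencies(const std::vector<NodeList> &inputs,
                              NodeList *node_indexes) {
  for (const NodeList &input : inputs)
    node_indexes->insert(node_indexes->end(), input.begin(), input.end());
}
```

For inputs reading nodes 3, 5 and 3, a vector that already held 7 ends up as 7, 3, 5, 3.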

◆ GetScaleForNode()

BaseFloat GetScaleForNode ( int32  node_index) const
virtual

This function returns the scale on the node-index 'node_index' when it appears in expressions inside this descriptor, or +infinity if it does not appear.

E.g. if the descriptor is just `Scale(tdnn2, 2.0)` and the node index for `tdnn2` is 4, then GetScaleForNode(4) would return 2.0. If a particular node_index > 0 appears in different sub-expressions of the descriptor with different scales, this is an error (it is not supported), and the function reports it via KALDI_ERR.

Implements ForwardingDescriptor.

Definition at line 473 of file nnet-descriptor.cc.

References rnnlm::i, and KALDI_ERR.

BaseFloat SwitchingForwardingDescriptor::GetScaleForNode(
    int32 node_index) const {
  BaseFloat inf = std::numeric_limits<BaseFloat>::infinity(),
      ans = inf;
  for (size_t i = 0; i < src_.size(); i++) {
    BaseFloat this_ans = src_[i]->GetScaleForNode(node_index);
    if (this_ans != inf) {
      if (ans != inf && ans != this_ans)
        KALDI_ERR << "Invalid Descriptor encountered: for node-index "
                  << node_index << ", got two different scales "
                  << this_ans << " vs. " << ans;
      ans = this_ans;
    }
  }
  return ans;
}
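
The combining rule can be isolated from the descriptor classes. +infinity is a sentinel meaning "node not present in this input". Any finite scales found must all agree, and a conflict is an error.

The sketch below makes these assumptions:

- `BaseFloat` is `float`.
- Each input is reduced to the scale it reports for one fixed node. `CombineScales` is a hypothetical name.
- A thrown `std::runtime_error` stands in for `KALDI_ERR`.

```cpp
#include <cassert>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <vector>

using BaseFloat = float;  // assumption: single-precision build

// Combines the scales reported by each input for one node index.
// +infinity means "node absent"; conflicting finite scales throw
// (standing in for KALDI_ERR in the real code).
BaseFloat CombineScales(const std::vector<BaseFloat> &per_input_scales) {
  const BaseFloat inf = std::numeric_limits<BaseFloat>::infinity();
  BaseFloat ans = inf;
  for (BaseFloat this_ans : per_input_scales) {
    if (this_ans != inf) {
      if (ans != inf && ans != this_ans) {
        std::ostringstream msg;
        msg << "got two different scales " << this_ans << " vs. " << ans;
        throw std::runtime_error(msg.str());
      }
      ans = this_ans;
    }
  }
  return ans;
}
```

A node that appears in no input therefore yields +infinity, matching the documented "+infinity if it does not appear" contract.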

◆ MapToInput()

Cindex MapToInput ( const Index & ind) const
virtual

Implements ForwardingDescriptor.

Definition at line 152 of file nnet-descriptor.cc.

References KALDI_ASSERT, and Index::t.

Cindex SwitchingForwardingDescriptor::MapToInput(const Index &ind) const {
  KALDI_ASSERT(!src_.empty());
  int32 size = src_.size(), mod = ind.t % size;
  // next line gets "mathematical" modulus, not broken "C" modulus.
  if (mod < 0) mod += size;
  return src_[mod]->MapToInput(ind);
}
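
The comment about the "C" modulus matters because, since C++11, integer `%` truncates toward zero, so `-1 % 3` is `-1`. `MapToInput()` adds `size` to a negative remainder so that every t selects a valid input. The selected input then receives the original `Index` unchanged.

The sketch below reproduces only the selection step for the written form `Switch(tdnn2a, tdnn2b, tdnn2c)`. `MathMod` and `SelectedInput` are hypothetical helpers, and inputs are represented by their names.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// "Mathematical" modulus: result is always in [0, size) for size > 0,
// unlike the built-in % which can return a negative remainder.
int32_t MathMod(int32_t t, int32_t size) {
  int32_t mod = t % size;
  if (mod < 0) mod += size;
  return mod;
}

// Hypothetical helper: which input Switch(names...) would use at time t.
std::string SelectedInput(const std::vector<std::string> &names, int32_t t) {
  assert(!names.empty());  // mirrors KALDI_ASSERT(!src_.empty())
  return names[MathMod(t, static_cast<int32_t>(names.size()))];
}
```

With three inputs, t = -1 selects `tdnn2c`, t = 0 selects `tdnn2a`, and t = 5 selects `tdnn2c`.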

◆ Modulus()

int32 Modulus ( ) const
virtual

This function is for use in things like clockwork RNNs, where shifting the time of the network's inputs and outputs by any integer multiple of some n would leave things the same, but shifting by a non-multiple of n would change the network structure.

It returns the smallest modulus under which this descriptor is invariant; the least common multiple of the moduli of all descriptors in the network gives the modulus for the whole network.

Reimplemented from ForwardingDescriptor.

Definition at line 466 of file nnet-descriptor.cc.

References rnnlm::i, kaldi::Lcm(), and ConstantSumDescriptor::Modulus().

int32 SwitchingForwardingDescriptor::Modulus() const {
  int32 ans = src_.size();
  for (size_t i = 0; i < src_.size(); i++)
    ans = Lcm(ans, src_[i]->Modulus());
  return ans;
}
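
The arithmetic in `Modulus()` is short enough to check by hand. The result starts from the number of inputs, because the selected input repeats with that period, and is then combined by least common multiple with each input's own modulus.

The sketch below makes these assumptions:

- `std::lcm` (C++17, `<numeric>`) stands in for `kaldi::Lcm()`.
- Each input is reduced to its modulus. `SwitchModulus` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Sketch of the Modulus() rule for a Switch with the given input moduli:
// start from the number of inputs, then fold in each input's modulus by LCM.
int32_t SwitchModulus(const std::vector<int32_t> &input_moduli) {
  assert(!input_moduli.empty());
  int32_t ans = static_cast<int32_t>(input_moduli.size());
  for (int32_t m : input_moduli) ans = std::lcm(ans, m);
  return ans;
}
```

Three time-invariant inputs (modulus 1 each) give 3. If one of the three inputs itself has modulus 2, the result is 6. Following the documentation above, a network containing one descriptor with modulus 2 and another with modulus 3 would have overall modulus lcm(2, 3) = 6.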

◆ WriteConfig()

void WriteConfig ( std::ostream &  os,
const std::vector< std::string > &  node_names 
) const
virtual

Implements ForwardingDescriptor.

Definition at line 169 of file nnet-descriptor.cc.

References rnnlm::i, and KALDI_ASSERT.

void SwitchingForwardingDescriptor::WriteConfig(
    std::ostream &os,
    const std::vector<std::string> &node_names) const {
  KALDI_ASSERT(!src_.empty());
  os << "Switch(";
  for (size_t i = 0; i < src_.size(); i++) {
    src_[i]->WriteConfig(os, node_names);
    if (i + 1 < src_.size())
      os << ", ";
  }
  os << ")";
}
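
The output format of `WriteConfig()` is `"Switch("`, then the inputs' own written forms separated by `", "`, then `")"`. This matches the written form in the detailed description.

The sketch below reproduces only that formatting. It assumes, hypothetically, that each input has already been rendered to a string rather than calling each sub-descriptor's `WriteConfig()`. `WriteSwitchConfig` is an illustrative name.

```cpp
#include <cassert>
#include <cstddef>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Writes "Switch(<in0>, <in1>, ...)" for already-rendered input strings.
void WriteSwitchConfig(std::ostream &os,
                       const std::vector<std::string> &inputs) {
  assert(!inputs.empty());  // mirrors KALDI_ASSERT(!src_.empty())
  os << "Switch(";
  for (std::size_t i = 0; i < inputs.size(); i++) {
    os << inputs[i];
    if (i + 1 < inputs.size()) os << ", ";  // separator only between inputs
  }
  os << ")";
}
```

Writing the inputs `tdnn2a`, `tdnn2b` and `tdnn2c` to a `std::ostringstream` produces `Switch(tdnn2a, tdnn2b, tdnn2c)`.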

Member Data Documentation

◆ src_

std::vector<ForwardingDescriptor*> src_
private

Definition at line 232 of file nnet-descriptor.h.


The documentation for this class was generated from the following files: