BinarySumDescriptor Class Reference

BinarySumDescriptor can represent either A + B, or (A if defined, else B).

#include <nnet-descriptor.h>


Public Types

enum  Operation { kSumOperation, kFailoverOperation }
 

Public Member Functions

virtual void GetDependencies (const Index &ind, std::vector< Cindex > *dependencies) const
 Given an Index at the output of this Descriptor, append to "dependencies" a list of Cindexes that describes what inputs we potentially depend on.
 
virtual bool IsComputable (const Index &ind, const CindexSet &cindex_set, std::vector< Cindex > *used_inputs) const
 This function exists to enable us to manage optional dependencies, i.e. for making sense of expressions like (A + (B if present)) and (A if present; if not, B).
 
virtual int32 Dim (const Nnet &nnet) const
 
virtual BaseFloat GetScaleForNode (int32 node_index) const
 This function returns the scale on the node-index 'node_index' when it appears in expressions inside this descriptor.
 
virtual void GetNodeDependencies (std::vector< int32 > *node_indexes) const
 This function appends to "node_indexes" a list (not necessarily sorted or unique) of all the node indexes that this descriptor may forward data from.
 
virtual int32 Modulus () const
 
virtual void WriteConfig (std::ostream &os, const std::vector< std::string > &node_names) const
 Written form is: if op_ == kSumOperation then "Sum(<src1>, <src2>)"; if op_ == kFailoverOperation then "Failover(<src1>, <src2>)". If you need more than binary operations, nest them, e.g. Sum(a, Sum(b, c)).
 
virtual SumDescriptor * Copy () const
 
 BinarySumDescriptor (Operation op, SumDescriptor *src1, SumDescriptor *src2)
 
virtual ~BinarySumDescriptor ()
 
- Public Member Functions inherited from SumDescriptor
virtual ~SumDescriptor ()
 

Private Attributes

Operation op_
 
SumDescriptor * src1_
 
SumDescriptor * src2_
 

Detailed Description

BinarySumDescriptor can represent either A + B, or (A if defined, else B).

Other expressions such as A + (B if defined, else zero), (A if defined, else zero) + (B if defined, else zero), and (A if defined, else B if defined, else zero) can be expressed using combinations of the two provided options for BinarySumDescriptor and the variant

Definition at line 487 of file nnet-descriptor.h.
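To make the two operations concrete, here is a minimal C++17 sketch that models a single output value as `std::optional<float>` (empty meaning "not defined"). The names `Op`, `Combine` and `OrZero` are hypothetical and are not part of Kaldi; the real class works on Cindexes and matrices rather than single numbers.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for BinarySumDescriptor::Operation.
enum class Op { kSum, kFailover };

// Models a single output value of a descriptor: std::nullopt means the value
// is not defined (not computable from the available inputs).
//   Op::kSum:      defined only if both operands are defined; value is a + b.
//   Op::kFailover: a if defined, otherwise b (which may itself be undefined).
std::optional<float> Combine(Op op, std::optional<float> a,
                             std::optional<float> b) {
  if (op == Op::kSum) {
    if (a && b) return *a + *b;
    return std::nullopt;
  }
  return a ? a : b;
}

// Hypothetical helper for an optional term: "x if defined, else zero".
std::optional<float> OrZero(std::optional<float> x) {
  return x ? x : std::optional<float>(0.0f);
}
```

Under these assumptions, A + (B if defined, else zero) corresponds to `Combine(Op::kSum, a, OrZero(b))`, and (A if defined, else B if defined, else zero) to `Combine(Op::kFailover, a, OrZero(b))`.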

Member Enumeration Documentation

◆ Operation

enum Operation
Enumerator
kSumOperation 
kFailoverOperation 

Definition at line 489 of file nnet-descriptor.h.

Constructor & Destructor Documentation

◆ BinarySumDescriptor()

BinarySumDescriptor ( Operation  op,
SumDescriptor * src1,
SumDescriptor * src2 
)
inline

Definition at line 511 of file nnet-descriptor.h.

BinarySumDescriptor(Operation op, SumDescriptor *src1, SumDescriptor *src2):
    op_(op), src1_(src1), src2_(src2) {}

◆ ~BinarySumDescriptor()

virtual ~BinarySumDescriptor ( )
inline, virtual

Definition at line 513 of file nnet-descriptor.h.

virtual ~BinarySumDescriptor() { delete src1_; delete src2_; }
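The constructor takes ownership of `src1` and `src2`, and the destructor deletes them. The sketch below expresses the same ownership contract with `std::unique_ptr`, which is a swapped-in technique rather than how Kaldi writes it. `Term`, `Leaf` and `BinaryTerm` are hypothetical names, and `Leaf::live` exists only to make the cleanup observable.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical minimal base class standing in for SumDescriptor.
struct Term {
  virtual ~Term() = default;
  virtual int Dim() const = 0;
};

// Leaf that counts how many instances are alive, so ownership is visible.
struct Leaf : Term {
  static int live;
  explicit Leaf(int dim) : dim_(dim) { ++live; }
  ~Leaf() override { --live; }
  int Dim() const override { return dim_; }
  int dim_;
};
int Leaf::live = 0;

// Owns both children; they are destroyed automatically with this object,
// which replaces the explicit "delete src1_; delete src2_;" destructor.
class BinaryTerm : public Term {
 public:
  BinaryTerm(std::unique_ptr<Term> src1, std::unique_ptr<Term> src2)
      : src1_(std::move(src1)), src2_(std::move(src2)) {}
  // Simplified: no dimension-consistency check here.
  int Dim() const override { return src1_->Dim(); }

 private:
  std::unique_ptr<Term> src1_, src2_;
};
```

If an object that owns raw pointers like this were copied by value, both copies would delete the same children. The `std::unique_ptr` members make `BinaryTerm` non-copyable, so the compiler rejects such a copy; with raw pointers, copies have to go through an explicit deep-copy function instead.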

Member Function Documentation

◆ Copy()

SumDescriptor * Copy ( ) const
virtual

Implements SumDescriptor.

Definition at line 450 of file nnet-descriptor.cc.

SumDescriptor *BinarySumDescriptor::Copy() const {
  return new BinarySumDescriptor(op_, src1_->Copy(), src2_->Copy());
}
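Copy() is a "virtual clone": each node allocates a new object of its own dynamic type, and a binary node clones both children, so the result is a deep copy owned by the caller. A minimal sketch of that pattern, using hypothetical `Node`, `Value` and `Pair` types in place of the Kaldi classes:

```cpp
#include <cassert>

// Hypothetical polymorphic node with a virtual Copy() that returns a newly
// allocated deep copy, owned by the caller.
struct Node {
  virtual ~Node() = default;
  virtual Node *Copy() const = 0;
};

struct Value : Node {
  explicit Value(int v) : v(v) {}
  Node *Copy() const override { return new Value(v); }
  int v;
};

struct Pair : Node {
  Pair(Node *a, Node *b) : a(a), b(b) {}  // takes ownership of a and b
  ~Pair() override { delete a; delete b; }
  Pair(const Pair &) = delete;             // forbid shallow copies
  Pair &operator=(const Pair &) = delete;
  // Deep copy: each child is cloned, so the copy shares no pointers.
  Node *Copy() const override { return new Pair(a->Copy(), b->Copy()); }
  Node *a, *b;
};
```

Deleting the shallow copy constructor is a design choice of this sketch: it forces every copy through Copy(), so no two objects ever own the same child.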

◆ Dim()

int32 Dim ( const Nnet & nnet) const
virtual

Implements SumDescriptor.

Definition at line 401 of file nnet-descriptor.cc.

References KALDI_ERR.

int32 BinarySumDescriptor::Dim(const Nnet &nnet) const {
  int32 dim1 = src1_->Dim(nnet), dim2 = src2_->Dim(nnet);
  if (dim1 != dim2)
    KALDI_ERR << "Neural net contains " << (op_ == kSumOperation ? "Sum" :
                                            "Failover")
              << " expression with inconsistent dimension: " << dim1
              << " vs. " << dim2;
  return dim1;
}
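Dim() only succeeds when both operands report the same dimension. The sketch below restates that check as a free function, assuming (for illustration only) that throwing `std::runtime_error` is an acceptable stand-in for KALDI_ERR; `CombinedDim` is a hypothetical name.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Both operands of Sum(...) or Failover(...) must have the same dimension;
// otherwise the network is ill-formed and we report an error.
int CombinedDim(bool is_sum, int dim1, int dim2) {
  if (dim1 != dim2)
    throw std::runtime_error(
        std::string("Neural net contains ") + (is_sum ? "Sum" : "Failover") +
        " expression with inconsistent dimension: " + std::to_string(dim1) +
        " vs. " + std::to_string(dim2));
  return dim1;
}
```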

◆ GetDependencies()

void GetDependencies ( const Index & ind,
std::vector< Cindex > *  dependencies 
) const
virtual

Given an Index at the output of this Descriptor, append to "dependencies" a list of Cindexes that describes what inputs we potentially depend on.

The output list is not necessarily sorted, and this function doesn't make sure that it's unique.

Implements SumDescriptor.

Definition at line 355 of file nnet-descriptor.cc.

void BinarySumDescriptor::GetDependencies(
    const Index &ind, std::vector<Cindex> *dependencies) const {
  src1_->GetDependencies(ind, dependencies);
  src2_->GetDependencies(ind, dependencies);
}
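GetDependencies() simply concatenates what the two operands report, so the same Cindex can appear twice (for example when both operands read the same frame). A minimal sketch, assuming a simplified `(node, t)` pair in place of the real Cindex type; `SimpleCindex`, `AppendDependencies` and `SortAndUniq` are hypothetical:

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Simplified, hypothetical Cindex: (node index, time offset). The real Cindex
// pairs a node index with a full Index.
using SimpleCindex = std::pair<int, int>;

// Append src1's dependencies, then src2's: no sorting, no de-duplication.
void AppendDependencies(const std::vector<SimpleCindex> &from_src1,
                        const std::vector<SimpleCindex> &from_src2,
                        std::vector<SimpleCindex> *dependencies) {
  dependencies->insert(dependencies->end(), from_src1.begin(), from_src1.end());
  dependencies->insert(dependencies->end(), from_src2.begin(), from_src2.end());
}

// A caller that needs a canonical list can sort and de-duplicate afterwards.
void SortAndUniq(std::vector<SimpleCindex> *v) {
  std::sort(v->begin(), v->end());
  v->erase(std::unique(v->begin(), v->end()), v->end());
}
```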

◆ GetNodeDependencies()

void GetNodeDependencies ( std::vector< int32 > *  node_indexes) const
virtual

This function appends to "node_indexes" a list (not necessarily sorted or unique) of all the node indexes that this descriptor may forward data from.

Implements SumDescriptor.

Definition at line 440 of file nnet-descriptor.cc.

void BinarySumDescriptor::GetNodeDependencies(
    std::vector<int32> *node_indexes) const {
  src1_->GetNodeDependencies(node_indexes);
  src2_->GetNodeDependencies(node_indexes);
}
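Because each operand may itself be a BinarySumDescriptor, GetNodeDependencies() recurses down the whole expression tree, visiting src1 before src2. The sketch below models that traversal with a hypothetical `Expr` tree (not a Kaldi type); for an expression shaped like Sum(n4, Sum(n2, n4)) it yields 4, 2, 4, duplicates included.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical expression tree: a leaf refers to one network node; an inner
// node combines two sub-expressions (as Sum(...) or Failover(...) would).
struct Expr {
  int node_index = -1;                // meaningful only for leaves
  std::unique_ptr<Expr> left, right;  // both null for a leaf
};

// Visit src1, then src2, appending each leaf's node index. Duplicates are kept.
void GetNodeDeps(const Expr &e, std::vector<int> *node_indexes) {
  if (!e.left) {  // leaf
    node_indexes->push_back(e.node_index);
    return;
  }
  GetNodeDeps(*e.left, node_indexes);
  GetNodeDeps(*e.right, node_indexes);
}

std::unique_ptr<Expr> Leaf(int node) {
  auto e = std::make_unique<Expr>();
  e->node_index = node;
  return e;
}

std::unique_ptr<Expr> Binary(std::unique_ptr<Expr> a, std::unique_ptr<Expr> b) {
  auto e = std::make_unique<Expr>();
  e->left = std::move(a);
  e->right = std::move(b);
  return e;
}
```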

◆ GetScaleForNode()

BaseFloat GetScaleForNode ( int32  node_index) const
virtual

This function returns the scale on the node-index 'node_index' when it appears in expressions inside this descriptor.

E.g. if the descriptor is just `Scale(tdnn2, 2.0)` and the node index for `tdnn2` is 4, then GetScaleForNode(4) would return 2.0. It returns +infinity if node_index >= 0 and that node does not appear in this descriptor. If node_index < 0, it returns the constant offset value from this descriptor, which equals 0.0 if there is no expression like `Const(1.0, 512)` in it. If a particular node_index >= 0 appears in different sub-expressions of the descriptor with different scales, that is an error (it is not supported), and this function fails with KALDI_ERR.

Implements SumDescriptor.

Definition at line 411 of file nnet-descriptor.cc.

References KALDI_ASSERT, and KALDI_ERR.

BaseFloat BinarySumDescriptor::GetScaleForNode(int32 node_index) const {
  BaseFloat ans1 = src1_->GetScaleForNode(node_index),
      ans2 = src2_->GetScaleForNode(node_index);
  bool ans1_valid = (ans1 - ans1 == 0),
      ans2_valid = (ans2 - ans2 == 0);  // Test for infinity.
  if (node_index < 0) {  // the query is about the constant offset, not for a
                         // specific node.
    KALDI_ASSERT(ans1_valid && ans2_valid);
    if (op_ == kSumOperation) {
      // For a sum operation, if there were more than one Const(..) expression,
      // they would logically add together (even though it would be redundant to
      // write such a thing).
      return ans1 + ans2;
    } else if (ans1 != ans2) {
      KALDI_ERR << "Illegal combination of Failover operation with Const() "
          "expression encountered in Descriptor (this is not supported).";
    }
  }
  if (ans1_valid && ans2_valid && ans1 != ans2) {
    // this would be a code error so don't print a very informative message.
    KALDI_ERR << "Inconsistent value for sum descriptor: for node "
              << node_index << ", it can have scales "
              << ans1 << " vs. " << ans2 << " (you have used unsupported "
        "combinations of descriptors).";
  }
  if (!ans2_valid) return ans1;
  else return ans2;
}
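The merging rules in GetScaleForNode() can be exercised in isolation. This is a sketch, not the Kaldi code: `CombineScales`, `IsFinite` and `kAbsent` are hypothetical names, and `std::logic_error` stands in for both KALDI_ASSERT and KALDI_ERR. It keeps the `x - x == 0` finiteness test used in the listing.

```cpp
#include <cassert>
#include <limits>
#include <stdexcept>

// +infinity means "this node does not appear in the sub-expression".
const float kAbsent = std::numeric_limits<float>::infinity();

// x - x == 0 holds for finite x; for inf or NaN the difference is NaN, and
// NaN == 0 is false, so this acts as a cheap "is finite" test.
bool IsFinite(float x) { return x - x == 0; }

// Merges the answers ans1 (from src1) and ans2 (from src2).
float CombineScales(bool is_sum, int node_index, float ans1, float ans2) {
  bool ans1_valid = IsFinite(ans1), ans2_valid = IsFinite(ans2);
  if (node_index < 0) {  // query about the constant offset
    if (!ans1_valid || !ans2_valid)
      throw std::logic_error("constant offset must be finite");
    if (is_sum) return ans1 + ans2;  // Const() terms add up under Sum
    if (ans1 != ans2)
      throw std::logic_error("Failover with differing Const() offsets");
  }
  if (ans1_valid && ans2_valid && ans1 != ans2)
    throw std::logic_error("node appears with inconsistent scales");
  return ans2_valid ? ans2 : ans1;
}
```

The `x - x == 0` trick relies on IEEE floating-point semantics; under flags such as `-ffast-math` a compiler may assume infinities never occur and fold finiteness tests like this one to `true`.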

◆ IsComputable()

bool IsComputable ( const Index & ind,
const CindexSet & cindex_set,
std::vector< Cindex > *  used_inputs 
) const
virtual

This function exists to enable us to manage optional dependencies, i.e. for making sense of expressions like (A + (B if present)) and (A if present; if not, B).

Suppose we are trying to compute the index "ind", and the caller specifies that "cindex_set" is the set of Cindexes that are available to the computation. This function then returns true if we can compute the expression given these inputs, and if so, outputs to "used_inputs" the list of Cindexes that this expression will be a summation over.

Parameters
[in] ind  The index that we want to compute at the output of the Descriptor.
[in] cindex_set  The set of Cindexes that are available at the input of the Descriptor.
[out] used_inputs  If non-NULL and this function returns true, the inputs that will actually participate in the computation are *appended* to this vector; otherwise it is left unchanged.
Returns
Returns true if this output is computable given the provided inputs.

Implements SumDescriptor.

Definition at line 361 of file nnet-descriptor.cc.

References KALDI_ASSERT.

bool BinarySumDescriptor::IsComputable(const Index &ind,
                                       const CindexSet &cindex_set,
                                       std::vector<Cindex> *used_inputs) const {
  std::vector<Cindex> src1_inputs, src2_inputs;
  bool r = (used_inputs != NULL);
  bool src1_computable = src1_->IsComputable(ind, cindex_set,
                                             r ? &src1_inputs : NULL),
      src2_computable = src2_->IsComputable(ind, cindex_set,
                                            r ? &src2_inputs : NULL);
  if (op_ == kSumOperation) {
    if (src1_computable && src2_computable) {
      if (r) {
        used_inputs->insert(used_inputs->end(),
                            src1_inputs.begin(), src1_inputs.end());
        used_inputs->insert(used_inputs->end(),
                            src2_inputs.begin(), src2_inputs.end());
      }
      return true;
    } else {
      return false;
    }
  } else {
    KALDI_ASSERT(op_ == kFailoverOperation);
    if (src1_computable) {
      if (r)
        used_inputs->insert(used_inputs->end(),
                            src1_inputs.begin(), src1_inputs.end());
      return true;
    } else if (src2_computable) {
      if (r)
        used_inputs->insert(used_inputs->end(),
                            src2_inputs.begin(), src2_inputs.end());
      return true;
    } else {
      return false;
    }
  }
}
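The sum/failover logic of IsComputable() can be modelled without the rest of nnet3. In the sketch below, which simplifies under stated assumptions, each operand is reduced to the list of inputs it needs (integers standing in for Cindexes), and an operand counts as computable only if all of them are available. `Op`, `OperandComputable` and `IsComputable` here are hypothetical.

```cpp
#include <cassert>
#include <set>
#include <vector>

enum class Op { kSum, kFailover };

// An operand is computable if every input it needs is available.
bool OperandComputable(const std::vector<int> &needs,
                       const std::set<int> &available) {
  for (int c : needs)
    if (available.count(c) == 0) return false;
  return true;
}

// kSum: computable only if both operands are; reports the inputs of both.
// kFailover: reports the inputs of the first computable operand (src1 first).
bool IsComputable(Op op, const std::vector<int> &src1,
                  const std::vector<int> &src2, const std::set<int> &available,
                  std::vector<int> *used_inputs) {
  bool ok1 = OperandComputable(src1, available),
       ok2 = OperandComputable(src2, available);
  if (op == Op::kSum) {
    if (!(ok1 && ok2)) return false;  // used_inputs left unchanged
    if (used_inputs) {
      used_inputs->insert(used_inputs->end(), src1.begin(), src1.end());
      used_inputs->insert(used_inputs->end(), src2.begin(), src2.end());
    }
    return true;
  }
  const std::vector<int> *chosen = ok1 ? &src1 : (ok2 ? &src2 : nullptr);
  if (!chosen) return false;
  if (used_inputs)
    used_inputs->insert(used_inputs->end(), chosen->begin(), chosen->end());
  return true;
}
```

As in the listing, `used_inputs` is only appended to when the result is `true`, and for Failover only the chosen operand's inputs are reported.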

◆ Modulus()

int32 Modulus ( ) const
virtual

Implements SumDescriptor.

Definition at line 446 of file nnet-descriptor.cc.

References kaldi::Lcm().

int32 BinarySumDescriptor::Modulus() const {
  return Lcm(src1_->Modulus(), src2_->Modulus());
}
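Modulus() combines the operands' moduli with a least common multiple. Assuming, as the name suggests, that each modulus is a period in the time index, a period valid for both operands has to be a common multiple of the two, and the LCM is the smallest such value. The sketch uses C++17 `std::lcm` in place of kaldi::Lcm(); `CombinedModulus` is a hypothetical name.

```cpp
#include <cassert>
#include <numeric>

// Smallest period that is a multiple of both operands' periods.
int CombinedModulus(int m1, int m2) { return std::lcm(m1, m2); }
```

For example, operands with moduli 4 and 6 give a combined modulus of 12, and the common case of two operands with modulus 1 stays at 1.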

◆ WriteConfig()

void WriteConfig ( std::ostream &  os,
const std::vector< std::string > &  node_names 
) const
virtual

Written form is: if op_ == kSumOperation then "Sum(<src1>, <src2>)"; if op_ == kFailoverOperation then "Failover(<src1>, <src2>)". If you need more than binary operations, nest them, e.g. Sum(a, Sum(b, c)).

Implements SumDescriptor.

Definition at line 454 of file nnet-descriptor.cc.

References KALDI_ASSERT.

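void BinarySumDescriptor::WriteConfig(
    std::ostream &os,
    const std::vector<std::string> &node_names) const {
  KALDI_ASSERT(op_ == kSumOperation || op_ == kFailoverOperation);
  if (op_ == kSumOperation) os << "Sum(";
  if (op_ == kFailoverOperation) os << "Failover(";
  src1_->WriteConfig(os, node_names);
  os << ", ";
  src2_->WriteConfig(os, node_names);
  os << ")";
}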
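The nested form Sum(a, Sum(b, c)) falls out of the recursion in WriteConfig(): each binary node writes its own prefix, then delegates to its children. A minimal sketch with a hypothetical `ConfigExpr` tree (not a Kaldi type) that produces the same text:

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical expression tree that prints "Sum(<src1>, <src2>)" or
// "Failover(<src1>, <src2>)"; leaves print a node name looked up by index.
struct ConfigExpr {
  enum Kind { kLeaf, kSum, kFailover } kind = kLeaf;
  int node_index = 0;
  std::shared_ptr<ConfigExpr> src1, src2;

  void Write(std::ostream &os, const std::vector<std::string> &node_names) const {
    if (kind == kLeaf) {
      os << node_names.at(static_cast<std::size_t>(node_index));
      return;
    }
    os << (kind == kSum ? "Sum(" : "Failover(");
    src1->Write(os, node_names);
    os << ", ";
    src2->Write(os, node_names);
    os << ")";
  }
};

using ExprPtr = std::shared_ptr<ConfigExpr>;

ExprPtr Named(int node_index) {
  auto e = std::make_shared<ConfigExpr>();
  e->node_index = node_index;
  return e;
}

ExprPtr Combine(ConfigExpr::Kind kind, ExprPtr a, ExprPtr b) {
  auto e = std::make_shared<ConfigExpr>();
  e->kind = kind;
  e->src1 = std::move(a);
  e->src2 = std::move(b);
  return e;
}
```

With node names {"a", "b", "c"}, `Combine(ConfigExpr::kSum, Named(0), Combine(ConfigExpr::kSum, Named(1), Named(2)))` writes `Sum(a, Sum(b, c))`.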

Member Data Documentation

◆ op_

Operation op_
private

Definition at line 515 of file nnet-descriptor.h.

◆ src1_

SumDescriptor* src1_
private

Definition at line 516 of file nnet-descriptor.h.

◆ src2_

SumDescriptor* src2_
private

Definition at line 517 of file nnet-descriptor.h.


The documentation for this class was generated from the following files: nnet-descriptor.h, nnet-descriptor.cc