20 #ifndef KALDI_NNET3_NNET_DESCRIPTOR_H_ 21 #define KALDI_NNET3_NNET_DESCRIPTOR_H_ 87 class ForwardingDescriptor {
103 virtual int32 Dim(
const Nnet &nnet)
const = 0;
118 const std::vector<std::string> &node_names)
const = 0;
156 const std::vector<std::string> &node_names)
const;
160 src_node_(src_node), scale_(scale) {
183 const std::vector<std::string> &node_names)
const;
192 Index offset): src_(src), offset_(offset) { }
213 virtual int32 Dim(
const Nnet &nnet)
const {
return src_[0]->Dim(nnet); }
217 const std::vector<std::string> &node_names)
const;
232 std::vector<ForwardingDescriptor*>
src_;
249 const std::vector<std::string> &node_names)
const;
261 src_(src), t_modulus_(t_modulus) { }
283 const std::vector<std::string> &node_names)
const;
294 src_(src), variable_name_(variable_name), value_(value) { }
321 virtual void GetDependencies(
const Index &ind,
322 std::vector<Cindex> *dependencies)
const = 0;
343 virtual bool IsComputable(
const Index &ind,
345 std::vector<Cindex> *used_inputs)
const = 0;
375 const std::vector<std::string> &node_names)
const = 0;
385 virtual void GetDependencies(
const Index &ind,
386 std::vector<Cindex> *dependencies)
const;
389 std::vector<Cindex> *used_inputs)
const {
390 return src_->IsComputable(ind, cindex_set, used_inputs) ||
true;
403 const std::vector<std::string> &node_names)
const;
418 virtual void GetDependencies(
const Index &ind,
419 std::vector<Cindex> *dependencies)
const;
420 virtual bool IsComputable(
const Index &ind,
422 std::vector<Cindex> *used_inputs)
const;
434 const std::vector<std::string> &node_names)
const;
458 std::vector<Cindex> *dependencies)
const { }
461 std::vector<Cindex> *used_inputs)
const {
472 const std::vector<std::string> &node_names)
const;
493 virtual void GetDependencies(
const Index &ind,
494 std::vector<Cindex> *dependencies)
const;
495 virtual bool IsComputable(
const Index &ind,
497 std::vector<Cindex> *used_inputs)
const;
509 const std::vector<std::string> &node_names)
const;
512 op_(op), src1_(src1), src2_(src2) {}
536 bool Parse(
const std::vector<std::string> &node_names,
537 const std::string **next_token);
543 const std::vector<std::string> &node_names)
const;
565 void GetDependencies(
const Index &index,
566 std::vector<Cindex> *used_inputs)
const;
571 bool IsComputable(
const Index &ind,
573 std::vector<Cindex> *used_inputs)
const;
611 kRound, kReplaceIndex,
kScale, kConst, kNodeName };
621 const std::string **next_token);
625 descriptor_type_(t), value1_(value1), value2_(value2),
636 void Print(
const std::vector<std::string> &node_names,
669 void ParseAppendOrSumOrSwitch(
const std::vector<std::string> &node_names,
670 const std::string **next_token);
673 void ParseIfDefined(
const std::vector<std::string> &node_names,
674 const std::string **next_token);
676 void ParseOffset(
const std::vector<std::string> &node_names,
677 const std::string **next_token);
678 void ParseSwitch(
const std::vector<std::string> &node_names,
679 const std::string **next_token);
680 void ParseFailover(
const std::vector<std::string> &node_names,
681 const std::string **next_token);
682 void ParseRound(
const std::vector<std::string> &node_names,
683 const std::string **next_token);
684 void ParseScale(
const std::vector<std::string> &node_names,
685 const std::string **next_token);
686 void ParseConst(
const std::vector<std::string> &node_names,
687 const std::string **next_token);
688 void ParseReplaceIndex(
const std::vector<std::string> &node_names,
689 const std::string **next_token);
697 int32 NumAppendTerms()
const;
KALDI_DISALLOW_COPY_AND_ASSIGN(ForwardingDescriptor)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
GeneralDescriptor(DescriptorType t, int32 value1=-1, int32 value2=-1, BaseFloat alpha=0.0)
virtual ~BinarySumDescriptor()
virtual void WriteConfig(std::ostream &os, const std::vector< std::string > &node_names) const =0
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
OffsetForwardingDescriptor(ForwardingDescriptor *src, Index offset)
virtual ~OffsetForwardingDescriptor()
VariableName variable_name_
This is the case of class SumDescriptor, in which we contain just one term, and that term is optional...
virtual ~SimpleSumDescriptor()
virtual int32 Dim(const Nnet &nnet) const =0
virtual ~ForwardingDescriptor()
RoundingForwardingDescriptor(ForwardingDescriptor *src, int32 t_modulus)
virtual void GetNodeDependencies(std::vector< int32 > *node_indexes) const =0
This function appends to "node_indexes" all the node indexes.
This class is only used when parsing Descriptors.
virtual ~ConstantSumDescriptor()
std::vector< GeneralDescriptor * > descriptors_
ForwardingDescriptor * src_
This is an alternative base-case of SumDescriptor (an alternative to SimpleSumDescriptor) which repre...
virtual bool IsComputable(const Index &ind, const CindexSet &cindex_set, std::vector< Cindex > *used_inputs) const
This function exists to enable us to manage optional dependencies, i.e.
SimpleForwardingDescriptor is the base-case of ForwardingDescriptor, consisting of a source node in t...
virtual Cindex MapToInput(const Index &output) const =0
virtual bool IsComputable(const Index &ind, const CindexSet &cindex_set, std::vector< Cindex > *used_inputs) const
This function exists to enable us to manage optional dependencies, i.e.
std::vector< SumDescriptor * > parts_
BinarySumDescriptor can represent either A + B, or (A if defined, else B).
virtual int32 Dim(const Nnet &nnet) const
ReplaceIndexForwardingDescriptor(ForwardingDescriptor *src, VariableName variable_name, int32 value)
virtual void GetDependencies(const Index &ind, std::vector< Cindex > *dependencies) const
Given an Index at the output of this Descriptor, append to "dependencies" a list of Cindexes that des...
For use in clockwork RNNs and the like, this forwarding-descriptor rounds the time-index t down to th...
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
ForwardingDescriptor * src_
virtual int32 Dim(const Nnet &nnet) const
std::pair< int32, Index > Cindex
virtual void GetNodeDependencies(std::vector< int32 > *node_indexes) const
This function appends to "node_indexes" a list (not necessarily sorted or unique) of all the node ind...
DescriptorType descriptor_type_
This is an abstract base-class.
virtual int32 Dim(const Nnet &nnet) const
const ForwardingDescriptor & Src() const
std::vector< ForwardingDescriptor * > src_
Chooses from different inputs based on the the time index modulo (the number of ForwardingDescriptors...
virtual int32 Dim(const Nnet &nnet) const
virtual int32 Modulus() const
virtual int32 Modulus() const
This function is for use in things like clockwork RNNs, where shifting the time of the inputs and out...
virtual int32 Modulus() const
virtual int32 Modulus() const
OptionalSumDescriptor(SumDescriptor *src)
This is the normal base-case of SumDescriptor which just wraps a ForwardingDescriptor.
virtual int32 Modulus() const
This function is for use in things like clockwork RNNs, where shifting the time of the inputs and out...
virtual int32 Modulus() const
This function is for use in things like clockwork RNNs, where shifting the time of the inputs and out...
const ForwardingDescriptor & Src() const
A ForwardingDescriptor describes how we copy data from another NetworkNode, or from multiple other Ne...
virtual ~OptionalSumDescriptor()
virtual ~RoundingForwardingDescriptor()
virtual ForwardingDescriptor * Copy() const =0
Offsets in 't' and 'x' values of other ForwardingDescriptors.
virtual ~SwitchingForwardingDescriptor()
#define KALDI_ASSERT(cond)
This ForwardingDescriptor modifies the indexes (n, t, x) by replacing one of them (normally t) with a...
virtual ~SimpleForwardingDescriptor()
SimpleSumDescriptor(ForwardingDescriptor *src)
ForwardingDescriptor * src_
void Print(const Fst< Arc > &fst, std::string message)
Descriptor(const Descriptor &other)
Copy constructor.
BinarySumDescriptor(Operation op, SumDescriptor *src1, SumDescriptor *src2)
SimpleForwardingDescriptor(int32 src_node, BaseFloat scale=1.0)
int32 NumParts() const
Returns the number of parts that are concatenated over.
virtual ~ReplaceIndexForwardingDescriptor()
Descriptor(const std::vector< SumDescriptor *> &parts)
Takes ownership of pointers in "parts".
virtual int32 Dim(const Nnet &nnet) const
SwitchingForwardingDescriptor(std::vector< ForwardingDescriptor *> &src)
ForwardingDescriptor * src_
virtual BaseFloat GetScaleForNode(int32 node_index) const =0
This function returns the scale on the node-index 'node_index' when it appears in expressions inside ...