20 #ifndef KALDI_NNET3_NNET_GENERAL_COMPONENT_H_ 21 #define KALDI_NNET3_NNET_GENERAL_COMPONENT_H_ 59 Init(input_dim, output_dim);
67 virtual std::string
Type()
const {
return "DistributeComponent"; }
72 virtual void Backprop(
const std::string &debug_info,
81 virtual void Read(std::istream &is,
bool binary);
85 virtual void Write(std::ostream &os,
bool binary)
const;
93 const Index &output_index,
94 std::vector<Index> *desired_indexes)
const;
99 const Index &output_index,
101 std::vector<Index> *used_inputs)
const;
105 const std::vector<Index> &input_indexes,
106 const std::vector<Index> &output_indexes,
107 bool need_backprop)
const;
120 int32 num_output_rows,
121 std::vector<const BaseFloat*> *input_pointers)
const;
125 int32 num_output_rows,
127 std::vector<BaseFloat*> *input_pointers)
const;
140 std::vector<std::pair<int32, int32> >
pairs;
150 virtual void Write(std::ostream &ostream,
bool binary)
const;
152 virtual void Read(std::istream &istream,
bool binary);
154 virtual std::string
Type()
const {
return "DistributeComponentPrecomputedIndexes"; }
215 virtual std::string
Type()
const {
return "StatisticsExtractionComponent"; }
223 virtual void Backprop(
const std::string &debug_info,
232 virtual void Read(std::istream &is,
bool binary);
236 virtual void Write(std::ostream &os,
bool binary)
const;
243 const Index &output_index,
244 std::vector<Index> *desired_indexes)
const;
247 const Index &output_index,
249 std::vector<Index> *used_inputs)
const;
254 std::vector<Index> *output_indexes)
const;
258 const std::vector<Index> &input_indexes,
259 const std::vector<Index> &output_indexes,
260 bool need_backprop)
const;
298 virtual void Write(std::ostream &os,
bool binary)
const;
300 virtual void Read(std::istream &is,
bool binary);
302 virtual std::string
Type()
const {
return "StatisticsExtractionComponentPrecomputedIndexes"; }
347 return input_dim_ + num_log_count_features_ - 1;
350 virtual std::string
Type()
const {
return "StatisticsPoolingComponent"; }
353 (output_stddevs_ || num_log_count_features_ > 0 ?
360 virtual void Backprop(
const std::string &debug_info,
369 virtual void Read(std::istream &is,
bool binary);
373 virtual void Write(std::ostream &os,
bool binary)
const;
380 const Index &output_index,
381 std::vector<Index> *desired_indexes)
const;
385 const Index &output_index,
387 std::vector<Index> *used_inputs)
const;
392 std::vector<Index> *output_indexes)
const;
396 const std::vector<Index> &input_indexes,
397 const std::vector<Index> &output_indexes,
398 bool need_backprop)
const;
454 virtual void Write(std::ostream &os,
bool binary)
const;
456 virtual void Read(std::istream &is,
bool binary);
458 virtual std::string
Type()
const {
return "StatisticsPoolingComponentPrecomputedIndexes"; }
472 int32 zeroing_interval,
473 int32 recurrence_interval) {
474 Init(dim, scale, clipping_threshold, zeroing_threshold,
475 zeroing_interval, recurrence_interval);}
478 zeroing_threshold_(-1), zeroing_interval_(0), recurrence_interval_(0),
479 num_clipped_(0), num_zeroed_(0), count_(0), count_zeroing_boundaries_(0) { }
486 int32 recurrence_interval);
488 virtual std::string
Type()
const {
return "BackpropTruncationComponent"; }
501 virtual void Backprop(
const std::string &debug_info,
512 const std::vector<Index> &input_indexes,
513 const std::vector<Index> &output_indexes,
514 bool need_backprop)
const;
518 virtual void Read(std::istream &is,
bool binary);
521 virtual void Write(std::ostream &os,
bool binary)
const;
522 virtual std::string
Info()
const;
593 virtual void Write(std::ostream &ostream,
bool binary)
const;
595 virtual void Read(std::istream &istream,
bool binary);
597 virtual std::string
Type()
const {
598 return "BackpropTruncationComponentPrecomputedIndexes";
636 virtual std::string
Info()
const;
647 virtual std::string
Type()
const {
return "ConstantComponent"; }
655 virtual void Backprop(
const std::string &debug_info,
664 virtual void Read(std::istream &is,
bool binary);
665 virtual void Write(std::ostream &os,
bool binary)
const;
671 const Index &output_index,
672 std::vector<Index> *desired_indexes)
const {
673 desired_indexes->clear();
680 const Index &output_index,
682 std::vector<Index> *used_inputs)
const {
683 if (used_inputs) used_inputs->clear();
730 virtual std::string
Info()
const;
745 virtual std::string
Type()
const {
return "DropoutMaskComponent"; }
753 virtual void Backprop(
const std::string &debug_info,
762 virtual void Read(std::istream &is,
bool binary);
763 virtual void Write(std::ostream &os,
bool binary)
const;
769 const Index &output_index,
770 std::vector<Index> *desired_indexes)
const {
771 desired_indexes->clear();
778 const Index &output_index,
780 std::vector<Index> *used_inputs)
const {
781 if (used_inputs) used_inputs->clear();
881 virtual std::string
Info()
const;
889 virtual std::string
Type()
const {
return "GeneralDropoutComponent"; }
898 virtual void Backprop(
const std::string &debug_info,
913 const std::vector<Index> &input_indexes,
914 const std::vector<Index> &output_indexes,
915 bool need_backprop)
const;
917 virtual void Read(std::istream &is,
bool binary);
918 virtual void Write(std::ostream &os,
bool binary)
const;
986 virtual void Write(std::ostream &os,
bool binary)
const;
988 virtual void Read(std::istream &is,
bool binary);
990 virtual std::string
Type()
const {
991 return "GeneralDropoutComponentPrecomputedIndexes";
1023 virtual std::string
Info()
const;
1031 virtual std::string
Type()
const {
return "SpecAugmentTimeMaskComponent"; }
1039 virtual void Backprop(
const std::string &debug_info,
1054 const std::vector<Index> &input_indexes,
1055 const std::vector<Index> &output_indexes,
1056 bool need_backprop)
const;
1058 virtual void Read(std::istream &is,
bool binary);
1059 virtual void Write(std::ostream &os,
bool binary)
const;
1105 virtual void Write(std::ostream &os,
bool binary)
const;
1107 virtual void Read(std::istream &is,
bool binary);
1110 return "SpecAugmentTimeMaskComponentPrecomputedIndexes";
std::vector< std::vector< int32 > > indexes
virtual bool IsComputable(const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
This function only does something interesting for non-simple Components, and it exists to make it pos...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
virtual int32 OutputDim() const
Returns output-dimension of this component.
virtual void Backprop(const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *, CuMatrixBase< BaseFloat > *in_deriv) const
Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL...
virtual std::string Type() const
Returns a string such as "SigmoidComponent", describing the type of the object.
BaseFloat zeroing_threshold_
void ComputeInputIndexAndBlock(const Index &output_index, Index *input_index, int32 *block) const
BaseFloat dropout_proportion_
bool use_natural_gradient_
virtual bool IsComputable(const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
This function only does something interesting for non-simple Components, and it exists to make it pos...
virtual Component * Copy() const
Copies component (deep copy).
virtual int32 OutputDim() const
Returns output-dimension of this component.
int32 recurrence_interval_
virtual void GetInputIndexes(const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
This function only does something interesting for non-simple Components.
Abstract base-class for neural-net components.
virtual void Backprop(const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &, const CuMatrixBase< BaseFloat > &, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const
Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL...
virtual ComponentPrecomputedIndexes * Copy() const
virtual int32 Properties() const
Return bitmask of the component's properties.
An abstract representation of a set of Indexes.
BackpropTruncationComponent()
void SetDropoutProportion(BaseFloat p)
virtual ~DistributeComponentPrecomputedIndexes()
virtual void DeleteMemo(void *memo) const
This virtual function only needs to be overwritten by Components that return a non-NULL memo from the...
virtual int32 InputDim() const
Returns input-dimension of this component.
BaseFloat dropout_proportion_
void SetRequireDirectInput(bool b)
virtual std::string Type() const
Returns a string such as "SigmoidComponent", describing the type of the object.
virtual std::string Type() const
virtual ~GeneralDropoutComponentPrecomputedIndexes()
virtual std::string Type() const
Returns a string such as "SigmoidComponent", describing the type of the object.
virtual void Write(std::ostream &os, bool binary) const
Write component to stream.
Keywords for search: natural gradient, naturalgradient, NG-SGD.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
virtual void InitFromConfig(ConfigLine *cfl)
Initialize, from a ConfigLine object.
virtual ~SpecAugmentTimeMaskComponentPrecomputedIndexes()
virtual std::string Type() const
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
SpecAugmentTimeMaskComponent implements the time part of SpecAugment.
BackpropTruncationComponentPrecomputedIndexes()
virtual std::string Type() const
Returns a string such as "SigmoidComponent", describing the type of the object.
virtual int32 OutputDim() const
Returns output-dimension of this component.
virtual int32 InputDim() const
Returns input-dimension of this component.
virtual void Scale(BaseFloat scale)
This virtual function when called on – an UpdatableComponent scales the parameters by "scale" when c...
virtual int32 OutputDim() const
Returns output-dimension of this component.
int32 NumParameters(const Nnet &src)
Returns the total of the number of parameters in the updatable components of the nnet.
virtual ~StatisticsPoolingComponentPrecomputedIndexes()
virtual void ZeroStats()
Components that provide an implementation of StoreStats should also provide an implementation of Zero...
CuArray< Int32Pair > forward_indexes
BaseFloat zeroed_proportion_
virtual bool IsComputable(const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
This function only does something interesting for non-simple Components, and it exists to make it pos...
int32 time_mask_max_frames_
bool require_direct_input_
CuArray< Int32Pair > backward_indexes
virtual std::string Type() const
Returns a string such as "SigmoidComponent", describing the type of the object.
int32 num_log_count_features_
virtual void ReorderIndexes(std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
This function only does something interesting for non-simple Components.
virtual std::string Type() const
BaseFloat variance_floor_
virtual void GetInputIndexes(const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
This function only does something interesting for non-simple Components.
virtual int32 Properties() const
Return bitmask of the component's properties.
This Component takes a larger input-dim than output-dim, where the input-dim must be a multiple of th...
virtual std::string Type() const
Returns a string such as "SigmoidComponent", describing the type of the object.
virtual int32 OutputDim() const
Returns output-dimension of this component.
virtual int32 InputDim() const
Returns input-dimension of this component.
virtual int32 Properties() const
Return bitmask of the component's properties.
BaseFloat DotProduct(const Nnet &nnet1, const Nnet &nnet2)
Returns dot product between two networks of the same structure (calls the DotProduct functions of the...
virtual int32 OutputDim() const
Returns output-dimension of this component.
virtual ~BackpropTruncationComponentPrecomputedIndexes()
virtual int32 InputDim() const
Returns input-dimension of this component.
DistributeComponent(int32 input_dim, int32 output_dim)
double count_zeroing_boundaries_
Class UpdatableComponent is a Component which has trainable parameters; it extends the interface of C...
int32 specaugment_max_regions_
void SetDropoutProportion(BaseFloat p)
virtual int32 Properties() const
Return bitmask of the component's properties.
GeneralDropoutComponent implements dropout, including a continuous variant where the thing we multipl...
virtual std::string Type() const
CuVector< BaseFloat > output_
Matrix for CUDA computing.
virtual void * Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
Propagate function.
ComponentPrecomputedIndexes * Copy() const
virtual ComponentPrecomputedIndexes * PrecomputeIndexes(const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
This function must return NULL for simple Components.
This class is responsible for parsing input like hi-there xx=yyy a=b c empty= f-oo=Append(bar, sss) ba_z=123 bing='a b c' baz="a b c d='a b' e" and giving you access to the fields, in this case.
BaseFloat clipping_threshold_
ComponentPrecomputedIndexes * Copy() const
virtual void Read(std::istream &is, bool binary)
Read function (used after we know the type of the Component); accepts input that is missing the token...
virtual void GetInputIndexes(const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
This function only does something interesting for non-simple Components.
void Init(int32 input_dim, int32 output_dim)
virtual void ConsolidateMemory()
This virtual function relates to memory management, and avoiding fragmentation.
virtual int32 InputDim() const
Returns input-dimension of this component.
void PerturbParams(BaseFloat stddev, Nnet *nnet)
Calls PerturbParams (with the given stddev) on all updatable components of the nnet.
OnlineNaturalGradient preconditioner_
virtual std::string Info() const
Returns some text-form information about this component, for diagnostics.
virtual int32 OutputDim() const
Returns output-dimension of this component.
virtual int32 Properties() const
Return bitmask of the component's properties.
CuVector< BaseFloat > zeroing
Provides a vector abstraction class.
virtual std::string Type() const
Returns a string such as "SigmoidComponent", describing the type of the object.
std::vector< std::pair< int32, int32 > > pairs
BaseFloat specaugment_max_proportion_
virtual void Add(BaseFloat alpha, const Component &other)
This virtual function when called by – an UpdatableComponent adds the parameters of another updatabl...
virtual ComponentPrecomputedIndexes * Copy() const
virtual int32 InputDim() const
Returns input-dimension of this component.
BackpropTruncationComponent(int32 dim, BaseFloat scale, BaseFloat clipping_threshold, BaseFloat zeroing_threshold, int32 zeroing_interval, int32 recurrence_interval)
virtual ~BackpropTruncationComponent()
virtual int32 InputDim() const
Returns input-dimension of this component.
virtual std::string Type() const
void ComputeInputPointers(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, int32 num_output_rows, std::vector< const BaseFloat *> *input_pointers) const
ComponentPrecomputedIndexes * Copy() const
virtual Component * Copy() const
Copies component (deep copy).
virtual int32 Properties() const
Return bitmask of the component's properties.
virtual int32 Properties() const
Return bitmask of the component's properties.
virtual void DeleteMemo(void *memo) const
This virtual function only needs to be overwritten by Components that return a non-NULL memo from the...