20 #ifndef KALDI_NNET3_NNET_ATTENTION_COMPONENT_H_ 21 #define KALDI_NNET3_NNET_ATTENTION_COMPONENT_H_ 128 virtual std::string
Info()
const;
130 virtual std::string
Type()
const {
return "RestrictedAttentionComponent"; }
145 virtual void Backprop(
const std::string &debug_info,
153 virtual void Read(std::istream &is,
bool binary);
154 virtual void Write(std::ostream &os,
bool binary)
const;
166 std::vector<Index> *output_indexes)
const;
169 const Index &output_index,
170 std::vector<Index> *desired_indexes)
const;
175 const Index &output_index,
177 std::vector<Index> *used_inputs)
const;
181 const std::vector<Index> &input_indexes,
182 const std::vector<Index> &output_indexes,
183 bool need_backprop)
const;
191 virtual void Write(std::ostream &os,
bool binary)
const;
192 virtual void Read(std::istream &os,
bool binary);
193 virtual std::string
Type()
const {
194 return "RestrictedAttentionComponentPrecomputedIndexes";
245 const std::vector<Index> &input_indexes,
246 const std::vector<Index> &output_indexes,
260 const std::vector<Index> &input_indexes,
261 const std::vector<Index> &output_indexes,
263 std::vector<Index> *new_input_indexes,
264 std::vector<Index> *new_output_indexes)
const;
273 const std::vector<std::pair<int32, int32> > &n_x_pairs,
275 const std::unordered_set<Index, IndexHasher> &index_set,
276 std::vector<Index> *output_indexes);
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void GetIndexes(const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo &io, std::vector< Index > *new_input_indexes, std::vector< Index > *new_output_indexes) const
virtual void Add(BaseFloat alpha, const Component &other)
This virtual function, when called by an UpdatableComponent, adds the parameters of another updatabl...
void GetComputationStructure(const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo *io) const
virtual int32 InputDim() const
Returns input-dimension of this component.
virtual void Read(std::istream &is, bool binary)
Read function (used after we know the type of the Component); accepts input that is missing the token...
Abstract base-class for neural-net components.
virtual Component * Copy() const
Copies component (deep copy).
An abstract representation of a set of Indexes.
PrecomputedIndexes(const PrecomputedIndexes &other)
int32 num_left_inputs_required_
void BackpropOneHead(const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &out_deriv, CuMatrixBase< BaseFloat > *in_deriv) const
static void CreateIndexesVector(const std::vector< std::pair< int32, int32 > > &n_x_pairs, int32 t_start, int32 t_step, int32 num_t_values, const std::unordered_set< Index, IndexHasher > &index_set, std::vector< Index > *output_indexes)
Utility function used in GetIndexes().
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
virtual bool IsComputable(const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
This function only does something interesting for non-simple Components, and it exists to make it pos...
RestrictedAttentionComponent()
int32 num_right_inputs_required_
CuMatrix< double > posterior_stats_
virtual int32 OutputDim() const
Returns output-dimension of this component.
virtual void Backprop(const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const
Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL...
virtual PrecomputedIndexes * Copy() const
time_height_convolution::ConvolutionComputationIo io
virtual void DeleteMemo(void *memo) const
This virtual function only needs to be overwritten by Components that return a non-NULL memo from the...
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
virtual void Scale(BaseFloat scale)
This virtual function, when called on an UpdatableComponent, scales the parameters by "scale" when c...
Vector< double > entropy_stats_
virtual void InitFromConfig(ConfigLine *cfl)
Initialize, from a ConfigLine object.
virtual ComponentPrecomputedIndexes * PrecomputeIndexes(const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
This function must return NULL for simple Components.
virtual std::string Type() const
Returns a string such as "SigmoidComponent", describing the type of the object.
virtual void ZeroStats()
Components that provide an implementation of StoreStats should also provide an implementation of Zero...
virtual void GetInputIndexes(const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
This function only does something interesting for non-simple Components.
virtual std::string Info() const
Returns some text-form information about this component, for diagnostics.
virtual std::string Type() const
virtual void * Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
Propagate function.
virtual void Write(std::ostream &os, bool binary) const
Write component to stream.
virtual void Read(std::istream &os, bool binary)
virtual void ReorderIndexes(std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
This function only does something interesting for non-simple Components.
Matrix for CUDA computing.
This class is responsible for parsing input like hi-there xx=yyy a=b c empty= f-oo=Append(bar, sss) ba_z=123 bing='a b c' baz="a b c d='a b' e" and giving you access to the fields, in this case.
RestrictedAttentionComponent implements an attention model with restricted temporal context...
virtual ~PrecomputedIndexes()
virtual void Write(std::ostream &os, bool binary) const
This file contains the lower-level interface for self-attention.
virtual int32 Properties() const
Return bitmask of the component's properties.
void PropagateOneHead(const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *out) const
virtual void StoreStats(const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo)
This function may store stats on average activation values, and for some component types...