RestrictedAttentionComponent implements an attention model with restricted temporal context. More...
#include <nnet-attention-component.h>
Classes | |
struct | Memo |
class | PrecomputedIndexes |
Public Member Functions | |
RestrictedAttentionComponent () | |
RestrictedAttentionComponent (const RestrictedAttentionComponent &other) | |
virtual int32 | InputDim () const |
Returns input-dimension of this component. More... | |
virtual int32 | OutputDim () const |
Returns output-dimension of this component. More... | |
virtual std::string | Info () const |
Returns some text-form information about this component, for diagnostics. More... | |
virtual void | InitFromConfig (ConfigLine *cfl) |
Initialize, from a ConfigLine object. More... | |
virtual std::string | Type () const |
Returns a string such as "SigmoidComponent", describing the type of the object. More... | |
virtual int32 | Properties () const |
Return bitmask of the component's properties. More... | |
virtual void * | Propagate (const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const |
Propagate function. More... | |
virtual void | StoreStats (const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo) |
This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity. More... | |
virtual void | Scale (BaseFloat scale) |
This virtual function, when called on an UpdatableComponent, scales the parameters by "scale". More... | |
virtual void | Add (BaseFloat alpha, const Component &other) |
This virtual function, when called on an UpdatableComponent, adds the parameters of another updatable component, times some constant, to the current parameters. More... | |
virtual void | ZeroStats () |
Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero. More... | |
virtual void | Backprop (const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const |
Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update. More... | |
virtual void | Read (std::istream &is, bool binary) |
Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed. More... | |
virtual void | Write (std::ostream &os, bool binary) const |
Write component to stream. More... | |
virtual Component * | Copy () const |
Copies component (deep copy). More... | |
virtual void | DeleteMemo (void *memo) const |
This virtual function only needs to be overwritten by Components that return a non-NULL memo from their Propagate() function. More... | |
virtual void | ReorderIndexes (std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const |
This function only does something interesting for non-simple Components. More... | |
virtual void | GetInputIndexes (const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const |
This function only does something interesting for non-simple Components. More... | |
virtual bool | IsComputable (const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const |
This function only does something interesting for non-simple Components, and it exists to make it possible to manage optionally-required inputs. More... | |
virtual ComponentPrecomputedIndexes * | PrecomputeIndexes (const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const |
This function must return NULL for simple Components. More... | |
Public Member Functions inherited from Component | |
virtual void | ConsolidateMemory () |
This virtual function relates to memory management, and avoiding fragmentation. More... | |
Component () | |
virtual | ~Component () |
Private Member Functions | |
void | PropagateOneHead (const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *out) const |
void | BackpropOneHead (const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &out_deriv, CuMatrixBase< BaseFloat > *in_deriv) const |
void | GetComputationStructure (const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo *io) const |
void | GetIndexes (const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo &io, std::vector< Index > *new_input_indexes, std::vector< Index > *new_output_indexes) const |
void | Check () const |
Static Private Member Functions | |
static void | CreateIndexesVector (const std::vector< std::pair< int32, int32 > > &n_x_pairs, int32 t_start, int32 t_step, int32 num_t_values, const std::unordered_set< Index, IndexHasher > &index_set, std::vector< Index > *output_indexes) |
Utility function used in GetIndexes(). More... | |
Private Attributes | |
int32 | num_heads_ |
int32 | key_dim_ |
int32 | value_dim_ |
int32 | num_left_inputs_ |
int32 | num_right_inputs_ |
int32 | time_stride_ |
int32 | context_dim_ |
int32 | num_left_inputs_required_ |
int32 | num_right_inputs_required_ |
bool | output_context_ |
BaseFloat | key_scale_ |
double | stats_count_ |
Vector< double > | entropy_stats_ |
CuMatrix< double > | posterior_stats_ |
Additional Inherited Members | |
Static Public Member Functions inherited from Component | |
static Component * | ReadNew (std::istream &is, bool binary) |
Read component from stream (works out its type). Dies on error. More... | |
static Component * | NewComponentOfType (const std::string &type) |
Returns a new Component of the given type e.g. More... | |
RestrictedAttentionComponent implements an attention model with restricted temporal context.
What is implemented here is a case of self-attention, meaning that the set of indexes on the input is the same set as the indexes on the output (like an N->N mapping, ignoring edge effects, as opposed to an N->M mapping that you might see in a translation model). "Restricted" means that the source indexes are constrained to be close to the destination indexes, i.e. when outputting something for time 't' we attend to a narrow range of source time indexes close to 't'.
This component is just a fixed nonlinearity (albeit of a type that "knows about" time, i.e. the output at time 't' depends on inputs at a range of time values). This component is not updatable; all the parameters are expected to live in the previous component which is most likely going to be of type NaturalGradientAffineComponent. For a more in-depth explanation, please see comments in the source of the file attention.h. Also, look at the comments for InputDim() and OutputDim() which help to clarify the input and output formats.
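As a rough illustration of the description above, the following plain C++ sketch computes one attention head over a single sequence. It is not Kaldi's implementation: the function name, the `Matrix` alias and the handling of edges by zero-padding are assumptions made for the example, and the layout of the query (a key-sized part followed by one entry per context position) is my reading of the comments in attention.h.

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Hypothetical plain-C++ sketch of one restricted-attention head; not Kaldi's API.
// keys[t] has key_dim entries, values[t] has value_dim entries, and queries[t]
// has key_dim + context_dim entries, with context_dim = num_left + 1 + num_right.
using Matrix = std::vector<std::vector<double>>;

Matrix RestrictedAttentionSketch(const Matrix &keys, const Matrix &values,
                                 const Matrix &queries, int num_left,
                                 int num_right, int time_stride,
                                 double key_scale) {
  const int num_t = static_cast<int>(keys.size());
  const int key_dim = num_t > 0 ? static_cast<int>(keys[0].size()) : 0;
  const int value_dim = num_t > 0 ? static_cast<int>(values[0].size()) : 0;
  const int context_dim = num_left + 1 + num_right;
  Matrix out(num_t, std::vector<double>(value_dim, 0.0));
  for (int t = 0; t < num_t; ++t) {
    // One logit per context position; only the key part is scaled.
    std::vector<double> weights(context_dim);
    double max_logit = -std::numeric_limits<double>::infinity();
    for (int o = 0; o < context_dim; ++o) {
      const int s = t + (o - num_left) * time_stride;  // source time index
      double dot = 0.0;
      if (s >= 0 && s < num_t)  // out-of-range keys are treated as zeros
        for (int d = 0; d < key_dim; ++d) dot += queries[t][d] * keys[s][d];
      weights[o] = key_scale * dot + queries[t][key_dim + o];
      max_logit = std::max(max_logit, weights[o]);
    }
    // Softmax over the context positions (shifted by the max for stability).
    double sum = 0.0;
    for (double &w : weights) {
      w = std::exp(w - max_logit);
      sum += w;
    }
    // Weighted sum of the values inside the restricted window.
    for (int o = 0; o < context_dim; ++o) {
      const int s = t + (o - num_left) * time_stride;
      if (s < 0 || s >= num_t) continue;  // zero-padded value adds nothing
      for (int d = 0; d < value_dim; ++d)
        out[t][d] += (weights[o] / sum) * values[s][d];
    }
  }
  return out;
}
```

The sketch keeps the N->N shape mentioned above and ignores multiple heads, the optional appended softmax output and the 'n' and 'x' parts of each Index.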
The following are the parameters accepted on the config line, with examples of their values.
num-heads | E.g. num-heads=10. Defaults to 1. Having multiple heads just means the same nonlinearity is repeated many times. InputDim() and OutputDim() are multiples of num-heads. |
key-dim | E.g. key-dim=60. Must be specified. Dimension of input keys. |
value-dim | E.g. value-dim=60. Must be specified. Dimension of input values (these are the things over which the component forms a weighted sum, although if output-context=true we append to the output the weights of the weighted sum, as they might also carry useful information). |
time-stride | Stride for 't' value, e.g. 1 or 3. For example, if time-stride=3, to compute the output for t=10 we'd use the input for time values like ... t=7, t=10, t=13, ... (the ends of this range depend on num-left-inputs and num-right-inputs). |
num-left-inputs | Number of frames to the left of the current frame that we use as input, e.g. 5. (The 't' values used will be separated by 'time-stride'.) Must be >= 0 and must be specified. |
num-right-inputs | Number of frames to the right of the current frame that we use as input, e.g. 2. Must be >= 0 and must be specified. You are not allowed to set both num-left-inputs and num-right-inputs to zero. |
num-left-inputs-required | The number of frames to the left that are required in order to produce an output. Defaults to the same as num-left-inputs, but you can set it to a smaller value if you want. We'll use zero-padding for non-required inputs that are not present in the input. Be careful with this because it interacts with decoding settings; for non-online decoding and for dumping of egs it would be advisable to increase the extra-left-context parameter by the sum of the differences between num-left-inputs and num-left-inputs-required, although you could leave extra-left-context-initial at zero. |
num-right-inputs-required | See num-left-inputs-required for explanation; it's the mirror image. Defaults to num-right-inputs. However, be even more careful with the right-hand version; if you set this, online (looped) decoding will not work correctly. It might be wiser just to reduce num-right-inputs if you care about real-time decoding. |
key-scale | Scale on the keys (but not the added context). Defaults to 1.0 / sqrt(key-dim), like the 1/sqrt(d_k) value in the "Attention is all you need" paper. This helps prevent saturation of the softmax. |
output-context | (Default: true). If true, output the softmax that encodes which positions we chose, in addition to the input values. |
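To make the interaction of time-stride, num-left-inputs and num-right-inputs concrete, here is a small sketch that lists the input 't' values attended to for a given output time, together with the documented default for key-scale. Both function names are hypothetical helpers written for this example and are not part of Kaldi.

```cpp
#include <cmath>
#include <vector>

// Hypothetical helper (not part of Kaldi): lists the input 't' values that
// contribute to the output at time 't', following the config-line semantics
// described above.
std::vector<int> AttendedTimes(int t, int time_stride, int num_left_inputs,
                               int num_right_inputs) {
  std::vector<int> times;
  for (int k = -num_left_inputs; k <= num_right_inputs; ++k)
    times.push_back(t + k * time_stride);  // neighbours are time_stride apart
  return times;
}

// Default key-scale as documented above: 1.0 / sqrt(key-dim).
double DefaultKeyScale(int key_dim) {
  return 1.0 / std::sqrt(static_cast<double>(key_dim));
}
```

For instance, `AttendedTimes(10, 3, 5, 2)` yields the eight values from t=-5 to t=16 in steps of 3, which includes the t=7, t=10, t=13 values mentioned for time-stride=3.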
Definition at line 106 of file nnet-attention-component.h.
RestrictedAttentionComponent () |
inline |
Definition at line 110 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Copy().
RestrictedAttentionComponent | ( | const RestrictedAttentionComponent & | other | ) |
Definition at line 61 of file nnet-attention-component.cc.
void | Add (BaseFloat alpha, const Component &other) |
virtual |
This virtual function does different things depending on the type of component it is called on:
– for an UpdatableComponent, it adds the parameters of another updatable component, times some constant, to the current parameters;
– for a NonlinearComponent (or another component that stores stats, like BatchNormComponent), it relates to adding stats.
Otherwise it will normally do nothing.
Reimplemented from Component.
Definition at line 261 of file nnet-attention-component.cc.
References CuMatrixBase< Real >::AddMat(), VectorBase< Real >::AddVec(), VectorBase< Real >::Dim(), RestrictedAttentionComponent::entropy_stats_, KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), RestrictedAttentionComponent::posterior_stats_, Vector< Real >::Resize(), CuMatrix< Real >::Resize(), and RestrictedAttentionComponent::stats_count_.
Referenced by RestrictedAttentionComponent::Properties().
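void | Backprop (const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const |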
virtual |
Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update.
[in] | debug_info | The component name, to be printed out in any warning messages. |
[in] | indexes | A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing). |
[in] | in_value | The matrix that was given as input to the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsInput == 0. |
[in] | out_value | The matrix that was output from the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsOutput == 0. |
[in] | out_deriv | The derivative at the output of this component. |
[in] | memo | This will normally be NULL, but for component types that set the flag kUsesMemo, this will be the return value of the Propagate() function that corresponds to this Backprop() function. Ownership of any pointers is not transferred to the Backprop function; DeleteMemo() will be called to delete it. |
[out] | to_update | If model update is desired, the Component to be updated, else NULL. Does not have to be identical to this. If supplied, you can assume that to_update->Properties() & kUpdatableComponent is nonzero. |
[out] | in_deriv | The derivative at the input of this component, if needed (else NULL). If Properties()&kBackpropInPlace, may be the same matrix as out_deriv. If Properties()&kBackpropAdds, this is added to by the Backprop routine, else it is set. The component code chooses which mode to work in, based on convenience. |
Implements Component.
Definition at line 292 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::BackpropOneHead(), RestrictedAttentionComponent::Memo::c, RestrictedAttentionComponent::context_dim_, RestrictedAttentionComponent::PrecomputedIndexes::io, KALDI_ASSERT, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::num_heads_, ConvolutionComputationIo::num_images, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, CuMatrixBase< Real >::NumRows(), NVTX_RANGE, RestrictedAttentionComponent::output_context_, kaldi::SameDim(), and RestrictedAttentionComponent::value_dim_.
Referenced by RestrictedAttentionComponent::Properties().
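void | BackpropOneHead (const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &out_deriv, CuMatrixBase< BaseFloat > *in_deriv) const |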
private |
Definition at line 335 of file nnet-attention-component.cc.
References kaldi::nnet3::attention::AttentionBackward(), RestrictedAttentionComponent::context_dim_, KALDI_ASSERT, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, ConvolutionComputationIo::num_images, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), RestrictedAttentionComponent::output_context_, kaldi::SameDim(), ConvolutionComputationIo::start_t_in, ConvolutionComputationIo::start_t_out, ConvolutionComputationIo::t_step_in, ConvolutionComputationIo::t_step_out, and RestrictedAttentionComponent::value_dim_.
Referenced by RestrictedAttentionComponent::Backprop().
void | Check () const |
private |
Definition at line 277 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::context_dim_, KALDI_ASSERT, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, RestrictedAttentionComponent::stats_count_, RestrictedAttentionComponent::time_stride_, and RestrictedAttentionComponent::value_dim_.
Referenced by RestrictedAttentionComponent::InitFromConfig().
Component * | Copy () const |
inlinevirtual |
Copies component (deep copy).
Implements Component.
Definition at line 155 of file nnet-attention-component.h.
References RestrictedAttentionComponent::RestrictedAttentionComponent().
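void | CreateIndexesVector (const std::vector< std::pair< int32, int32 > > &n_x_pairs, int32 t_start, int32 t_step, int32 num_t_values, const std::unordered_set< Index, IndexHasher > &index_set, std::vector< Index > *output_indexes) |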
staticprivate |
Utility function used in GetIndexes().
Creates a grid of Indexes, where 't' has the larger stride, and within each block of Indexes for a given 't', we have the given list of (n, x) pairs. For Indexes that we create where the 't' value was not present in 'index_set', we set the 't' value to kNoTime (indicating that it's only for padding, not a real input or an output that's ever used).
Definition at line 575 of file nnet-attention-component.cc.
References KALDI_ASSERT, and kaldi::nnet3::kNoTime.
Referenced by RestrictedAttentionComponent::GetIndexes().
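The grid construction described for CreateIndexesVector() can be sketched in standalone C++ as follows. `SketchIndex`, `kSketchNoTime` and `CreateGridSketch` are stand-ins invented for this example; the real function works on Kaldi's Index type with an `std::unordered_set< Index, IndexHasher >` and writes through an output pointer, and the actual value of kNoTime is not reproduced here.

```cpp
#include <set>
#include <tuple>
#include <utility>
#include <vector>

// Simplified stand-in for Kaldi's Index; the real type lives in the nnet3
// headers and differs in detail.
struct SketchIndex {
  int n, t, x;
  bool operator<(const SketchIndex &o) const {
    return std::tie(n, t, x) < std::tie(o.n, o.t, o.x);
  }
};

// Hypothetical "very negative" marker playing the role of kNoTime.
const int kSketchNoTime = -1000000;

std::vector<SketchIndex> CreateGridSketch(
    const std::vector<std::pair<int, int>> &n_x_pairs, int t_start,
    int t_step, int num_t_values, const std::set<SketchIndex> &index_set) {
  std::vector<SketchIndex> out;
  out.reserve(n_x_pairs.size() * num_t_values);
  for (int i = 0; i < num_t_values; ++i) {  // 't' has the larger stride
    const int t = t_start + i * t_step;
    for (const auto &nx : n_x_pairs) {      // (n, x) pairs vary fastest
      SketchIndex idx{nx.first, t, nx.second};
      // Entries absent from index_set become padding: keep n and x, blank t.
      if (index_set.count(idx) == 0) idx.t = kSketchNoTime;
      out.push_back(idx);
    }
  }
  return out;
}
```

Keeping 'n' and 'x' intact on padded entries preserves the regular pattern over the 'n' indexes, which is the property the ReorderIndexes() documentation asks of such blanks.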
void | DeleteMemo (void *memo) const |
inlinevirtual |
This virtual function only needs to be overwritten by Components that return a non-NULL memo from their Propagate() function.
It's called by NnetComputer in cases where Propagate returns a memo but there will be no backprop to consume it.
Reimplemented from Component.
Definition at line 158 of file nnet-attention-component.h.
References RestrictedAttentionComponent::GetInputIndexes(), RestrictedAttentionComponent::IsComputable(), RestrictedAttentionComponent::PrecomputeIndexes(), and RestrictedAttentionComponent::ReorderIndexes().
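void | GetComputationStructure (const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo *io) const |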
private |
Definition at line 389 of file nnet-attention-component.cc.
References kaldi::Gcd(), kaldi::nnet3::time_height_convolution::GetComputationIo(), KALDI_ASSERT, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, ConvolutionComputationIo::start_t_in, ConvolutionComputationIo::start_t_out, ConvolutionComputationIo::t_step_in, ConvolutionComputationIo::t_step_out, and RestrictedAttentionComponent::time_stride_.
Referenced by RestrictedAttentionComponent::PrecomputeIndexes(), and RestrictedAttentionComponent::ReorderIndexes().
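void | GetIndexes (const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo &io, std::vector< Index > *new_input_indexes, std::vector< Index > *new_output_indexes) const |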
private |
Definition at line 597 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::CreateIndexesVector(), kaldi::nnet3::GetNxList(), KALDI_ASSERT, ConvolutionComputationIo::num_images, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, ConvolutionComputationIo::start_t_in, ConvolutionComputationIo::start_t_out, ConvolutionComputationIo::t_step_in, and ConvolutionComputationIo::t_step_out.
Referenced by RestrictedAttentionComponent::PrecomputeIndexes(), and RestrictedAttentionComponent::ReorderIndexes().
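void | GetInputIndexes (const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const |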
virtual |
This function only does something interesting for non-simple Components.
For a given index at the output of the component, tells us what indexes are required at its input (note: "required" encompasses also optionally-required things; it will enumerate all things that we'd like to have). See also IsComputable().
[in] | misc_info | This argument is supplied to handle things that the framework can't very easily supply: information like which time indexes are needed for AggregateComponent, which time-indexes are available at the input of a recurrent network, and so on. We will add members to misc_info as needed. |
[in] | output_index | The Index at the output of the component, for which we are requesting the list of indexes at the component's input. |
[out] | desired_indexes | The list of indexes that are desired at the input is written to here. By "desired" we mean required or optionally-required. |
The default implementation of this function is suitable for any SimpleComponent; it just copies the output_index to a single identical element in desired_indexes.
Reimplemented from Component.
Definition at line 507 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::context_dim_, rnnlm::i, KALDI_ASSERT, kaldi::nnet3::kNoTime, Index::n, rnnlm::n, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_right_inputs_, Index::t, RestrictedAttentionComponent::time_stride_, and Index::x.
Referenced by RestrictedAttentionComponent::DeleteMemo().
std::string | Info () const |
virtual |
Returns some text-form information about this component, for diagnostics.
Starts with the type of the component. E.g. "SigmoidComponent dim=900", although most components will have much more info.
Reimplemented from Component.
Definition at line 32 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::context_dim_, VectorBase< Real >::Dim(), RestrictedAttentionComponent::entropy_stats_, rnnlm::i, RestrictedAttentionComponent::InputDim(), rnnlm::j, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, CuMatrixBase< Real >::NumCols(), RestrictedAttentionComponent::output_context_, RestrictedAttentionComponent::OutputDim(), RestrictedAttentionComponent::posterior_stats_, RestrictedAttentionComponent::stats_count_, RestrictedAttentionComponent::time_stride_, RestrictedAttentionComponent::Type(), and RestrictedAttentionComponent::value_dim_.
Referenced by RestrictedAttentionComponent::OutputDim().
void | InitFromConfig (ConfigLine *cfl) |
virtual |
Initialize, from a ConfigLine object.
[in] | cfl | A ConfigLine containing any parameters that are needed for initialization. For example: "dim=100 param-stddev=0.1" |
Implements Component.
Definition at line 80 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::context_dim_, ConfigLine::GetValue(), KALDI_ERR, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, RestrictedAttentionComponent::output_context_, RestrictedAttentionComponent::stats_count_, RestrictedAttentionComponent::time_stride_, RestrictedAttentionComponent::value_dim_, and ConfigLine::WholeLine().
Referenced by RestrictedAttentionComponent::OutputDim().
int32 | InputDim () const |
inlinevirtual |
Returns input-dimension of this component.
Implements Component.
Definition at line 115 of file nnet-attention-component.h.
References RestrictedAttentionComponent::context_dim_, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::num_heads_, and RestrictedAttentionComponent::value_dim_.
Referenced by RestrictedAttentionComponent::Info().
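bool | IsComputable (const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const |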
virtual |
This function only does something interesting for non-simple Components, and it exists to make it possible to manage optionally-required inputs.
It tells the user whether a given output index is computable from a given set of input indexes, and if so, says which input indexes will be used in the computation.
Implementations of this function are required to have the property that adding an element to "input_index_set" can only ever change IsComputable from false to true, never vice versa.
[in] | misc_info | Some information specific to the computation, such as minimum and maximum times for certain components to do adaptation on; it's a place to put things that don't easily fit in the framework. |
[in] | output_index | The index that is to be computed at the output of this Component. |
[in] | input_index_set | The set of indexes that is available at the input of this Component. |
[out] | used_inputs | If this is non-NULL and the output is computable this will be set to the list of input indexes that will actually be used in the computation. |
The default implementation of this function is suitable for any SimpleComponent: it just returns true if output_index is in input_index_set, and if so sets used_inputs to a vector containing that one Index.
Reimplemented from Component.
Definition at line 527 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::context_dim_, KALDI_ASSERT, kaldi::nnet3::kNoTime, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, Index::t, and RestrictedAttentionComponent::time_stride_.
Referenced by RestrictedAttentionComponent::DeleteMemo().
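The role of the "required" counts in deciding computability can be sketched as below. This is a simplification that tracks only 't' values for a fixed (n, x); `IsComputableSketch` and its parameter names are hypothetical, and the logic reflects my understanding of the behaviour documented for num-left-inputs-required and num-right-inputs-required rather than a copy of the real method.

```cpp
#include <set>
#include <vector>

// Hypothetical sketch: an output at 'output_t' is computable if every input
// whose offset lies within the required range is available. Optional inputs
// that are missing are simply skipped (they would be zero-padded).
bool IsComputableSketch(int output_t, int time_stride, int num_left_inputs,
                        int num_right_inputs, int num_left_required,
                        int num_right_required,
                        const std::set<int> &available_t,
                        std::vector<int> *used_t) {
  std::vector<int> used;
  for (int k = -num_left_inputs; k <= num_right_inputs; ++k) {
    const int t = output_t + k * time_stride;
    if (available_t.count(t) != 0) {
      used.push_back(t);  // available inputs are always used
    } else if (k >= -num_left_required && k <= num_right_required) {
      return false;       // a required input is missing; used_t is untouched
    }
  }
  if (used_t != nullptr) *used_t = used;
  return true;
}
```

Adding a time to `available_t` can only turn a false result into true, never the reverse, which matches the monotonicity property required of IsComputable() implementations.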
int32 | OutputDim () const |
inlinevirtual |
Returns output-dimension of this component.
Implements Component.
Definition at line 121 of file nnet-attention-component.h.
References RestrictedAttentionComponent::context_dim_, RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::output_context_, and RestrictedAttentionComponent::value_dim_.
Referenced by RestrictedAttentionComponent::Info().
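The dimension bookkeeping behind InputDim() and OutputDim() might look like the sketch below. The formulas are an assumption based on the members these functions reference and on the config-line description (each head's input block holding a key, a value and a query whose size is key-dim plus the context size); the free functions are hypothetical, so check them against nnet-attention-component.h before relying on them.

```cpp
// Hypothetical free functions; the real logic lives in the inline
// InputDim()/OutputDim() members and may differ in detail.

// Number of context positions: left inputs, the current frame, right inputs.
int ContextDimSketch(int num_left_inputs, int num_right_inputs) {
  return num_left_inputs + 1 + num_right_inputs;
}

// Assumed per-head input layout: key, value, and a query of size
// key_dim + context_dim.
int InputDimSketch(int num_heads, int key_dim, int value_dim,
                   int context_dim) {
  const int query_dim = key_dim + context_dim;
  return num_heads * (key_dim + value_dim + query_dim);
}

// Per-head output: the weighted sum of values, plus the softmax weights
// when output-context is true.
int OutputDimSketch(int num_heads, int value_dim, int context_dim,
                    bool output_context) {
  return num_heads * (value_dim + (output_context ? context_dim : 0));
}
```

With num-heads=10, key-dim=60, value-dim=60, num-left-inputs=5 and num-right-inputs=2, these formulas give a context size of 8, an input dimension of 1880 and an output dimension of 680 (600 with output-context=false), all multiples of num-heads as stated in the parameter description.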
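ComponentPrecomputedIndexes * | PrecomputeIndexes (const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const |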
virtual |
This function must return NULL for simple Components.
Returns a pointer to a class that may contain some precomputed component-specific and computation-specific indexes to be used in the Propagate and Backprop functions.
[in] | misc_info | This argument is supplied to handle things that the framework can't very easily supply: information like which time indexes are needed for AggregateComponent, which time-indexes are available at the input of a recurrent network, and so on. misc_info may never actually be used here. We will add members to misc_info as needed. |
[in] | input_indexes | A vector of indexes that explains what time-indexes (and other indexes) each row of the in/in_value/in_deriv matrices given to Propagate and Backprop will mean. |
[in] | output_indexes | A vector of indexes that explains what time-indexes (and other indexes) each row of the out/out_value/out_deriv matrices given to Propagate and Backprop will mean. |
[in] | need_backprop | True if we might need to do backprop with this component, so that if any different indexes are needed for backprop then those should be computed too. |
Reimplemented from Component.
Definition at line 622 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::GetComputationStructure(), RestrictedAttentionComponent::GetIndexes(), kaldi::GetVerboseLevel(), RestrictedAttentionComponent::PrecomputedIndexes::io, and KALDI_ASSERT.
Referenced by RestrictedAttentionComponent::DeleteMemo().
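void * | Propagate (const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const |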
virtual |
Propagate function.
[in] | indexes | A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing). |
[in] | in | The input to this component. Num-columns == InputDim(). |
[out] | out | The output of this component. Num-columns == OutputDim(). Note: output of this component will be added to the initial value of "out" if Properties()&kPropagateAdds != 0; otherwise the output will be set and the initial value ignored. Each Component chooses whether it is more convenient implementation-wise to add or set, and the calling code has to deal with it. |
Implements Component.
Definition at line 132 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::context_dim_, RestrictedAttentionComponent::PrecomputedIndexes::io, KALDI_ASSERT, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::num_heads_, ConvolutionComputationIo::num_images, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, CuMatrixBase< Real >::NumRows(), RestrictedAttentionComponent::output_context_, RestrictedAttentionComponent::PropagateOneHead(), and RestrictedAttentionComponent::value_dim_.
Referenced by RestrictedAttentionComponent::Properties().
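void | PropagateOneHead (const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *out) const |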
private |
Definition at line 160 of file nnet-attention-component.cc.
References kaldi::nnet3::attention::AttentionForward(), RestrictedAttentionComponent::context_dim_, KALDI_ASSERT, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, ConvolutionComputationIo::num_images, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), RestrictedAttentionComponent::output_context_, ConvolutionComputationIo::start_t_in, ConvolutionComputationIo::start_t_out, ConvolutionComputationIo::t_step_in, ConvolutionComputationIo::t_step_out, and RestrictedAttentionComponent::value_dim_.
Referenced by RestrictedAttentionComponent::Propagate().
int32 | Properties () const |
inlinevirtual |
Return bitmask of the component's properties.
These properties depend only on the component's type. See enum ComponentProperties.
Implements Component.
Definition at line 131 of file nnet-attention-component.h.
References RestrictedAttentionComponent::Add(), RestrictedAttentionComponent::Backprop(), kaldi::nnet3::kBackpropAdds, kaldi::nnet3::kBackpropNeedsInput, kaldi::nnet3::kPropagateAdds, kaldi::nnet3::kReordersIndexes, kaldi::nnet3::kStoresStats, kaldi::nnet3::kUsesMemo, RestrictedAttentionComponent::Propagate(), RestrictedAttentionComponent::Read(), RestrictedAttentionComponent::Scale(), RestrictedAttentionComponent::StoreStats(), RestrictedAttentionComponent::Write(), and RestrictedAttentionComponent::ZeroStats().
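Callers typically test individual flags in the value returned by Properties() with a bitwise AND. The sketch below shows that pattern using the flag names referenced above; the enum, its numeric values and both function names are placeholders invented for illustration and should not be taken as the values in Kaldi's ComponentProperties enum.

```cpp
// Placeholder flag values for illustration only; the real ones are defined
// in Kaldi's ComponentProperties enum.
enum SketchProperties {
  kSketchUpdatableComponent = 1 << 0,
  kSketchPropagateAdds = 1 << 1,
  kSketchReordersIndexes = 1 << 2,
  kSketchBackpropAdds = 1 << 3,
  kSketchBackpropNeedsInput = 1 << 4,
  kSketchStoresStats = 1 << 5,
  kSketchUsesMemo = 1 << 6
};

// Combines the flags that the reference list above associates with this
// component (assumed to be the full set it returns).
int SketchAttentionProperties() {
  return kSketchReordersIndexes | kSketchBackpropNeedsInput |
         kSketchPropagateAdds | kSketchBackpropAdds | kSketchStoresStats |
         kSketchUsesMemo;
}

// Returns true if 'flag' is set in the bitmask 'properties'.
bool HasProperty(int properties, int flag) {
  return (properties & flag) != 0;
}
```

Under these assumptions, `HasProperty(SketchAttentionProperties(), kSketchUsesMemo)` is true, while the updatable flag is not set, consistent with the statement that this component is not updatable.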
void | Read (std::istream &is, bool binary) |
virtual |
Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed.
Implements Component.
Definition at line 473 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::context_dim_, RestrictedAttentionComponent::entropy_stats_, kaldi::ExpectOneOrTwoTokens(), kaldi::nnet3::ExpectToken(), RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, RestrictedAttentionComponent::output_context_, RestrictedAttentionComponent::posterior_stats_, Vector< Real >::Read(), CuMatrix< Real >::Read(), kaldi::ReadBasicType(), RestrictedAttentionComponent::stats_count_, RestrictedAttentionComponent::time_stride_, and RestrictedAttentionComponent::value_dim_.
Referenced by RestrictedAttentionComponent::Properties().
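void | ReorderIndexes (std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const |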
virtual |
This function only does something interesting for non-simple Components.
It provides an opportunity for a Component to reorder or pad the indexes at its input and output. This might be useful, for instance, if a component requires a particular ordering of the indexes that doesn't correspond to their natural ordering. Components that might modify the indexes are required to return the kReordersIndexes flag in their Properties(). The ReorderIndexes() function is now allowed to insert blanks into the indexes. The 'blanks' must be of the form (n,kNoTime,x), where the marker kNoTime (a very negative number) is there where the 't' indexes normally live. The reason we don't just have, say, (-1,-1,-1), relates to the need to preserve a regular pattern over the 'n' indexes so that 'shortcut compilation' (cf. ExpandComputation()) can work correctly.
[in,out] | input_indexes | Indexes at the input of the Component. |
[in,out] | output_indexes | Indexes at the output of the Component. |
Reimplemented from Component.
Definition at line 376 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::GetComputationStructure(), and RestrictedAttentionComponent::GetIndexes().
Referenced by RestrictedAttentionComponent::DeleteMemo().
void | Scale (BaseFloat scale) |
virtual |
This virtual function does different things depending on the type of component it is called on:
– for an UpdatableComponent, it scales the parameters by "scale";
– for a nonlinear component (or another component that stores stats, like BatchNormComponent), it relates to scaling activation stats, not parameters.
Otherwise it will normally do nothing.
Reimplemented from Component.
Definition at line 255 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::entropy_stats_, RestrictedAttentionComponent::posterior_stats_, VectorBase< Real >::Scale(), CuMatrixBase< Real >::Scale(), and RestrictedAttentionComponent::stats_count_.
Referenced by RestrictedAttentionComponent::Properties().
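Because this component has no parameters, Scale() and Add() act only on its accumulated statistics. The sketch below mimics that with plain containers; `AttentionStatsSketch` is a hypothetical struct, the flattened layout of the posterior stats is an assumption, and the resize-when-empty step in `Add` is inferred from the Resize() calls in the reference list for Add().

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the three stats members of the component.
struct AttentionStatsSketch {
  double stats_count = 0.0;
  std::vector<double> entropy_stats;    // assumed: one entry per head
  std::vector<double> posterior_stats;  // assumed: matrix stored flattened

  // Scales the stats (there are no parameters to scale).
  void Scale(double scale) {
    stats_count *= scale;
    for (double &e : entropy_stats) e *= scale;
    for (double &p : posterior_stats) p *= scale;
  }

  // Adds alpha times the other object's stats to this one.
  void Add(double alpha, const AttentionStatsSketch &other) {
    // If we have no stats yet, take our sizes from 'other'.
    if (entropy_stats.empty())
      entropy_stats.resize(other.entropy_stats.size(), 0.0);
    if (posterior_stats.empty())
      posterior_stats.resize(other.posterior_stats.size(), 0.0);
    for (std::size_t i = 0;
         i < entropy_stats.size() && i < other.entropy_stats.size(); ++i)
      entropy_stats[i] += alpha * other.entropy_stats[i];
    for (std::size_t i = 0;
         i < posterior_stats.size() && i < other.posterior_stats.size(); ++i)
      posterior_stats[i] += alpha * other.posterior_stats[i];
    stats_count += alpha * other.stats_count;
  }
};
```

Scaling the count together with the sums keeps any average derived from them unchanged, which is why scaling stats is harmless for diagnostics.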
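void | StoreStats (const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo) |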
virtual |
This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity.
It only does something for those components that have nonzero Properties()&kStoresStats.
[in] | in_value | The input to the Propagate() function. Note: if the component sets the flag kPropagateInPlace, this should not be used; the empty matrix will be provided here if in-place propagation was used. |
[in] | out_value | The output of the Propagate() function. |
[in] | memo | The 'memo' returned by the Propagate() function; this will usually be NULL. |
Reimplemented from Component.
Definition at line 200 of file nnet-attention-component.cc.
References CuVectorBase< Real >::AddDiagMatMat(), CuMatrixBase< Real >::AddMat(), CuVectorBase< Real >::AddRowSumMat(), VectorBase< Real >::AddVec(), CuMatrixBase< Real >::ApplyFloor(), CuMatrixBase< Real >::ApplyLog(), RestrictedAttentionComponent::Memo::c, RestrictedAttentionComponent::context_dim_, CuVectorBase< Real >::Data(), VectorBase< Real >::Dim(), RestrictedAttentionComponent::entropy_stats_, KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, RestrictedAttentionComponent::num_heads_, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), RestrictedAttentionComponent::posterior_stats_, kaldi::RandInt(), Vector< Real >::Resize(), CuMatrix< Real >::Resize(), and RestrictedAttentionComponent::stats_count_.
Referenced by RestrictedAttentionComponent::Properties().
std::string Type () const |
inline virtual |
Returns a string such as "SigmoidComponent", describing the type of the object.
Implements Component.
Definition at line 130 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Info().
void Write (std::ostream &os, bool binary) const |
virtual |
Write component to stream.
Implements Component.
Definition at line 442 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::entropy_stats_, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, RestrictedAttentionComponent::output_context_, RestrictedAttentionComponent::posterior_stats_, RestrictedAttentionComponent::stats_count_, RestrictedAttentionComponent::time_stride_, RestrictedAttentionComponent::value_dim_, VectorBase< Real >::Write(), CuMatrixBase< Real >::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().
Referenced by RestrictedAttentionComponent::Properties().
void ZeroStats () |
virtual |
Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero.
Other components that store other types of statistics (e.g. regarding gradient clipping) should implement ZeroStats() also.
Reimplemented from Component.
Definition at line 249 of file nnet-attention-component.cc.
References RestrictedAttentionComponent::entropy_stats_, RestrictedAttentionComponent::posterior_stats_, VectorBase< Real >::SetZero(), CuMatrixBase< Real >::SetZero(), and RestrictedAttentionComponent::stats_count_.
Referenced by RestrictedAttentionComponent::Properties().
context_dim_ |
private |
Definition at line 287 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Backprop(), RestrictedAttentionComponent::BackpropOneHead(), RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::GetInputIndexes(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::InputDim(), RestrictedAttentionComponent::IsComputable(), RestrictedAttentionComponent::OutputDim(), RestrictedAttentionComponent::Propagate(), RestrictedAttentionComponent::PropagateOneHead(), RestrictedAttentionComponent::Read(), and RestrictedAttentionComponent::StoreStats().
entropy_stats_ |
private |
Definition at line 295 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Add(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::Read(), RestrictedAttentionComponent::Scale(), RestrictedAttentionComponent::StoreStats(), RestrictedAttentionComponent::Write(), and RestrictedAttentionComponent::ZeroStats().
key_dim_ |
private |
Definition at line 282 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Backprop(), RestrictedAttentionComponent::BackpropOneHead(), RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::InputDim(), RestrictedAttentionComponent::Propagate(), RestrictedAttentionComponent::PropagateOneHead(), RestrictedAttentionComponent::Read(), and RestrictedAttentionComponent::Write().
key_scale_ |
private |
Definition at line 292 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::BackpropOneHead(), RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::PropagateOneHead(), RestrictedAttentionComponent::Read(), and RestrictedAttentionComponent::Write().
num_heads_ |
private |
Definition at line 281 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Backprop(), RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::InputDim(), RestrictedAttentionComponent::OutputDim(), RestrictedAttentionComponent::Propagate(), RestrictedAttentionComponent::Read(), RestrictedAttentionComponent::StoreStats(), and RestrictedAttentionComponent::Write().
num_left_inputs_ |
private |
Definition at line 284 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::GetComputationStructure(), RestrictedAttentionComponent::GetInputIndexes(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::IsComputable(), RestrictedAttentionComponent::Read(), and RestrictedAttentionComponent::Write().
num_left_inputs_required_ |
private |
Definition at line 289 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::GetComputationStructure(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::IsComputable(), RestrictedAttentionComponent::Read(), and RestrictedAttentionComponent::Write().
num_right_inputs_ |
private |
Definition at line 285 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::GetComputationStructure(), RestrictedAttentionComponent::GetInputIndexes(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::IsComputable(), RestrictedAttentionComponent::Read(), and RestrictedAttentionComponent::Write().
num_right_inputs_required_ |
private |
Definition at line 290 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::GetComputationStructure(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::IsComputable(), RestrictedAttentionComponent::Read(), and RestrictedAttentionComponent::Write().
output_context_ |
private |
Definition at line 291 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Backprop(), RestrictedAttentionComponent::BackpropOneHead(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::OutputDim(), RestrictedAttentionComponent::Propagate(), RestrictedAttentionComponent::PropagateOneHead(), RestrictedAttentionComponent::Read(), and RestrictedAttentionComponent::Write().
posterior_stats_ |
private |
Definition at line 298 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Add(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::Read(), RestrictedAttentionComponent::Scale(), RestrictedAttentionComponent::StoreStats(), RestrictedAttentionComponent::Write(), and RestrictedAttentionComponent::ZeroStats().
stats_count_ |
private |
Definition at line 294 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Add(), RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::Read(), RestrictedAttentionComponent::Scale(), RestrictedAttentionComponent::StoreStats(), RestrictedAttentionComponent::Write(), and RestrictedAttentionComponent::ZeroStats().
time_stride_ |
private |
Definition at line 286 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::GetComputationStructure(), RestrictedAttentionComponent::GetInputIndexes(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::IsComputable(), RestrictedAttentionComponent::Read(), and RestrictedAttentionComponent::Write().
value_dim_ |
private |
Definition at line 283 of file nnet-attention-component.h.
Referenced by RestrictedAttentionComponent::Backprop(), RestrictedAttentionComponent::BackpropOneHead(), RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::InputDim(), RestrictedAttentionComponent::OutputDim(), RestrictedAttentionComponent::Propagate(), RestrictedAttentionComponent::PropagateOneHead(), RestrictedAttentionComponent::Read(), and RestrictedAttentionComponent::Write().