RestrictedAttentionComponent Class Reference

RestrictedAttentionComponent implements an attention model with restricted temporal context.

#include <nnet-attention-component.h>


Classes

struct  Memo
 
class  PrecomputedIndexes
 

Public Member Functions

 RestrictedAttentionComponent ()
 
 RestrictedAttentionComponent (const RestrictedAttentionComponent &other)
 
virtual int32 InputDim () const
 Returns input-dimension of this component.

virtual int32 OutputDim () const
 Returns output-dimension of this component.

virtual std::string Info () const
 Returns some text-form information about this component, for diagnostics.

virtual void InitFromConfig (ConfigLine *cfl)
 Initialize, from a ConfigLine object.

virtual std::string Type () const
 Returns a string such as "SigmoidComponent", describing the type of the object.

virtual int32 Properties () const
 Return bitmask of the component's properties.

virtual void * Propagate (const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
 Propagate function.

virtual void StoreStats (const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo)
 This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity.

virtual void Scale (BaseFloat scale)
 When called on an UpdatableComponent, this virtual function scales the parameters by "scale".

virtual void Add (BaseFloat alpha, const Component &other)
 When called on an UpdatableComponent, this virtual function adds the parameters of another updatable component, times some constant, to the current parameters.

virtual void ZeroStats ()
 Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero.

virtual void Backprop (const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const
 Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update.

virtual void Read (std::istream &is, bool binary)
 Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed.

virtual void Write (std::ostream &os, bool binary) const
 Write component to stream.

virtual Component * Copy () const
 Copies component (deep copy).

virtual void DeleteMemo (void *memo) const
 This virtual function only needs to be overwritten by Components that return a non-NULL memo from their Propagate() function.

virtual void ReorderIndexes (std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
 This function only does something interesting for non-simple Components.

virtual void GetInputIndexes (const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
 This function only does something interesting for non-simple Components.

virtual bool IsComputable (const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
 This function only does something interesting for non-simple Components, and it exists to make it possible to manage optionally-required inputs.

virtual ComponentPrecomputedIndexes * PrecomputeIndexes (const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
 This function must return NULL for simple Components.
 
- Public Member Functions inherited from Component
virtual void ConsolidateMemory ()
 This virtual function relates to memory management, and avoiding fragmentation.
 
 Component ()
 
virtual ~Component ()
 

Private Member Functions

void PropagateOneHead (const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *out) const
 
void BackpropOneHead (const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &out_deriv, CuMatrixBase< BaseFloat > *in_deriv) const
 
void GetComputationStructure (const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo *io) const
 
void GetIndexes (const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo &io, std::vector< Index > *new_input_indexes, std::vector< Index > *new_output_indexes) const
 
void Check () const
 

Static Private Member Functions

static void CreateIndexesVector (const std::vector< std::pair< int32, int32 > > &n_x_pairs, int32 t_start, int32 t_step, int32 num_t_values, const std::unordered_set< Index, IndexHasher > &index_set, std::vector< Index > *output_indexes)
 Utility function used in GetIndexes().
 

Private Attributes

int32 num_heads_
 
int32 key_dim_
 
int32 value_dim_
 
int32 num_left_inputs_
 
int32 num_right_inputs_
 
int32 time_stride_
 
int32 context_dim_
 
int32 num_left_inputs_required_
 
int32 num_right_inputs_required_
 
bool output_context_
 
BaseFloat key_scale_
 
double stats_count_
 
Vector< double > entropy_stats_
 
CuMatrix< double > posterior_stats_
 

Additional Inherited Members

- Static Public Member Functions inherited from Component
static Component * ReadNew (std::istream &is, bool binary)
 Read component from stream (works out its type). Dies on error.

static Component * NewComponentOfType (const std::string &type)
 Returns a new Component of the given type.
 

Detailed Description

RestrictedAttentionComponent implements an attention model with restricted temporal context.

What is implemented here is a case of self-attention, meaning that the set of indexes on the input is the same set as the indexes on the output (like an N->N mapping, ignoring edge effects, as opposed to an N->M mapping that you might see in a translation model). "Restricted" means that the source indexes are constrained to be close to the destination indexes, i.e. when outputting something for time 't' we attend to a narrow range of source time indexes close to 't'.

This component is just a fixed nonlinearity (albeit of a type that "knows about" time, i.e. the output at time 't' depends on inputs at a range of time values). This component is not updatable; all the parameters are expected to live in the previous component which is most likely going to be of type NaturalGradientAffineComponent. For a more in-depth explanation, please see comments in the source of the file attention.h. Also, look at the comments for InputDim() and OutputDim() which help to clarify the input and output formats.

The following are the parameters accepted on the config line, with examples of their values.

num-heads
 E.g. num-heads=10. Defaults to 1. Having multiple heads just means the same nonlinearity is repeated many times. InputDim() and OutputDim() are multiples of num-heads.

key-dim
 E.g. key-dim=60. Must be specified. Dimension of input keys.

value-dim
 E.g. value-dim=60. Must be specified. Dimension of input values (these are the things over which the component forms a weighted sum, although if output-context=true we append to the output the weights of the weighted sum, as they might also carry useful information).

time-stride
 Stride for 't' value, e.g. 1 or 3. For example, if time-stride=3, to compute the output for t=10 we'd use the input for time values like ... t=7, t=10, t=13, ... (the ends of this range depend on num-left-inputs and num-right-inputs).

num-left-inputs
 Number of frames to the left of the current frame that we use as input, e.g. 5. (The 't' values used will be separated by 'time-stride'.) Must be >= 0 and must be specified.

num-right-inputs
 Number of frames to the right of the current frame that we use as input, e.g. 2. Must be >= 0 and must be specified. You are not allowed to set both num-left-inputs and num-right-inputs to zero.

num-left-inputs-required
 The number of frames to the left that are required in order to produce an output. Defaults to the same as num-left-inputs, but you can set it to a smaller value if you want. We'll use zero-padding for non-required inputs that are not present in the input. Be careful with this because it interacts with decoding settings; for non-online decoding and for dumping of egs it would be advisable to increase the extra-left-context parameter by the sum of the difference between num-left-inputs-required and num-left-inputs, although you could leave extra-left-context-initial at zero.

num-right-inputs-required
 See num-left-inputs-required for explanation; it's the mirror image. Defaults to num-right-inputs. However, be even more careful with the right-hand version; if you set this, online (looped) decoding will not work correctly. It might be wiser just to reduce num-right-inputs if you care about real-time decoding.

key-scale
 Scale on the keys (but not the added context). Defaults to 1.0 / sqrt(key-dim), like the 1/sqrt(d_k) value in the "Attention is all you need" paper. This helps prevent saturation of the softmax.

output-context
 (Default: true). If true, output the softmax that encodes which positions we chose, in addition to the input values.
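To make the relationship between these parameters and the component's dimensions concrete, here is a minimal sketch in plain C++. The struct AttentionConfig and the free functions are hypothetical helpers, not part of Kaldi. The per-head formulas are taken from the Backprop() listing on this page; the rule context_dim = num-left-inputs + num-right-inputs + 1 is an assumption inferred from the loop in GetInputIndexes().

```cpp
#include <cmath>

// Hypothetical plain struct mirroring the config-line parameters described
// above; it is not a Kaldi type.
struct AttentionConfig {
  int num_heads = 1;
  int key_dim = 0;
  int value_dim = 0;
  int num_left_inputs = 0;
  int num_right_inputs = 0;
  bool output_context = true;
};

// Number of time offsets attended to: the left frames, the right frames and
// the current frame (assumption inferred from GetInputIndexes()).
inline int ContextDim(const AttentionConfig &c) {
  return c.num_left_inputs + c.num_right_inputs + 1;
}

// Per head, the input row holds a key, a value and a query; the query has
// context_dim more elements than the key.
inline int InputDim(const AttentionConfig &c) {
  int query_dim = c.key_dim + ContextDim(c);
  return c.num_heads * (c.key_dim + c.value_dim + query_dim);
}

// Per head, the output holds the weighted sum of values, plus the attention
// weights themselves when output-context is true.
inline int OutputDim(const AttentionConfig &c) {
  return c.num_heads * (c.value_dim + (c.output_context ? ContextDim(c) : 0));
}

// The documented default for key-scale.
inline double DefaultKeyScale(const AttentionConfig &c) {
  return 1.0 / std::sqrt(static_cast<double>(c.key_dim));
}
```

With num-heads=10, key-dim=60, value-dim=60, num-left-inputs=5 and num-right-inputs=2, this sketch gives a context dimension of 8, an input dimension of 1880 and an output dimension of 680 (600 with output-context=false).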

Definition at line 106 of file nnet-attention-component.h.

Constructor & Destructor Documentation

◆ RestrictedAttentionComponent() [1/2]

Definition at line 110 of file nnet-attention-component.h.

Referenced by RestrictedAttentionComponent::Copy().

RestrictedAttentionComponent() { }

◆ RestrictedAttentionComponent() [2/2]

Definition at line 61 of file nnet-attention-component.cc.

RestrictedAttentionComponent::RestrictedAttentionComponent(
    const RestrictedAttentionComponent &other):
    num_heads_(other.num_heads_),
    key_dim_(other.key_dim_),
    value_dim_(other.value_dim_),
    num_left_inputs_(other.num_left_inputs_),
    num_right_inputs_(other.num_right_inputs_),
    time_stride_(other.time_stride_),
    context_dim_(other.context_dim_),
    num_left_inputs_required_(other.num_left_inputs_required_),
    num_right_inputs_required_(other.num_right_inputs_required_),
    output_context_(other.output_context_),
    key_scale_(other.key_scale_),
    stats_count_(other.stats_count_),
    entropy_stats_(other.entropy_stats_),
    posterior_stats_(other.posterior_stats_) { }

Member Function Documentation

◆ Add()

void Add ( BaseFloat  alpha,
const Component &  other 
)
virtual

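What this virtual function does depends on the type of component it is called on:

– for an UpdatableComponent, it adds the parameters of another updatable component, times some constant, to the current parameters;
– for a NonlinearComponent (or another component that stores stats, like BatchNormComponent), it relates to adding stats.

Otherwise it will normally do nothing. RestrictedAttentionComponent is not updatable, so this implementation only accumulates stats (entropy_stats_, posterior_stats_ and stats_count_).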

Reimplemented from Component.

Definition at line 261 of file nnet-attention-component.cc.

References CuMatrixBase< Real >::AddMat(), VectorBase< Real >::AddVec(), VectorBase< Real >::Dim(), RestrictedAttentionComponent::entropy_stats_, KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), RestrictedAttentionComponent::posterior_stats_, Vector< Real >::Resize(), CuMatrix< Real >::Resize(), and RestrictedAttentionComponent::stats_count_.

Referenced by RestrictedAttentionComponent::Properties().

void RestrictedAttentionComponent::Add(BaseFloat alpha,
                                       const Component &other_in) {
  const RestrictedAttentionComponent *other =
      dynamic_cast<const RestrictedAttentionComponent*>(&other_in);
  KALDI_ASSERT(other != NULL);
  // If we have no stats yet but 'other' does, size ours to match.
  if (entropy_stats_.Dim() == 0 && other->entropy_stats_.Dim() != 0)
    entropy_stats_.Resize(other->entropy_stats_.Dim());
  if (posterior_stats_.NumRows() == 0 && other->posterior_stats_.NumRows() != 0)
    posterior_stats_.Resize(other->posterior_stats_.NumRows(),
                            other->posterior_stats_.NumCols());
  // Accumulate the other component's stats, scaled by alpha.
  if (other->entropy_stats_.Dim() != 0)
    entropy_stats_.AddVec(alpha, other->entropy_stats_);
  if (other->posterior_stats_.NumRows() != 0)
    posterior_stats_.AddMat(alpha, other->posterior_stats_);
  stats_count_ += alpha * other->stats_count_;
}
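The accumulation pattern used by Add() (resize lazily, then add the other object's stats scaled by alpha) can be sketched without any Kaldi types. The struct Stats and the function AddStats below are hypothetical stand-ins that use std::vector in place of Vector<double> and CuMatrix<double>; the sketch assumes that two non-empty stats vectors always have the same size.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the stats members of the component.
struct Stats {
  double count = 0.0;
  std::vector<double> entropy;
};

// Adds alpha * other to *self. If *self has no entropy stats yet, it is first
// resized (zero-filled) to match 'other'. Assumes matching sizes otherwise.
inline void AddStats(double alpha, const Stats &other, Stats *self) {
  if (self->entropy.empty() && !other.entropy.empty())
    self->entropy.assign(other.entropy.size(), 0.0);
  for (std::size_t i = 0; i < other.entropy.size(); i++)
    self->entropy[i] += alpha * other.entropy[i];
  self->count += alpha * other.count;
}
```

Resizing lazily means that a freshly constructed object with empty stats can serve as the accumulator when stats from several copies of the model are combined.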

◆ Backprop()

void Backprop ( const std::string &  debug_info,
const ComponentPrecomputedIndexes *  indexes,
const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
const CuMatrixBase< BaseFloat > &  out_deriv,
void *  memo,
Component *  to_update,
CuMatrixBase< BaseFloat > *  in_deriv 
) const
virtual

Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL, this can compute input-data derivatives and/or perform model update.

Parameters
[in]  debug_info  The component name, to be printed out in any warning messages.
[in]  indexes  A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in]  in_value  The matrix that was given as input to the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsInput == 0.
[in]  out_value  The matrix that was output from the Propagate function. Will be ignored (and may be empty) if Properties()&kBackpropNeedsOutput == 0.
[in]  out_deriv  The derivative at the output of this component.
[in]  memo  This will normally be NULL, but for component types that set the flag kUsesMemo, this will be the return value of the Propagate() function that corresponds to this Backprop() function. Ownership of any pointers is not transferred to the Backprop function; DeleteMemo() will be called to delete it.
[out]  to_update  If model update is desired, the Component to be updated, else NULL. Does not have to be identical to this. If supplied, you can assume that to_update->Properties() & kUpdatableComponent is nonzero.
[out]  in_deriv  The derivative at the input of this component, if needed (else NULL). If Properties()&kBackpropInPlace, may be the same matrix as out_deriv. If Properties()&kBackpropAdds, this is added to by the Backprop routine, else it is set. The component code chooses which mode to work in, based on convenience.

Implements Component.

Definition at line 292 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::BackpropOneHead(), RestrictedAttentionComponent::Memo::c, RestrictedAttentionComponent::context_dim_, RestrictedAttentionComponent::PrecomputedIndexes::io, KALDI_ASSERT, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::num_heads_, ConvolutionComputationIo::num_images, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, CuMatrixBase< Real >::NumRows(), NVTX_RANGE, RestrictedAttentionComponent::output_context_, kaldi::SameDim(), and RestrictedAttentionComponent::value_dim_.

Referenced by RestrictedAttentionComponent::Properties().

void RestrictedAttentionComponent::Backprop(
    const std::string &debug_info,
    const ComponentPrecomputedIndexes *indexes_in,
    const CuMatrixBase<BaseFloat> &in_value,
    const CuMatrixBase<BaseFloat> &out_value,
    const CuMatrixBase<BaseFloat> &out_deriv,
    void *memo_in,
    Component *to_update,
    CuMatrixBase<BaseFloat> *in_deriv) const {
  NVTX_RANGE("RestrictedAttentionComponent::Backprop");
  const PrecomputedIndexes *indexes =
      dynamic_cast<const PrecomputedIndexes*>(indexes_in);
  KALDI_ASSERT(indexes != NULL);
  Memo *memo = static_cast<Memo*>(memo_in);
  KALDI_ASSERT(memo != NULL);
  const time_height_convolution::ConvolutionComputationIo &io = indexes->io;
  KALDI_ASSERT(indexes != NULL &&
               in_value.NumRows() == io.num_t_in * io.num_images &&
               out_deriv.NumRows() == io.num_t_out * io.num_images &&
               in_deriv != NULL && SameDim(in_value, *in_deriv));

  const CuMatrix<BaseFloat> &c = memo->c;

  int32 query_dim = key_dim_ + context_dim_,
      input_dim_per_head = key_dim_ + value_dim_ + query_dim,
      output_dim_per_head = value_dim_ + (output_context_ ? context_dim_ : 0);

  // Each head owns a contiguous block of columns in every matrix.
  for (int32 h = 0; h < num_heads_; h++) {
    CuSubMatrix<BaseFloat>
        in_value_part(in_value, 0, in_value.NumRows(),
                      h * input_dim_per_head, input_dim_per_head),
        c_part(c, 0, out_deriv.NumRows(),
               // Column range missing from this listing (source line 324);
               // reconstructed from the check c.NumCols() == context_dim_
               // in BackpropOneHead().
               h * context_dim_, context_dim_),
        out_deriv_part(out_deriv, 0, out_deriv.NumRows(),
                       h * output_dim_per_head, output_dim_per_head),
        in_deriv_part(*in_deriv, 0, in_value.NumRows(),
                      h * input_dim_per_head, input_dim_per_head);
    BackpropOneHead(io, in_value_part, c_part, out_deriv_part,
                    &in_deriv_part);
  }
}

◆ BackpropOneHead()

void BackpropOneHead ( const time_height_convolution::ConvolutionComputationIo &  io,
const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  c,
const CuMatrixBase< BaseFloat > &  out_deriv,
CuMatrixBase< BaseFloat > *  in_deriv 
) const
private

Definition at line 335 of file nnet-attention-component.cc.

References kaldi::nnet3::attention::AttentionBackward(), RestrictedAttentionComponent::context_dim_, KALDI_ASSERT, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, ConvolutionComputationIo::num_images, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), RestrictedAttentionComponent::output_context_, kaldi::SameDim(), ConvolutionComputationIo::start_t_in, ConvolutionComputationIo::start_t_out, ConvolutionComputationIo::t_step_in, ConvolutionComputationIo::t_step_out, and RestrictedAttentionComponent::value_dim_.

Referenced by RestrictedAttentionComponent::Backprop().

void RestrictedAttentionComponent::BackpropOneHead(
    const time_height_convolution::ConvolutionComputationIo &io,
    const CuMatrixBase<BaseFloat> &in_value,
    const CuMatrixBase<BaseFloat> &c,
    const CuMatrixBase<BaseFloat> &out_deriv,
    CuMatrixBase<BaseFloat> *in_deriv) const {
  // the easiest way to understand this is to compare it with PropagateOneHead().
  int32 query_dim = key_dim_ + context_dim_,
      full_value_dim = value_dim_ + (output_context_ ? context_dim_ : 0);
  KALDI_ASSERT(in_value.NumRows() == io.num_images * io.num_t_in &&
               out_deriv.NumRows() == io.num_images * io.num_t_out &&
               out_deriv.NumCols() == full_value_dim &&
               in_value.NumCols() == (key_dim_ + value_dim_ + query_dim) &&
               io.t_step_in == io.t_step_out &&
               (io.start_t_out - io.start_t_in) % io.t_step_in == 0 &&
               SameDim(in_value, *in_deriv) &&
               c.NumRows() == out_deriv.NumRows() &&
               c.NumCols() == context_dim_);

  // 'steps_left_context' is the number of time-steps the input has on the left
  // that don't appear in the output.
  int32 steps_left_context = (io.start_t_out - io.start_t_in) / io.t_step_in,
      rows_left_context = steps_left_context * io.num_images;
  KALDI_ASSERT(rows_left_context >= 0);

  CuSubMatrix<BaseFloat> queries(in_value, rows_left_context, out_deriv.NumRows(),
                                 key_dim_ + value_dim_, query_dim),
      queries_deriv(*in_deriv, rows_left_context, out_deriv.NumRows(),
                    key_dim_ + value_dim_, query_dim),
      keys(in_value, 0, in_value.NumRows(), 0, key_dim_),
      keys_deriv(*in_deriv, 0, in_value.NumRows(), 0, key_dim_),
      values(in_value, 0, in_value.NumRows(), key_dim_, value_dim_),
      values_deriv(*in_deriv, 0, in_value.NumRows(), key_dim_, value_dim_);

  attention::AttentionBackward(key_scale_, keys, queries, values, c, out_deriv,
                               &keys_deriv, &queries_deriv, &values_deriv);
}
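The sub-matrix construction above implies a fixed column layout for one head's input: keys first, then values, then queries. The following sketch spells that layout out in plain C++. ColumnRange, HeadLayout, GetHeadLayout and RowsLeftContext are hypothetical names used only for illustration; the offsets and the row-offset formula are copied from the listing.

```cpp
// Hypothetical description of a block of columns: [start, start + size).
struct ColumnRange {
  int start;
  int size;
};

// Column layout of one head's slice of the input matrix.
struct HeadLayout {
  ColumnRange keys;
  ColumnRange values;
  ColumnRange queries;
};

// Keys occupy the first key_dim columns, values the next value_dim columns,
// and queries the final key_dim + context_dim columns.
inline HeadLayout GetHeadLayout(int key_dim, int value_dim, int context_dim) {
  int query_dim = key_dim + context_dim;
  HeadLayout layout;
  layout.keys = ColumnRange{0, key_dim};
  layout.values = ColumnRange{key_dim, value_dim};
  layout.queries = ColumnRange{key_dim + value_dim, query_dim};
  return layout;
}

// Number of input rows that precede the first output time. Queries are taken
// starting from this row, since they are only needed at output times.
inline int RowsLeftContext(int start_t_in, int start_t_out, int t_step,
                           int num_images) {
  int steps_left_context = (start_t_out - start_t_in) / t_step;
  return steps_left_context * num_images;
}
```

Keys and values span every input row because each output frame attends to its neighbours, whereas queries are read only from the rows that correspond to output frames.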

◆ Check()

void Check ( ) const
private

Definition at line 277 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::context_dim_, KALDI_ASSERT, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, RestrictedAttentionComponent::stats_count_, RestrictedAttentionComponent::time_stride_, and RestrictedAttentionComponent::value_dim_.

Referenced by RestrictedAttentionComponent::InitFromConfig().

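void RestrictedAttentionComponent::Check() const {
  KALDI_ASSERT(num_heads_ > 0 && key_dim_ > 0 && value_dim_ > 0 &&
               num_left_inputs_ >= 0 && num_right_inputs_ >= 0 &&
               // ... (source line 280 is not shown in this listing)
               time_stride_ > 0 &&
               // ... (source lines 282-286 are not shown in this listing; per
               // the References above they involve context_dim_,
               // num_left_inputs_required_ and num_right_inputs_required_)
               key_scale_ > 0.0 && key_scale_ <= 1.0 &&
               stats_count_ >= 0.0);
}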

◆ Copy()

virtual Component* Copy ( ) const
inline virtual

Copies component (deep copy).

Implements Component.

Definition at line 155 of file nnet-attention-component.h.

References RestrictedAttentionComponent::RestrictedAttentionComponent().

◆ CreateIndexesVector()

void CreateIndexesVector ( const std::vector< std::pair< int32, int32 > > &  n_x_pairs,
int32  t_start,
int32  t_step,
int32  num_t_values,
const std::unordered_set< Index, IndexHasher > &  index_set,
std::vector< Index > *  output_indexes 
)
static private

Utility function used in GetIndexes().

Creates a grid of Indexes, where 't' has the larger stride, and within each block of Indexes for a given 't', we have the given list of (n, x) pairs. For Indexes that we create where the 't' value was not present in 'index_set', we set the 't' value to kNoTime (indicating that it's only for padding, not a real input or an output that's ever used).

Definition at line 575 of file nnet-attention-component.cc.

References KALDI_ASSERT, and kaldi::nnet3::kNoTime.

Referenced by RestrictedAttentionComponent::GetIndexes().

void RestrictedAttentionComponent::CreateIndexesVector(
    const std::vector<std::pair<int32, int32> > &n_x_pairs,
    int32 t_start, int32 t_step, int32 num_t_values,
    const std::unordered_set<Index, IndexHasher> &index_set,
    std::vector<Index> *output_indexes) {
  output_indexes->resize(static_cast<size_t>(num_t_values) * n_x_pairs.size());
  std::vector<Index>::iterator out_iter = output_indexes->begin();
  for (int32 t = t_start; t < t_start + (t_step * num_t_values); t += t_step) {
    std::vector<std::pair<int32, int32> >::const_iterator
        iter = n_x_pairs.begin(), end = n_x_pairs.end();
    for (; iter != end; ++iter) {
      out_iter->n = iter->first;
      out_iter->t = t;
      out_iter->x = iter->second;
      // Indexes absent from 'index_set' are padding only.
      if (index_set.count(*out_iter) == 0)
        out_iter->t = kNoTime;
      ++out_iter;
    }
  }
  KALDI_ASSERT(out_iter == output_indexes->end());
}
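The grid construction can be tried outside Kaldi with a small stand-alone sketch. Everything named here is a hypothetical stand-in: Idx replaces Index, std::set replaces the unordered_set with IndexHasher, and kPaddingTime plays the role of kNoTime (its numeric value is an assumption made for this sketch).

```cpp
#include <cstddef>
#include <limits>
#include <set>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical stand-in for the (n, t, x) Index type.
struct Idx {
  int n, t, x;
};

inline bool operator<(const Idx &a, const Idx &b) {
  return std::tie(a.n, a.t, a.x) < std::tie(b.n, b.t, b.x);
}

// Stand-in for kNoTime; the actual value is an assumption of this sketch.
const int kPaddingTime = std::numeric_limits<int>::min();

// Builds a grid with 't' as the outer (larger-stride) dimension and the
// (n, x) pairs as the inner dimension. Grid points that are not in 'present'
// get their 't' replaced by kPaddingTime.
inline std::vector<Idx> MakeGrid(
    const std::vector<std::pair<int, int> > &n_x_pairs,
    int t_start, int t_step, int num_t_values,
    const std::set<Idx> &present) {
  std::vector<Idx> out;
  out.reserve(static_cast<std::size_t>(num_t_values) * n_x_pairs.size());
  for (int i = 0; i < num_t_values; i++) {
    int t = t_start + i * t_step;
    for (const auto &nx : n_x_pairs) {
      Idx idx{nx.first, t, nx.second};
      if (present.count(idx) == 0)
        idx.t = kPaddingTime;
      out.push_back(idx);
    }
  }
  return out;
}
```

Because every 't' value contributes exactly one entry per (n, x) pair, the result always has num_t_values * n_x_pairs.size() entries, which is what allows the rest of the computation to treat the data as a regular grid.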

◆ DeleteMemo()

virtual void DeleteMemo ( void *  memo) const
inline virtual

This virtual function only needs to be overwritten by Components that return a non-NULL memo from their Propagate() function.

It's called by NnetComputer in cases where Propagate returns a memo but there will be no backprop to consume it.

Reimplemented from Component.

Definition at line 158 of file nnet-attention-component.h.

References RestrictedAttentionComponent::GetInputIndexes(), RestrictedAttentionComponent::IsComputable(), RestrictedAttentionComponent::PrecomputeIndexes(), and RestrictedAttentionComponent::ReorderIndexes().

virtual void DeleteMemo(void *memo) const { delete static_cast<Memo*>(memo); }

◆ GetComputationStructure()

void GetComputationStructure ( const std::vector< Index > &  input_indexes,
const std::vector< Index > &  output_indexes,
time_height_convolution::ConvolutionComputationIo *  io 
) const
private

Definition at line 389 of file nnet-attention-component.cc.

References kaldi::Gcd(), kaldi::nnet3::time_height_convolution::GetComputationIo(), KALDI_ASSERT, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, ConvolutionComputationIo::start_t_in, ConvolutionComputationIo::start_t_out, ConvolutionComputationIo::t_step_in, ConvolutionComputationIo::t_step_out, and RestrictedAttentionComponent::time_stride_.

Referenced by RestrictedAttentionComponent::PrecomputeIndexes(), and RestrictedAttentionComponent::ReorderIndexes().

void RestrictedAttentionComponent::GetComputationStructure(
    const std::vector<Index> &input_indexes,
    const std::vector<Index> &output_indexes,
    time_height_convolution::ConvolutionComputationIo *io) const {
  GetComputationIo(input_indexes, output_indexes, io);
  // if there was only one output and/or input index (unlikely),
  // just let the grid periodicity be time_stride_.
  if (io->t_step_out == 0) io->t_step_out = time_stride_;
  if (io->t_step_in == 0) io->t_step_in = time_stride_;

  // We need the grid size on the input and output to be the same, and to divide
  // time_stride_. If someone is requesting the output more frequently than
  // time_stride_, then after this change we may end up computing more outputs
  // than we need, but this is not a configuration that I think is very likely.
  // We let the grid step be the gcd of the input and output steps, and of
  // time_stride_.
  // The next few statements may have the effect of making the grid finer at the
  // input and output, while having the same start and end point.
  int32 t_step = Gcd(Gcd(io->t_step_out, io->t_step_in), time_stride_);
  int32 multiple_out = io->t_step_out / t_step,
      multiple_in = io->t_step_in / t_step;
  io->t_step_in = t_step;
  io->t_step_out = t_step;
  io->num_t_out = 1 + multiple_out * (io->num_t_out - 1);
  io->num_t_in = 1 + multiple_in * (io->num_t_in - 1);

  // Now ensure that the extent of the input has at least
  // the requested left-context and right context; if
  // this increases the amount of input, we'll do zero-padding.
  int32 first_requested_input =
      io->start_t_out - (time_stride_ * num_left_inputs_),
      first_required_input =
      io->start_t_out - (time_stride_ * num_left_inputs_required_),
      last_t_out = io->start_t_out + (io->num_t_out - 1) * t_step,
      last_t_in = io->start_t_in + (io->num_t_in - 1) * t_step,
      last_requested_input = last_t_out + (time_stride_ * num_right_inputs_),
      last_required_input =
      last_t_out + (time_stride_ * num_right_inputs_required_);

  // check that we don't have *more* than the requested context,
  // but that we have at least the required context.
  KALDI_ASSERT(io->start_t_in >= first_requested_input &&
               last_t_in <= last_requested_input &&
               io->start_t_in <= first_required_input &&
               last_t_in >= last_required_input);

  // For the inputs that were requested, but not required,
  // we pad with zeros. We pad the 'io' object, adding these
  // extra inputs structurally; in runtime they'll be set to zero.
  io->start_t_in = first_requested_input;
  io->num_t_in = 1 + (last_requested_input - first_requested_input) / t_step;
}
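The arithmetic in this function is easier to follow on concrete numbers. The sketch below reproduces the grid-refinement and padding steps on a hypothetical GridIo struct that holds only the time-related fields; GcdInt and RefineAndPad are illustrative names, and the context-range assertions of the original are omitted for brevity.

```cpp
// Hypothetical reduced version of ConvolutionComputationIo: time fields only.
struct GridIo {
  int start_t_in, t_step_in, num_t_in;
  int start_t_out, t_step_out, num_t_out;
};

// Euclid's algorithm; inputs are assumed to be positive.
inline int GcdInt(int a, int b) {
  while (b != 0) {
    int r = a % b;
    a = b;
    b = r;
  }
  return a;
}

// Puts input and output on a common grid whose step divides time_stride, then
// widens the input range to the full requested context (padding frames are
// added structurally, to be zero-filled at run time).
inline void RefineAndPad(int time_stride, int num_left_inputs,
                         int num_right_inputs, GridIo *io) {
  if (io->t_step_out == 0) io->t_step_out = time_stride;
  if (io->t_step_in == 0) io->t_step_in = time_stride;

  int t_step = GcdInt(GcdInt(io->t_step_out, io->t_step_in), time_stride);
  int multiple_out = io->t_step_out / t_step;
  int multiple_in = io->t_step_in / t_step;
  io->t_step_in = t_step;
  io->t_step_out = t_step;
  // A finer grid between the same end points has more points.
  io->num_t_out = 1 + multiple_out * (io->num_t_out - 1);
  io->num_t_in = 1 + multiple_in * (io->num_t_in - 1);

  int first_requested = io->start_t_out - time_stride * num_left_inputs;
  int last_t_out = io->start_t_out + (io->num_t_out - 1) * t_step;
  int last_requested = last_t_out + time_stride * num_right_inputs;
  io->start_t_in = first_requested;
  io->num_t_in = 1 + (last_requested - first_requested) / t_step;
}
```

For example, with outputs at t = 0 and 6 (step 6), inputs at t = -3, 0, 3, 6 (step 3), time_stride = 3, two left inputs and one right input, the common step becomes 3, the output grid becomes t = 0, 3, 6, and the input grid is widened to t = -6, ..., 9 (six frames), of which t = -6 and t = 9 are padding.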

◆ GetIndexes()

void GetIndexes ( const std::vector< Index > &  input_indexes,
const std::vector< Index > &  output_indexes,
const time_height_convolution::ConvolutionComputationIo &  io,
std::vector< Index > *  new_input_indexes,
std::vector< Index > *  new_output_indexes 
) const
private

Definition at line 597 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::CreateIndexesVector(), kaldi::nnet3::GetNxList(), KALDI_ASSERT, ConvolutionComputationIo::num_images, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, ConvolutionComputationIo::start_t_in, ConvolutionComputationIo::start_t_out, ConvolutionComputationIo::t_step_in, and ConvolutionComputationIo::t_step_out.

Referenced by RestrictedAttentionComponent::PrecomputeIndexes(), and RestrictedAttentionComponent::ReorderIndexes().

void RestrictedAttentionComponent::GetIndexes(
    const std::vector<Index> &input_indexes,
    const std::vector<Index> &output_indexes,
    time_height_convolution::ConvolutionComputationIo &io,
    std::vector<Index> *new_input_indexes,
    std::vector<Index> *new_output_indexes) const {

  std::unordered_set<Index, IndexHasher> input_set, output_set;
  for (std::vector<Index>::const_iterator iter = input_indexes.begin();
       iter != input_indexes.end(); ++iter)
    input_set.insert(*iter);
  for (std::vector<Index>::const_iterator iter = output_indexes.begin();
       iter != output_indexes.end(); ++iter)
    output_set.insert(*iter);

  std::vector<std::pair<int32, int32> > n_x_pairs;
  GetNxList(input_indexes, &n_x_pairs);  // the n,x pairs at the output will be
                                         // identical.
  KALDI_ASSERT(n_x_pairs.size() == io.num_images);
  CreateIndexesVector(n_x_pairs, io.start_t_in, io.t_step_in, io.num_t_in,
                      input_set, new_input_indexes);
  CreateIndexesVector(n_x_pairs, io.start_t_out, io.t_step_out, io.num_t_out,
                      output_set, new_output_indexes);
}

◆ GetInputIndexes()

void GetInputIndexes ( const MiscComputationInfo &  misc_info,
const Index &  output_index,
std::vector< Index > *  desired_indexes 
) const
virtual

This function only does something interesting for non-simple Components.

For a given index at the output of the component, tells us what indexes are required at its input (note: "required" encompasses also optionally-required things; it will enumerate all things that we'd like to have). See also IsComputable().

Parameters
[in]misc_infoThis argument is supplied to handle things that the framework can't very easily supply: information like which time indexes are needed for AggregateComponent, which time-indexes are available at the input of a recurrent network, and so on. We will add members to misc_info as needed.
[in]output_indexThe Index at the output of the component, for which we are requesting the list of indexes at the component's input.
[out]desired_indexesA list of indexes that are desired at the input. are to be written to here. By "desired" we mean required or optionally-required.

The default implementation of this function is suitable for any SimpleComponent; it just copies the output_index to a single identical element in desired_indexes.

Reimplemented from Component.

Definition at line 507 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::context_dim_, rnnlm::i, KALDI_ASSERT, kaldi::nnet3::kNoTime, Index::n, rnnlm::n, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_right_inputs_, Index::t, RestrictedAttentionComponent::time_stride_, and Index::x.

Referenced by RestrictedAttentionComponent::DeleteMemo().

510  {
511  KALDI_ASSERT(output_index.t != kNoTime);
512  int32 first_time = output_index.t - (time_stride_ * num_left_inputs_),
513  last_time = output_index.t + (time_stride_ * num_right_inputs_);
514  desired_indexes->clear();
515  desired_indexes->resize(context_dim_);
516  int32 n = output_index.n, x = output_index.x,
517  i = 0;
518  for (int32 t = first_time; t <= last_time; t += time_stride_, i++) {
519  (*desired_indexes)[i].n = n;
520  (*desired_indexes)[i].t = t;
521  (*desired_indexes)[i].x = x;
522  }
524 }
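
The window arithmetic in the listing above can be illustrated with a small standalone sketch. SimpleIndex and DesiredInputs are hypothetical stand-ins, not Kaldi types or functions, and the sketch assumes that the context size equals num_left + 1 + num_right, which is what the loop bounds imply.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for kaldi::nnet3::Index (n, t, x); not the Kaldi type.
struct SimpleIndex {
  int n, t, x;
};

// Lists the input indexes wanted for one output index: num_left steps back,
// the current frame, and num_right steps forward, spaced by time_stride.
std::vector<SimpleIndex> DesiredInputs(const SimpleIndex &output_index,
                                       int time_stride, int num_left,
                                       int num_right) {
  assert(time_stride > 0 && num_left >= 0 && num_right >= 0);
  int first_time = output_index.t - time_stride * num_left;
  int last_time = output_index.t + time_stride * num_right;
  std::vector<SimpleIndex> desired;
  desired.reserve(num_left + 1 + num_right);
  // n and x are copied unchanged; only t varies across the window.
  for (int t = first_time; t <= last_time; t += time_stride)
    desired.push_back({output_index.n, t, output_index.x});
  return desired;
}
```

For example, an output at t = 10 with time_stride = 3, two left inputs and one right input requests the times 4, 7, 10 and 13.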

◆ Info()

std::string Info ( ) const
virtual

Returns some text-form information about this component, for diagnostics.

Starts with the type of the component. E.g. "SigmoidComponent dim=900", although most components will have much more info.

Reimplemented from Component.

Definition at line 32 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::context_dim_, VectorBase< Real >::Dim(), RestrictedAttentionComponent::entropy_stats_, rnnlm::i, RestrictedAttentionComponent::InputDim(), rnnlm::j, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, CuMatrixBase< Real >::NumCols(), RestrictedAttentionComponent::output_context_, RestrictedAttentionComponent::OutputDim(), RestrictedAttentionComponent::posterior_stats_, RestrictedAttentionComponent::stats_count_, RestrictedAttentionComponent::time_stride_, RestrictedAttentionComponent::Type(), and RestrictedAttentionComponent::value_dim_.

Referenced by RestrictedAttentionComponent::OutputDim().

32  {
33  std::stringstream stream;
34  stream << Type() << ", input-dim=" << InputDim()
35  << ", output-dim=" << OutputDim()
36  << ", num-heads=" << num_heads_
37  << ", time-stride=" << time_stride_
38  << ", key-dim=" << key_dim_
39  << ", value-dim=" << value_dim_
40  << ", num-left-inputs=" << num_left_inputs_
41  << ", num-right-inputs=" << num_right_inputs_
42  << ", context-dim=" << context_dim_
43  << ", num-left-inputs-required=" << num_left_inputs_required_
44  << ", num-right-inputs-required=" << num_right_inputs_required_
45  << ", output-context=" << (output_context_ ? "true" : "false")
46  << ", key-scale=" << key_scale_;
47  if (stats_count_ != 0.0) {
48  stream << ", entropy=";
49  for (int32 i = 0; i < entropy_stats_.Dim(); i++)
50  stream << (entropy_stats_(i) / stats_count_) << ',';
51  for (int32 i = 0; i < num_heads_ && i < 5; i++) {
52  stream << " posterior-stats[" << i <<"]=";
53  for (int32 j = 0; j < posterior_stats_.NumCols(); j++)
54  stream << (posterior_stats_(i,j) / stats_count_) << ',';
55  }
56  stream << " stats-count=" << stats_count_;
57  }
58  return stream.str();
59 }

◆ InitFromConfig()

void InitFromConfig ( ConfigLine *  cfl)
virtual

Initialize, from a ConfigLine object.

Parameters
[in] cfl: A ConfigLine containing any parameters that are needed for initialization. For example: "dim=100 param-stddev=0.1"

Implements Component.

Definition at line 80 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::Check(), RestrictedAttentionComponent::context_dim_, ConfigLine::GetValue(), KALDI_ERR, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, RestrictedAttentionComponent::output_context_, RestrictedAttentionComponent::stats_count_, RestrictedAttentionComponent::time_stride_, RestrictedAttentionComponent::value_dim_, and ConfigLine::WholeLine().

Referenced by RestrictedAttentionComponent::OutputDim().

80  {
81  num_heads_ = 1;
82  key_dim_ = -1;
83  value_dim_ = -1;
84  num_left_inputs_ = -1;
85  num_right_inputs_ = -1;
86  time_stride_ = 1;
89  output_context_ = true;
90  key_scale_ = -1.0;
91 
92 
93  // mandatory arguments.
94  bool ok = cfl->GetValue("key-dim", &key_dim_) &&
95  cfl->GetValue("value-dim", &value_dim_) &&
96  cfl->GetValue("num-left-inputs", &num_left_inputs_) &&
97  cfl->GetValue("num-right-inputs", &num_right_inputs_);
98 
99  if (!ok)
100  KALDI_ERR << "All of the values key-dim, value-dim, "
101  "num-left-inputs and num-right-inputs must be defined.";
102  // optional arguments.
103  cfl->GetValue("num-heads", &num_heads_);
104  cfl->GetValue("time-stride", &time_stride_);
105  cfl->GetValue("num-left-inputs-required", &num_left_inputs_required_);
106  cfl->GetValue("num-right-inputs-required", &num_right_inputs_required_);
107  cfl->GetValue("output-context", &output_context_);
108  cfl->GetValue("key-scale", &key_scale_);
109 
110  if (key_scale_ < 0.0) key_scale_ = 1.0 / sqrt(key_dim_);
115 
116  if (num_heads_ <= 0 || key_dim_ <= 0 || value_dim_ <= 0 ||
117  num_left_inputs_ < 0 || num_right_inputs_ < 0 ||
121  time_stride_ <= 0)
122  KALDI_ERR << "Config line contains invalid values: "
123  << cfl->WholeLine();
124  stats_count_ = 0.0;
126  Check();
127 }
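
The handling of the optional key-scale value can be sketched on its own. ResolveKeyScale is a hypothetical helper name used only for illustration; it mirrors the rule in the listing that a negative (unset) key-scale is replaced by 1 / sqrt(key-dim).

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper, not a Kaldi function. A negative value means
// "not set in the config line", in which case the default 1/sqrt(key_dim)
// is used; any non-negative value is kept as given.
float ResolveKeyScale(float configured_key_scale, int key_dim) {
  assert(key_dim > 0);
  if (configured_key_scale < 0.0f)
    return 1.0f / std::sqrt(static_cast<float>(key_dim));
  return configured_key_scale;
}
```

With key-dim=64 and no key-scale given, the resolved scale is 0.125.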

◆ InputDim()

virtual int32 InputDim ( ) const
inlinevirtual

Returns input-dimension of this component.

Implements Component.

Definition at line 115 of file nnet-attention-component.h.

References RestrictedAttentionComponent::context_dim_, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::num_heads_, and RestrictedAttentionComponent::value_dim_.

Referenced by RestrictedAttentionComponent::Info().

115  {
116  // the input is interpreted as being appended blocks one for each head; each
117  // such block is interpreted as (key, value, query).
118  int32 query_dim = key_dim_ + context_dim_;
119  return num_heads_ * (key_dim_ + value_dim_ + query_dim);
120  }
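
As a rough illustration of the column layout described in the comment above, the following sketch computes where the key, value and query sub-blocks of one head start. HeadInputLayout and InputLayoutForHead are hypothetical names; the ordering (key, value, query) is taken from the comment in the listing.

```cpp
#include <cassert>

// Hypothetical description of one head's block of input columns.
struct HeadInputLayout {
  int key_offset;    // first column of the key
  int value_offset;  // first column of the value
  int query_offset;  // first column of the query
  int block_dim;     // total number of columns used by one head
};

// Computes the column offsets for the given head, assuming the blocks for
// all heads are appended one after another.
HeadInputLayout InputLayoutForHead(int head, int key_dim, int value_dim,
                                   int context_dim) {
  assert(head >= 0 && key_dim > 0 && value_dim > 0 && context_dim > 0);
  int query_dim = key_dim + context_dim;
  int block_dim = key_dim + value_dim + query_dim;
  int start = head * block_dim;
  return {start, start + key_dim, start + key_dim + value_dim, block_dim};
}
```

With key_dim = 40, value_dim = 60 and context_dim = 8, each head uses 148 columns, so two heads give an input dimension of 296.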

◆ IsComputable()

bool IsComputable ( const MiscComputationInfo &  misc_info,
const Index &  output_index,
const IndexSet &  input_index_set,
std::vector< Index > *  used_inputs 
) const
virtual

This function only does something interesting for non-simple Components, and it exists to make it possible to manage optionally-required inputs.

It tells the user whether a given output index is computable from a given set of input indexes, and if so, says which input indexes will be used in the computation.

Implementations of this function are required to have the property that adding an element to "input_index_set" can only ever change IsComputable from false to true, never vice versa.

Parameters
[in] misc_info: Some information specific to the computation, such as minimum and maximum times for certain components to do adaptation on; it's a place to put things that don't easily fit in the framework.
[in] output_index: The index that is to be computed at the output of this Component.
[in] input_index_set: The set of indexes that is available at the input of this Component.
[out] used_inputs: If this is non-NULL and the output is computable this will be set to the list of input indexes that will actually be used in the computation.
Returns
Returns true iff this output is computable from the provided inputs.

The default implementation of this function is suitable for any SimpleComponent: it just returns true if output_index is in input_index_set, and if so sets used_inputs to a vector containing that one Index.

Reimplemented from Component.

Definition at line 527 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::context_dim_, KALDI_ASSERT, kaldi::nnet3::kNoTime, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, Index::t, and RestrictedAttentionComponent::time_stride_.

Referenced by RestrictedAttentionComponent::DeleteMemo().

531  {
532  KALDI_ASSERT(output_index.t != kNoTime);
533  Index index(output_index);
534 
535  if (used_inputs != NULL) {
536  int32 first_time = output_index.t - (time_stride_ * num_left_inputs_),
537  last_time = output_index.t + (time_stride_ * num_right_inputs_);
538  used_inputs->clear();
539  used_inputs->reserve(context_dim_);
540 
541  for (int32 t = first_time; t <= last_time; t += time_stride_) {
542  index.t = t;
543  if (input_index_set(index)) {
544  // This input index is available.
545  used_inputs->push_back(index);
546  } else {
547  // This input index is not available.
548  int32 offset = (t - output_index.t) / time_stride_;
549  if (offset >= -num_left_inputs_required_ &&
550  offset <= num_right_inputs_required_) {
551  used_inputs->clear();
552  return false;
553  }
554  }
555  }
556  // All required time-offsets of the output were computable. -> return true.
557  return true;
558  } else {
559  int32 t = output_index.t,
560  first_time_required = t - (time_stride_ * num_left_inputs_required_),
561  last_time_required = t + (time_stride_ * num_right_inputs_required_);
562  for (int32 t = first_time_required;
563  t <= last_time_required;
564  t += time_stride_) {
565  index.t = t;
566  if (!input_index_set(index))
567  return false;
568  }
569  return true;
570  }
571 }
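
The distinction between optional and required context can be shown with a simplified sketch that models only the time index. IsComputableSketch is a hypothetical function, and the std::set of available times stands in for the IndexSet argument. Unlike the listing, it uses a single loop for both the NULL and non-NULL cases, which gives the same return value because missing optional inputs never affect the result.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Returns true if every required time offset is available. Optional offsets
// (outside the required range) are used when present and skipped otherwise.
// If used_times is non-null it receives the times that would be used, or is
// cleared when the output is not computable.
bool IsComputableSketch(int output_t, const std::set<int> &available,
                        int time_stride, int num_left, int num_right,
                        int num_left_required, int num_right_required,
                        std::vector<int> *used_times) {
  assert(time_stride > 0);
  std::vector<int> used;
  for (int offset = -num_left; offset <= num_right; offset++) {
    int t = output_t + offset * time_stride;
    if (available.count(t) != 0) {
      used.push_back(t);
    } else if (offset >= -num_left_required &&
               offset <= num_right_required) {
      if (used_times != nullptr) used_times->clear();
      return false;  // a required input is missing
    }
  }
  if (used_times != nullptr) *used_times = used;
  return true;
}
```

For an output at t = 10 with time_stride = 3, two left and one right inputs, of which one on each side is required, the available times {7, 10, 13} are sufficient even though t = 4 is missing; removing t = 7 makes the output non-computable.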

◆ OutputDim()

virtual int32 OutputDim ( ) const
inlinevirtual

Returns output-dimension of this component.

Implements Component.

Definition at line 121 of file nnet-attention-component.h.

References RestrictedAttentionComponent::context_dim_, RestrictedAttentionComponent::Info(), RestrictedAttentionComponent::InitFromConfig(), RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::output_context_, and RestrictedAttentionComponent::value_dim_.

Referenced by RestrictedAttentionComponent::Info().

121  {
122  // the output consists of appended blocks, one for each head; each such
123  // block is the attention weighted average of the input values, to which
124  // we append softmax encoding of the positions we chose, if output_context_
125  // == true.
126  return num_heads_ * (value_dim_ + (output_context_ ? context_dim_ : 0));
127  }
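
A worked instance of the formula above, as a standalone sketch. AttentionOutputDim is a hypothetical free function that repeats the arithmetic of OutputDim() with the member variables passed in as arguments.

```cpp
#include <cassert>

// Output dimension: per head, value_dim columns plus (optionally)
// context_dim columns holding the attention weights over the context.
int AttentionOutputDim(int num_heads, int value_dim, int context_dim,
                       bool output_context) {
  assert(num_heads > 0 && value_dim > 0 && context_dim > 0);
  return num_heads * (value_dim + (output_context ? context_dim : 0));
}
```

With two heads, value_dim = 60 and context_dim = 8, the output dimension is 136 when the context is appended and 120 when it is not.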

◆ PrecomputeIndexes()

ComponentPrecomputedIndexes * PrecomputeIndexes ( const MiscComputationInfo &  misc_info,
const std::vector< Index > &  input_indexes,
const std::vector< Index > &  output_indexes,
bool  need_backprop 
) const
virtual

This function must return NULL for simple Components.

Returns a pointer to a class that may contain some precomputed component-specific and computation-specific indexes to be used in the Propagate and Backprop functions.

Parameters
[in] misc_info: This argument is supplied to handle things that the framework can't very easily supply: information like which time indexes are needed for AggregateComponent, which time-indexes are available at the input of a recurrent network, and so on. misc_info may not even ever be used here. We will add members to misc_info as needed.
[in] input_indexes: A vector of indexes that explains what time-indexes (and other indexes) each row of the in/in_value/in_deriv matrices given to Propagate and Backprop will mean.
[in] output_indexes: A vector of indexes that explains what time-indexes (and other indexes) each row of the out/out_value/out_deriv matrices given to Propagate and Backprop will mean.
[in] need_backprop: True if we might need to do backprop with this component, so that if any different indexes are needed for backprop then those should be computed too.
Returns
Returns a child-class of class ComponentPrecomputedIndexes, or NULL if this component does not need to precompute any indexes (e.g. if it is a simple component and does not care about indexes).

Reimplemented from Component.

Definition at line 622 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::GetComputationStructure(), RestrictedAttentionComponent::GetIndexes(), kaldi::GetVerboseLevel(), RestrictedAttentionComponent::PrecomputedIndexes::io, and KALDI_ASSERT.

Referenced by RestrictedAttentionComponent::DeleteMemo().

627  {
628  PrecomputedIndexes *ans = new PrecomputedIndexes();
629  GetComputationStructure(input_indexes, output_indexes, &(ans->io));
630  if (GetVerboseLevel() >= 2) {
631  // what goes next is just a check.
632  std::vector<Index> new_input_indexes, new_output_indexes;
633  GetIndexes(input_indexes, output_indexes, ans->io,
634  &new_input_indexes, &new_output_indexes);
635  // input_indexes and output_indexes should be the ones that were
636  // output by ReorderIndexes(), so they should already
637  // have gone through the GetComputationStructure()->GetIndexes()
638  // procedure. Applying the same procedure twice is supposed to
639  // give unchanged results.
640  KALDI_ASSERT(input_indexes == new_input_indexes &&
641  output_indexes == new_output_indexes);
642  }
643  return ans;
644 }

◆ Propagate()

void * Propagate ( const ComponentPrecomputedIndexes *  indexes,
const CuMatrixBase< BaseFloat > &  in,
CuMatrixBase< BaseFloat > *  out 
) const
virtual

Propagate function.

Parameters
[in] indexes: A pointer to some information output by this class's PrecomputeIndexes function (will be NULL for simple components, i.e. those that don't do things like splicing).
[in] in: The input to this component. Num-columns == InputDim().
[out] out: The output of this component. Num-columns == OutputDim(). Note: output of this component will be added to the initial value of "out" if Properties()&kPropagateAdds != 0; otherwise the output will be set and the initial value ignored. Each Component chooses whether it is more convenient implementation-wise to add or set, and the calling code has to deal with it.
Returns
Normally returns NULL, but may return a non-NULL value for components which have the flag kUsesMemo set. This value will be passed into the corresponding Backprop routine.

Implements Component.

Definition at line 132 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::context_dim_, RestrictedAttentionComponent::PrecomputedIndexes::io, KALDI_ASSERT, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::num_heads_, ConvolutionComputationIo::num_images, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, CuMatrixBase< Real >::NumRows(), RestrictedAttentionComponent::output_context_, RestrictedAttentionComponent::PropagateOneHead(), and RestrictedAttentionComponent::value_dim_.

Referenced by RestrictedAttentionComponent::Properties().

134  {
135  const PrecomputedIndexes *indexes = dynamic_cast<const PrecomputedIndexes*>(
136  indexes_in);
137  KALDI_ASSERT(indexes != NULL &&
138  in.NumRows() == indexes->io.num_t_in * indexes->io.num_images &&
139  out->NumRows() == indexes->io.num_t_out * indexes->io.num_images);
140 
141 
142  Memo *memo = new Memo();
143  memo->c.Resize(out->NumRows(), context_dim_ * num_heads_);
144 
145  int32 query_dim = key_dim_ + context_dim_;
146  int32 input_dim_per_head = key_dim_ + value_dim_ + query_dim,
147  output_dim_per_head = value_dim_ + (output_context_ ? context_dim_ : 0);
148  for (int32 h = 0; h < num_heads_; h++) {
149  CuSubMatrix<BaseFloat> in_part(in, 0, in.NumRows(),
150  h * input_dim_per_head, input_dim_per_head),
151  c_part(memo->c, 0, out->NumRows(),
152  h * context_dim_, context_dim_),
153  out_part(*out, 0, out->NumRows(),
154  h * output_dim_per_head, output_dim_per_head);
155  PropagateOneHead(indexes->io, in_part, &c_part, &out_part);
156  }
157  return static_cast<void*>(memo);
158 }

◆ PropagateOneHead()

void PropagateOneHead ( const time_height_convolution::ConvolutionComputationIo &  io,
const CuMatrixBase< BaseFloat > &  in,
CuMatrixBase< BaseFloat > *  c,
CuMatrixBase< BaseFloat > *  out 
) const
private

Definition at line 160 of file nnet-attention-component.cc.

References kaldi::nnet3::attention::AttentionForward(), RestrictedAttentionComponent::context_dim_, KALDI_ASSERT, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, ConvolutionComputationIo::num_images, ConvolutionComputationIo::num_t_in, ConvolutionComputationIo::num_t_out, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), RestrictedAttentionComponent::output_context_, ConvolutionComputationIo::start_t_in, ConvolutionComputationIo::start_t_out, ConvolutionComputationIo::t_step_in, ConvolutionComputationIo::t_step_out, and RestrictedAttentionComponent::value_dim_.

Referenced by RestrictedAttentionComponent::Propagate().

164  {
165  int32 query_dim = key_dim_ + context_dim_,
166  full_value_dim = value_dim_ + (output_context_ ? context_dim_ : 0);
167  KALDI_ASSERT(in.NumRows() == io.num_images * io.num_t_in &&
168  out->NumRows() == io.num_images * io.num_t_out &&
169  out->NumCols() == full_value_dim &&
170  in.NumCols() == (key_dim_ + value_dim_ + query_dim) &&
171  io.t_step_in == io.t_step_out &&
172  (io.start_t_out - io.start_t_in) % io.t_step_in == 0);
173 
174  // 'steps_left_context' is the number of time-steps the input has on the left
175  // that don't appear in the output.
176  int32 steps_left_context = (io.start_t_out - io.start_t_in) / io.t_step_in,
177  rows_left_context = steps_left_context * io.num_images;
178  KALDI_ASSERT(rows_left_context >= 0);
179 
180  // 'queries' contains the queries. We don't use all rows of the input
181  // queries; only the rows that correspond to the time-indexes at the
182  // output, i.e. we exclude the left-context and right-context.
183  // The remaining rows of 'in' that we didn't select correspond to left
184  // and right temporal context.
185  CuSubMatrix<BaseFloat> queries(in, rows_left_context, out->NumRows(),
186  key_dim_ + value_dim_, query_dim);
187  // 'keys' contains the keys; note, these are not extended with
188  // context information; that happens further in.
189  CuSubMatrix<BaseFloat> keys(in, 0, in.NumRows(), 0, key_dim_);
190 
191  // 'values' contains the values which we will interpolate.
192  // these don't contain the context information; that will be added
193  // later if output_context_ == true.
194  CuSubMatrix<BaseFloat> values(in, 0, in.NumRows(), key_dim_, value_dim_);
195 
196  attention::AttentionForward(key_scale_, keys, queries, values, c, out);
197 }
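
The row arithmetic that selects the queries can be sketched independently of the matrix classes. IoSketch is a hypothetical reduced version of ConvolutionComputationIo that holds only the fields used in the listing, and QueryRows is a hypothetical helper. The sketch assumes, as the listing does, that each time step occupies num_images consecutive rows.

```cpp
#include <cassert>

// Hypothetical reduced stand-in for ConvolutionComputationIo.
struct IoSketch {
  int num_images;
  int start_t_in, t_step_in, num_t_in;
  int start_t_out, t_step_out, num_t_out;
};

// A contiguous range of rows in the input matrix.
struct RowRange {
  int first_row;
  int num_rows;
};

// Returns the rows of the input whose time indexes also appear at the
// output, i.e. the rows from which the queries are taken.
RowRange QueryRows(const IoSketch &io) {
  assert(io.t_step_in == io.t_step_out && io.t_step_in > 0);
  assert((io.start_t_out - io.start_t_in) % io.t_step_in == 0);
  int steps_left_context = (io.start_t_out - io.start_t_in) / io.t_step_in;
  assert(steps_left_context >= 0);
  return {steps_left_context * io.num_images, io.num_t_out * io.num_images};
}
```

For example, with 2 images, input times starting at -6 and output times starting at 0, both with step 3, the first 4 input rows are left context and the queries occupy the following num_t_out * 2 rows.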

◆ Properties()

◆ Read()

void Read ( std::istream &  is,
bool  binary 
)
virtual

Read function (used after we know the type of the Component); accepts input that is missing the token that describes the component type, in case it has already been consumed.

Implements Component.

Definition at line 473 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::context_dim_, RestrictedAttentionComponent::entropy_stats_, kaldi::ExpectOneOrTwoTokens(), kaldi::nnet3::ExpectToken(), RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, RestrictedAttentionComponent::output_context_, RestrictedAttentionComponent::posterior_stats_, Vector< Real >::Read(), CuMatrix< Real >::Read(), kaldi::ReadBasicType(), RestrictedAttentionComponent::stats_count_, RestrictedAttentionComponent::time_stride_, and RestrictedAttentionComponent::value_dim_.

Referenced by RestrictedAttentionComponent::Properties().

473  {
474  ExpectOneOrTwoTokens(is, binary, "<RestrictedAttentionComponent>",
475  "<NumHeads>");
476  ReadBasicType(is, binary, &num_heads_);
477  ExpectToken(is, binary, "<KeyDim>");
478  ReadBasicType(is, binary, &key_dim_);
479  ExpectToken(is, binary, "<ValueDim>");
480  ReadBasicType(is, binary, &value_dim_);
481  ExpectToken(is, binary, "<NumLeftInputs>");
482  ReadBasicType(is, binary, &num_left_inputs_);
483  ExpectToken(is, binary, "<NumRightInputs>");
484  ReadBasicType(is, binary, &num_right_inputs_);
485  ExpectToken(is, binary, "<TimeStride>");
486  ReadBasicType(is, binary, &time_stride_);
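487  ExpectToken(is, binary, "<NumLeftInputsRequired>");
488  ReadBasicType(is, binary, &num_left_inputs_required_);
489  ExpectToken(is, binary, "<NumRightInputsRequired>");
490  ReadBasicType(is, binary, &num_right_inputs_required_);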
491  ExpectToken(is, binary, "<OutputContext>");
492  ReadBasicType(is, binary, &output_context_);
493  ExpectToken(is, binary, "<KeyScale>");
494  ReadBasicType(is, binary, &key_scale_);
495  ExpectToken(is, binary, "<StatsCount>");
496  ReadBasicType(is, binary, &stats_count_);
497  ExpectToken(is, binary, "<EntropyStats>");
498  entropy_stats_.Read(is, binary);
499  ExpectToken(is, binary, "<PosteriorStats>");
500  posterior_stats_.Read(is, binary);
501  ExpectToken(is, binary, "</RestrictedAttentionComponent>");
502 
504 }
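
The "one or two tokens" behaviour at the start of Read() can be illustrated with a text-mode-only sketch. ExpectOneOrTwoTokensSketch and ReadNumHeads are hypothetical stand-ins written against std::istream; they are not the Kaldi functions, which also handle binary mode.

```cpp
#include <cassert>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>

// Accepts either "token1 token2" or just "token2" at the current position.
// This lets a reader work whether or not the component-type token has
// already been consumed by the caller.
void ExpectOneOrTwoTokensSketch(std::istream &is, const std::string &token1,
                                const std::string &token2) {
  assert(!token2.empty());
  std::string found;
  is >> found;
  if (found == token1) is >> found;  // skip the optional type token
  if (found != token2)
    throw std::runtime_error("expected " + token2 + ", got " + found);
}

// Reads the first field that follows the optional type token.
int ReadNumHeads(std::istream &is) {
  ExpectOneOrTwoTokensSketch(is, "<RestrictedAttentionComponent>",
                             "<NumHeads>");
  int num_heads = 0;
  is >> num_heads;
  if (!is) throw std::runtime_error("failed to read <NumHeads> value");
  return num_heads;
}
```

Both "<RestrictedAttentionComponent> <NumHeads> 4" and "<NumHeads> 4" are accepted by this sketch and yield 4.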

◆ ReorderIndexes()

void ReorderIndexes ( std::vector< Index > *  input_indexes,
std::vector< Index > *  output_indexes 
) const
virtual

This function only does something interesting for non-simple Components.

It provides an opportunity for a Component to reorder or pad the indexes at its input and output. This might be useful, for instance, if a component requires a particular ordering of the indexes that doesn't correspond to their natural ordering. Components that might modify the indexes are required to return the kReordersIndexes flag in their Properties(). The ReorderIndexes() function is now allowed to insert blanks into the indexes. The 'blanks' must be of the form (n,kNoTime,x), where the marker kNoTime (a very negative number) is there where the 't' indexes normally live. The reason we don't just have, say, (-1,-1,-1), relates to the need to preserve a regular pattern over the 'n' indexes so that 'shortcut compilation' (c.f. ExpandComputation()) can work correctly.

Parameters
[in,out] input_indexes: Indexes at the input of the Component.
[in,out] output_indexes: Indexes at the output of the Component.

Reimplemented from Component.

Definition at line 376 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::GetComputationStructure(), and RestrictedAttentionComponent::GetIndexes().

Referenced by RestrictedAttentionComponent::DeleteMemo().

378  {
379  using namespace time_height_convolution;
380  ConvolutionComputationIo io;
381  GetComputationStructure(*input_indexes, *output_indexes, &io);
382  std::vector<Index> new_input_indexes, new_output_indexes;
383  GetIndexes(*input_indexes, *output_indexes, io,
384  &new_input_indexes, &new_output_indexes);
385  input_indexes->swap(new_input_indexes);
386  output_indexes->swap(new_output_indexes);
387 }

◆ Scale()

void Scale ( BaseFloat  scale)
virtual

This virtual function, when called on:
- an UpdatableComponent, scales the parameters by "scale";
- a nonlinear component (or another component that stores stats, like BatchNormComponent), scales the activation stats, not the parameters.

Otherwise it will normally do nothing.

Reimplemented from Component.

Definition at line 255 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::entropy_stats_, RestrictedAttentionComponent::posterior_stats_, VectorBase< Real >::Scale(), CuMatrixBase< Real >::Scale(), and RestrictedAttentionComponent::stats_count_.

Referenced by RestrictedAttentionComponent::Properties().

255  {
256  entropy_stats_.Scale(scale);
257  posterior_stats_.Scale(scale);
258  stats_count_ *= scale;
259 }
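
Because Info() reports each statistic divided by stats_count_, scaling the accumulated sums and the count by the same nonzero factor leaves the reported averages unchanged. The sketch below shows this with a hypothetical StatsSketch struct that tracks only per-head entropy sums; it is not the component's actual storage.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical accumulator: one entropy sum per head plus a frame count.
struct StatsSketch {
  std::vector<double> entropy_sum;
  double count;

  // Scales the sums and the count together, as Scale() does in the listing.
  void Scale(double scale) {
    for (double &e : entropy_sum) e *= scale;
    count *= scale;
  }

  // The average that a diagnostic printout would report for one head.
  double AverageEntropy(std::size_t head) const {
    assert(count != 0.0 && head < entropy_sum.size());
    return entropy_sum[head] / count;
  }
};
```

Scaling by 0.5, for instance, halves the weight of the stored stats relative to stats accumulated later, without changing the current averages.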

◆ StoreStats()

void StoreStats ( const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
void *  memo 
)
virtual

This function may store stats on average activation values, and for some component types, the average value of the derivative of the nonlinearity.

It only does something for those components that have nonzero Properties()&kStoresStats.

Parameters
[in] in_value: The input to the Propagate() function. Note: if the component sets the flag kPropagateInPlace, this should not be used; the empty matrix will be provided here if in-place propagation was used.
[in] out_value: The output of the Propagate() function.
[in] memo: The 'memo' returned by the Propagate() function; this will usually be NULL.

Reimplemented from Component.

Definition at line 200 of file nnet-attention-component.cc.

References CuVectorBase< Real >::AddDiagMatMat(), CuMatrixBase< Real >::AddMat(), CuVectorBase< Real >::AddRowSumMat(), VectorBase< Real >::AddVec(), CuMatrixBase< Real >::ApplyFloor(), CuMatrixBase< Real >::ApplyLog(), RestrictedAttentionComponent::Memo::c, RestrictedAttentionComponent::context_dim_, CuVectorBase< Real >::Data(), VectorBase< Real >::Dim(), RestrictedAttentionComponent::entropy_stats_, KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, RestrictedAttentionComponent::num_heads_, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), RestrictedAttentionComponent::posterior_stats_, kaldi::RandInt(), Vector< Real >::Resize(), CuMatrix< Real >::Resize(), and RestrictedAttentionComponent::stats_count_.

Referenced by RestrictedAttentionComponent::Properties().

203  {
204  const Memo *memo = static_cast<const Memo*>(memo_in);
205  KALDI_ASSERT(memo != NULL);
206  if (entropy_stats_.Dim() != num_heads_) {
207  entropy_stats_.Resize(num_heads_);
208  posterior_stats_.Resize(num_heads_, context_dim_);
209  stats_count_ = 0.0;
210  }
211  const CuMatrix<BaseFloat> &c = memo->c;
212  if (RandInt(0, 2) == 0)
213  return; // only actually store the stats for one in three minibatches, to
214  // save time.
215 
216  { // first get the posterior stats.
217  CuVector<BaseFloat> c_sum(num_heads_ * context_dim_);
218  c_sum.AddRowSumMat(1.0, c, 0.0);
219  // view the vector as a matrix.
220  CuSubMatrix<BaseFloat> c_sum_as_mat(c_sum.Data(), num_heads_,
221  context_dim_, context_dim_);
222  CuMatrix<double> c_sum_as_mat_dbl(c_sum_as_mat);
223  posterior_stats_.AddMat(1.0, c_sum_as_mat_dbl);
224  KALDI_ASSERT(c.NumCols() == num_heads_ * context_dim_);
225  }
226  { // now get the entropy stats.
227  CuMatrix<BaseFloat> log_c(c);
228  log_c.ApplyFloor(1.0e-20);
229  log_c.ApplyLog();
230  CuVector<BaseFloat> dot_prod(num_heads_ * context_dim_);
231  dot_prod.AddDiagMatMat(-1.0, c, kTrans, log_c, kNoTrans, 0.0);
232  // dot_prod is the sum over the matrix's rows (which correspond
233  // to heads, and context positions), of - c * log(c), which is
234  // part of the entropy. To get the actual contribution to the
235  // entropy, we have to sum 'dot_prod' over blocks of
236  // size 'context_dim_'; that gives us the entropy contribution
237  // per head. We'd have to divide by c.NumRows() to get the
238  // actual entropy, but that's reflected in stats_count_.
239  CuSubMatrix<BaseFloat> entropy_mat(dot_prod.Data(), num_heads_,
240  context_dim_, context_dim_);
241  CuVector<BaseFloat> entropy_vec(num_heads_);
242  entropy_vec.AddColSumMat(1.0, entropy_mat);
243  Vector<double> entropy_vec_dbl(entropy_vec);
244  entropy_stats_.AddVec(1.0, entropy_vec_dbl);
245  }
246  stats_count_ += c.NumRows();
247 }
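
The entropy accumulation can be written out with plain loops to make the quantity explicit. EntropySums is a hypothetical function operating on nested std::vector rather than on CuMatrix; it computes, for each head, the sum over rows of -c * log(c) with the same 1e-20 floor inside the logarithm as in the listing.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// c has one row per output frame and num_heads * context_dim columns, each
// block of context_dim columns being one head's attention weights.
// Returns the per-head entropy summed over rows; dividing by the number of
// rows gives the average entropy per frame.
std::vector<double> EntropySums(const std::vector<std::vector<double>> &c,
                                int num_heads, int context_dim) {
  std::vector<double> sums(num_heads, 0.0);
  for (const std::vector<double> &row : c) {
    assert(static_cast<int>(row.size()) == num_heads * context_dim);
    for (int h = 0; h < num_heads; h++) {
      for (int j = 0; j < context_dim; j++) {
        double p = row[h * context_dim + j];
        // Floor the argument of log so that p == 0 contributes zero.
        sums[h] -= p * std::log(std::max(p, 1.0e-20));
      }
    }
  }
  return sums;
}
```

A uniform row over two context positions contributes log(2) to the sum, while a one-hot row contributes zero.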

◆ Type()

virtual std::string Type ( ) const
inlinevirtual

Returns a string such as "SigmoidComponent", describing the type of the object.

Implements Component.

Definition at line 130 of file nnet-attention-component.h.

Referenced by RestrictedAttentionComponent::Info().

130 { return "RestrictedAttentionComponent"; }

◆ Write()

void Write ( std::ostream &  os,
bool  binary 
) const
virtual

Write component to stream.

Implements Component.

Definition at line 442 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::entropy_stats_, RestrictedAttentionComponent::key_dim_, RestrictedAttentionComponent::key_scale_, RestrictedAttentionComponent::num_heads_, RestrictedAttentionComponent::num_left_inputs_, RestrictedAttentionComponent::num_left_inputs_required_, RestrictedAttentionComponent::num_right_inputs_, RestrictedAttentionComponent::num_right_inputs_required_, RestrictedAttentionComponent::output_context_, RestrictedAttentionComponent::posterior_stats_, RestrictedAttentionComponent::stats_count_, RestrictedAttentionComponent::time_stride_, RestrictedAttentionComponent::value_dim_, VectorBase< Real >::Write(), CuMatrixBase< Real >::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by RestrictedAttentionComponent::Properties().

442  {
443  WriteToken(os, binary, "<RestrictedAttentionComponent>");
444  WriteToken(os, binary, "<NumHeads>");
445  WriteBasicType(os, binary, num_heads_);
446  WriteToken(os, binary, "<KeyDim>");
447  WriteBasicType(os, binary, key_dim_);
448  WriteToken(os, binary, "<ValueDim>");
449  WriteBasicType(os, binary, value_dim_);
450  WriteToken(os, binary, "<NumLeftInputs>");
451  WriteBasicType(os, binary, num_left_inputs_);
452  WriteToken(os, binary, "<NumRightInputs>");
453  WriteBasicType(os, binary, num_right_inputs_);
454  WriteToken(os, binary, "<TimeStride>");
455  WriteBasicType(os, binary, time_stride_);
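456  WriteToken(os, binary, "<NumLeftInputsRequired>");
457  WriteBasicType(os, binary, num_left_inputs_required_);
458  WriteToken(os, binary, "<NumRightInputsRequired>");
459  WriteBasicType(os, binary, num_right_inputs_required_);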
460  WriteToken(os, binary, "<OutputContext>");
461  WriteBasicType(os, binary, output_context_);
462  WriteToken(os, binary, "<KeyScale>");
463  WriteBasicType(os, binary, key_scale_);
464  WriteToken(os, binary, "<StatsCount>");
465  WriteBasicType(os, binary, stats_count_);
466  WriteToken(os, binary, "<EntropyStats>");
467  entropy_stats_.Write(os, binary);
468  WriteToken(os, binary, "<PosteriorStats>");
469  posterior_stats_.Write(os, binary);
470  WriteToken(os, binary, "</RestrictedAttentionComponent>");
471 }
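
The listing above follows a simple pattern: an opening tag, a sequence of token/value pairs, and a closing tag. The sketch below imitates that pattern in text form using only the standard library. WriteTokenText, WriteValueText, AttentionConfigSketch and WriteSketch are hypothetical names introduced for illustration; they are not Kaldi's WriteToken or WriteBasicType, and the space-separated output is an assumption about text mode rather than a statement of Kaldi's exact format.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical text-mode stand-ins for the token and value writers.
void WriteTokenText(std::ostream &os, const std::string &token) {
  os << token << ' ';
}

template <typename T>
void WriteValueText(std::ostream &os, const T &value) {
  os << value << ' ';
}

// A small subset of the fields that the real Write() serializes.
struct AttentionConfigSketch {
  int num_heads;
  int key_dim;
  int value_dim;
};

// Writes the opening tag, the token/value pairs, and the closing tag.
std::string WriteSketch(const AttentionConfigSketch &cfg) {
  assert(cfg.num_heads > 0);
  std::ostringstream os;
  WriteTokenText(os, "<RestrictedAttentionComponent>");
  WriteTokenText(os, "<NumHeads>");
  WriteValueText(os, cfg.num_heads);
  WriteTokenText(os, "<KeyDim>");
  WriteValueText(os, cfg.key_dim);
  WriteTokenText(os, "<ValueDim>");
  WriteValueText(os, cfg.value_dim);
  WriteTokenText(os, "</RestrictedAttentionComponent>");
  return os.str();
}
```

Because each value is preceded by a named token, a matching reader can check that fields arrive in the expected order, which is presumably why the listing writes a token before every member.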

◆ ZeroStats()

void ZeroStats ( )
virtual

Components that provide an implementation of StoreStats should also provide an implementation of ZeroStats(), to set those stats to zero.

Other components that store other types of statistics (e.g. regarding gradient clipping) should implement ZeroStats() also.

Reimplemented from Component.

Definition at line 249 of file nnet-attention-component.cc.

References RestrictedAttentionComponent::entropy_stats_, RestrictedAttentionComponent::posterior_stats_, VectorBase< Real >::SetZero(), CuMatrixBase< Real >::SetZero(), and RestrictedAttentionComponent::stats_count_.

Referenced by RestrictedAttentionComponent::Properties().

249  {
250  entropy_stats_.SetZero();
251  posterior_stats_.SetZero();
252  stats_count_ = 0.0;
253 }
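
The accumulators and the count have to be reset together, because averages are obtained by dividing the accumulated sums by the count. The sketch below shows this with plain std::vector storage. AttentionStatsSketch and AverageEntropy are hypothetical names used for illustration only, and returning 0 for an empty count is an assumption of this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the component's statistics members.
struct AttentionStatsSketch {
  std::vector<double> entropy_stats;    // one accumulated sum per head
  std::vector<double> posterior_stats;  // num_heads * context_dim, flattened
  double stats_count = 0.0;             // number of rows accumulated so far

  AttentionStatsSketch(std::size_t num_heads, std::size_t context_dim)
      : entropy_stats(num_heads, 0.0),
        posterior_stats(num_heads * context_dim, 0.0) {}

  // Resets both accumulators and the count, keeping their sizes.
  void ZeroStats() {
    std::fill(entropy_stats.begin(), entropy_stats.end(), 0.0);
    std::fill(posterior_stats.begin(), posterior_stats.end(), 0.0);
    stats_count = 0.0;
  }

  // Average entropy for head h, or 0 if nothing has been accumulated.
  double AverageEntropy(std::size_t h) const {
    assert(h < entropy_stats.size());
    return stats_count > 0.0 ? entropy_stats[h] / stats_count : 0.0;
  }
};
```

Zeroing only the sums, or only the count, would leave the averages inconsistent, which is why all three members appear in the listing.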

Member Data Documentation

◆ context_dim_

◆ entropy_stats_

◆ key_dim_

◆ key_scale_

◆ num_heads_

◆ num_left_inputs_

◆ num_left_inputs_required_

◆ num_right_inputs_

◆ num_right_inputs_required_

◆ output_context_

◆ posterior_stats_

◆ stats_count_

◆ time_stride_

◆ value_dim_


The documentation for this class was generated from the following files: