Functions

void GetAttentionDotProductsSimple (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &B, CuMatrixBase< BaseFloat > *C)

void ApplyScalesToOutputSimple (BaseFloat alpha, const CuMatrixBase< BaseFloat > &B, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *A)

void ApplyScalesToInputSimple (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *B)

void UnitTestAttentionDotProductAndAddScales ()

void TestAttentionForwardBackward ()

void UnitTestAttention ()

void GetAttentionDotProducts (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &B, CuMatrixBase< BaseFloat > *C)
    This function is a utility function that is at the core of how we implement attention.

void ApplyScalesToOutput (BaseFloat alpha, const CuMatrixBase< BaseFloat > &B, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *A)
    This function is related to GetAttentionDotProducts(); it is used in scaling the values by the softmax scales, and in backprop.

void ApplyScalesToInput (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *B)
    This function is related to GetAttentionDotProducts(); it is used in backprop.

void AttentionForward (BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *output)
    This is a higher-level interface to the attention code.

void AttentionBackward (BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &output_deriv, CuMatrixBase< BaseFloat > *keys_deriv, CuMatrixBase< BaseFloat > *queries_deriv, CuMatrixBase< BaseFloat > *values_deriv)
    Performs the backward pass corresponding to 'AttentionForward', propagating the derivative back to the keys, queries and values.
void ApplyScalesToInput (BaseFloat alpha,
                         const CuMatrixBase< BaseFloat > &A,
                         const CuMatrixBase< BaseFloat > &C,
                         CuMatrixBase< BaseFloat > *B)
This function is related to GetAttentionDotProducts(); it is used in backprop.

We have put A, B and C in an unusual order here to make the relationship with GetAttentionDotProducts() clearer: the matrices have the same dimensional relationship as A, B and C do in GetAttentionDotProducts().

This function implements:

B->Row(i + j * row_shift) += alpha * C(i, j) * A.Row(i).
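For illustration, here is a minimal CPU reference sketch of that operation, in the spirit of the *Simple reference versions in attention-test.cc. The name ApplyScalesToInputRef is invented for this example, and it uses plain Matrix< BaseFloat > (which permits elementwise access) rather than CuMatrixBase; the actual implementation expresses the inner loop via AddDiagVecMat().

    #include "matrix/kaldi-matrix.h"

    namespace kaldi {

    // Hypothetical CPU reference version (name invented for this example).
    // Implements: B->Row(i + j * row_shift) += alpha * C(i, j) * A.Row(i).
    void ApplyScalesToInputRef(BaseFloat alpha,
                               const MatrixBase<BaseFloat> &A,
                               const MatrixBase<BaseFloat> &C,
                               MatrixBase<BaseFloat> *B) {
      KALDI_ASSERT(A.NumCols() == B->NumCols() && A.NumRows() == C.NumRows());
      int32 num_extra_rows = B->NumRows() - A.NumRows();
      KALDI_ASSERT(num_extra_rows > 0 &&
                   num_extra_rows % (C.NumCols() - 1) == 0);
      int32 row_shift = num_extra_rows / (C.NumCols() - 1);
      for (int32 i = 0; i < A.NumRows(); i++)
        for (int32 j = 0; j < C.NumCols(); j++)
          B->Row(i + j * row_shift).AddVec(alpha * C(i, j), A.Row(i));
    }

    }  // namespace kaldi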
Definition at line 76 of file attention.cc.
References CuMatrixBase< Real >::AddDiagVecMat(), KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().
Referenced by AttentionBackward(), and UnitTestAttentionDotProductAndAddScales().
void kaldi::nnet3::attention::ApplyScalesToInputSimple (BaseFloat alpha,
        const CuMatrixBase< BaseFloat > &A,
        const CuMatrixBase< BaseFloat > &C,
        CuMatrixBase< BaseFloat > *B)
Definition at line 72 of file attention-test.cc.
References rnnlm::i, rnnlm::j, KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().
Referenced by UnitTestAttentionDotProductAndAddScales().
void ApplyScalesToOutput (BaseFloat alpha,
                          const CuMatrixBase< BaseFloat > &B,
                          const CuMatrixBase< BaseFloat > &C,
                          CuMatrixBase< BaseFloat > *A)
This function is related to GetAttentionDotProducts(); it is used in scaling the values by the softmax scales, and in backprop.

We have put A, B and C in an unusual order here to make the relationship with GetAttentionDotProducts() clearer: the matrices have the same dimensional relationship as A, B and C do in GetAttentionDotProducts().

This function implements:

A->Row(i) += alpha * C(i, j) * B.Row(i + j * row_shift).
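As with ApplyScalesToInput(), a minimal CPU reference sketch may clarify the operation. ApplyScalesToOutputRef is an invented name; the actual implementation works on CuMatrixBase and uses AddDiagVecMat() instead of explicit loops.

    #include "matrix/kaldi-matrix.h"

    namespace kaldi {

    // Hypothetical CPU reference version (name invented for this example).
    // Implements: A->Row(i) += alpha * C(i, j) * B.Row(i + j * row_shift).
    void ApplyScalesToOutputRef(BaseFloat alpha,
                                const MatrixBase<BaseFloat> &B,
                                const MatrixBase<BaseFloat> &C,
                                MatrixBase<BaseFloat> *A) {
      KALDI_ASSERT(A->NumCols() == B.NumCols() && A->NumRows() == C.NumRows());
      int32 num_extra_rows = B.NumRows() - A->NumRows();
      KALDI_ASSERT(num_extra_rows > 0 &&
                   num_extra_rows % (C.NumCols() - 1) == 0);
      int32 row_shift = num_extra_rows / (C.NumCols() - 1);
      for (int32 i = 0; i < A->NumRows(); i++)
        for (int32 j = 0; j < C.NumCols(); j++)
          A->Row(i).AddVec(alpha * C(i, j), B.Row(i + j * row_shift));
    }

    }  // namespace kaldi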
Definition at line 55 of file attention.cc.
References CuMatrixBase< Real >::AddDiagVecMat(), KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().
Referenced by AttentionBackward(), AttentionForward(), and UnitTestAttentionDotProductAndAddScales().
void kaldi::nnet3::attention::ApplyScalesToOutputSimple (BaseFloat alpha,
        const CuMatrixBase< BaseFloat > &B,
        const CuMatrixBase< BaseFloat > &C,
        CuMatrixBase< BaseFloat > *A)
Definition at line 52 of file attention-test.cc.
References rnnlm::i, rnnlm::j, KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().
Referenced by UnitTestAttentionDotProductAndAddScales().
void AttentionBackward (BaseFloat key_scale,
                        const CuMatrixBase< BaseFloat > &keys,
                        const CuMatrixBase< BaseFloat > &queries,
                        const CuMatrixBase< BaseFloat > &values,
                        const CuMatrixBase< BaseFloat > &c,
                        const CuMatrixBase< BaseFloat > &output_deriv,
                        CuMatrixBase< BaseFloat > *keys_deriv,
                        CuMatrixBase< BaseFloat > *queries_deriv,
                        CuMatrixBase< BaseFloat > *values_deriv)
Performs the backward pass corresponding to 'AttentionForward', propagating the derivative back to the keys, queries and values.
The interface should be easy to understand with reference to AttentionForward(), so we won't document it, except to note that 'keys_deriv', 'queries_deriv' and 'values_deriv' are *added to*, not set, by this function.
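Because the derivative matrices are accumulated into rather than set, a caller that wants plain derivatives must start from zeroed matrices. Below is a minimal, hypothetical usage sketch (the wrapper name BackwardSketch and the key_scale value are invented for illustration); it relies on the CuMatrix constructor zero-initializing its contents.

    #include "nnet3/attention.h"

    namespace kaldi {
    namespace nnet3 {
    namespace attention {

    // Hypothetical helper showing one backward call; dimensions must match
    // those used in the corresponding AttentionForward() call.
    void BackwardSketch(const CuMatrixBase<BaseFloat> &keys,
                        const CuMatrixBase<BaseFloat> &queries,
                        const CuMatrixBase<BaseFloat> &values,
                        const CuMatrixBase<BaseFloat> &c,
                        const CuMatrixBase<BaseFloat> &output_deriv) {
      BaseFloat key_scale = 0.5;  // Must match the forward pass.
      // CuMatrix construction zero-initializes, which matters here because
      // the derivatives are added to, not set.
      CuMatrix<BaseFloat> keys_deriv(keys.NumRows(), keys.NumCols()),
          queries_deriv(queries.NumRows(), queries.NumCols()),
          values_deriv(values.NumRows(), values.NumCols());
      AttentionBackward(key_scale, keys, queries, values, c, output_deriv,
                        &keys_deriv, &queries_deriv, &values_deriv);
    }

    }  // namespace attention
    }  // namespace nnet3
    }  // namespace kaldi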
Definition at line 154 of file attention.cc.
References CuMatrixBase< Real >::AddMat(), ApplyScalesToInput(), ApplyScalesToOutput(), CuMatrixBase< Real >::DiffSoftmaxPerRow(), GetAttentionDotProducts(), KALDI_ASSERT, kaldi::kUndefined, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), and kaldi::SameDim().
Referenced by RestrictedAttentionComponent::BackpropOneHead(), and TestAttentionForwardBackward().
void AttentionForward (BaseFloat key_scale,
                       const CuMatrixBase< BaseFloat > &keys,
                       const CuMatrixBase< BaseFloat > &queries,
                       const CuMatrixBase< BaseFloat > &values,
                       CuMatrixBase< BaseFloat > *c,
                       CuMatrixBase< BaseFloat > *output)
This is a higher-level interface to the attention code.
Read the extended comment in the file nnet3/attention.h for context.
Parameters
    [in]  key_scale   Scale on the non-context part of the keys.
    [in]  keys        Matrix whose rows contain the keys; dimension is num-input-rows by key-dim.
    [in]  queries     Matrix whose rows contain the queries; dimension is num-output-rows by query-dim, where query-dim == key-dim + context-dim. num-input-rows - num-output-rows must be a multiple of context-dim - 1 (we'll 'shift' the keys by multiples of 0, n, 2n, ..., (context-dim - 1) * n).
    [in]  values      Values to average at the output, of dimension num-input-rows by value-dim. (We may append context information to these averages if required; see the comment for 'output'.)
    [out] c           Expected to be finite at entry (no infs or NaNs); at exit this will contain the output of the softmax. Must be of dimension num-output-rows by context-dim.
    [out] output      The output of the attention mechanism will be *added* to this location. Dimension must be num-output-rows by either value-dim or value-dim + context-dim. The weighted combination of 'values' is added to the first value-dim columns, weighted by the corresponding weights in 'c' (e.g. the first column of 'c' scales the first num-output-rows rows of 'values', the next column of 'c' scales the submatrix of 'values' shifted by 'n', and so on). If output->NumCols() is value-dim + context-dim, 'c' will be added to the remaining columns of 'output'.
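Below is a hedged usage sketch with concrete dimensions (all values invented for illustration): with context-dim = 3 and a row shift of n = 2, the keys and values have (context-dim - 1) * n = 4 more rows than the queries, and the output has value-dim + context-dim columns so that 'c' is appended to it.

    #include "nnet3/attention.h"

    namespace kaldi {
    namespace nnet3 {
    namespace attention {

    // Hypothetical example (name and dimensions chosen for illustration).
    void ForwardSketch() {
      int32 context_dim = 3, n = 2, num_output_rows = 10,
          key_dim = 8, value_dim = 5;
      int32 num_input_rows = num_output_rows + (context_dim - 1) * n;  // = 14
      int32 query_dim = key_dim + context_dim;                         // = 11
      CuMatrix<BaseFloat> keys(num_input_rows, key_dim),
          queries(num_output_rows, query_dim),
          values(num_input_rows, value_dim),
          c(num_output_rows, context_dim),  // zero at entry, hence finite.
          // value_dim + context_dim columns, so 'c' is appended to output.
          output(num_output_rows, value_dim + context_dim);
      keys.SetRandn();
      queries.SetRandn();
      values.SetRandn();
      AttentionForward(0.5, keys, queries, values, &c, &output);
    }

    }  // namespace attention
    }  // namespace nnet3
    }  // namespace kaldi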
Definition at line 97 of file attention.cc.
References CuMatrixBase< Real >::AddMat(), ApplyScalesToOutput(), CuMatrixBase< Real >::CopyFromMat(), GetAttentionDotProducts(), KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), and CuMatrixBase< Real >::SoftMaxPerRow().
Referenced by RestrictedAttentionComponent::PropagateOneHead(), and TestAttentionForwardBackward().
void GetAttentionDotProducts (BaseFloat alpha,
                              const CuMatrixBase< BaseFloat > &A,
                              const CuMatrixBase< BaseFloat > &B,
                              CuMatrixBase< BaseFloat > *C)
This function is a utility function that is at the core of how we implement attention.
It may in the future need to be renamed, and possibly moved into the cudamatrix directory and implemented in CUDA, since the current implementation is quite inefficient. We could also consider a complete redesign of how the implementation works, such that this function doesn't exist at all; or we could have a batched version of this function that would operate on a batch of A, B and C at once (or a "strided, batched" version where the difference between the members of the batch is expressed as a stride).
This function implements a special operation that you could view as some kind of matrix multiplication where only a band of the product is retained.
The inputs A and B must have the same number of columns (A.NumCols() == B.NumCols()), and A and C must have the same number of rows (A.NumRows() == C->NumRows()). The number of rows of B must exceed the number of rows of A. Define num_extra_rows = B.NumRows() - A.NumRows(). Then C.NumCols() - 1 must divide num_extra_rows. Define row_shift = num_extra_rows / (C.NumCols() - 1).
This function implements: (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift))
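A minimal CPU reference sketch of this operation, analogous to GetAttentionDotProductsSimple() in attention-test.cc (the name GetAttentionDotProductsRef is invented; the actual implementation uses AddDiagMatMat() on shifted submatrices):

    #include "matrix/kaldi-matrix.h"

    namespace kaldi {

    // Hypothetical CPU reference version (name invented for this example).
    // Implements: (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift)).
    void GetAttentionDotProductsRef(BaseFloat alpha,
                                    const MatrixBase<BaseFloat> &A,
                                    const MatrixBase<BaseFloat> &B,
                                    MatrixBase<BaseFloat> *C) {
      KALDI_ASSERT(A.NumCols() == B.NumCols() && A.NumRows() == C->NumRows());
      int32 num_extra_rows = B.NumRows() - A.NumRows();
      KALDI_ASSERT(num_extra_rows > 0 &&
                   num_extra_rows % (C->NumCols() - 1) == 0);
      int32 row_shift = num_extra_rows / (C->NumCols() - 1);
      for (int32 i = 0; i < C->NumRows(); i++)
        for (int32 j = 0; j < C->NumCols(); j++)
          (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift));
    }

    }  // namespace kaldi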
Definition at line 32 of file attention.cc.
References CuVectorBase< Real >::AddDiagMatMat(), CuMatrixBase< Real >::CopyFromMat(), KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().
Referenced by AttentionBackward(), AttentionForward(), and UnitTestAttentionDotProductAndAddScales().
void kaldi::nnet3::attention::GetAttentionDotProductsSimple (BaseFloat alpha,
        const CuMatrixBase< BaseFloat > &A,
        const CuMatrixBase< BaseFloat > &B,
        CuMatrixBase< BaseFloat > *C)
Definition at line 30 of file attention-test.cc.
References rnnlm::i, rnnlm::j, KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().
Referenced by UnitTestAttentionDotProductAndAddScales().
void kaldi::nnet3::attention::TestAttentionForwardBackward ()
Definition at line 120 of file attention-test.cc.
References CuMatrixBase< Real >::AddMat(), AttentionBackward(), AttentionForward(), rnnlm::i, KALDI_ASSERT, KALDI_LOG, kaldi::kTrans, CuMatrixBase< Real >::NumCols(), kaldi::RandInt(), CuMatrixBase< Real >::Scale(), CuMatrixBase< Real >::SetRandn(), CuMatrixBase< Real >::SetZero(), and kaldi::TraceMatMat().
Referenced by UnitTestAttention().
void kaldi::nnet3::attention::UnitTestAttention ()
Definition at line 229 of file attention-test.cc.
References TestAttentionForwardBackward(), and UnitTestAttentionDotProductAndAddScales().
Referenced by main().
void kaldi::nnet3::attention::UnitTestAttentionDotProductAndAddScales ()
Definition at line 91 of file attention-test.cc.
References ApplyScalesToInput(), ApplyScalesToInputSimple(), ApplyScalesToOutput(), ApplyScalesToOutputSimple(), kaldi::AssertEqual(), GetAttentionDotProducts(), GetAttentionDotProductsSimple(), kaldi::RandInt(), and CuMatrixBase< Real >::SetRandn().
Referenced by UnitTestAttention().