  // The same dimension check and row_shift computation appears in each of the
  // three lower-level routines (GetAttentionDotProducts(), ApplyScalesToOutput()
  // and ApplyScalesToInput()); it is shown once here.
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  ...
  for (int32 o = 0; o < context_dim; o++) {
    ...
  }
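As a rough worked example of what these two lines imply (the numbers are made up for illustration and are not from the source): with num_output_rows = 100 output rows, num_input_rows = 106 input rows and context_dim = 3, we get num_extra_rows = 106 - 100 = 6 and row_shift = 6 / (3 - 1) = 3; as far as the row-shifted sub-matrices in these routines suggest, the loop over o = 0, 1, 2 then pairs output row i with input rows i, i + 3 and i + 6.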
  // Fragment from AttentionForward(): dimension setup, checks, and sub-matrix
  // views; lines missing from the extract are marked "...".
  int32 ...,
      num_output_rows = queries.NumRows(),
      context_dim = queries.NumCols() - key_dim,
      ...;
  KALDI_ASSERT(... &&
               num_input_rows > num_output_rows &&
               ... &&
               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
               values.NumRows() == num_input_rows);
  KALDI_ASSERT(... &&
               (output->NumCols() == value_dim ||
                output->NumCols() == value_dim + context_dim));
  ...(
      queries, 0, num_output_rows,
      ...),
      queries_context_part(
          queries, 0, num_output_rows,
          key_dim, context_dim);
  ...
  c->AddMat(1.0, queries_context_part);
  ...
  ...(
      *output, 0, num_output_rows, 0, value_dim);
  ...
  if (output->NumCols() == value_dim + context_dim) {
    ...(
        *output, 0, num_output_rows, value_dim, context_dim);
    ...
  }
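A minimal forward-pass usage sketch, assuming the AttentionForward() signature listed further below and the kaldi::nnet3::attention namespace of attention.h; the dimensions, the SetRandn() initialization and the 1/sqrt(key_dim) key scale are illustrative choices, not taken from this file:

  #include <cmath>
  #include "nnet3/attention.h"

  using namespace kaldi;
  using namespace kaldi::nnet3::attention;

  void ExampleAttentionForward() {
    int32 key_dim = 40, value_dim = 50, context_dim = 3, row_shift = 2,
        num_output_rows = 100,
        num_input_rows = num_output_rows + row_shift * (context_dim - 1);
    // The queries carry key_dim "key" columns plus context_dim "context" columns.
    CuMatrix<BaseFloat> keys(num_input_rows, key_dim),
        queries(num_output_rows, key_dim + context_dim),
        values(num_input_rows, value_dim),
        c(num_output_rows, context_dim),
        output(num_output_rows, value_dim + context_dim);
    keys.SetRandn();
    queries.SetRandn();
    values.SetRandn();
    AttentionForward(1.0 / std::sqrt(BaseFloat(key_dim)),
                     keys, queries, values, &c, &output);
  }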
  // Fragment from AttentionBackward(): dimension checks and the backprop of
  // the context part of the queries; lines missing from the extract are
  // marked "...".
  int32 ...,
      num_output_rows = queries.NumRows(),
      context_dim = queries.NumCols() - key_dim,
      ...;
  KALDI_ASSERT(... &&
               num_input_rows > num_output_rows &&
               ... &&
               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
               values.NumRows() == num_input_rows);
  KALDI_ASSERT(... &&
               SameDim(queries, *queries_deriv) &&
               SameDim(values, *values_deriv));
  KALDI_ASSERT(... &&
               (output_deriv.NumCols() == value_dim ||
                output_deriv.NumCols() == value_dim + context_dim));
  ...
  ...(
      output_deriv, 0, num_output_rows, 0, value_dim);
  ...
  if (output_deriv.NumCols() == value_dim + context_dim) {
    ...(
        output_deriv, 0, num_output_rows, value_dim, context_dim);
    ...
    c_deriv.AddMat(1.0, output_deriv_context_part);
  }
  ...
  ...(
      queries, 0, num_output_rows,
      ...),
      queries_key_part_deriv(
          *queries_deriv, 0, num_output_rows,
          ...),
      queries_context_part_deriv(
          *queries_deriv, 0, num_output_rows,
          key_dim, context_dim);
  ...
  queries_context_part_deriv.AddMat(1.0, c_deriv);
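Continuing the forward-pass sketch above (again with made-up dimensions), the backward call needs derivative matrices whose dimensions match keys, queries and values, which is exactly what the SameDim() asserts here check; a rough sketch, assuming the AttentionBackward() signature listed below:

    // Derivative matrices must match the corresponding inputs' dimensions.
    CuMatrix<BaseFloat> output_deriv(num_output_rows, value_dim + context_dim),
        keys_deriv(num_input_rows, key_dim),
        queries_deriv(num_output_rows, key_dim + context_dim),
        values_deriv(num_input_rows, value_dim);
    output_deriv.SetRandn();
    // 'c' must be the matrix computed by the corresponding AttentionForward() call.
    AttentionBackward(1.0 / std::sqrt(BaseFloat(key_dim)),
                      keys, queries, values, c, output_deriv,
                      &keys_deriv, &queries_deriv, &values_deriv);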
void CopyFromMat(const MatrixBase< OtherReal > &src, MatrixTransposeType trans=kNoTrans)
void AttentionBackward(BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &output_deriv, CuMatrixBase< BaseFloat > *keys_deriv, CuMatrixBase< BaseFloat > *queries_deriv, CuMatrixBase< BaseFloat > *values_deriv)
Performs the backward pass corresponding to 'AttentionForward', propagating the derivative back to the keys, queries and values.
void ApplyScalesToInput(BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *B)
This function is related to GetAttentionDotProducts(); it is used in backprop.
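A hedged paraphrase of what it does, based on the companion documentation in attention.h (the exact indexing convention should be checked there): for each output row i and offset o it effectively does B->Row(i + o * row_shift) += alpha * C(i, o) * A.Row(i), with row_shift = (B->NumRows() - A.NumRows()) / (C.NumCols() - 1).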
void AttentionForward(BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *output)
This is a higher-level interface to the attention code.
void AddMat(Real alpha, const CuMatrixBase< Real > &A, MatrixTransposeType trans=kNoTrans)
*this += alpha * A
This class represents a matrix that's stored on the GPU if we have one, and in memory if not.
void AddDiagMatMat(Real alpha, const CuMatrixBase< Real > &M, MatrixTransposeType transM, const CuMatrixBase< Real > &N, MatrixTransposeType transN, Real beta=1.0)
Add the diagonal of a matrix product: *this = diag(M N), assuming the "trans" arguments are both kNoTrans.
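Written out per element for the kNoTrans case (this follows directly from the formula above): (*this)(i) = beta * (*this)(i) + alpha * sum_j M(i, j) * N(j, i), i.e. the i'th diagonal entry of the product M N.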
bool SameDim(const MatrixBase< Real > &M, const MatrixBase< Real > &N)
void ApplyScalesToOutput(BaseFloat alpha, const CuMatrixBase< BaseFloat > &B, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *A)
This function is related to GetAttentionDotProducts(); it is used in scaling the values by the softmax probabilities, and in the backprop.
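A hedged paraphrase, again based on the attention.h documentation: for each output row i and offset o, A->Row(i) += alpha * C(i, o) * B.Row(i + o * row_shift); in the forward pass C holds the per-row softmax weights and B holds the values.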
void GetAttentionDotProducts(BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &B, CuMatrixBase< BaseFloat > *C)
This function is a utility function that is at the core of how we implement attention.
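A hedged summary of what it computes, following the attention.h documentation: C(i, o) = alpha * VecVec(A.Row(i), B.Row(i + o * row_shift)), so each row of C contains the dot products of one row of A against context_dim row-shifted rows of B; in the forward pass A is the key part of the queries and B is the keys.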
void SoftMaxPerRow(const CuMatrixBase< Real > &src)
Softmax nonlinearity Y = Softmax(X) : Yij = e^Xij / sum_k(e^Xik), done to each row, with attention to avoiding overflow or underflow.
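The standard way to get the "avoiding overflow or underflow" part (and presumably what is done here) is to subtract each row's maximum before exponentiating: Yij = e^(Xij - m_i) / sum_k e^(Xik - m_i) with m_i = max_k Xik, which is mathematically identical to the formula above.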
This class is used for a piece of a CuMatrix.
void DiffSoftmaxPerRow(const CuMatrixBase< Real > &value, const CuMatrixBase< Real > &diff)
Differentiate backward through the softmax function.
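For a single row this is the usual softmax Jacobian product (a sketch of the semantics; the exact argument convention is as in the Kaldi source): if y is the softmax output ('value') and e is the derivative at the output ('diff'), then the derivative at the input is d_j = y_j * (e_j - sum_k e_k * y_k).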
Matrix for CUDA computing.
MatrixIndexT NumCols() const
#define KALDI_ASSERT(cond)
This file contains the lower-level interface for self-attention.
MatrixIndexT NumRows() const
Dimensions.
void AddDiagVecMat(const Real alpha, const CuVectorBase< Real > &v, const CuMatrixBase< Real > &M, MatrixTransposeType transM, Real beta=1.0)
*this = beta * *this + alpha * diag(v) * M [or M^T].
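Per element, for the kNoTrans case (this follows directly from the formula above): (*this)(i, j) = beta * (*this)(i, j) + alpha * v(i) * M(i, j), i.e. each row of M is scaled by the corresponding entry of v.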