Fragments of the Kaldi self-attention implementation (apparently nnet3/attention.cc), reassembled from the extracted source listing. A "..." line marks code that did not survive extraction, and a few declarations whose opening line was lost have been completed from context.

The same context/row-shift setup appears three times in the listing, presumably once near the top of each of the three lower-level helpers (GetAttentionDotProducts(), ApplyScalesToOutput(), ApplyScalesToInput()):

  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  ...
  for (int32 o = 0; o < context_dim; o++) {
    ...
  }
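
As a quick sanity check of that assertion, a small standalone sketch (plain C++, hypothetical numbers rather than anything from the source): the keys/values cover (context_dim - 1) * row_shift more rows than the queries, so row_shift is recovered exactly by the division above.

#include <cassert>
#include <cstdio>

int main() {
  // Hypothetical sizes, chosen only to satisfy the assertion in the listing.
  int num_output_rows = 100;  // rows of the queries / output frames
  int context_dim = 11;       // number of context positions
  int row_shift = 3;          // spacing, in input rows, between context positions
  int num_input_rows = num_output_rows + (context_dim - 1) * row_shift;
  int num_extra_rows = num_input_rows - num_output_rows;

  // Mirrors: KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  assert(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  assert(num_extra_rows / (context_dim - 1) == row_shift);
  std::printf("num_input_rows = %d, row_shift = %d\n", num_input_rows, row_shift);
  return 0;
}
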
Fragments from the body of AttentionForward(), covering the dimension checks and the main sub-matrix bookkeeping:

      num_output_rows = queries.NumRows(),
      context_dim = queries.NumCols() - key_dim,
      ...
  KALDI_ASSERT(...
               num_input_rows > num_output_rows &&
               ...
               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
               values.NumRows() == num_input_rows);
  ...
  KALDI_ASSERT(...
               (output->NumCols() == value_dim ||
                output->NumCols() == value_dim + context_dim));

  // 'queries' is the key part followed by the context part.
  CuSubMatrix<BaseFloat> queries_key_part(
      queries, 0, num_output_rows,
      0, key_dim),
      queries_context_part(
          queries, 0, num_output_rows,
          key_dim, context_dim);
  ...
  // add the position-dependent context offsets to the pre-softmax scores in 'c'.
  c->AddMat(1.0, queries_context_part);
  ...
  // the first value_dim columns of the output: weighted combinations of the values.
  CuSubMatrix<BaseFloat> output_values_part(
      *output, 0, num_output_rows, 0, value_dim);
  ...
  // optionally, the attention weights in 'c' are also appended to the output.
  if (output->NumCols() == value_dim + context_dim) {
    CuSubMatrix<BaseFloat> output_context_part(
        *output, 0, num_output_rows, value_dim, context_dim);
    ...
  }
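
A minimal, hedged usage sketch of AttentionForward(), built only from the signature listed further below and the dimension checks in the fragment above. The include paths, the kaldi::nnet3::attention namespace and the 1/sqrt(key_dim) choice of key_scale are assumptions.

#include <cmath>
#include "cudamatrix/cu-matrix.h"
#include "nnet3/attention.h"   // assumed location of AttentionForward()

int main() {
  using namespace kaldi;
  // Sizes chosen to satisfy the asserts above:
  // num_input_rows = num_output_rows + (context_dim - 1) * row_shift.
  int32 key_dim = 40, value_dim = 50, context_dim = 9, row_shift = 2;
  int32 num_output_rows = 64;
  int32 num_input_rows = num_output_rows + (context_dim - 1) * row_shift;

  CuMatrix<BaseFloat> keys(num_input_rows, key_dim),
      values(num_input_rows, value_dim),
      queries(num_output_rows, key_dim + context_dim),   // key part + context part
      c(num_output_rows, context_dim),                   // attention weights, written by the call
      output(num_output_rows, value_dim + context_dim);  // values part, with 'c' appended
  keys.SetRandn();
  queries.SetRandn();
  values.SetRandn();

  BaseFloat key_scale = 1.0 / std::sqrt(static_cast<BaseFloat>(key_dim));  // assumed scaling choice
  nnet3::attention::AttentionForward(key_scale, keys, queries, values, &c, &output);
  return 0;
}
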
Fragments from the body of AttentionBackward(), covering the analogous checks plus the places where the derivative w.r.t. the attention weights 'c' is accumulated:

      num_output_rows = queries.NumRows(),
      context_dim = queries.NumCols() - key_dim,
      ...
  KALDI_ASSERT(...
               num_input_rows > num_output_rows &&
               ...
               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
               values.NumRows() == num_input_rows);
  KALDI_ASSERT(...
               SameDim(queries, *queries_deriv) &&
               SameDim(values, *values_deriv));
  ...
  KALDI_ASSERT(...
               (output_deriv.NumCols() == value_dim ||
                output_deriv.NumCols() == value_dim + context_dim));
  ...
  // derivative w.r.t. the values part of the output.
  CuSubMatrix<BaseFloat> output_values_part_deriv(
      output_deriv, 0, num_output_rows, 0, value_dim);
  ...
  // if the attention weights were appended to the output, that slice of the
  // output derivative adds directly into the derivative w.r.t. 'c'.
  if (output_deriv.NumCols() == value_dim + context_dim) {
    CuSubMatrix<BaseFloat> output_deriv_context_part(
        output_deriv, 0, num_output_rows, value_dim, context_dim);
    ...
    c_deriv.AddMat(1.0, output_deriv_context_part);
  }
  ...
  // sub-matrices of 'queries' and of the derivative w.r.t. 'queries'.
  CuSubMatrix<BaseFloat> queries_key_part(
      queries, 0, num_output_rows,
      0, key_dim),
      queries_key_part_deriv(
          *queries_deriv, 0, num_output_rows,
          0, key_dim),
      queries_context_part_deriv(
          *queries_deriv, 0, num_output_rows,
          key_dim, context_dim);
  ...
  // the context part of 'queries' was added directly to the scores in the
  // forward pass, so its derivative picks up c_deriv unchanged.
  queries_context_part_deriv.AddMat(1.0, c_deriv);
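
A matching hedged sketch for the backward pass, again using only the AttentionBackward() signature listed below. The derivative matrices are allocated with the same dimensions as their forward counterparts, as the SameDim() checks above require; the random output_deriv stands in for whatever the next layer would supply. Include paths and namespace are assumptions, as before.

#include <cmath>
#include "cudamatrix/cu-matrix.h"
#include "nnet3/attention.h"   // assumed location of AttentionForward()/AttentionBackward()

int main() {
  using namespace kaldi;
  int32 key_dim = 40, value_dim = 50, context_dim = 9, row_shift = 2;
  int32 num_output_rows = 64;
  int32 num_input_rows = num_output_rows + (context_dim - 1) * row_shift;

  CuMatrix<BaseFloat> keys(num_input_rows, key_dim),
      values(num_input_rows, value_dim),
      queries(num_output_rows, key_dim + context_dim),
      c(num_output_rows, context_dim),
      output(num_output_rows, value_dim + context_dim);
  keys.SetRandn();
  queries.SetRandn();
  values.SetRandn();

  BaseFloat key_scale = 1.0 / std::sqrt(static_cast<BaseFloat>(key_dim));
  nnet3::attention::AttentionForward(key_scale, keys, queries, values, &c, &output);

  // Derivatives: same shapes as the corresponding forward quantities
  // (the CuMatrix constructor zero-initializes them).
  CuMatrix<BaseFloat> output_deriv(num_output_rows, value_dim + context_dim),
      keys_deriv(num_input_rows, key_dim),
      queries_deriv(num_output_rows, key_dim + context_dim),
      values_deriv(num_input_rows, value_dim);
  output_deriv.SetRandn();

  nnet3::attention::AttentionBackward(key_scale, keys, queries, values, c,
                                      output_deriv, &keys_deriv,
                                      &queries_deriv, &values_deriv);
  return 0;
}
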
Reference briefs for the symbols and files referenced above:

void CopyFromMat(const MatrixBase< OtherReal > &src, MatrixTransposeType trans=kNoTrans)
 
void AttentionBackward(BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &output_deriv, CuMatrixBase< BaseFloat > *keys_deriv, CuMatrixBase< BaseFloat > *queries_deriv, CuMatrixBase< BaseFloat > *values_deriv)
Performs the backward pass corresponding to 'AttentionForward', propagating the derivative back to th...
 
void ApplyScalesToInput(BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *B)
This function is related to GetAttentionDotProducts(); it is used in backprop. 
 
void AttentionForward(BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *output)
This is a higher-level interface to the attention code. 
 
void AddMat(Real alpha, const CuMatrixBase< Real > &A, MatrixTransposeType trans=kNoTrans)
*this += alpha * A 
 
CuMatrix
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
 
void AddDiagMatMat(Real alpha, const CuMatrixBase< Real > &M, MatrixTransposeType transM, const CuMatrixBase< Real > &N, MatrixTransposeType transN, Real beta=1.0)
Add the diagonal of a matrix product: *this = diag(M N), assuming the "trans" arguments are both kNoT...
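
A plain-C++ sketch of the elementwise meaning of the formula quoted above, assuming both transpose arguments are kNoTrans (an illustration, not the CUDA implementation):

#include <cstddef>
#include <vector>

// v(i) = beta * v(i) + alpha * sum_k M(i, k) * N(k, i),
// i.e. 'v' accumulates the diagonal of the matrix product M * N.
void AddDiagMatMatRef(double alpha,
                      const std::vector<std::vector<double> > &M,
                      const std::vector<std::vector<double> > &N,
                      double beta,
                      std::vector<double> *v) {
  for (std::size_t i = 0; i < v->size(); i++) {
    double sum = 0.0;
    for (std::size_t k = 0; k < M[i].size(); k++)
      sum += M[i][k] * N[k][i];
    (*v)[i] = beta * (*v)[i] + alpha * sum;
  }
}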
 
bool SameDim(const MatrixBase< Real > &M, const MatrixBase< Real > &N)
 
void ApplyScalesToOutput(BaseFloat alpha, const CuMatrixBase< BaseFloat > &B, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *A)
This function is related to GetAttentionDotProducts(); it is used in scaling the values by the softma...
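
A hedged reading of what "scaling the values by the softmax weights" could mean here, assuming the same row-shift indexing as in the fragments above: output row i accumulates, for each context position o, the weight C(i, o) times input row i + o * row_shift. That semantics is an assumption (the description above is truncated), so treat this as a sketch only:

#include <vector>

typedef std::vector<std::vector<double> > Mat;

// Assumed semantics (illustration only):
//   A.Row(i) += alpha * sum_o C(i, o) * B.Row(i + o * row_shift)
void ApplyScalesToOutputRef(double alpha, const Mat &B, const Mat &C,
                            int row_shift, Mat *A) {
  int num_output_rows = static_cast<int>(A->size());
  int value_dim = static_cast<int>((*A)[0].size());
  int context_dim = static_cast<int>(C[0].size());
  for (int i = 0; i < num_output_rows; i++)
    for (int o = 0; o < context_dim; o++)
      for (int d = 0; d < value_dim; d++)
        (*A)[i][d] += alpha * C[i][o] * B[i + o * row_shift][d];
}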
 
void GetAttentionDotProducts(BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &B, CuMatrixBase< BaseFloat > *C)
This function is a utility function that is at the core of how we implement attention. 
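
The row-shift loop in the fragments at the top, combined with the AddDiagMatMat() brief above, suggests the following per-element behaviour; here is a plain-C++ sketch under that assumption (not the CUDA implementation):

#include <vector>

typedef std::vector<std::vector<double> > Mat;

// Assumed semantics (illustration only):
//   C(i, o) = alpha * <A.Row(i), B.Row(i + o * row_shift)>
// With A as the key part of the queries and B as the keys, C would then hold
// the scaled query/key dot products for each context offset.
void GetAttentionDotProductsRef(double alpha, const Mat &A, const Mat &B,
                                int row_shift, Mat *C) {
  int num_output_rows = static_cast<int>(A.size());
  int dim = static_cast<int>(A[0].size());
  int context_dim = static_cast<int>((*C)[0].size());
  for (int i = 0; i < num_output_rows; i++) {
    for (int o = 0; o < context_dim; o++) {
      double sum = 0.0;
      for (int k = 0; k < dim; k++)
        sum += A[i][k] * B[i + o * row_shift][k];
      (*C)[i][o] = alpha * sum;
    }
  }
}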
 
void SoftMaxPerRow(const CuMatrixBase< Real > &src)
Softmax nonlinearity Y = Softmax(X) : Yij = e^Xij / sum_k(e^Xik), done to each row, with attention to avoiding overflow or underflow. 
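
For reference, a standalone sketch of a per-row softmax that subtracts each row's maximum before exponentiating; the result is mathematically unchanged but exp() can no longer overflow, which is the standard way to get the numerical care the brief mentions:

#include <cmath>
#include <vector>

void SoftMaxPerRowRef(std::vector<std::vector<double> > *m) {
  for (std::vector<double> &row : *m) {
    double max = row[0];
    for (double x : row)
      if (x > max) max = x;
    double sum = 0.0;
    for (double &x : row) {
      x = std::exp(x - max);
      sum += x;
    }
    for (double &x : row)
      x /= sum;
  }
}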
 
CuSubMatrix
This class is used for a piece of a CuMatrix.
 
void DiffSoftmaxPerRow(const CuMatrixBase< Real > &value, const CuMatrixBase< Real > &diff)
Differentiate backward through the softmax function. 
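
A sketch of the standard formula for backpropagating through a row-wise softmax, with 'value' in the role of the softmax output y and 'diff' in the role of dL/dy, matching the parameter names in the signature above (where Kaldi writes the result is not shown here, so this is only illustrative):

#include <cstddef>
#include <vector>

// dL/dx_j = y_j * (dL/dy_j - sum_k y_k * dL/dy_k), applied to one row.
std::vector<double> DiffSoftmaxRowRef(const std::vector<double> &value,
                                      const std::vector<double> &diff) {
  double dot = 0.0;
  for (std::size_t k = 0; k < value.size(); k++)
    dot += value[k] * diff[k];
  std::vector<double> in_deriv(value.size());
  for (std::size_t j = 0; j < value.size(); j++)
    in_deriv[j] = value[j] * (diff[j] - dot);
  return in_deriv;
}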
 
CuMatrixBase
Matrix for CUDA computing.
 
MatrixIndexT NumCols() const
 
#define KALDI_ASSERT(cond)
 
attention.h
This file contains the lower-level interface for self-attention.
 
MatrixIndexT NumRows() const
Dimensions. 
 
void AddDiagVecMat(const Real alpha, const CuVectorBase< Real > &v, const CuMatrixBase< Real > &M, MatrixTransposeType transM, Real beta=1.0)
*this = beta * *this + alpha * diag(v) * M [or M^T].
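
And a plain-C++ sketch of the formula quoted above for the kNoTrans case: every row of M is scaled by the corresponding element of v and accumulated into *this.

#include <cstddef>
#include <vector>

// out(i, j) = beta * out(i, j) + alpha * v(i) * M(i, j)
void AddDiagVecMatRef(double alpha, const std::vector<double> &v,
                      const std::vector<std::vector<double> > &M,
                      double beta, std::vector<std::vector<double> > *out) {
  for (std::size_t i = 0; i < out->size(); i++)
    for (std::size_t j = 0; j < (*out)[i].size(); j++)
      (*out)[i][j] = beta * (*out)[i][j] + alpha * v[i] * M[i][j];
}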