// From GetAttentionDotProductsSimple(), the reference implementation of
// GetAttentionDotProducts():
KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
int32 row_shift = num_extra_rows / (context_dim - 1);
for (int32 i = 0; i < C->NumRows(); i++) {
  for (int32 j = 0; j < context_dim; j++) {
    for (int32 k = 0; k < input_num_cols; k++) {
      (*C)(i, j) += alpha * A(i, k) * B(i + (j * row_shift), k);
    }
  }
}
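// What the loop computes, for output row i and context position j:
//   C(i, j) += alpha * <row i of A, row (i + j*row_shift) of B>,
// i.e. each output frame takes dot-products with context_dim rows of B spaced
// row_shift rows apart. A minimal standalone sketch of the same indexing,
// using plain std::vector instead of Kaldi's CuMatrix (all names here are
// hypothetical, for illustration only):
#include <cassert>
#include <vector>

void DotProductsSketch(float alpha, int row_shift,
                       const std::vector<std::vector<float> > &A,  // rows x d
                       const std::vector<std::vector<float> > &B,  // (rows + (ctx-1)*row_shift) x d
                       std::vector<std::vector<float> > *C) {      // rows x ctx
  size_t d = A[0].size(), ctx = (*C)[0].size();
  assert(B.size() == A.size() + (ctx - 1) * row_shift);
  for (size_t i = 0; i < A.size(); i++)
    for (size_t j = 0; j < ctx; j++) {
      float sum = 0.0f;
      for (size_t k = 0; k < d; k++)
        sum += A[i][k] * B[i + j * row_shift][k];
      (*C)[i][j] += alpha * sum;
    }
}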
// From ApplyScalesToOutputSimple(), the reference implementation of
// ApplyScalesToOutput():
KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
int32 row_shift = num_extra_rows / (context_dim - 1);
for (int32 i = 0; i < A->NumRows(); i++) {
  for (int32 j = 0; j < A->NumCols(); j++) {
    for (int32 k = 0; k < context_dim; k++) {
      (*A)(i, j) += alpha * C(i, k) * B(i + (k * row_shift), j);
    }
  }
}
// From ApplyScalesToInputSimple(), the reference implementation of
// ApplyScalesToInput():
KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
int32 row_shift = num_extra_rows / (context_dim - 1);
for (int32 i = 0; i < A.NumRows(); i++) {
  for (int32 j = 0; j < A.NumCols(); j++) {
    for (int32 k = 0; k < context_dim; k++) {
      (*B)(i + (k * row_shift), j) += alpha * C(i, k) * A(i, j);
    }
  }
}
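// Taken together, the three Simple routines above are one bilinear map and
// its two adjoints. Writing
//   F(A, B)(i, j) = alpha * sum_k A(i, k) * B(i + j*row_shift, k),
// then:
//   GetAttentionDotProductsSimple accumulates  C += F(A, B);
//   ApplyScalesToOutputSimple accumulates into A the derivative of
//     tr(C^T F(A, B)) with respect to A;
//   ApplyScalesToInputSimple accumulates into B the derivative of
//     tr(C^T F(A, B)) with respect to B.
// That adjoint structure is why the same trio covers both the forward
// computation and backprop.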
// From UnitTestAttentionDotProductAndAddScales(), where the random choice of
// output_num_rows, input_num_cols, row_shift and context_dim is elided:
    num_extra_rows = (context_dim - 1) * row_shift,
    input_num_rows = output_num_rows + num_extra_rows;
CuMatrix<BaseFloat> A(output_num_rows, input_num_cols),
    B(input_num_rows, input_num_cols),
    C(output_num_rows, context_dim);
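// The elided body of this test presumably fills the matrices randomly and
// checks each optimized kernel against its Simple reference above. A hedged
// sketch of that pattern for the dot-product routine (the exact assertions in
// the real test are elided; the ApproxEqual tolerance is an assumption):
A.SetRandn();
B.SetRandn();
CuMatrix<BaseFloat> C2(C);                        // both start out zero
GetAttentionDotProducts(alpha, A, B, &C);         // optimized version
GetAttentionDotProductsSimple(alpha, A, B, &C2);  // reference version
KALDI_ASSERT(C.ApproxEqual(C2, 0.001));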
// From TestAttentionForwardBackward(); elisions are marked with "// ...".
bool output_context = (RandInt(0, 1) == 0);
// ...
    num_extra_rows = (context_dim - 1) * row_shift,
    input_num_rows = output_num_rows + num_extra_rows,
    query_dim = key_dim + context_dim;
CuMatrix<BaseFloat> keys(input_num_rows, key_dim),
    queries(output_num_rows, query_dim),
    values(input_num_rows, value_dim),
    C(output_num_rows, context_dim),
    output(output_num_rows, value_dim + (output_context ? context_dim : 0));
// ...
CuMatrix<BaseFloat> keys_deriv(input_num_rows, key_dim),
    queries_deriv(output_num_rows, query_dim),
    values_deriv(input_num_rows, value_dim),
    output_deriv(output_num_rows, output.NumCols());
// ...
AttentionBackward(key_scale, keys, queries, values, C,
                  output_deriv, &keys_deriv, &queries_deriv,
                  &values_deriv);
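// The objective being differentiated is a random linear function of the
// output, objf = tr(output^T output_deriv), so the derivatives produced by
// AttentionBackward() are exact and can be verified by finite differences in
// the three loops below. A hedged sketch of the baseline computation (the
// actual line is elided above):
BaseFloat objf_baseline = TraceMatMat(output, output_deriv, kTrans);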
for (int32 i = 0; i < test_dim; i++) {
  // ...
  values2.Scale(epsilon);
  // ...
  values2.AddMat(1.0, values);
  // ...
      observed_delta_objf = objf2 - objf_baseline;
  KALDI_LOG << "Changing values: predicted objf change is "
            << predicted_delta_objf << ", observed objf change is "
            << observed_delta_objf;
  predicted_vec(i) = predicted_delta_objf;
  observed_vec(i) = observed_delta_objf;
}
KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
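// In the loop above, values2 ends up as values + delta, where delta is a
// random matrix scaled by epsilon, and predicted_delta_objf is the chain-rule
// prediction tr(values_deriv^T delta). A hedged sketch of the elided parts of
// the loop body (the name 'values_delta' is an assumption):
CuMatrix<BaseFloat> values_delta(values2);  // right after Scale(epsilon), before AddMat()
BaseFloat predicted_delta_objf = TraceMatMat(values_deriv, values_delta, kTrans);
// ... then, after values2.AddMat(1.0, values), re-run the forward pass:
CuMatrix<BaseFloat> c2(output_num_rows, context_dim),
    output2(output_num_rows, output.NumCols());
AttentionForward(key_scale, keys, queries, values2, &c2, &output2);
BaseFloat objf2 = TraceMatMat(output2, output_deriv, kTrans);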
for (int32 i = 0; i < test_dim; i++) {
  // ...
  keys2.Scale(epsilon);
  // ...
      observed_delta_objf = objf2 - objf_baseline;
  KALDI_LOG << "Changing keys: predicted objf change is "
            << predicted_delta_objf << ", observed objf change is "
            << observed_delta_objf;
  predicted_vec(i) = predicted_delta_objf;
  observed_vec(i) = observed_delta_objf;
}
KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
for (int32 i = 0; i < test_dim; i++) {
  // ...
  queries2.Scale(epsilon);
  // ...
  queries2.AddMat(1.0, queries);
  // ...
      observed_delta_objf = objf2 - objf_baseline;
  KALDI_LOG << "Changing queries: predicted objf change is "
            << predicted_delta_objf << ", observed objf change is "
            << observed_delta_objf;
  predicted_vec(i) = predicted_delta_objf;
  observed_vec(i) = observed_delta_objf;
}
KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
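// The keys and queries loops follow the same pattern as the values loop, with
// the perturbation predicted through keys_deriv and queries_deriv
// respectively; each ApproxEqual check allows roughly 10% relative error (in
// L2 norm) between the predicted and observed objective changes across the
// test_dim trials.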
// From main(): the whole test suite runs twice, once on the CPU and once with
// GPU use enabled if a GPU is available; elisions are marked with "// ...".
using namespace kaldi;
// ...
for (int32 loop = 0; loop < 2; loop++) {
#if HAVE_CUDA == 1
  CuDevice::Instantiate().SetDebugStrideMode(true);
  if (loop == 0)
    CuDevice::Instantiate().SelectGpuId("no");
  else
    CuDevice::Instantiate().SelectGpuId("optional");
#endif
Referenced declarations:

void AttentionBackward(BaseFloat key_scale, const CuMatrixBase<BaseFloat> &keys, const CuMatrixBase<BaseFloat> &queries, const CuMatrixBase<BaseFloat> &values, const CuMatrixBase<BaseFloat> &c, const CuMatrixBase<BaseFloat> &output_deriv, CuMatrixBase<BaseFloat> *keys_deriv, CuMatrixBase<BaseFloat> *queries_deriv, CuMatrixBase<BaseFloat> *values_deriv)
Performs the backward pass corresponding to 'AttentionForward', propagating the derivative back to the keys, queries and values.

void ApplyScalesToInputSimple(BaseFloat alpha, const CuMatrixBase<BaseFloat> &A, const CuMatrixBase<BaseFloat> &C, CuMatrixBase<BaseFloat> *B)

void ApplyScalesToInput(BaseFloat alpha, const CuMatrixBase<BaseFloat> &A, const CuMatrixBase<BaseFloat> &C, CuMatrixBase<BaseFloat> *B)
This function is related to GetAttentionDotProducts(); it is used in backprop.

void AttentionForward(BaseFloat key_scale, const CuMatrixBase<BaseFloat> &keys, const CuMatrixBase<BaseFloat> &queries, const CuMatrixBase<BaseFloat> &values, CuMatrixBase<BaseFloat> *c, CuMatrixBase<BaseFloat> *output)
This is a higher-level interface to the attention code.

void AddMat(Real alpha, const CuMatrixBase<Real> &A, MatrixTransposeType trans = kNoTrans)
*this += alpha * A

This class represents a matrix that's stored on the GPU if we have one, and in memory if not.

void ApplyScalesToOutputSimple(BaseFloat alpha, const CuMatrixBase<BaseFloat> &B, const CuMatrixBase<BaseFloat> &C, CuMatrixBase<BaseFloat> *A)

void TestAttentionForwardBackward()

void GetAttentionDotProductsSimple(BaseFloat alpha, const CuMatrixBase<BaseFloat> &A, const CuMatrixBase<BaseFloat> &B, CuMatrixBase<BaseFloat> *C)

void ApplyScalesToOutput(BaseFloat alpha, const CuMatrixBase<BaseFloat> &B, const CuMatrixBase<BaseFloat> &C, CuMatrixBase<BaseFloat> *A)
This function is related to GetAttentionDotProducts(); it is used in scaling the values by the softmax scales, and in backprop.

void GetAttentionDotProducts(BaseFloat alpha, const CuMatrixBase<BaseFloat> &A, const CuMatrixBase<BaseFloat> &B, CuMatrixBase<BaseFloat> *C)
This function is a utility function that is at the core of how we implement attention.

void SetZero()
Math operations, some calling kernels.

Real TraceMatMat(const MatrixBase<Real> &A, const MatrixBase<Real> &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.

Matrix for CUDA computing.

MatrixIndexT NumCols() const

A class representing a vector.

#define KALDI_ASSERT(cond)

void UnitTestAttentionDotProductAndAddScales()

This file contains the lower-level interface for self-attention.

static void AssertEqual(float a, float b, float relative_tolerance = 0.001)
assert abs(a - b) <= relative_tolerance * (abs(a) + abs(b))

MatrixIndexT NumRows() const
Dimensions.

int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)