// Core of GetAttentionDotProductsSimple(): for output row i and context
// position j, accumulate the dot product of row i of A with the shifted
// row (i + j * row_shift) of B.
KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
int32 row_shift = num_extra_rows / (context_dim - 1);
for (int32 k = 0; k < input_num_cols; k++)
  (*C)(i, j) += alpha * A(i, k) * B(i + (j * row_shift), k);
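// Hedged reconstruction of the surrounding loop nest (the outer loops over i
// and j are an assumption inferred from the indexing above, not verbatim from
// the elided lines): every entry C(i, j) gets the scaled dot product of row i
// of A with row (i + j * row_shift) of B.
C->SetZero();
for (int32 i = 0; i < C->NumRows(); i++)
  for (int32 j = 0; j < context_dim; j++)
    for (int32 k = 0; k < input_num_cols; k++)
      (*C)(i, j) += alpha * A(i, k) * B(i + (j * row_shift), k);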
// Core of ApplyScalesToOutputSimple(): output row i of A accumulates the
// C(i, k)-weighted combination of the shifted rows (i + k * row_shift) of B.
KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
int32 row_shift = num_extra_rows / (context_dim - 1);
for (int32 k = 0; k < context_dim; k++)
  (*A)(i, j) += alpha * C(i, k) * B(i + (k * row_shift), j);
// Core of ApplyScalesToInputSimple(): the reverse mapping, scattering
// C(i, k) * A(i, j) back into the shifted rows of B.
KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
int32 row_shift = num_extra_rows / (context_dim - 1);
for (int32 k = 0; k < context_dim; k++)
  (*B)(i + (k * row_shift), j) += alpha * C(i, k) * A(i, j);
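// Hedged reconstruction of the full loop nest for this scatter pattern (the
// outer loops over i and j are assumed, as above, not verbatim code): row
// (i + k * row_shift) of B accumulates C(i, k) times row i of A.
// ApplyScalesToOutputSimple() above is the matching gather, with the roles of
// the matrix being read and the matrix being written swapped.
for (int32 i = 0; i < A.NumRows(); i++)
  for (int32 j = 0; j < A.NumCols(); j++)
    for (int32 k = 0; k < context_dim; k++)
      (*B)(i + (k * row_shift), j) += alpha * C(i, k) * A(i, j);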
// In UnitTestAttentionDotProductAndAddScales(): tails of the dimension and
// test-matrix declarations.
    num_extra_rows = (context_dim - 1) * row_shift,
    input_num_rows = output_num_rows + num_extra_rows;
    B(input_num_rows, input_num_cols),
    C(output_num_rows, context_dim);
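// Hedged sketch of the elided test body (an assumption, not verbatim code):
// the CUDA implementations are presumably checked against their *Simple
// reference versions on random data. 'A' is assumed to be a
// CuMatrix<BaseFloat> A(output_num_rows, input_num_cols) declared alongside
// B and C, and the value 0.5 for alpha is purely illustrative.
CuMatrix<BaseFloat> C2(output_num_rows, context_dim);
A.SetRandn();
B.SetRandn();
BaseFloat alpha = 0.5;
GetAttentionDotProducts(alpha, A, B, &C);
GetAttentionDotProductsSimple(alpha, A, B, &C2);
KALDI_ASSERT(C.ApproxEqual(C2));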
// In TestAttentionForwardBackward(): random configuration, dimensions, and
// the matrices for the forward pass and its derivatives.
bool output_context = (RandInt(0, 1) == 0);
    num_extra_rows = (context_dim - 1) * row_shift,
    input_num_rows = output_num_rows + num_extra_rows,
    query_dim = key_dim + context_dim;
    queries(output_num_rows, query_dim),
    values(input_num_rows, value_dim),
    C(output_num_rows, context_dim),
    output(output_num_rows, value_dim + (output_context ? context_dim : 0));
    queries_deriv(output_num_rows, query_dim),
    values_deriv(input_num_rows, value_dim),
    output_deriv(output_num_rows, output.NumCols());
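// Hedged sketch of the elided forward-pass lines (an assumption, not the
// verbatim test): after random initialization, AttentionForward() fills 'C'
// (the post-softmax attention weights) and 'output', and the scalar objective
// that the checks below differentiate is taken to be the inner product of
// 'output' with the random 'output_deriv'; 'key_scale' is assumed to have
// been chosen earlier.
AttentionForward(key_scale, keys, queries, values, &C, &output);
BaseFloat objf_baseline = TraceMatMat(output_deriv, output, kTrans);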
// Backward pass: propagate output_deriv to keys_deriv, queries_deriv and
// values_deriv (argument order as in the AttentionBackward() declaration).
AttentionBackward(key_scale, keys, queries, values, C,
                  output_deriv, &keys_deriv, &queries_deriv,
                  &values_deriv);
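// Hedged sketch (an assumption, not the verbatim elided lines) of the
// per-trial setup inside each of the three finite-difference loops below,
// shown here for the values: a small random perturbation is drawn, the
// predicted objective change comes from the analytic gradient, and the
// observed change comes from re-running the forward pass on the perturbed
// input. The keys and queries loops follow the same pattern.
CuMatrix<BaseFloat> values2(input_num_rows, value_dim);
values2.SetRandn();
values2.Scale(epsilon);
BaseFloat predicted_delta_objf = TraceMatMat(values_deriv, values2, kTrans);
values2.AddMat(1.0, values);
CuMatrix<BaseFloat> C2(output_num_rows, context_dim),
    output2(output_num_rows, output.NumCols());
AttentionForward(key_scale, keys, queries, values2, &C2, &output2);
BaseFloat objf2 = TraceMatMat(output_deriv, output2, kTrans);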
// Finite-difference check of the derivative w.r.t. the values.
for (int32 i = 0; i < test_dim; i++) {
  values2.Scale(epsilon);
  values2.AddMat(1.0, values);
  observed_delta_objf = objf2 - objf_baseline;
  KALDI_LOG << "Changing values: predicted objf change is "
            << predicted_delta_objf << ", observed objf change is "
            << observed_delta_objf;
  predicted_vec(i) = predicted_delta_objf;
  observed_vec(i) = observed_delta_objf;
}
KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
// Finite-difference check of the derivative w.r.t. the keys.
for (int32 i = 0; i < test_dim; i++) {
  keys2.Scale(epsilon);
  observed_delta_objf = objf2 - objf_baseline;
  KALDI_LOG << "Changing keys: predicted objf change is "
            << predicted_delta_objf << ", observed objf change is "
            << observed_delta_objf;
  predicted_vec(i) = predicted_delta_objf;
  observed_vec(i) = observed_delta_objf;
}
KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
// Finite-difference check of the derivative w.r.t. the queries.
for (int32 i = 0; i < test_dim; i++) {
  queries2.Scale(epsilon);
  queries2.AddMat(1.0, queries);
  observed_delta_objf = objf2 - objf_baseline;
  KALDI_LOG << "Changing queries: predicted objf change is "
            << predicted_delta_objf << ", observed objf change is "
            << observed_delta_objf;
  predicted_vec(i) = predicted_delta_objf;
  observed_vec(i) = observed_delta_objf;
}
KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
// In main(): run the tests twice, once without a GPU and once with a GPU if
// one is available.
using namespace kaldi;
for (int32 loop = 0; loop < 2; loop++) {
  CuDevice::Instantiate().SetDebugStrideMode(true);
  if (loop == 0)
    CuDevice::Instantiate().SelectGpuId("no");
  else
    CuDevice::Instantiate().SelectGpuId("optional");
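// Hedged sketch (an assumption about the elided body of the loop in main(),
// not verbatim code): each pass presumably runs the two unit tests defined
// above, so that both the CPU path and, when Kaldi is built with CUDA, the
// GPU path are exercised.
UnitTestAttentionDotProductAndAddScales();
TestAttentionForwardBackward();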