kaldi::nnet3::attention Namespace Reference

Functions

void GetAttentionDotProductsSimple (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &B, CuMatrixBase< BaseFloat > *C)
 
void ApplyScalesToOutputSimple (BaseFloat alpha, const CuMatrixBase< BaseFloat > &B, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *A)
 
void ApplyScalesToInputSimple (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *B)
 
void UnitTestAttentionDotProductAndAddScales ()
 
void TestAttentionForwardBackward ()
 
void UnitTestAttention ()
 
void GetAttentionDotProducts (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &B, CuMatrixBase< BaseFloat > *C)
 This function is a utility function that is at the core of how we implement attention. More...
 
void ApplyScalesToOutput (BaseFloat alpha, const CuMatrixBase< BaseFloat > &B, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *A)
 This function is related to GetAttentionDotProducts(); it is used in scaling the values by the softmax scales, and in backprop. More...
 
void ApplyScalesToInput (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *B)
 This function is related to GetAttentionDotProducts(); it is used in backprop. More...
 
void AttentionForward (BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *output)
 This is a higher-level interface to the attention code. More...
 
void AttentionBackward (BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &output_deriv, CuMatrixBase< BaseFloat > *keys_deriv, CuMatrixBase< BaseFloat > *queries_deriv, CuMatrixBase< BaseFloat > *values_deriv)
 Performs the backward pass corresponding to 'AttentionForward', propagating the derivative back to the keys, queries and values. More...
 

Function Documentation

◆ ApplyScalesToInput()

void ApplyScalesToInput ( BaseFloat  alpha,
const CuMatrixBase< BaseFloat > &  A,
const CuMatrixBase< BaseFloat > &  C,
CuMatrixBase< BaseFloat > *  B 
)

This function is related to GetAttentionDotProducts(); it is used in backprop.

We have put A, B and C in an unusual order here to make the relationship with GetAttentionDotProducts() clearer: the matrices stand in the same dimensional relationship as A, B and C do in GetAttentionDotProducts().

This function implements:

B->Row(i + j * row_shift) += alpha * C(i, j) * A.Row(i).
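As a rough illustration of the dimensional contract, here is a minimal sketch (not from the Kaldi sources; the function name ExampleApplyScalesToInput and the dimensions are invented for illustration, and the includes assume the usual nnet3/attention.h and cudamatrix/cu-matrix.h headers):

#include "cudamatrix/cu-matrix.h"
#include "nnet3/attention.h"

using namespace kaldi;
using namespace kaldi::nnet3::attention;

// Hypothetical example: B has (context_dim - 1) * row_shift more rows than A.
void ExampleApplyScalesToInput() {
  int32 num_output_rows = 6, dim = 4, context_dim = 3, row_shift = 2;
  int32 num_input_rows = num_output_rows + (context_dim - 1) * row_shift;  // 10
  CuMatrix<BaseFloat> A(num_output_rows, dim),   // "output-like" matrix
      C(num_output_rows, context_dim),           // per-row scales
      B(num_input_rows, dim);                    // "input-like" matrix, accumulated into
  A.SetRandn();
  C.SetRandn();
  // Adds alpha * C(i, j) * A.Row(i) into B.Row(i + j * row_shift).
  ApplyScalesToInput(1.0, A, C, &B);
}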

Definition at line 76 of file attention.cc.

References CuMatrixBase< Real >::AddDiagVecMat(), KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().

Referenced by AttentionBackward(), and UnitTestAttentionDotProductAndAddScales().

void ApplyScalesToInput(BaseFloat alpha,
                        const CuMatrixBase<BaseFloat> &A,
                        const CuMatrixBase<BaseFloat> &C,
                        CuMatrixBase<BaseFloat> *B) {
  KALDI_ASSERT(A.NumCols() == B->NumCols() &&
               A.NumRows() == C.NumRows());
  int32 num_output_rows = A.NumRows(),
      input_num_cols = A.NumCols(),
      num_extra_rows = B->NumRows() - A.NumRows(),
      context_dim = C.NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  CuMatrix<BaseFloat> Ctrans(C, kTrans);
  for (int32 o = 0; o < context_dim; o++) {
    CuSubVector<BaseFloat> c_col(Ctrans, o);
    CuSubMatrix<BaseFloat> B_part(*B, o * row_shift, num_output_rows,
                                  0, input_num_cols);
    B_part.AddDiagVecMat(alpha, c_col, A, kNoTrans, 1.0);
  }
}

◆ ApplyScalesToInputSimple()

void kaldi::nnet3::attention::ApplyScalesToInputSimple ( BaseFloat  alpha,
const CuMatrixBase< BaseFloat > &  A,
const CuMatrixBase< BaseFloat > &  C,
CuMatrixBase< BaseFloat > *  B 
)

Definition at line 72 of file attention-test.cc.

References KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().

Referenced by UnitTestAttentionDotProductAndAddScales().

void ApplyScalesToInputSimple(BaseFloat alpha,
                              const CuMatrixBase<BaseFloat> &A,
                              const CuMatrixBase<BaseFloat> &C,
                              CuMatrixBase<BaseFloat> *B) {
  KALDI_ASSERT(A.NumCols() == B->NumCols() &&
               A.NumRows() == C.NumRows());
  int32 num_extra_rows = B->NumRows() - A.NumRows(),
      context_dim = C.NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  for (int32 i = 0; i < A.NumRows(); i++) {
    for (int32 j = 0; j < A.NumCols(); j++) {
      for (int32 k = 0; k < context_dim; k++) {
        (*B)(i + (k * row_shift), j) += alpha * C(i, k) * A(i, j);
      }
    }
  }
}

◆ ApplyScalesToOutput()

void ApplyScalesToOutput ( BaseFloat  alpha,
const CuMatrixBase< BaseFloat > &  B,
const CuMatrixBase< BaseFloat > &  C,
CuMatrixBase< BaseFloat > *  A 
)

This function is related to GetAttentionDotProducts(); it is used in scaling the values by the softmax scales, and in backprop.

We have put A, B and C in an unusual order here to make the relationship with GetAttentionDotProducts() clearer: the matrices stand in the same dimensional relationship as A, B and C do in GetAttentionDotProducts().

This function implements:

A->Row(i) += alpha * C(i, j) * B.Row(i + j * row_shift).
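A minimal sketch of how this is used to form the weighted sum of values, as in AttentionForward() (hypothetical names and sizes; same header assumptions as the earlier sketch on this page):

#include "cudamatrix/cu-matrix.h"
#include "nnet3/attention.h"

using namespace kaldi;
using namespace kaldi::nnet3::attention;

void ExampleApplyScalesToOutput() {
  int32 num_output_rows = 6, value_dim = 4, context_dim = 3, row_shift = 2;
  int32 num_input_rows = num_output_rows + (context_dim - 1) * row_shift;  // 10
  CuMatrix<BaseFloat> values(num_input_rows, value_dim),  // plays the role of B
      c(num_output_rows, context_dim),                    // plays the role of C (e.g. softmax weights)
      output(num_output_rows, value_dim);                 // plays the role of A, accumulated into
  values.SetRandn();
  c.SetRandn();
  // Adds alpha * c(i, j) * values.Row(i + j * row_shift) to output.Row(i).
  ApplyScalesToOutput(1.0, values, c, &output);
}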

Definition at line 55 of file attention.cc.

References CuMatrixBase< Real >::AddDiagVecMat(), KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().

Referenced by AttentionBackward(), AttentionForward(), and UnitTestAttentionDotProductAndAddScales().

void ApplyScalesToOutput(BaseFloat alpha,
                         const CuMatrixBase<BaseFloat> &B,
                         const CuMatrixBase<BaseFloat> &C,
                         CuMatrixBase<BaseFloat> *A) {
  KALDI_ASSERT(A->NumCols() == B.NumCols() &&
               A->NumRows() == C.NumRows());
  int32 num_output_rows = A->NumRows(),
      input_num_cols = A->NumCols(),
      num_extra_rows = B.NumRows() - A->NumRows(),
      context_dim = C.NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  CuMatrix<BaseFloat> Ctrans(C, kTrans);
  for (int32 o = 0; o < context_dim; o++) {
    CuSubVector<BaseFloat> c_col(Ctrans, o);
    CuSubMatrix<BaseFloat> B_part(B, o * row_shift, num_output_rows,
                                  0, input_num_cols);
    A->AddDiagVecMat(alpha, c_col, B_part, kNoTrans, 1.0);
  }
}

◆ ApplyScalesToOutputSimple()

void kaldi::nnet3::attention::ApplyScalesToOutputSimple ( BaseFloat  alpha,
const CuMatrixBase< BaseFloat > &  B,
const CuMatrixBase< BaseFloat > &  C,
CuMatrixBase< BaseFloat > *  A 
)

Definition at line 52 of file attention-test.cc.

References KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().

Referenced by UnitTestAttentionDotProductAndAddScales().

void ApplyScalesToOutputSimple(BaseFloat alpha,
                               const CuMatrixBase<BaseFloat> &B,
                               const CuMatrixBase<BaseFloat> &C,
                               CuMatrixBase<BaseFloat> *A) {
  KALDI_ASSERT(A->NumCols() == B.NumCols() &&
               A->NumRows() == C.NumRows());
  int32 num_extra_rows = B.NumRows() - A->NumRows(),
      context_dim = C.NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  for (int32 i = 0; i < A->NumRows(); i++) {
    for (int32 j = 0; j < A->NumCols(); j++) {
      for (int32 k = 0; k < context_dim; k++) {
        (*A)(i, j) += alpha * C(i, k) * B(i + (k * row_shift), j);
      }
    }
  }
}

◆ AttentionBackward()

void AttentionBackward ( BaseFloat  key_scale,
const CuMatrixBase< BaseFloat > &  keys,
const CuMatrixBase< BaseFloat > &  queries,
const CuMatrixBase< BaseFloat > &  values,
const CuMatrixBase< BaseFloat > &  c,
const CuMatrixBase< BaseFloat > &  output_deriv,
CuMatrixBase< BaseFloat > *  keys_deriv,
CuMatrixBase< BaseFloat > *  queries_deriv,
CuMatrixBase< BaseFloat > *  values_deriv 
)

Performs the backward pass corresponding to 'AttentionForward', propagating the derivative back to the keys, queries and values.

The interface should be easy to understand with reference to AttentionForward(), so we won't document it, except to note that 'keys_deriv', 'queries_deriv' and 'values_deriv' are *added to*, not set, by this function.
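A minimal sketch of the forward/backward pairing, modeled on TestAttentionForwardBackward() further down this page (the function name and dimensions are invented; same header assumptions as the earlier sketches; the derivative matrices are allocated by the caller and start at zero because this function adds to them):

#include "cudamatrix/cu-matrix.h"
#include "nnet3/attention.h"

using namespace kaldi;
using namespace kaldi::nnet3::attention;

void ExampleAttentionBackward() {
  int32 output_rows = 10, key_dim = 8, value_dim = 5,
      context_dim = 3, row_shift = 2,
      input_rows = output_rows + (context_dim - 1) * row_shift,
      query_dim = key_dim + context_dim;
  BaseFloat key_scale = 0.5;
  CuMatrix<BaseFloat> keys(input_rows, key_dim), queries(output_rows, query_dim),
      values(input_rows, value_dim), c(output_rows, context_dim),
      output(output_rows, value_dim);
  keys.SetRandn();
  queries.SetRandn();
  values.SetRandn();
  AttentionForward(key_scale, keys, queries, values, &c, &output);

  // The derivative matrices are *added to*, so they start out zero here
  // (the CuMatrix constructor zero-initializes them).
  CuMatrix<BaseFloat> output_deriv(output_rows, value_dim),
      keys_deriv(input_rows, key_dim),
      queries_deriv(output_rows, query_dim),
      values_deriv(input_rows, value_dim);
  output_deriv.SetRandn();
  AttentionBackward(key_scale, keys, queries, values, c, output_deriv,
                    &keys_deriv, &queries_deriv, &values_deriv);
}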

Definition at line 154 of file attention.cc.

References CuMatrixBase< Real >::AddMat(), ApplyScalesToInput(), ApplyScalesToOutput(), CuMatrixBase< Real >::DiffSoftmaxPerRow(), GetAttentionDotProducts(), KALDI_ASSERT, kaldi::kUndefined, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), and kaldi::SameDim().

Referenced by RestrictedAttentionComponent::BackpropOneHead(), and TestAttentionForwardBackward().

void AttentionBackward(BaseFloat key_scale,
                       const CuMatrixBase<BaseFloat> &keys,
                       const CuMatrixBase<BaseFloat> &queries,
                       const CuMatrixBase<BaseFloat> &values,
                       const CuMatrixBase<BaseFloat> &c,
                       const CuMatrixBase<BaseFloat> &output_deriv,
                       CuMatrixBase<BaseFloat> *keys_deriv,
                       CuMatrixBase<BaseFloat> *queries_deriv,
                       CuMatrixBase<BaseFloat> *values_deriv) {

  // First check the dimensions and values.
  KALDI_ASSERT(key_scale > 0.0);
  int32 num_input_rows = keys.NumRows(),
      key_dim = keys.NumCols(),
      num_output_rows = queries.NumRows(),
      context_dim = queries.NumCols() - key_dim,
      value_dim = values.NumCols();
  KALDI_ASSERT(num_input_rows > 0 && key_dim > 0 &&
               num_input_rows > num_output_rows &&
               context_dim > 0 &&
               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
               values.NumRows() == num_input_rows);
  KALDI_ASSERT(SameDim(keys, *keys_deriv) &&
               SameDim(queries, *queries_deriv) &&
               SameDim(values, *values_deriv));

  KALDI_ASSERT(c.NumRows() == num_output_rows &&
               c.NumCols() == context_dim);
  KALDI_ASSERT(output_deriv.NumRows() == num_output_rows &&
               (output_deriv.NumCols() == value_dim ||
                output_deriv.NumCols() == value_dim + context_dim));

  CuMatrix<BaseFloat> c_deriv(num_output_rows, context_dim,
                              kUndefined);

  CuSubMatrix<BaseFloat> output_values_part_deriv(
      output_deriv, 0, num_output_rows, 0, value_dim);
  // This is the backprop w.r.t. the forward-pass statement:
  // ApplyScalesToOutput(1.0, values, *c, &output_values_part);
  GetAttentionDotProducts(1.0, output_values_part_deriv,
                          values, &c_deriv);

  if (output_deriv.NumCols() == value_dim + context_dim) {
    CuSubMatrix<BaseFloat> output_deriv_context_part(
        output_deriv, 0, num_output_rows, value_dim, context_dim);
    // this is the backprop w.r.t. the
    // forward-pass statement: output_context_part.CopyFromMat(*c);
    c_deriv.AddMat(1.0, output_deriv_context_part);
  }

  // Propagate the derivatives back through the softmax nonlinearity,
  // in-place; this is the backprop w.r.t. the statement
  // 'c->SoftMaxPerRow(*c);'.  From this point on, c_deriv actually
  // contains the derivative to the pre-softmax values which we call
  // 'b' in the math.
  c_deriv.DiffSoftmaxPerRow(c, c_deriv);

  CuSubMatrix<BaseFloat> queries_key_part(
      queries, 0, num_output_rows,
      0, key_dim),
      queries_key_part_deriv(
          *queries_deriv, 0, num_output_rows,
          0, key_dim),
      queries_context_part_deriv(
          *queries_deriv, 0, num_output_rows,
          key_dim, context_dim);

  // Below is the backprop corresponding to the forward-propagation command:
  // c->AddMat(1.0, queries_context_part)
  queries_context_part_deriv.AddMat(1.0, c_deriv);

  // The following statement is the part of the backprop w.r.t. the
  // statement:
  // GetAttentionDotProducts(key_scale, queries_key_part, keys, c);
  // which propagates the derivative back to 'queries_key_part'.
  ApplyScalesToOutput(key_scale, keys, c_deriv, &queries_key_part_deriv);

  // The following statement is the part of the backprop w.r.t. the
  // statement:
  // GetAttentionDotProducts(key_scale, queries_key_part, keys, c);
  // which propagates the derivative back to 'keys'.
  ApplyScalesToInput(key_scale, queries_key_part, c_deriv, keys_deriv);

  // The following statement is the part of the backprop w.r.t.
  // the statement:
  // ApplyScalesToOutput(1.0, values, *c, &output_values_part);
  // which propagates the derivative back to 'values'.
  ApplyScalesToInput(1.0, output_values_part_deriv, c, values_deriv);
}

◆ AttentionForward()

void AttentionForward ( BaseFloat  key_scale,
const CuMatrixBase< BaseFloat > &  keys,
const CuMatrixBase< BaseFloat > &  queries,
const CuMatrixBase< BaseFloat > &  values,
CuMatrixBase< BaseFloat > *  c,
CuMatrixBase< BaseFloat > *  output 
)

This is a higher-level interface to the attention code.

Read the extended comment in the file nnet3/attention.h for context.

Parameters
[in] key_scale  Scale on the non-context part of the keys.
[in] keys  Matrix whose rows contain the keys; dimension is num-input-rows by key-dim.
[in] queries  Matrix whose rows contain the queries; dimension is num-output-rows by query-dim, where query-dim == key-dim + context-dim. num-input-rows - num-output-rows must be a multiple of context-dim - 1 (we 'shift' the keys by multiples 0, n, 2n, ..., (context-dim - 1) * n, where n = (num-input-rows - num-output-rows) / (context-dim - 1)).
[in] values  Values to average at the output, of dimension num-input-rows by value-dim. [We may add context information to these averages if required; see the comment for 'output'.]
[out] c  Expected to be finite at entry (no infs or NaNs); at exit this will contain the output of the softmax. Must be of dimension num-output-rows by context-dim.
[out] output  The output of the attention mechanism will be *added* to this location. Dimension must be num-output-rows by either value-dim, or value-dim + context-dim. To the first 'value-dim' columns of this will be added the weighted combination of 'values', weighted by the corresponding weights of 'c' (e.g. the first column of 'c' scaling the first num-output-rows rows of 'values', the next column of 'c' scaling the submatrix of 'values' shifted by 'n', and so on). If output->NumCols() is value-dim + context-dim, 'c' will be added to the remaining columns of 'output'.
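A minimal usage sketch (hypothetical name and sizes; same header assumptions as the earlier sketches on this page), here with value-dim + context-dim output columns so the softmax weights are appended to the output:

#include "cudamatrix/cu-matrix.h"
#include "nnet3/attention.h"

using namespace kaldi;
using namespace kaldi::nnet3::attention;

void ExampleAttentionForward() {
  int32 output_rows = 10, key_dim = 8, value_dim = 5,
      context_dim = 3, row_shift = 2,
      input_rows = output_rows + (context_dim - 1) * row_shift;
  CuMatrix<BaseFloat> keys(input_rows, key_dim),
      queries(output_rows, key_dim + context_dim),
      values(input_rows, value_dim),
      c(output_rows, context_dim),
      // value_dim + context_dim columns: the post-softmax weights in 'c'
      // are copied into the last context_dim columns of the output.
      output(output_rows, value_dim + context_dim);
  keys.SetRandn();
  queries.SetRandn();
  values.SetRandn();
  // 'output' is *added* to; the constructor has zero-initialized it, but
  // call output.SetZero() first if the matrix is being reused.
  AttentionForward(0.5 /* key_scale */, keys, queries, values, &c, &output);
}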

Definition at line 97 of file attention.cc.

References CuMatrixBase< Real >::AddMat(), ApplyScalesToOutput(), CuMatrixBase< Real >::CopyFromMat(), GetAttentionDotProducts(), KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), and CuMatrixBase< Real >::SoftMaxPerRow().

Referenced by RestrictedAttentionComponent::PropagateOneHead(), and TestAttentionForwardBackward().

void AttentionForward(BaseFloat key_scale,
                      const CuMatrixBase<BaseFloat> &keys,
                      const CuMatrixBase<BaseFloat> &queries,
                      const CuMatrixBase<BaseFloat> &values,
                      CuMatrixBase<BaseFloat> *c,
                      CuMatrixBase<BaseFloat> *output) {
  // First check the dimensions and values.
  KALDI_ASSERT(key_scale > 0.0);
  int32 num_input_rows = keys.NumRows(),
      key_dim = keys.NumCols(),
      num_output_rows = queries.NumRows(),
      context_dim = queries.NumCols() - key_dim,
      value_dim = values.NumCols();
  KALDI_ASSERT(num_input_rows > 0 && key_dim > 0 &&
               num_input_rows > num_output_rows &&
               context_dim > 0 &&
               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
               values.NumRows() == num_input_rows);
  KALDI_ASSERT(c->NumRows() == num_output_rows &&
               c->NumCols() == context_dim);
  KALDI_ASSERT(output->NumRows() == num_output_rows &&
               (output->NumCols() == value_dim ||
                output->NumCols() == value_dim + context_dim));

  CuSubMatrix<BaseFloat> queries_key_part(
      queries, 0, num_output_rows,
      0, key_dim),
      queries_context_part(
          queries, 0, num_output_rows,
          key_dim, context_dim);

  GetAttentionDotProducts(key_scale,
                          queries_key_part,
                          keys, c);
  // think of 'queries_context_part' as a position-dependent bias term.
  c->AddMat(1.0, queries_context_part);
  // compute the soft-max function.  Up till this point, 'c'
  // actually contained what in attention.h we called 'b', which is
  // the input to the softmax.
  c->SoftMaxPerRow(*c);

  // the part of the output that is weighted
  // combinations of the input values.
  CuSubMatrix<BaseFloat> output_values_part(
      *output, 0, num_output_rows, 0, value_dim);

  ApplyScalesToOutput(1.0, values, *c, &output_values_part);

  if (output->NumCols() == value_dim + context_dim) {
    CuSubMatrix<BaseFloat> output_context_part(
        *output, 0, num_output_rows, value_dim, context_dim);
    output_context_part.CopyFromMat(*c);
  }
}

◆ GetAttentionDotProducts()

void GetAttentionDotProducts ( BaseFloat  alpha,
const CuMatrixBase< BaseFloat > &  A,
const CuMatrixBase< BaseFloat > &  B,
CuMatrixBase< BaseFloat > *  C 
)

This function is a utility function that is at the core of how we implement attention.

It may in future need to be renamed and possibly moved into the cudamatrix directory and implemented in CUDA. The current implementation is quite inefficient. We can also consider doing a complete redesign of how the implementation works, such that this function doesn't exist at all; or we could have a batched version of this function that would operate on a batch of A, B and C at once (or a "strided, batched" version where the difference between the members of the batch is expressed as a stride).

This function implements a special operation that you could view as some kind of matrix multiplication where only a band of the product is retained.

The inputs A and B must have the same number of columns (A.NumCols() == B.NumCols()), and A and C must have the same number of rows (A.NumRows() == C->NumRows()). The number of rows of B must exceed the number of rows of A. Define num_extra_rows = B.NumRows() - A.NumRows(). Then C->NumCols() - 1 must divide num_extra_rows. Define row_shift = num_extra_rows / (C->NumCols() - 1).

This function implements: (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift))
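For concreteness, a minimal sketch of the dimensional contract (hypothetical name and sizes; same header assumptions as the earlier sketches on this page): with context_dim = 3 and row_shift = 2, B has 4 more rows than A, and each row of A is dotted with 3 rows of B spaced 2 apart.

#include "cudamatrix/cu-matrix.h"
#include "nnet3/attention.h"

using namespace kaldi;
using namespace kaldi::nnet3::attention;

void ExampleGetAttentionDotProducts() {
  int32 num_output_rows = 6, dim = 4, context_dim = 3, row_shift = 2;
  CuMatrix<BaseFloat> A(num_output_rows, dim),
      B(num_output_rows + (context_dim - 1) * row_shift, dim),
      C(num_output_rows, context_dim);
  A.SetRandn();
  B.SetRandn();
  // After this call, C(i, j) == alpha * VecVec(A.Row(i), B.Row(i + j * row_shift)).
  GetAttentionDotProducts(1.0, A, B, &C);
}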

Definition at line 32 of file attention.cc.

References CuVectorBase< Real >::AddDiagMatMat(), CuMatrixBase< Real >::CopyFromMat(), KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().

Referenced by AttentionBackward(), AttentionForward(), and UnitTestAttentionDotProductAndAddScales().

void GetAttentionDotProducts(BaseFloat alpha,
                             const CuMatrixBase<BaseFloat> &A,
                             const CuMatrixBase<BaseFloat> &B,
                             CuMatrixBase<BaseFloat> *C) {
  KALDI_ASSERT(A.NumCols() == B.NumCols() &&
               A.NumRows() == C->NumRows());
  int32 num_output_rows = A.NumRows(),
      input_num_cols = A.NumCols(),
      num_extra_rows = B.NumRows() - A.NumRows(),
      context_dim = C->NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  CuMatrix<BaseFloat> Ctrans(C->NumCols(),
                             C->NumRows());
  for (int32 o = 0; o < context_dim; o++) {
    CuSubVector<BaseFloat> c_col(Ctrans, o);
    CuSubMatrix<BaseFloat> B_part(B, o * row_shift, num_output_rows,
                                  0, input_num_cols);
    c_col.AddDiagMatMat(alpha, A, kNoTrans, B_part, kTrans, 0.0);
  }
  C->CopyFromMat(Ctrans, kTrans);
}

◆ GetAttentionDotProductsSimple()

void kaldi::nnet3::attention::GetAttentionDotProductsSimple ( BaseFloat  alpha,
const CuMatrixBase< BaseFloat > &  A,
const CuMatrixBase< BaseFloat > &  B,
CuMatrixBase< BaseFloat > *  C 
)

Definition at line 30 of file attention-test.cc.

References KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().

Referenced by UnitTestAttentionDotProductAndAddScales().

void GetAttentionDotProductsSimple(BaseFloat alpha,
                                   const CuMatrixBase<BaseFloat> &A,
                                   const CuMatrixBase<BaseFloat> &B,
                                   CuMatrixBase<BaseFloat> *C) {
  KALDI_ASSERT(A.NumCols() == B.NumCols() &&
               A.NumRows() == C->NumRows());
  int32 input_num_cols = A.NumCols(),
      num_extra_rows = B.NumRows() - A.NumRows(),
      context_dim = C->NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  for (int32 i = 0; i < C->NumRows(); i++) {
    for (int32 j = 0; j < C->NumCols(); j++) {
      (*C)(i, j) = 0.0;
      for (int32 k = 0; k < input_num_cols; k++) {
        (*C)(i, j) += alpha * A(i, k) * B(i + (j * row_shift), k);
      }
    }
  }
}

◆ TestAttentionForwardBackward()

void kaldi::nnet3::attention::TestAttentionForwardBackward ( )

Definition at line 120 of file attention-test.cc.

References CuMatrixBase< Real >::AddMat(), AttentionBackward(), AttentionForward(), KALDI_ASSERT, KALDI_LOG, kaldi::kTrans, CuMatrixBase< Real >::NumCols(), kaldi::RandInt(), CuMatrixBase< Real >::Scale(), CuMatrixBase< Real >::SetRandn(), CuMatrixBase< Real >::SetZero(), and kaldi::TraceMatMat().

Referenced by UnitTestAttention().

void TestAttentionForwardBackward() {
  BaseFloat key_scale = 0.5 * RandInt(1, 3);
  BaseFloat epsilon = 1.0e-03;
  int32 test_dim = 3;
  bool output_context = (RandInt(0, 1) == 0);
  int32 output_num_rows = RandInt(1, 50),
      value_dim = RandInt(10, 30), key_dim = RandInt(10, 30),
      row_shift = RandInt(1, 5), context_dim = RandInt(2, 5),
      num_extra_rows = (context_dim - 1) * row_shift,
      input_num_rows = output_num_rows + num_extra_rows,
      query_dim = key_dim + context_dim;
  CuMatrix<BaseFloat> keys(input_num_rows, key_dim),
      queries(output_num_rows, query_dim),
      values(input_num_rows, value_dim),
      C(output_num_rows, context_dim),
      output(output_num_rows, value_dim + (output_context ? context_dim : 0));

  keys.SetRandn();
  queries.SetRandn();
  values.SetRandn();

  AttentionForward(key_scale, keys, queries, values, &C, &output);

  CuMatrix<BaseFloat> keys_deriv(input_num_rows, key_dim),
      queries_deriv(output_num_rows, query_dim),
      values_deriv(input_num_rows, value_dim),
      output_deriv(output_num_rows, output.NumCols());

  output_deriv.SetRandn();

  AttentionBackward(key_scale, keys, queries, values, C,
                    output_deriv, &keys_deriv, &queries_deriv,
                    &values_deriv);

  BaseFloat objf_baseline = TraceMatMat(output_deriv, output, kTrans);

  { // perturb the values and see if the objf changes as predicted.
    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
    for (int32 i = 0; i < test_dim; i++) {
      CuMatrix<BaseFloat> values2(input_num_rows, value_dim);
      values2.SetRandn();
      values2.Scale(epsilon);
      BaseFloat predicted_delta_objf = TraceMatMat(values_deriv, values2, kTrans);
      values2.AddMat(1.0, values);

      output.SetZero();
      AttentionForward(key_scale, keys, queries, values2, &C, &output);
      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
          observed_delta_objf = objf2 - objf_baseline;
      KALDI_LOG << "Changing values: predicted objf change is "
                << predicted_delta_objf << ", observed objf change is "
                << observed_delta_objf;
      predicted_vec(i) = predicted_delta_objf;
      observed_vec(i) = observed_delta_objf;
    }
    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
  }

  { // perturb the keys and see if the objf changes as predicted.
    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
    for (int32 i = 0; i < test_dim; i++) {
      CuMatrix<BaseFloat> keys2(input_num_rows, key_dim);
      keys2.SetRandn();
      keys2.Scale(epsilon);
      BaseFloat predicted_delta_objf = TraceMatMat(keys_deriv, keys2, kTrans);
      keys2.AddMat(1.0, keys);

      output.SetZero();
      AttentionForward(key_scale, keys2, queries, values, &C, &output);
      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
          observed_delta_objf = objf2 - objf_baseline;
      KALDI_LOG << "Changing keys: predicted objf change is "
                << predicted_delta_objf << ", observed objf change is "
                << observed_delta_objf;
      predicted_vec(i) = predicted_delta_objf;
      observed_vec(i) = observed_delta_objf;
    }
    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
  }

  { // perturb the queries and see if the objf changes as predicted.
    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
    for (int32 i = 0; i < test_dim; i++) {
      CuMatrix<BaseFloat> queries2(output_num_rows, query_dim);
      queries2.SetRandn();
      queries2.Scale(epsilon);
      BaseFloat predicted_delta_objf = TraceMatMat(queries_deriv, queries2, kTrans);
      queries2.AddMat(1.0, queries);

      output.SetZero();
      AttentionForward(key_scale, keys, queries2, values, &C, &output);
      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
          observed_delta_objf = objf2 - objf_baseline;
      KALDI_LOG << "Changing queries: predicted objf change is "
                << predicted_delta_objf << ", observed objf change is "
                << observed_delta_objf;
      predicted_vec(i) = predicted_delta_objf;
      observed_vec(i) = observed_delta_objf;
    }
    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
  }
}

◆ UnitTestAttention()

void kaldi::nnet3::attention::UnitTestAttention ( )

◆ UnitTestAttentionDotProductAndAddScales()

void kaldi::nnet3::attention::UnitTestAttentionDotProductAndAddScales ( )

Definition at line 91 of file attention-test.cc.

References ApplyScalesToInput(), ApplyScalesToInputSimple(), ApplyScalesToOutput(), ApplyScalesToOutputSimple(), kaldi::AssertEqual(), GetAttentionDotProducts(), GetAttentionDotProductsSimple(), kaldi::RandInt(), and CuMatrixBase< Real >::SetRandn().

Referenced by UnitTestAttention().

void UnitTestAttentionDotProductAndAddScales() {
  int32 output_num_rows = RandInt(1, 50), input_num_cols = RandInt(1, 10),
      row_shift = RandInt(1, 5), context_dim = RandInt(2, 5),
      num_extra_rows = (context_dim - 1) * row_shift,
      input_num_rows = output_num_rows + num_extra_rows;
  BaseFloat alpha = 0.25 * RandInt(1, 5);
  CuMatrix<BaseFloat> A(output_num_rows, input_num_cols),
      B(input_num_rows, input_num_cols),
      C(output_num_rows, context_dim);

  B.SetRandn();
  C.SetRandn();
  A.Set(0.0);
  CuMatrix<BaseFloat> A2(A);
  ApplyScalesToOutput(alpha, B, C, &A);
  ApplyScalesToOutputSimple(alpha, B, C, &A2);
  AssertEqual(A, A2);

  CuMatrix<BaseFloat> C2(C);
  GetAttentionDotProductsSimple(alpha, A, B, &C);
  GetAttentionDotProducts(alpha, A, B, &C2);
  AssertEqual(C, C2);

  CuMatrix<BaseFloat> B2(B);
  ApplyScalesToInput(alpha, A, C, &B);
  ApplyScalesToInputSimple(alpha, A, C, &B2);
  AssertEqual(B, B2);
}