// nnet3/attention-test.cc

// Copyright      2017  Hossein Hadian
//                2017  Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/attention.h"
#include "util/common-utils.h"

namespace kaldi {
namespace nnet3 {
namespace attention {

// (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift))
void GetAttentionDotProductsSimple(BaseFloat alpha,
                                   const CuMatrixBase<BaseFloat> &A,
                                   const CuMatrixBase<BaseFloat> &B,
                                   CuMatrixBase<BaseFloat> *C) {
  KALDI_ASSERT(A.NumCols() == B.NumCols() &&
               A.NumRows() == C->NumRows());
  int32 input_num_cols = A.NumCols(),
      num_extra_rows = B.NumRows() - A.NumRows(),
      context_dim = C->NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  for (int32 i = 0; i < C->NumRows(); i++) {
    for (int32 j = 0; j < C->NumCols(); j++) {
      (*C)(i, j) = 0.0;
      for (int32 k = 0; k < input_num_cols; k++) {
        (*C)(i, j) += alpha * A(i, k) * B(i + (j * row_shift), k);
      }
    }
  }
}
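
// Illustrative example (hypothetical numbers, for reading the indexing):
// if A has 2 rows, B has 6 rows and context_dim = 3, then num_extra_rows = 4
// and row_shift = 2, so C is 2 x 3 and, e.g.,
//   C(0, 0) = alpha * VecVec(A.Row(0), B.Row(0)),
//   C(0, 1) = alpha * VecVec(A.Row(0), B.Row(2)),
//   C(1, 2) = alpha * VecVec(A.Row(1), B.Row(5)).
// Each output row i is compared with context_dim rows of B, spaced row_shift
// apart and starting at row i.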

// A->Row(i) += \sum_k alpha * C(i, k) * B.Row(i + k * row_shift).
void ApplyScalesToOutputSimple(BaseFloat alpha,
                               const CuMatrixBase<BaseFloat> &B,
                               const CuMatrixBase<BaseFloat> &C,
                               CuMatrixBase<BaseFloat> *A) {
  KALDI_ASSERT(A->NumCols() == B.NumCols() &&
               A->NumRows() == C.NumRows());
  int32 num_extra_rows = B.NumRows() - A->NumRows(),
      context_dim = C.NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  for (int32 i = 0; i < A->NumRows(); i++) {
    for (int32 j = 0; j < A->NumCols(); j++) {
      for (int32 k = 0; k < context_dim; k++) {
        (*A)(i, j) += alpha * C(i, k) * B(i + (k * row_shift), j);
      }
    }
  }
}
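
// Note: with C holding attention weights, this is the weighted-sum step of
// attention (row i of A accumulates the context_dim rows of B that it
// attends to); equivalently, it is the backprop of
// GetAttentionDotProductsSimple() w.r.t. A when C holds the derivative of
// the objective w.r.t. the dot-products.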

// B->Row(i + j * row_shift) += alpha * C(i, j) * A.Row(i).
void ApplyScalesToInputSimple(BaseFloat alpha,
                              const CuMatrixBase<BaseFloat> &A,
                              const CuMatrixBase<BaseFloat> &C,
                              CuMatrixBase<BaseFloat> *B) {
  KALDI_ASSERT(A.NumCols() == B->NumCols() &&
               A.NumRows() == C.NumRows());
  int32 num_extra_rows = B->NumRows() - A.NumRows(),
      context_dim = C.NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  for (int32 i = 0; i < A.NumRows(); i++) {
    for (int32 j = 0; j < A.NumCols(); j++) {
      for (int32 k = 0; k < context_dim; k++) {
        (*B)(i + (k * row_shift), j) += alpha * C(i, k) * A(i, j);
      }
    }
  }
}
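
// Note: this is the counterpart of ApplyScalesToOutputSimple(), scattering
// each output row back into the input rows it attended to; with C holding
// the derivative w.r.t. the dot-products, it is the backprop of
// GetAttentionDotProductsSimple() w.r.t. B.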

void UnitTestAttentionDotProductAndAddScales() {
  int32 output_num_rows = RandInt(1, 50), input_num_cols = RandInt(1, 10),
      row_shift = RandInt(1, 5), context_dim = RandInt(2, 5),
      num_extra_rows = (context_dim - 1) * row_shift,
      input_num_rows = output_num_rows + num_extra_rows;
  BaseFloat alpha = 0.25 * RandInt(1, 5);
  CuMatrix<BaseFloat> A(output_num_rows, input_num_cols),
      B(input_num_rows, input_num_cols),
      C(output_num_rows, context_dim);

  B.SetRandn();
  C.SetRandn();
  A.Set(0.0);
  CuMatrix<BaseFloat> A2(A);
  ApplyScalesToOutput(alpha, B, C, &A);
  ApplyScalesToOutputSimple(alpha, B, C, &A2);
  AssertEqual(A, A2);

  CuMatrix<BaseFloat> C2(C);
  GetAttentionDotProductsSimple(alpha, A, B, &C);
  GetAttentionDotProducts(alpha, A, B, &C2);
  AssertEqual(C, C2);

  CuMatrix<BaseFloat> B2(B);
  ApplyScalesToInput(alpha, A, C, &B);
  ApplyScalesToInputSimple(alpha, A, C, &B2);
  AssertEqual(B, B2);
}
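
// Note: the test above checks each optimized routine (defined in
// attention.cc) against its naive triple-loop "Simple" reference on random
// inputs; AssertEqual() allows a small relative floating-point tolerance.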

void TestAttentionForwardBackward() {
  BaseFloat key_scale = 0.5 * RandInt(1, 3);
  BaseFloat epsilon = 1.0e-03;
  int32 test_dim = 3;
  bool output_context = (RandInt(0, 1) == 0);
  int32 output_num_rows = RandInt(1, 50),
      value_dim = RandInt(10, 30), key_dim = RandInt(10, 30),
      row_shift = RandInt(1, 5), context_dim = RandInt(2, 5),
      num_extra_rows = (context_dim - 1) * row_shift,
      input_num_rows = output_num_rows + num_extra_rows,
      query_dim = key_dim + context_dim;
  CuMatrix<BaseFloat> keys(input_num_rows, key_dim),
      queries(output_num_rows, query_dim),
      values(input_num_rows, value_dim),
      C(output_num_rows, context_dim),
      output(output_num_rows, value_dim + (output_context ? context_dim : 0));
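
  // Note: per the interface in attention.h, each query has
  // key_dim + context_dim columns; the first key_dim columns are dotted with
  // the keys, and (as we read the convention) the remaining context_dim
  // columns act as additive priors on the attention logits for each offset.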

  keys.SetRandn();
  queries.SetRandn();
  values.SetRandn();

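  // Note: AttentionForward() fills C with the attention weights (a row-wise
  // softmax over the scaled key/query dot products) and writes the weighted
  // sum of the values to 'output'; when 'output' has value_dim + context_dim
  // columns (the output_context case above), the weights are appended to it.
  // See attention.h for the authoritative description.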
  AttentionForward(key_scale, keys, queries, values, &C, &output);

  CuMatrix<BaseFloat> keys_deriv(input_num_rows, key_dim),
      queries_deriv(output_num_rows, query_dim),
      values_deriv(input_num_rows, value_dim),
      output_deriv(output_num_rows, output.NumCols());

  output_deriv.SetRandn();

  AttentionBackward(key_scale, keys, queries, values, C,
                    output_deriv, &keys_deriv, &queries_deriv,
                    &values_deriv);

  BaseFloat objf_baseline = TraceMatMat(output_deriv, output, kTrans);
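
  // Note: the three blocks below verify the backward pass by finite
  // differences. For a small random perturbation 'delta' of one input X, a
  // first-order Taylor expansion predicts
  //   objf(X + delta) - objf(X) =~ TraceMatMat(X_deriv, delta, kTrans),
  // where objf(.) = TraceMatMat(output_deriv, output, kTrans). The 0.1
  // tolerance in ApproxEqual() below absorbs the O(epsilon^2) terms that
  // this expansion ignores.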

  { // perturb the values and see if the objf changes as predicted.
    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
    for (int32 i = 0; i < test_dim; i++) {
      CuMatrix<BaseFloat> values2(input_num_rows, value_dim);
      values2.SetRandn();
      values2.Scale(epsilon);
      BaseFloat predicted_delta_objf = TraceMatMat(values_deriv, values2, kTrans);
      values2.AddMat(1.0, values);

      output.SetZero();
      AttentionForward(key_scale, keys, queries, values2, &C, &output);
      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
          observed_delta_objf = objf2 - objf_baseline;
      KALDI_LOG << "Changing values: predicted objf change is "
                << predicted_delta_objf << ", observed objf change is "
                << observed_delta_objf;
      predicted_vec(i) = predicted_delta_objf;
      observed_vec(i) = observed_delta_objf;
    }
    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
  }

  { // perturb the keys and see if the objf changes as predicted.
    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
    for (int32 i = 0; i < test_dim; i++) {
      CuMatrix<BaseFloat> keys2(input_num_rows, key_dim);
      keys2.SetRandn();
      keys2.Scale(epsilon);
      BaseFloat predicted_delta_objf = TraceMatMat(keys_deriv, keys2, kTrans);
      keys2.AddMat(1.0, keys);

      output.SetZero();
      AttentionForward(key_scale, keys2, queries, values, &C, &output);
      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
          observed_delta_objf = objf2 - objf_baseline;
      KALDI_LOG << "Changing keys: predicted objf change is "
                << predicted_delta_objf << ", observed objf change is "
                << observed_delta_objf;
      predicted_vec(i) = predicted_delta_objf;
      observed_vec(i) = observed_delta_objf;
    }
    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
  }

  { // perturb the queries and see if the objf changes as predicted.
    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
    for (int32 i = 0; i < test_dim; i++) {
      CuMatrix<BaseFloat> queries2(output_num_rows, query_dim);
      queries2.SetRandn();
      queries2.Scale(epsilon);
      BaseFloat predicted_delta_objf = TraceMatMat(queries_deriv, queries2, kTrans);
      queries2.AddMat(1.0, queries);

      output.SetZero();
      AttentionForward(key_scale, keys, queries2, values, &C, &output);
      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
          observed_delta_objf = objf2 - objf_baseline;
      KALDI_LOG << "Changing queries: predicted objf change is "
                << predicted_delta_objf << ", observed objf change is "
                << observed_delta_objf;
      predicted_vec(i) = predicted_delta_objf;
      observed_vec(i) = observed_delta_objf;
    }
    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
  }
}

void UnitTestAttention() {
  UnitTestAttentionDotProductAndAddScales();
  TestAttentionForwardBackward();
}


} // namespace attention
} // namespace nnet3
} // namespace kaldi


int main() {
  using namespace kaldi;
  using namespace kaldi::nnet3;
  using namespace kaldi::nnet3::attention;
  for (int32 loop = 0; loop < 2; loop++) {
#if HAVE_CUDA == 1
    CuDevice::Instantiate().SetDebugStrideMode(true);
    if (loop == 0)
      CuDevice::Instantiate().SelectGpuId("no");  // "no" means don't use a GPU
    else
      CuDevice::Instantiate().SelectGpuId("optional");  // use a GPU if available
#endif
    for (int32 i = 0; i < 5; i++) {
      UnitTestAttention();
    }
  }
}