// nnet3/attention.cc

// Copyright      2017  Johns Hopkins University (author: Daniel Povey)
//                      Hossein Hadian

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include <iterator>
#include <sstream>
#include <iomanip>
#include "nnet3/attention.h"
#include "nnet3/nnet-parse.h"

namespace kaldi {
namespace nnet3 {
namespace attention {

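// Computes the dot products at the core of the attention mechanism.  With
// row_shift = (B.NumRows() - A.NumRows()) / (context_dim - 1), it sets
//   C(i, o) = alpha * VecVec(A.Row(i), B.Row(i + o * row_shift))
// for each output row i and each context offset o.  Working one context
// offset at a time lets each step be a single AddDiagMatMat() call on a
// contiguous sub-matrix of B.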
void GetAttentionDotProducts(BaseFloat alpha,
                             const CuMatrixBase<BaseFloat> &A,
                             const CuMatrixBase<BaseFloat> &B,
                             CuMatrixBase<BaseFloat> *C) {
  KALDI_ASSERT(A.NumCols() == B.NumCols() &&
               A.NumRows() == C->NumRows());
  int32 num_output_rows = A.NumRows(),
      input_num_cols = A.NumCols(),
      num_extra_rows = B.NumRows() - A.NumRows(),
      context_dim = C->NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  CuMatrix<BaseFloat> Ctrans(C->NumCols(),
                             C->NumRows());
  for (int32 o = 0; o < context_dim; o++) {
    CuSubVector<BaseFloat> c_col(Ctrans, o);
    CuSubMatrix<BaseFloat> B_part(B, o * row_shift, num_output_rows,
                                  0, input_num_cols);
    c_col.AddDiagMatMat(alpha, A, kNoTrans, B_part, kTrans, 0.0);
  }
  C->CopyFromMat(Ctrans, kTrans);
}

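// The counterpart of GetAttentionDotProducts() used in the forward pass:
// it forms the weighted combination
//   A->Row(i) += alpha * \sum_o C(i, o) * B.Row(i + o * row_shift),
// i.e. it scales the rows of B (the values) by the softmax weights in C and
// accumulates the result into A (the output).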
void ApplyScalesToOutput(BaseFloat alpha,
                         const CuMatrixBase<BaseFloat> &B,
                         const CuMatrixBase<BaseFloat> &C,
                         CuMatrixBase<BaseFloat> *A) {
  KALDI_ASSERT(A->NumCols() == B.NumCols() &&
               A->NumRows() == C.NumRows());
  int32 num_output_rows = A->NumRows(),
      input_num_cols = A->NumCols(),
      num_extra_rows = B.NumRows() - A->NumRows(),
      context_dim = C.NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  CuMatrix<BaseFloat> Ctrans(C, kTrans);
  for (int32 o = 0; o < context_dim; o++) {
    CuSubVector<BaseFloat> c_col(Ctrans, o);
    CuSubMatrix<BaseFloat> B_part(B, o * row_shift, num_output_rows,
                                  0, input_num_cols);
    A->AddDiagVecMat(alpha, c_col, B_part, kNoTrans, 1.0);
  }
}

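// The 'transpose' of ApplyScalesToOutput(), used in backprop:
//   B->Row(i + o * row_shift) += alpha * C(i, o) * A.Row(i),
// propagating a derivative at the output (A) back to the input rows of B.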
void ApplyScalesToInput(BaseFloat alpha,
                        const CuMatrixBase<BaseFloat> &A,
                        const CuMatrixBase<BaseFloat> &C,
                        CuMatrixBase<BaseFloat> *B) {
  KALDI_ASSERT(A.NumCols() == B->NumCols() &&
               A.NumRows() == C.NumRows());
  int32 num_output_rows = A.NumRows(),
      input_num_cols = A.NumCols(),
      num_extra_rows = B->NumRows() - A.NumRows(),
      context_dim = C.NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  CuMatrix<BaseFloat> Ctrans(C, kTrans);
  for (int32 o = 0; o < context_dim; o++) {
    CuSubVector<BaseFloat> c_col(Ctrans, o);
    CuSubMatrix<BaseFloat> B_part(*B, o * row_shift, num_output_rows,
                                  0, input_num_cols);
    B_part.AddDiagVecMat(alpha, c_col, A, kNoTrans, 1.0);
  }
}

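// Higher-level interface implementing one head of the restricted
// self-attention described in attention.h.  The first key_dim columns of
// 'queries' hold the query vectors and the remaining context_dim columns a
// per-offset bias; with row_shift = (num_input_rows - num_output_rows) /
// (context_dim - 1), it computes
//   b(i, o) = key_scale * VecVec(queries_key_part.Row(i),
//                                keys.Row(i + o * row_shift))
//             + queries_context_part(i, o),
//   c(i, :) = Softmax(b(i, :)),
// and sets the output to the c-weighted sum of rows of 'values', optionally
// with c appended as extra output columns.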
void AttentionForward(BaseFloat key_scale,
                      const CuMatrixBase<BaseFloat> &keys,
                      const CuMatrixBase<BaseFloat> &queries,
                      const CuMatrixBase<BaseFloat> &values,
                      CuMatrixBase<BaseFloat> *c,
                      CuMatrixBase<BaseFloat> *output) {
  // First check the dimensions and values.
  KALDI_ASSERT(key_scale > 0.0);
  int32 num_input_rows = keys.NumRows(),
      key_dim = keys.NumCols(),
      num_output_rows = queries.NumRows(),
      context_dim = queries.NumCols() - key_dim,
      value_dim = values.NumCols();
  KALDI_ASSERT(num_input_rows > 0 && key_dim > 0 &&
               num_input_rows > num_output_rows &&
               context_dim > 0 &&
               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
               values.NumRows() == num_input_rows);
  KALDI_ASSERT(c->NumRows() == num_output_rows &&
               c->NumCols() == context_dim);
  KALDI_ASSERT(output->NumRows() == num_output_rows &&
               (output->NumCols() == value_dim ||
                output->NumCols() == value_dim + context_dim));

  CuSubMatrix<BaseFloat> queries_key_part(
      queries, 0, num_output_rows,
      0, key_dim),
      queries_context_part(
          queries, 0, num_output_rows,
          key_dim, context_dim);

  GetAttentionDotProducts(key_scale,
                          queries_key_part,
                          keys, c);
  // Think of 'queries_context_part' as a position-dependent bias term.
  c->AddMat(1.0, queries_context_part);
  // Compute the soft-max function.  Up to this point, 'c' actually contained
  // what in attention.h we call 'b', which is the input to the softmax.
  c->SoftMaxPerRow(*c);


  // The part of the output that consists of weighted combinations of the
  // input values.
  CuSubMatrix<BaseFloat> output_values_part(
      *output, 0, num_output_rows, 0, value_dim);

  ApplyScalesToOutput(1.0, values, *c, &output_values_part);


  if (output->NumCols() == value_dim + context_dim) {
    CuSubMatrix<BaseFloat> output_context_part(
        *output, 0, num_output_rows, value_dim, context_dim);
    output_context_part.CopyFromMat(*c);
  }
}

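// Backward pass corresponding to AttentionForward().  'c' must be the
// attention weights computed in the forward pass; given the derivative of
// the objective w.r.t. the output, this adds (does not set) the
// corresponding derivatives w.r.t. the keys, queries and values, undoing
// each forward-pass statement in reverse order.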
void AttentionBackward(BaseFloat key_scale,
                       const CuMatrixBase<BaseFloat> &keys,
                       const CuMatrixBase<BaseFloat> &queries,
                       const CuMatrixBase<BaseFloat> &values,
                       const CuMatrixBase<BaseFloat> &c,
                       const CuMatrixBase<BaseFloat> &output_deriv,
                       CuMatrixBase<BaseFloat> *keys_deriv,
                       CuMatrixBase<BaseFloat> *queries_deriv,
                       CuMatrixBase<BaseFloat> *values_deriv) {

  // First check the dimensions and values.
  KALDI_ASSERT(key_scale > 0.0);
  int32 num_input_rows = keys.NumRows(),
      key_dim = keys.NumCols(),
      num_output_rows = queries.NumRows(),
      context_dim = queries.NumCols() - key_dim,
      value_dim = values.NumCols();
  KALDI_ASSERT(num_input_rows > 0 && key_dim > 0 &&
               num_input_rows > num_output_rows &&
               context_dim > 0 &&
               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
               values.NumRows() == num_input_rows);
  KALDI_ASSERT(SameDim(keys, *keys_deriv) &&
               SameDim(queries, *queries_deriv) &&
               SameDim(values, *values_deriv));

  KALDI_ASSERT(c.NumRows() == num_output_rows &&
               c.NumCols() == context_dim);
  KALDI_ASSERT(output_deriv.NumRows() == num_output_rows &&
               (output_deriv.NumCols() == value_dim ||
                output_deriv.NumCols() == value_dim + context_dim));

  CuMatrix<BaseFloat> c_deriv(num_output_rows, context_dim,
                              kUndefined);

  CuSubMatrix<BaseFloat> output_values_part_deriv(
      output_deriv, 0, num_output_rows, 0, value_dim);
  // This is the backprop w.r.t. the forward-pass statement:
  // ApplyScalesToOutput(1.0, values, *c, &output_values_part);
  GetAttentionDotProducts(1.0, output_values_part_deriv,
                          values, &c_deriv);

  if (output_deriv.NumCols() == value_dim + context_dim) {
    CuSubMatrix<BaseFloat> output_deriv_context_part(
        output_deriv, 0, num_output_rows, value_dim, context_dim);
    // This is the backprop w.r.t. the forward-pass statement:
    // output_context_part.CopyFromMat(*c);
    c_deriv.AddMat(1.0, output_deriv_context_part);
  }

  // Propagate the derivatives back through the softmax nonlinearity,
  // in-place; this is the backprop w.r.t. the statement
  // 'c->SoftMaxPerRow(*c);'.  From this point on, c_deriv actually
  // contains the derivative w.r.t. the pre-softmax values which we call
  // 'b' in the math.
  c_deriv.DiffSoftmaxPerRow(c, c_deriv);


  CuSubMatrix<BaseFloat> queries_key_part(
      queries, 0, num_output_rows,
      0, key_dim),
      queries_key_part_deriv(
          *queries_deriv, 0, num_output_rows,
          0, key_dim),
      queries_context_part_deriv(
          *queries_deriv, 0, num_output_rows,
          key_dim, context_dim);

  // Below is the backprop corresponding to the forward-propagation command:
  // c->AddMat(1.0, queries_context_part)
  queries_context_part_deriv.AddMat(1.0, c_deriv);

  // The following statement is the part of the backprop w.r.t. the
  // statement:
  // GetAttentionDotProducts(key_scale, queries_key_part, keys, c);
  // which propagates the derivative back to 'queries_key_part'.
  ApplyScalesToOutput(key_scale, keys, c_deriv, &queries_key_part_deriv);

  // The following statement is the part of the backprop w.r.t. the
  // statement:
  // GetAttentionDotProducts(key_scale, queries_key_part, keys, c);
  // which propagates the derivative back to 'keys'.
  ApplyScalesToInput(key_scale, queries_key_part, c_deriv, keys_deriv);

  // The following statement is the part of the backprop w.r.t.
  // the statement:
  // ApplyScalesToOutput(1.0, values, *c, &output_values_part);
  // which propagates the derivative back to 'values'.
  ApplyScalesToInput(1.0, output_values_part_deriv, c, values_deriv);
}

} // namespace attention
} // namespace nnet3
} // namespace kaldi
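
// The function below is an illustrative usage sketch, not part of the
// original file: the name ExampleAttentionUsage and the toy dimensions are
// hypothetical, chosen only to satisfy the dimension checks above
// (num_input_rows = num_output_rows + (context_dim - 1) * row_shift).
namespace kaldi {
namespace nnet3 {
namespace attention {

void ExampleAttentionUsage() {
  int32 row_shift = 1, context_dim = 3, num_output_rows = 10,
      num_input_rows = num_output_rows + (context_dim - 1) * row_shift,
      key_dim = 4, value_dim = 5;
  BaseFloat key_scale = 0.5;  // typically on the order of 1/sqrt(key_dim).
  CuMatrix<BaseFloat> keys(num_input_rows, key_dim),
      queries(num_output_rows, key_dim + context_dim),
      values(num_input_rows, value_dim),
      c(num_output_rows, context_dim),
      output(num_output_rows, value_dim + context_dim);
  keys.SetRandn();
  queries.SetRandn();
  values.SetRandn();
  // Forward pass: fills 'c' with the per-row attention weights and 'output'
  // with the weighted values (plus the weights appended as extra columns).
  AttentionForward(key_scale, keys, queries, values, &c, &output);

  // Backward pass: 'output_deriv' would normally come from the layer above;
  // the *_deriv matrices start at zero because AttentionBackward adds to
  // them rather than setting them.
  CuMatrix<BaseFloat> output_deriv(num_output_rows, value_dim + context_dim),
      keys_deriv(num_input_rows, key_dim),
      queries_deriv(num_output_rows, key_dim + context_dim),
      values_deriv(num_input_rows, value_dim);
  output_deriv.SetRandn();
  AttentionBackward(key_scale, keys, queries, values, c, output_deriv,
                    &keys_deriv, &queries_deriv, &values_deriv);
}

}  // namespace attention
}  // namespace nnet3
}  // namespace kaldi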