// nnet3/attention.h

// Copyright      2017  Johns Hopkins University (author: Daniel Povey)
//                      Hossein Hadian

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET3_ATTENTION_H_
#define KALDI_NNET3_ATTENTION_H_

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "itf/options-itf.h"
#include "matrix/matrix-lib.h"
#include "cudamatrix/cu-matrix-lib.h"
#include "nnet3/nnet-common.h"
#include "nnet3/convolution.h"

#include <iostream>

namespace kaldi {
namespace nnet3 {
namespace attention {


// Our attention is "multi-head", like in Google's paper.  Note: we're basically
// implementing multi-head attention as a fixed nonlinearity, with the actual
// parameters relegated to the previous layer.  That is, the attention layer
// won't have any parameters of its own, but the parameters of the preceding
// layer will be interpretable as the attention parameters.  This doesn't change
// what's computed; it just affects how the neural net is divided into components.
//
// * Basic restricted self-attention (without positional encoding).
//
// To explain what's going on, we start with the simplest form of attention:
// single-head, with no positional encoding, but with restricted context.  For
// purposes of exposition we assume that the time offsets we need form a
// contiguous range, i.e. that the time-stride is 1; the code does have the
// notion of a stride (you'll see that later).
//
// Using notation similar to the Google paper, suppose we have a time-sequence
// of inputs, and the inputs are (keys, values and queries):
//
//    k_t, v_t, q_t
//
// where k_t and q_t are vectors of dimension 'key_dim' and v_t is a vector
// of dimension 'value_dim' (you may choose to make this the same as key_dim,
// but that isn't a constraint).

// Let num_left_inputs and num_right_inputs be the number of
// left-context and right-context frames required, and for some t,
// let input_indexes(t) be the set
//   [ t - num_left_inputs, t - num_left_inputs + 1, ... t + num_right_inputs ].
// To evaluate the output (which we'll write as u_t), we need the query
// value q_t, plus the keys and values k_s and v_s for all s in input_indexes(t).
// If the inputs are not available for some subset of input_indexes(t),
// we just let them be zeros; the network can learn to ignore them if it wants,
// but making them zeros is simpler to implement.
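//
// For example (illustrative numbers only, not defaults from the code): if
// num_left_inputs = 2 and num_right_inputs = 1, then
// input_indexes(t) = [ t - 2, t - 1, t, t + 1 ], so the output at time t
// attends to four input frames in total.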
//
//
// Anyway, the output u_t (without positional encoding yet) is:
//
//    u_t := \sum_{s in input_indexes(t)} Z_t exp(q_t . k_s) v_s
//
// where Z_t is 1/(\sum_s exp(q_t . k_s)).  We'll handle scaling
// issues (the 1/sqrt(dim) factor in the Google paper) later on,
// by scaling the keys.
//
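// As a small worked example (illustrative numbers only): suppose
// input_indexes(t) = { t-1, t, t+1 } and the dot-products are
// q_t . k_{t-1} = 1.0, q_t . k_t = 2.0, q_t . k_{t+1} = 0.0.  Then
// Z_t = 1 / (e^1 + e^2 + e^0) ~= 1 / 11.11, and
// u_t ~= 0.245 v_{t-1} + 0.665 v_t + 0.090 v_{t+1},
// i.e. a convex combination of the values, weighted by a softmax over the
// query-key dot-products.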
//
// * Positional encoding
// We now explain how we include positional encoding in the model.
//
//
// Let context_dim = 1 + num_left_inputs + num_right_inputs.
// Let v be a vector, and let the function Extend(v, o) (where
// 0 <= o < context_dim) extend v with extra dimensions
// that encode the time-offset.  To be precise, we have
//
//    Extend(v, o) = Append(v, u_o)
//
// where u_o is a unit vector of dimension context_dim that is nonzero in the
// o'th dimension (assuming zero-based indexing).
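//
// For example (an illustrative case, not from the code): if context_dim = 3
// and v = [ 4 5 ], then
//    Extend(v, 0) = [ 4 5 1 0 0 ]
//    Extend(v, 1) = [ 4 5 0 1 0 ]
//    Extend(v, 2) = [ 4 5 0 0 1 ],
// so the appended unit vector just marks which time-offset the key or value
// came from.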
//
// So when we add the positional encoding (and the scale on the keys), we replace
// the equation:
//    u_t := \sum_{s in input_indexes(t)} Z_t exp(q_t . k_s) v_s
// with:
//    u_t := \sum_{s in input_indexes(t)} Z_t exp(q_t . Extend(key-scale * k_s, s - t + num_left_inputs)) Extend(v_s, s - t + num_left_inputs)
//
// (We won't actually physically extend the vectors as we compute this;
// we'll do it a different way, but it's equivalent to what we wrote
// above.)  The dimension of q_t is key_dim + context_dim, and the dimension
// of the output u_t is value_dim + context_dim.
//
//
// * Multi-head attention
//
// The attention component, if we had a single head, would have an input dimension
// of (2*key_dim + context_dim + value_dim), which would be interpreted
// as Append(k_t, q_t, v_t), of dimensions respectively
// (key_dim, key_dim + context_dim, value_dim).  It would have an output
// dimension of value_dim + context_dim.
//
// In any case, the multi-head version has input and output dimensions that are
// larger by a factor of 'num_heads'; it is equivalent to several single-head
// components appended together.
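//
// As a concrete dimension check (illustrative numbers, not defaults from the
// code): with key_dim = 40, value_dim = 60, num_left_inputs = 5 and
// num_right_inputs = 2, we get context_dim = 1 + 5 + 2 = 8; a single head then
// has input dimension 2*40 + 8 + 60 = 148 (40 for k_t, 48 for q_t, 60 for v_t)
// and output dimension 60 + 8 = 68.  With num_heads = 4 the component's input
// and output dimensions would be 4 * 148 = 592 and 4 * 68 = 272.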
//
//
//
// * The actual calculation
//
// Let's assume that we might have multiple independent sequences; we'll
// call this 'num_images' because we're borrowing certain structures from
// the convolution code.
//
// The input is formatted as a matrix whose NumRows() could be written as
// num_images * num_t_in, where num_t_in is the number of distinct input 't'
// values; the output has num_images * num_t_out rows.  To keep it simple we'll
// explain this under the assumption that we don't have any 't' stride in the
// attention (t_stride == 1 in the code), and that num_heads == 1; both of
// those are fairly simple modifications to the basic scheme.
// The image (normally 'n') index has a higher stride than the 't' index in
// both the input and the output.  We assume that there is 'enough'
// context of the input to compute all required offsets of the output.
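//
// To illustrate the row ordering (made-up sizes, purely for illustration): if
// num_images = 2 and num_t_in = 5, the input matrix has 10 rows laid out as
//    (n=0,t=0), (n=0,t=1), ..., (n=0,t=4), (n=1,t=0), ..., (n=1,t=4),
// i.e. the row index is n * num_t_in + t, which is what we mean by the image
// index having the higher stride.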
//
// Define the intermediate quantity b_{t,o}, which you can think of
// as the input to the softmax; the index 't' is the time index
// at the output, and o ranges from 0 to context_dim - 1.
//
//    b_{t,o} = q_t . Extend(key-scale * k_{t + o - num_left_inputs}, o)
//
// To get rid of the Extend() expressions, define sub-ranges of q_t by
// writing q_t = Append(r_t, s_t), where r_t is of dimension 'key_dim'
// and s_t is of dimension context_dim.
//
//    b_{t,o} = s_{t,o} + key-scale (r_t . k_{t+o-num_left_inputs})      [eqn:b]
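//
// (To see why [eqn:b] follows from the previous expression: writing
// q_t = Append(r_t, s_t), the dot-product with Extend(key-scale * k, o),
// which equals Append(key-scale * k, u_o), splits into r_t . (key-scale * k)
// plus s_t . u_o; and since u_o is a unit vector, s_t . u_o just picks out
// the o'th component s_{t,o}.)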
//
// The 'b' quantity is the input to the softmax.  Define
//    c_t = Softmax(b_t)
// so \sum_o c_{t,o} = 1.0.  These are the weights on the values.
//
//
// The output can be written as:
//    u_t := \sum_o c_{t,o} Extend(v_{t+o-num_left_inputs}, o)
// but we can write this in a form more suitable for computation as:
//    u_t := Append(\sum_o c_{t,o} v_{t+o-num_left_inputs}, c_t)         [eqn:u]
//
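// Putting [eqn:b] and [eqn:u] together, the forward computation for a single
// output frame t could be sketched as follows (illustrative pseudocode only;
// the real code operates on whole matrices, and 'VecVec' here just denotes a
// vector dot-product):
//
//    for o in [0, context_dim):
//      b(t, o) = s(t, o) + key_scale * VecVec(r_t, k_{t + o - num_left_inputs})
//    c_t = Softmax(b_t)                    // weights; they sum to 1.0
//    weighted_values = \sum_o c(t, o) * v_{t + o - num_left_inputs}
//    u_t = Append(weighted_values, c_t)    // dim = value_dim + context_dim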
//
// * Implementation
//
// The most time-consuming parts of this computation, we imagine, would be the
// dot-products in [eqn:b] and the weighted sum (\sum_o) in [eqn:u].  They have
// an awkward band-diagonal structure that would not be particularly convenient
// to implement using CUBLAS and the like; I don't believe the relevant operations
// exist in the BLAS interface, at least for [eqn:u].
//
// In the future I hope to implement this with block-wise matrix
// multiplies-- imagine covering the band-diagonal part of a matrix with
// rectangular blocks in such a way that all the nonzero elements are covered,
// but the blocks might go over the zero parts a bit.  Or perhaps we can figure
// out how to implement the block-diagonal matrix multiplies in CUDA.
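//
// (One way to picture this, as a sketch rather than a description of the
// current code: for each offset o, the o'th column of the dot-product matrix
// in [eqn:b] is the set of row-wise dot-products between the queries' r_t part
// and the keys matrix shifted by o rows, i.e. the diagonal of R * K_shifted^T;
// that is one dense operation per offset, for context_dim offsets in total,
// instead of a single truly band-diagonal operation.  Here R and K_shifted are
// just descriptive names for this comment.)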


// This function is a utility function that is at the core of how we implement
// attention.
void GetAttentionDotProducts(BaseFloat alpha,
                             const CuMatrixBase<BaseFloat> &A,
                             const CuMatrixBase<BaseFloat> &B,
                             CuMatrixBase<BaseFloat> *C);


// This function is related to GetAttentionDotProducts(); it is used in scaling
// the values by the softmax weights.
void ApplyScalesToOutput(BaseFloat alpha,
                         const CuMatrixBase<BaseFloat> &B,
                         const CuMatrixBase<BaseFloat> &C,
                         CuMatrixBase<BaseFloat> *A);


// This function is related to GetAttentionDotProducts(); it is used in backprop.
void ApplyScalesToInput(BaseFloat alpha,
                        const CuMatrixBase<BaseFloat> &A,
                        const CuMatrixBase<BaseFloat> &C,
                        CuMatrixBase<BaseFloat> *B);


// This is a higher-level interface to the attention code.
void AttentionForward(BaseFloat key_scale,
                      const CuMatrixBase<BaseFloat> &keys,
                      const CuMatrixBase<BaseFloat> &queries,
                      const CuMatrixBase<BaseFloat> &values,
                      CuMatrixBase<BaseFloat> *c,
                      CuMatrixBase<BaseFloat> *output);

// Performs the backward pass corresponding to 'AttentionForward', propagating
// the derivative back to the keys, queries and values.
void AttentionBackward(BaseFloat key_scale,
                       const CuMatrixBase<BaseFloat> &keys,
                       const CuMatrixBase<BaseFloat> &queries,
                       const CuMatrixBase<BaseFloat> &values,
                       const CuMatrixBase<BaseFloat> &c,
                       const CuMatrixBase<BaseFloat> &output_deriv,
                       CuMatrixBase<BaseFloat> *keys_deriv,
                       CuMatrixBase<BaseFloat> *queries_deriv,
                       CuMatrixBase<BaseFloat> *values_deriv);

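// A rough usage sketch (the matrix shapes below are inferred from the math in
// the comment at the top of this file; treat them as an assumption rather than
// as documentation of the interface, and 'num_input_rows' / 'num_output_rows'
// are just descriptive names for this comment):
//
//    // keys:    num_input_rows  x key_dim
//    // queries: num_output_rows x (key_dim + context_dim)
//    // values:  num_input_rows  x value_dim
//    // c:       num_output_rows x context_dim             (softmax weights)
//    // output:  num_output_rows x (value_dim + context_dim)
//    AttentionForward(key_scale, keys, queries, values, &c, &output);
//    // ... and in backprop, given output_deriv:
//    AttentionBackward(key_scale, keys, queries, values, c, output_deriv,
//                      &keys_deriv, &queries_deriv, &values_deriv);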


} // namespace attention
} // namespace nnet3
} // namespace kaldi


#endif