This file contains the lower-level interface for self-attention.
#include "base/kaldi-common.h"#include "util/common-utils.h"#include "itf/options-itf.h"#include "matrix/matrix-lib.h"#include "cudamatrix/cu-matrix-lib.h"#include "nnet3/nnet-common.h"#include "nnet3/convolution.h"#include <iostream>

Namespaces

namespace kaldi
namespace kaldi::nnet3
namespace kaldi::nnet3::attention
Functions

void GetAttentionDotProducts(BaseFloat alpha, const CuMatrixBase<BaseFloat> &A, const CuMatrixBase<BaseFloat> &B, CuMatrixBase<BaseFloat> *C)
    This utility function is at the core of how we implement attention; see the reference sketch after this list.

void ApplyScalesToOutput(BaseFloat alpha, const CuMatrixBase<BaseFloat> &B, const CuMatrixBase<BaseFloat> &C, CuMatrixBase<BaseFloat> *A)
    This function is related to GetAttentionDotProducts(); it is used to scale the values by the softmax scales, and in backprop. See the sketch after this list.

void ApplyScalesToInput(BaseFloat alpha, const CuMatrixBase<BaseFloat> &A, const CuMatrixBase<BaseFloat> &C, CuMatrixBase<BaseFloat> *B)
    This function is related to GetAttentionDotProducts(); it is used in backprop. See the sketch after this list.

void AttentionForward(BaseFloat key_scale, const CuMatrixBase<BaseFloat> &keys, const CuMatrixBase<BaseFloat> &queries, const CuMatrixBase<BaseFloat> &values, CuMatrixBase<BaseFloat> *c, CuMatrixBase<BaseFloat> *output)
    This is a higher-level interface to the attention code; see the sketch after the detailed description below.

void AttentionBackward(BaseFloat key_scale, const CuMatrixBase<BaseFloat> &keys, const CuMatrixBase<BaseFloat> &queries, const CuMatrixBase<BaseFloat> &values, const CuMatrixBase<BaseFloat> &c, const CuMatrixBase<BaseFloat> &output_deriv, CuMatrixBase<BaseFloat> *keys_deriv, CuMatrixBase<BaseFloat> *queries_deriv, CuMatrixBase<BaseFloat> *values_deriv)
    Performs the backward pass corresponding to AttentionForward(), propagating the derivative back to the keys, queries and values.
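To make the indexing concrete, here is a minimal CPU reference sketch of the semantics of GetAttentionDotProducts(): C(i, o) = alpha * dot(A.row(i), B.row(i + o * row_shift)), where row_shift = (B.NumRows() - A.NumRows()) / (C.NumCols() - 1). The Mat helper is a hypothetical stand-in for CuMatrixBase<BaseFloat> (the real function operates on GPU matrices), and the exact assertions are our assumptions, not a copy of the implementation.

    // CPU reference sketch (not the in-tree implementation) of the
    // GetAttentionDotProducts() semantics; Mat is a hypothetical
    // row-major float matrix standing in for CuMatrixBase<BaseFloat>.
    #include <cassert>
    #include <cstdio>
    #include <vector>

    struct Mat {
      int rows, cols;
      std::vector<float> data;
      Mat(int r, int c) : rows(r), cols(c), data(r * c, 0.0f) {}
      float &at(int r, int c) { return data[r * cols + c]; }
      float at(int r, int c) const { return data[r * cols + c]; }
    };

    // C(i, o) = alpha * <A.row(i), B.row(i + o * row_shift)>, where
    // row_shift = (B.rows - A.rows) / (C.cols - 1). Each output row i
    // is compared against a strided window of rows of B: its context.
    void GetAttentionDotProductsRef(float alpha, const Mat &A, const Mat &B,
                                    Mat *C) {
      assert(A.cols == B.cols && A.rows == C->rows);
      int context_dim = C->cols;
      assert(context_dim > 1 && (B.rows - A.rows) % (context_dim - 1) == 0);
      int row_shift = (B.rows - A.rows) / (context_dim - 1);
      for (int i = 0; i < A.rows; i++) {
        for (int o = 0; o < context_dim; o++) {
          float sum = 0.0f;
          for (int k = 0; k < A.cols; k++)
            sum += A.at(i, k) * B.at(i + o * row_shift, k);
          C->at(i, o) = alpha * sum;
        }
      }
    }

    int main() {
      // 3 queries, 5 keys, context of 3 offsets => row_shift == 1,
      // so output row i is compared against input rows i, i+1, i+2.
      Mat A(3, 2), B(5, 2), C(3, 3);
      for (int i = 0; i < A.rows; i++) A.at(i, 0) = 1.0f;
      for (int i = 0; i < B.rows; i++) B.at(i, 0) = (float) i;
      GetAttentionDotProductsRef(1.0f, A, B, &C);
      for (int i = 0; i < C.rows; i++)
        std::printf("%.1f %.1f %.1f\n", C.at(i, 0), C.at(i, 1), C.at(i, 2));
      return 0;
    }

In this toy setup each of the 3 output rows sees only 3 of the 5 input rows: the finite left/right context described in the detailed description below.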
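ApplyScalesToOutput() and ApplyScalesToInput() are, in this reading, the two "transposes" of the same indexing scheme. A sketch under the same assumptions, reusing the hypothetical Mat helper and row_shift convention from the sketch above:

    // A.row(i) += alpha * sum_o C(i, o) * B.row(i + o * row_shift):
    // forms each output row as a weighted combination of the rows of B
    // in its context (e.g. weighting value rows by softmax weights).
    void ApplyScalesToOutputRef(float alpha, const Mat &B, const Mat &C,
                                Mat *A) {
      assert(A->cols == B.cols && A->rows == C.rows && C.cols > 1);
      int row_shift = (B.rows - A->rows) / (C.cols - 1);
      for (int i = 0; i < A->rows; i++)
        for (int o = 0; o < C.cols; o++)
          for (int k = 0; k < A->cols; k++)
            A->at(i, k) += alpha * C.at(i, o) * B.at(i + o * row_shift, k);
    }

    // The transpose of the above, used in backprop:
    // B.row(i + o * row_shift) += alpha * C(i, o) * A.row(i).
    void ApplyScalesToInputRef(float alpha, const Mat &A, const Mat &C,
                               Mat *B) {
      assert(A.cols == B->cols && A.rows == C.rows && C.cols > 1);
      int row_shift = (B->rows - A.rows) / (C.cols - 1);
      for (int i = 0; i < A.rows; i++)
        for (int o = 0; o < C.cols; o++)
          for (int k = 0; k < A.cols; k++)
            B->at(i + o * row_shift, k) += alpha * C.at(i, o) * A.at(i, k);
    }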
This file contains the lower-level interface for self-attention.
This is a form of self-attention, inspired by Google's paper "Attention is all you need", but implemented in a way that's more obviously suitable for speech tasks. The main difference is that instead of taking as input *all frames* from the previous layer, we accept a limited grid of frames (so the left-context and right-context are finite). Time encoding is also handled differently: we encode the time as a relative offset.
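Putting the pieces together, here is a simplified CPU sketch of how we understand AttentionForward(): the first key-dim columns of each query are matched against the keys (scaled by key_scale), the trailing context-dim columns act as a learned relative-position offset added to the scores, a per-row softmax produces the weights c, and the output is the weighted sum of value rows. It reuses the hypothetical Mat helper and the Ref functions from the sketches above, and omits the option of appending c to the output; treat it as an illustration, not the implementation.

    #include <algorithm>
    #include <cmath>

    // keys:    num-input-rows  x key-dim
    // queries: num-output-rows x (key-dim + context-dim); the trailing
    //          context-dim columns encode the relative time offset.
    // values:  num-input-rows  x value-dim
    // c:       num-output-rows x context-dim (softmax weights, written here)
    // output:  num-output-rows x value-dim
    void AttentionForwardRef(float key_scale, const Mat &keys,
                             const Mat &queries, const Mat &values,
                             Mat *c, Mat *output) {
      int key_dim = keys.cols, context_dim = c->cols;
      assert(queries.cols == key_dim + context_dim);
      // (1) Scaled dot products of each query's key part against the
      //     keys in its context window.
      Mat q_key_part(queries.rows, key_dim);
      for (int i = 0; i < queries.rows; i++)
        for (int k = 0; k < key_dim; k++)
          q_key_part.at(i, k) = queries.at(i, k);
      GetAttentionDotProductsRef(key_scale, q_key_part, keys, c);
      // (2) Add the relative-position part of the query directly to the
      //     score for the corresponding offset.
      for (int i = 0; i < c->rows; i++)
        for (int o = 0; o < context_dim; o++)
          c->at(i, o) += queries.at(i, key_dim + o);
      // (3) Softmax over each row of c (numerically stabilized).
      for (int i = 0; i < c->rows; i++) {
        float m = c->at(i, 0), sum = 0.0f;
        for (int o = 1; o < context_dim; o++) m = std::max(m, c->at(i, o));
        for (int o = 0; o < context_dim; o++) {
          c->at(i, o) = std::exp(c->at(i, o) - m);
          sum += c->at(i, o);
        }
        for (int o = 0; o < context_dim; o++) c->at(i, o) /= sum;
      }
      // (4) Output = softmax-weighted sum of the value rows in context.
      std::fill(output->data.begin(), output->data.end(), 0.0f);
      ApplyScalesToOutputRef(1.0f, values, *c, output);
    }

AttentionBackward() then reverses these steps using the c saved from the forward pass; in this reading, the values derivative is an ApplyScalesToInput()-style update, and backprop through the softmax routes the remaining derivative to the keys (scaled by key_scale) and to both parts of the queries.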
Definition in file attention.h.