This file contains the lower-level interface for self-attention.
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "itf/options-itf.h"
#include "matrix/matrix-lib.h"
#include "cudamatrix/cu-matrix-lib.h"
#include "nnet3/nnet-common.h"
#include "nnet3/convolution.h"
#include <iostream>
Namespaces

kaldi
kaldi::nnet3
kaldi::nnet3::attention
Functions

void GetAttentionDotProducts (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &B, CuMatrixBase< BaseFloat > *C)
    This is a utility function at the core of how we implement attention.

void ApplyScalesToOutput (BaseFloat alpha, const CuMatrixBase< BaseFloat > &B, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *A)
    This function is related to GetAttentionDotProducts(); it is used in scaling the values by the softmax scales, and in backprop.

void ApplyScalesToInput (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *B)
    This function is related to GetAttentionDotProducts(); it is used in backprop.

void AttentionForward (BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *output)
    This is a higher-level interface to the attention code.

void AttentionBackward (BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &output_deriv, CuMatrixBase< BaseFloat > *keys_deriv, CuMatrixBase< BaseFloat > *queries_deriv, CuMatrixBase< BaseFloat > *values_deriv)
    Performs the backward pass corresponding to AttentionForward(), propagating the derivative back to the keys, queries and values.
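
For orientation, the sketch below shows how the two higher-level entry points, AttentionForward() and AttentionBackward(), might be wired together for a single attention head. The function names and signatures are taken from the listing above; the dimension conventions (extra rows in the keys/values for the temporal context, extra query columns for the relative-offset encoding, and the column layouts of c and output) are assumptions made for illustration, not statements from this page.

    // Hypothetical usage sketch, not taken from the Kaldi sources.  The
    // dimension relationships below are assumptions chosen so the calls are
    // self-consistent; consult the per-function documentation for the real
    // requirements.
    #include <cmath>
    #include "nnet3/attention.h"

    using namespace kaldi;
    using namespace kaldi::nnet3::attention;

    void AttentionExample() {
      int32 num_queries = 100, key_dim = 64, value_dim = 64,
          context_dim = 11;  // assumed: 5 frames left-context + center + 5 frames right-context
      // Assumed layout: keys/values carry (context_dim - 1) extra rows of temporal context.
      int32 num_kv_rows = num_queries + context_dim - 1;

      CuMatrix<BaseFloat> keys(num_kv_rows, key_dim),
          values(num_kv_rows, value_dim),
          queries(num_queries, key_dim + context_dim),   // assumed: extra cols hold the relative-offset encoding
          c(num_queries, context_dim),                   // assumed: post-softmax attention weights
          output(num_queries, value_dim + context_dim);  // assumed: attended values plus the weights
      keys.SetRandn();
      values.SetRandn();
      queries.SetRandn();

      BaseFloat key_scale = 1.0 / std::sqrt(static_cast<BaseFloat>(key_dim));
      AttentionForward(key_scale, keys, queries, values, &c, &output);

      // Backward pass: propagate a derivative w.r.t. the output back to the inputs.
      CuMatrix<BaseFloat> output_deriv(output.NumRows(), output.NumCols()),
          keys_deriv(keys.NumRows(), keys.NumCols()),
          queries_deriv(queries.NumRows(), queries.NumCols()),
          values_deriv(values.NumRows(), values.NumCols());
      output_deriv.SetRandn();
      AttentionBackward(key_scale, keys, queries, values, c, output_deriv,
                        &keys_deriv, &queries_deriv, &values_deriv);
    }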
This file contains the lower-level interface for self-attention.
This is a form of self-attention, inspired by Google's paper "Attention is all you need", but implemented in a way that is more obviously suitable for speech tasks. The main difference is that instead of taking as input *all frames* from the previous layer, we accept a limited grid of frames (so the left-context and right-context are finite). Also, time-encoding is handled in a different way: we encode the time as a relative offset.
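
To make the relative-offset idea concrete, here is one plausible reading, written as a hypothetical helper rather than the actual implementation: assume each query vector carries context_dim extra entries beyond the key dimension, and the entry corresponding to a given time offset is added to the content dot product as a learned per-offset term. The helper name and the exact layout are assumptions for illustration only.

    // Illustration only (assumed reading of the relative-offset encoding, not
    // the actual Kaldi code): the query has key_dim + context_dim entries; the
    // first key_dim entries are compared with the key, and the remaining
    // context_dim entries act as a learned bias, one per relative time offset.
    #include "base/kaldi-common.h"
    #include "matrix/matrix-lib.h"

    using namespace kaldi;

    BaseFloat AttentionLogit(BaseFloat key_scale,
                             const VectorBase<BaseFloat> &query,  // dim: key_dim + context_dim
                             const VectorBase<BaseFloat> &key,    // dim: key_dim
                             int32 offset_index) {                // 0 <= offset_index < context_dim
      int32 key_dim = key.Dim();
      SubVector<BaseFloat> content_part(query, 0, key_dim);
      // Content term (scaled dot product) plus the per-offset positional term.
      return key_scale * VecVec(content_part, key) + query(key_dim + offset_index);
    }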
Definition in file attention.h.