attention.h File Reference

This file contains the lower-level interface for self-attention. More...

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "itf/options-itf.h"
#include "matrix/matrix-lib.h"
#include "cudamatrix/cu-matrix-lib.h"
#include "nnet3/nnet-common.h"
#include "nnet3/convolution.h"
#include <iostream>

Go to the source code of this file.

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks.
 
 kaldi::nnet3
 
 kaldi::nnet3::attention
 

Functions

void GetAttentionDotProducts (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &B, CuMatrixBase< BaseFloat > *C)
 This utility function is at the core of how we implement attention (see the sketch after this list). More...
 
void ApplyScalesToOutput (BaseFloat alpha, const CuMatrixBase< BaseFloat > &B, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *A)
 This function is related to GetAttentionDotProducts(); it is used to scale the values by the softmax scales, and in backprop. More...
 
void ApplyScalesToInput (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *B)
 This function is related to GetAttentionDotProducts(); it is used in backprop. More...
 
void AttentionForward (BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *output)
 This is a higher-level interface to the attention code; a usage sketch appears in the Detailed Description below. More...
 
void AttentionBackward (BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &output_deriv, CuMatrixBase< BaseFloat > *keys_deriv, CuMatrixBase< BaseFloat > *queries_deriv, CuMatrixBase< BaseFloat > *values_deriv)
 Performs the backward pass corresponding to 'AttentionForward', propagating the derivative back to the keys, queries and values. More...
 

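Taken together, GetAttentionDotProducts(), ApplyScalesToOutput() and ApplyScalesToInput() implement a "banded" matrix product and its two adjoints. The following is a sketch of their semantics, inferred from the descriptions above rather than quoted from the header: writing s for the row shift that relates the dimensions (so that B.NumRows() == A.NumRows() + s * (C.NumCols() - 1)), the three operations behave roughly as

  GetAttentionDotProducts:  C(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * s))
  ApplyScalesToOutput:      A->Row(i) += alpha * C(i, j) * B.Row(i + j * s)      (summed over j)
  ApplyScalesToInput:       B->Row(i + j * s) += alpha * C(i, j) * A.Row(i)      (accumulated over i and j)

All three share the same A/B/C dimension relationship, which is why the same primitive can serve both to apply the softmax scales to the values in the forward pass and to propagate derivatives in backprop.
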
Detailed Description

This file contains the lower-level interface for self-attention.

This is a form of self-attention, inspired by Google's paper "Attention is all you need", but implemented in a way that's more obviously suitable for speech tasks. The main difference is that instead of taking as input *all frames* from the previous layer, we accept a limited grid of frames (so the left-context and right-context are finite). Also, time encoding is handled in a different way: we encode the time as a relative offset.
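
As a concrete illustration, here is a minimal, hypothetical usage sketch of AttentionForward() and AttentionBackward(). The dimension choices (context-dim = 3 with a row shift of 1, so the keys and values have context-dim - 1 more rows than the queries, and query-dim = key-dim + context-dim) and the 1/sqrt(key-dim) key scale are assumptions made for this example, following common attention conventions; the authoritative dimension constraints are documented in attention.h itself.

#include <cmath>

#include "cudamatrix/cu-matrix-lib.h"
#include "nnet3/attention.h"

void AttentionExample() {
  using namespace kaldi;
  using namespace kaldi::nnet3::attention;

  // Assumed dimensions for this sketch (not prescribed by this page).
  int32 key_dim = 4, value_dim = 4, context_dim = 3;
  int32 num_output_rows = 10;
  int32 num_input_rows = num_output_rows + context_dim - 1;  // row shift of 1.
  int32 query_dim = key_dim + context_dim;  // extra columns encode relative time offsets.

  // CuMatrix zero-initializes by default, so 'c' and 'output' start out finite.
  CuMatrix<BaseFloat> keys(num_input_rows, key_dim),
      queries(num_output_rows, query_dim),
      values(num_input_rows, value_dim),
      c(num_output_rows, context_dim),     // will hold the softmax weights.
      output(num_output_rows, value_dim);  // the attention output is added here.
  keys.SetRandn();
  queries.SetRandn();
  values.SetRandn();

  // 1/sqrt(key-dim) as in "Attention is all you need"; an assumption, not a requirement.
  BaseFloat key_scale = 1.0 / std::sqrt(static_cast<BaseFloat>(key_dim));

  AttentionForward(key_scale, keys, queries, values, &c, &output);

  // Backward pass: given d(objective)/d(output), accumulate the derivatives
  // w.r.t. keys, queries and values, reusing 'c' from the forward pass.
  CuMatrix<BaseFloat> output_deriv(num_output_rows, value_dim),
      keys_deriv(num_input_rows, key_dim),
      queries_deriv(num_output_rows, query_dim),
      values_deriv(num_input_rows, value_dim);
  output_deriv.SetRandn();

  AttentionBackward(key_scale, keys, queries, values, c, output_deriv,
                    &keys_deriv, &queries_deriv, &values_deriv);
}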

Definition in file attention.h.