#include <iterator>
#include <sstream>
#include <iomanip>
#include "nnet3/attention.h"
#include "nnet3/nnet-parse.h"
void GetAttentionDotProducts(BaseFloat alpha, const CuMatrixBase<BaseFloat> &A, const CuMatrixBase<BaseFloat> &B, CuMatrixBase<BaseFloat> *C)
    A utility function that is at the core of how we implement attention.

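As a rough illustration of the kind of computation a function with this signature performs, here is a plain-C++ reference sketch that uses no Kaldi types. The indexing convention below, C(i, j) = alpha * dot(A.row(i), B.row(i + j * row_shift)) with row_shift = (B-rows - A-rows) / (C-cols - 1), is an assumption made for illustration; the authoritative definition is in the detailed comments in attention.h.

// Reference-only sketch of the kind of banded dot-product GetAttentionDotProducts()
// computes.  Plain C++ containers stand in for CuMatrixBase, and the row_shift
// convention is an assumption (see the note above).
#include <vector>

typedef std::vector<std::vector<float> > Matrix;  // toy row-major matrix

void GetAttentionDotProductsRef(float alpha,
                                const Matrix &A,   // num-output-rows x dim
                                const Matrix &B,   // num-input-rows  x dim
                                Matrix *C) {       // num-output-rows x context-dim
  int num_output_rows = A.size(), num_input_rows = B.size(),
      context_dim = (*C)[0].size();
  // Assumed relationship between the row counts and the context dim.
  int row_shift = (context_dim == 1 ? 0 :
                   (num_input_rows - num_output_rows) / (context_dim - 1));
  for (int i = 0; i < num_output_rows; i++) {
    for (int j = 0; j < context_dim; j++) {
      float dot = 0.0f;
      for (size_t k = 0; k < A[i].size(); k++)
        dot += A[i][k] * B[i + j * row_shift][k];
      (*C)[i][j] = alpha * dot;
    }
  }
}
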
void ApplyScalesToOutput(BaseFloat alpha, const CuMatrixBase<BaseFloat> &B, const CuMatrixBase<BaseFloat> &C, CuMatrixBase<BaseFloat> *A)
    This function is related to GetAttentionDotProducts(); it is used in scaling the values by the softmax scales, and in backprop (see the sketch after ApplyScalesToInput() below).

void ApplyScalesToInput(BaseFloat alpha, const CuMatrixBase<BaseFloat> &A, const CuMatrixBase<BaseFloat> &C, CuMatrixBase<BaseFloat> *B)
    This function is related to GetAttentionDotProducts(); it is used in backprop.

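The following plain-C++ sketch suggests how ApplyScalesToOutput() and ApplyScalesToInput() might apply the softmax scales in C: the first combines rows of B into A, and the second scatters scaled rows of A back into B. The row_shift convention, the in-place accumulation (+=), and the exact update directions are all assumptions made for illustration; the real functions operate on CuMatrixBase<BaseFloat> and are documented in attention.h.

// Reference-only sketch of the scale-application pair; the update formulas are
// assumptions, as noted above.
#include <vector>

typedef std::vector<std::vector<float> > Matrix;  // toy row-major matrix

// Assumed behaviour of ApplyScalesToOutput():
//   A(i, :) += alpha * sum_j C(i, j) * B(i + j * row_shift, :)
void ApplyScalesToOutputRef(float alpha, const Matrix &B, const Matrix &C, Matrix *A) {
  int num_output_rows = A->size(), num_input_rows = B.size(),
      context_dim = C[0].size();
  int row_shift = (context_dim == 1 ? 0 :
                   (num_input_rows - num_output_rows) / (context_dim - 1));
  for (int i = 0; i < num_output_rows; i++)
    for (int j = 0; j < context_dim; j++)
      for (size_t k = 0; k < (*A)[i].size(); k++)
        (*A)[i][k] += alpha * C[i][j] * B[i + j * row_shift][k];
}

// Assumed behaviour of ApplyScalesToInput():
//   B(i + j * row_shift, :) += alpha * C(i, j) * A(i, :)
void ApplyScalesToInputRef(float alpha, const Matrix &A, const Matrix &C, Matrix *B) {
  int num_output_rows = A.size(), num_input_rows = B->size(),
      context_dim = C[0].size();
  int row_shift = (context_dim == 1 ? 0 :
                   (num_input_rows - num_output_rows) / (context_dim - 1));
  for (int i = 0; i < num_output_rows; i++)
    for (int j = 0; j < context_dim; j++)
      for (size_t k = 0; k < A[i].size(); k++)
        (*B)[i + j * row_shift][k] += alpha * C[i][j] * A[i][k];
}
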
void AttentionForward(BaseFloat key_scale, const CuMatrixBase<BaseFloat> &keys, const CuMatrixBase<BaseFloat> &queries, const CuMatrixBase<BaseFloat> &values, CuMatrixBase<BaseFloat> *c, CuMatrixBase<BaseFloat> *output)
    This is a higher-level interface to the attention code (see the usage sketch after AttentionBackward() below).

void AttentionBackward(BaseFloat key_scale, const CuMatrixBase<BaseFloat> &keys, const CuMatrixBase<BaseFloat> &queries, const CuMatrixBase<BaseFloat> &values, const CuMatrixBase<BaseFloat> &c, const CuMatrixBase<BaseFloat> &output_deriv, CuMatrixBase<BaseFloat> *keys_deriv, CuMatrixBase<BaseFloat> *queries_deriv, CuMatrixBase<BaseFloat> *values_deriv)
    Performs the backward pass corresponding to 'AttentionForward', propagating the derivative back to the keys, queries and values.

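Below is a hedged usage sketch of this higher-level interface. The dimension relationships (queries having key-dim + context-dim columns, keys and values sharing num-input-rows, output having value-dim columns, and (num-input-rows - num-output-rows) being divisible by context-dim - 1), the kaldi::nnet3::attention namespace qualification, and the 1/sqrt(key-dim) choice of key_scale are assumptions made for illustration rather than facts stated on this page; attention.h and attention-test.cc document the actual requirements.

// Usage sketch only; dimension relationships and namespace are assumptions
// (see the note above).
#include <cmath>
#include "cudamatrix/cu-matrix.h"
#include "nnet3/attention.h"

int main() {
  using namespace kaldi;
  using namespace kaldi::nnet3;

  int32 num_input_rows = 20, num_output_rows = 10,
        key_dim = 8, value_dim = 12, context_dim = 3;

  CuMatrix<BaseFloat> keys(num_input_rows, key_dim),
      queries(num_output_rows, key_dim + context_dim),  // assumed: extra columns per context position
      values(num_input_rows, value_dim),
      c(num_output_rows, context_dim),                  // softmax attention weights
      output(num_output_rows, value_dim);               // assumed output dimension
  keys.SetRandn();
  queries.SetRandn();
  values.SetRandn();

  BaseFloat key_scale = 1.0 / std::sqrt(static_cast<BaseFloat>(key_dim));

  // Forward pass: writes the attention weights to 'c' and the attended values to 'output'.
  attention::AttentionForward(key_scale, keys, queries, values, &c, &output);

  // Backward pass: given a derivative w.r.t. 'output', propagate derivatives
  // back to the keys, queries and values.
  CuMatrix<BaseFloat> output_deriv(output.NumRows(), output.NumCols()),
      keys_deriv(keys.NumRows(), keys.NumCols()),
      queries_deriv(queries.NumRows(), queries.NumCols()),
      values_deriv(values.NumRows(), values.NumCols());
  output_deriv.SetRandn();
  attention::AttentionBackward(key_scale, keys, queries, values, c, output_deriv,
                               &keys_deriv, &queries_deriv, &values_deriv);
  return 0;
}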