#ifndef KALDI_NNET3_ATTENTION_H_
#define KALDI_NNET3_ATTENTION_H_

void GetAttentionDotProducts(BaseFloat alpha,
                             const CuMatrixBase<BaseFloat> &A,
                             const CuMatrixBase<BaseFloat> &B,
                             CuMatrixBase<BaseFloat> *C);
void ApplyScalesToOutput(BaseFloat alpha,
                         const CuMatrixBase<BaseFloat> &B,
                         const CuMatrixBase<BaseFloat> &C,
                         CuMatrixBase<BaseFloat> *A);
void ApplyScalesToInput(BaseFloat alpha,
                        const CuMatrixBase<BaseFloat> &A,
                        const CuMatrixBase<BaseFloat> &C,
                        CuMatrixBase<BaseFloat> *B);
void AttentionForward(BaseFloat key_scale,
                      const CuMatrixBase<BaseFloat> &keys,
                      const CuMatrixBase<BaseFloat> &queries,
                      const CuMatrixBase<BaseFloat> &values,
                      CuMatrixBase<BaseFloat> *c,
                      CuMatrixBase<BaseFloat> *output);
void AttentionBackward(BaseFloat key_scale,
                       const CuMatrixBase<BaseFloat> &keys,
                       const CuMatrixBase<BaseFloat> &queries,
                       const CuMatrixBase<BaseFloat> &values,
                       const CuMatrixBase<BaseFloat> &c,
                       const CuMatrixBase<BaseFloat> &output_deriv,
                       CuMatrixBase<BaseFloat> *keys_deriv,
                       CuMatrixBase<BaseFloat> *queries_deriv,
                       CuMatrixBase<BaseFloat> *values_deriv);
void AttentionBackward(BaseFloat key_scale, const CuMatrixBase<BaseFloat> &keys, const CuMatrixBase<BaseFloat> &queries, const CuMatrixBase<BaseFloat> &values, const CuMatrixBase<BaseFloat> &c, const CuMatrixBase<BaseFloat> &output_deriv, CuMatrixBase<BaseFloat> *keys_deriv, CuMatrixBase<BaseFloat> *queries_deriv, CuMatrixBase<BaseFloat> *values_deriv)
Performs the backward pass corresponding to 'AttentionForward', propagating the derivative back to the keys, queries and values. 'c' must be the softmax output saved from the forward pass, and the derivative matrices are added to, not set.
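As an illustration of how the backward call lines up with the forward one, here is a hypothetical wrapper (the helper name ExampleAttentionBackward is made up; the namespace and zero-initialization behavior are assumptions based on Kaldi conventions):

#include "nnet3/attention.h"

using namespace kaldi;
using namespace kaldi::nnet3::attention;

// Hypothetical helper showing the backward call; 'c' is the softmax
// output that AttentionForward() wrote, and output_deriv is the
// derivative of the objective w.r.t. the forward pass's output.
void ExampleAttentionBackward(BaseFloat key_scale,
                              const CuMatrixBase<BaseFloat> &keys,
                              const CuMatrixBase<BaseFloat> &queries,
                              const CuMatrixBase<BaseFloat> &values,
                              const CuMatrixBase<BaseFloat> &c,
                              const CuMatrixBase<BaseFloat> &output_deriv) {
  // The derivative matrices are added to, so they must start out zero
  // (the CuMatrix constructor zero-initializes by default).
  CuMatrix<BaseFloat> keys_deriv(keys.NumRows(), keys.NumCols()),
      queries_deriv(queries.NumRows(), queries.NumCols()),
      values_deriv(values.NumRows(), values.NumCols());
  AttentionBackward(key_scale, keys, queries, values, c, output_deriv,
                    &keys_deriv, &queries_deriv, &values_deriv);
}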
void ApplyScalesToInput(BaseFloat alpha, const CuMatrixBase<BaseFloat> &A, const CuMatrixBase<BaseFloat> &C, CuMatrixBase<BaseFloat> *B)
This function is related to GetAttentionDotProducts(); it is used in backprop. It implements B->Row(i + j * row_shift) += alpha * C(i, j) * A.Row(i).
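A minimal CPU reference sketch of that update, assuming plain MatrixBase types and a hypothetical ApplyScalesToInputRef name (the real function operates on CuMatrixBase and is implemented with batched matrix operations):

#include "matrix/matrix-lib.h"

namespace kaldi {

// Hypothetical CPU reference version of ApplyScalesToInput():
// implements B->Row(i + j * row_shift) += alpha * C(i, j) * A.Row(i).
void ApplyScalesToInputRef(BaseFloat alpha,
                           const MatrixBase<BaseFloat> &A,
                           const MatrixBase<BaseFloat> &C,
                           MatrixBase<BaseFloat> *B) {
  int32 context_dim = C.NumCols(),
      row_shift = (B->NumRows() - A.NumRows()) / (context_dim - 1);
  for (int32 i = 0; i < A.NumRows(); i++)
    for (int32 j = 0; j < context_dim; j++) {
      // Scatter row i of A into the j'th shifted row of B, weighted
      // by C(i, j); this is the transpose-sense counterpart of
      // ApplyScalesToOutput() below.
      SubVector<BaseFloat> b_row(*B, i + j * row_shift);
      b_row.AddVec(alpha * C(i, j), A.Row(i));
    }
}

}  // namespace kaldi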
void AttentionForward(BaseFloat key_scale, const CuMatrixBase<BaseFloat> &keys, const CuMatrixBase<BaseFloat> &queries, const CuMatrixBase<BaseFloat> &values, CuMatrixBase<BaseFloat> *c, CuMatrixBase<BaseFloat> *output)
This is a higher-level interface to the attention code. The softmax weights are written to 'c', and the weighted combination of 'values' is added to 'output'.
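A hypothetical usage sketch with made-up dimensions may help. The dimension relationships assumed here follow the declarations above: query-dim equals key-dim plus context-dim, and the keys/values have (context-dim - 1) * row-shift more rows than the queries. The choice key_scale = 1/sqrt(key_dim) is a common convention, not a requirement of this interface.

#include <cmath>
#include "nnet3/attention.h"

using namespace kaldi;
using namespace kaldi::nnet3::attention;

void ExampleAttentionForward() {
  int32 key_dim = 8, value_dim = 8, context_dim = 3, row_shift = 1,
      num_output_rows = 10,
      num_input_rows = num_output_rows + (context_dim - 1) * row_shift;
  BaseFloat key_scale = 1.0 / std::sqrt(static_cast<BaseFloat>(key_dim));

  CuMatrix<BaseFloat> keys(num_input_rows, key_dim),
      queries(num_output_rows, key_dim + context_dim),
      values(num_input_rows, value_dim);
  keys.SetRandn();
  queries.SetRandn();
  values.SetRandn();

  // 'c' is overwritten with the softmax weights; 'output' is added
  // to, so both start out zero-initialized here.
  CuMatrix<BaseFloat> c(num_output_rows, context_dim),
      output(num_output_rows, value_dim);
  AttentionForward(key_scale, keys, queries, values, &c, &output);
}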
void ApplyScalesToOutput(BaseFloat alpha, const CuMatrixBase<BaseFloat> &B, const CuMatrixBase<BaseFloat> &C, CuMatrixBase<BaseFloat> *A)
This function is related to GetAttentionDotProducts(); it is used in scaling the values by the softmax scales, and in backprop. It implements A->Row(i) += alpha * C(i, j) * B.Row(i + j * row_shift).
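A minimal CPU reference sketch of that update, under the same assumptions as the sketch above (plain MatrixBase types, hypothetical *Ref name):

#include "matrix/matrix-lib.h"

namespace kaldi {

// Hypothetical CPU reference version of ApplyScalesToOutput():
// implements A->Row(i) += alpha * C(i, j) * B.Row(i + j * row_shift).
void ApplyScalesToOutputRef(BaseFloat alpha,
                            const MatrixBase<BaseFloat> &B,
                            const MatrixBase<BaseFloat> &C,
                            MatrixBase<BaseFloat> *A) {
  int32 context_dim = C.NumCols(),
      row_shift = (B.NumRows() - A->NumRows()) / (context_dim - 1);
  for (int32 i = 0; i < A->NumRows(); i++)
    for (int32 j = 0; j < context_dim; j++) {
      // Accumulate into row i of A the j'th shifted row of B,
      // weighted by the attention weight C(i, j).
      SubVector<BaseFloat> a_row(*A, i);
      a_row.AddVec(alpha * C(i, j), B.Row(i + j * row_shift));
    }
}

}  // namespace kaldi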
void GetAttentionDotProducts(BaseFloat alpha, const CuMatrixBase<BaseFloat> &A, const CuMatrixBase<BaseFloat> &B, CuMatrixBase<BaseFloat> *C)
This function is a utility function that is at the core of how we implement attention: a banded matrix multiplication implementing C(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift)), where row_shift = (B.NumRows() - A.NumRows()) / (C.NumCols() - 1).
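To make the banded dot-product pattern concrete, here is a minimal single-threaded reference sketch, again assuming plain MatrixBase types and a hypothetical *Ref name; the real function operates on CuMatrixBase and is implemented differently for efficiency:

#include "matrix/matrix-lib.h"

namespace kaldi {

// Hypothetical CPU reference version of GetAttentionDotProducts():
// computes C(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift)).
void GetAttentionDotProductsRef(BaseFloat alpha,
                                const MatrixBase<BaseFloat> &A,
                                const MatrixBase<BaseFloat> &B,
                                MatrixBase<BaseFloat> *C) {
  KALDI_ASSERT(A.NumCols() == B.NumCols() && A.NumRows() == C->NumRows());
  int32 context_dim = C->NumCols(),
      num_extra_rows = B.NumRows() - A.NumRows();
  KALDI_ASSERT(context_dim > 1 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  for (int32 i = 0; i < A.NumRows(); i++)
    for (int32 j = 0; j < context_dim; j++)
      // Each output element is a dot product between row i of A and a
      // row of B shifted down by a multiple of row_shift.
      (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift));
}

}  // namespace kaldi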