attention-test.cc File Reference
#include "nnet3/attention.h"
#include "util/common-utils.h"


Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks; the reference:
 
 kaldi::nnet3
 
 kaldi::nnet3::attention
 

Functions

void GetAttentionDotProductsSimple (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &B, CuMatrixBase< BaseFloat > *C)
 
void ApplyScalesToOutputSimple (BaseFloat alpha, const CuMatrixBase< BaseFloat > &B, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *A)
 
void ApplyScalesToInputSimple (BaseFloat alpha, const CuMatrixBase< BaseFloat > &A, const CuMatrixBase< BaseFloat > &C, CuMatrixBase< BaseFloat > *B)
 
void UnitTestAttentionDotProductAndAddScales ()
 
void TestAttentionForwardBackward ()
 
void UnitTestAttention ()
 
int main ()
 

Function Documentation

◆ main()

int main ( )

Definition at line 240 of file attention-test.cc.

References rnnlm::i, and kaldi::nnet3::attention::UnitTestAttention().

{
  using namespace kaldi;
  using namespace kaldi::nnet3;
  using namespace kaldi::nnet3::attention;
  // With CUDA, the first pass runs on the CPU and the second on a GPU if one
  // is available; without CUDA, both passes run on the CPU.
  for (int32 loop = 0; loop < 2; loop++) {
#if HAVE_CUDA == 1
    CuDevice::Instantiate().SetDebugStrideMode(true);
    if (loop == 0)
      CuDevice::Instantiate().SelectGpuId("no");  // -1 means no GPU
    else
      CuDevice::Instantiate().SelectGpuId("optional");  // -2 .. automatic selection
#endif
    // Repeat the randomized attention tests several times per device setting.
    for (int32 i = 0; i < 5; i++) {
      UnitTestAttention();
    }
  }
}