41 template<
typename Real>
52 template<
typename Real>
59 template<
typename Real>
75 KALDI_LOG <<
"the upper diaganoal sum for A is : " << sum;
82 KALDI_LOG <<
"the upper diaganoal sum for B is : " << sum;
98 C(
i) = 1 +
Rand() % 4;
196 C(
i) = (
i/(1.0*dim)) + 1;
212 AssertEqual(Identity, X, (
sizeof(Real) == 4 ? 0.1 : 0.001));
237 std::cout << D(
i,
j) <<
" ";
314 std::cout << D(
i,
j) <<
" ";
322 std::cout << D(
i,
j) <<
" ";
331 std::cout << D(
i,
j) <<
" ";
560 template<
typename Real>
562 UnitTestTrace<Real>();
563 UnitTestCholesky<Real>();
564 UnitTestInvert<Real>();
566 UnitTestCopyFromMat<Real>();
567 UnitTestCopySp<Real>();
568 UnitTestConstructor<Real>();
569 UnitTestVector<Real>();
570 UnitTestMulTp<Real>();
571 UnitTestMatrix<Real>();
572 UnitTestSetZeroAboveDiag<Real>();
577 using namespace kaldi;
580 for (
int32 loop = 0; loop < 2; loop++) {
582 CuDevice::Instantiate().SelectGpuId(
"no");
584 CuDevice::Instantiate().SelectGpuId(
"yes");
586 kaldi::CuMatrixUnitTest<float>();
589 if (!kaldi::CuDevice::Instantiate().DoublePrecisionSupported()) {
590 KALDI_WARN <<
"Double precision not supported, not testing that code";
594 kaldi::CuMatrixUnitTest<double>();
598 kaldi::CuDevice::Instantiate().PrintProfile();
void MulElements(const CuVectorBase< Real > &v)
void CopyFromMat(const MatrixBase< OtherReal > &src, MatrixTransposeType trans=kNoTrans)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
MatrixIndexT Stride() const
static void UnitTestCopyFromMat()
Packed symetric matrix class.
static void UnitTestConstructor()
void MulTp(const CuTpMatrix< Real > &M, const MatrixTransposeType trans)
Multiplies this vector by lower-triangular marix: *this <– *this *M.
void CopyToMat(MatrixBase< OtherReal > *dst, MatrixTransposeType trans=kNoTrans) const
static void UnitTestSetZeroAboveDiag()
static void UnitTestVector()
static void UnitTestInvert()
float RandGauss(struct RandomState *state=NULL)
void MulTp(const TpMatrix< Real > &M, const MatrixTransposeType trans)
Multiplies this vector by lower-triangular matrix: *this <– *this *M.
A class for storing matrices.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
MatrixIndexT NumRows() const
static void UnitTestMatrix()
static void InitRand(VectorBase< Real > *v)
void SetRandn()
< Set to unit matrix.
void SetVerboseLevel(int32 i)
This should be rarely used, except by programs using Kaldi as library; command-line programs set the ...
void SymInvertPosDef()
Inversion for positive definite symmetric matrices.
static void UnitTestMulTp()
static void UnitTestCholesky()
void MulElements(const VectorBase< Real > &v)
Multiply element-by-element by another vector.
void SetRandn()
Sets to random values of a normal distribution.
void AddMatMat(const Real alpha, const MatrixBase< Real > &A, MatrixTransposeType transA, const MatrixBase< Real > &B, MatrixTransposeType transB, const Real beta)
static void CuMatrixUnitTest()
Packed symetric matrix class.
void Cholesky(CuMatrixBase< Real > *inv_cholesky=NULL)
This function does sets *this to the Cholesky factor of *this (i.e.
static void UnitTestTrace()
MatrixIndexT Dim() const
Returns the dimension of the vector.
void Scale(Real alpha)
Multiplies all elements by this constant.
int Rand(struct RandomState *state)
void SetRandn()
Set vector to random normally-distributed noise.
void CopyFromMat(const CuMatrixBase< Real > &orig, SpCopyType copy_type=kTakeLower)
void CopyToSp(SpMatrix< Real > *dst) const
A class representing a vector.
#define KALDI_ASSERT(cond)
Real Cond() const
Returns maximum ratio of singular values.
void AddVecVec(const Real alpha, const VectorBase< OtherReal > &a, const VectorBase< OtherReal > &b)
*this += alpha * a * b^T
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
static void UnitTestCopySp()
Provides a vector abstraction class.
void CopyToVec(VectorBase< OtherReal > *dst) const
void SetZeroAboveDiag()
Zeroes all elements for which col > row.