20 #ifndef KALDI_MATRIX_CBLAS_WRAPPERS_H_ 21 #define KALDI_MATRIX_CBLAS_WRAPPERS_H_ 1 37 inline void cblas_Xcopy(
const int N,
const float *X,
const int incX,
float *Y,
39 cblas_scopy(N, X, incX, Y, incY);
42 inline void cblas_Xcopy(
const int N,
const double *X,
const int incX,
double *Y,
44 cblas_dcopy(N, X, incX, Y, incY);
48 inline float cblas_Xasum(
const int N,
const float *X,
const int incX) {
49 return cblas_sasum(N, X, incX);
52 inline double cblas_Xasum(
const int N,
const double *X,
const int incX) {
53 return cblas_dasum(N, X, incX);
56 inline void cblas_Xrot(
const int N,
float *X,
const int incX,
float *Y,
57 const int incY,
const float c,
const float s) {
58 cblas_srot(N, X, incX, Y, incY, c, s);
60 inline void cblas_Xrot(
const int N,
double *X,
const int incX,
double *Y,
61 const int incY,
const double c,
const double s) {
62 cblas_drot(N, X, incX, Y, incY, c, s);
64 inline float cblas_Xdot(
const int N,
const float *
const X,
65 const int incX,
const float *
const Y,
67 return cblas_sdot(N, X, incX, Y, incY);
69 inline double cblas_Xdot(
const int N,
const double *
const X,
70 const int incX,
const double *
const Y,
72 return cblas_ddot(N, X, incX, Y, incY);
74 inline void cblas_Xaxpy(
const int N,
const float alpha,
const float *X,
75 const int incX,
float *Y,
const int incY) {
76 cblas_saxpy(N, alpha, X, incX, Y, incY);
78 inline void cblas_Xaxpy(
const int N,
const double alpha,
const double *X,
79 const int incX,
double *Y,
const int incY) {
80 cblas_daxpy(N, alpha, X, incX, Y, incY);
82 inline void cblas_Xscal(
const int N,
const float alpha,
float *data,
84 cblas_sscal(N, alpha, data, inc);
86 inline void cblas_Xscal(
const int N,
const double alpha,
double *data,
88 cblas_dscal(N, alpha, data, inc);
90 inline void cblas_Xspmv(
const float alpha,
const int num_rows,
const float *Mdata,
91 const float *v,
const int v_inc,
92 const float beta,
float *y,
const int y_inc) {
93 cblas_sspmv(CblasRowMajor, CblasLower, num_rows, alpha, Mdata, v, v_inc, beta, y, y_inc);
95 inline void cblas_Xspmv(
const double alpha,
const int num_rows,
const double *Mdata,
96 const double *v,
const int v_inc,
97 const double beta,
double *y,
const int y_inc) {
98 cblas_dspmv(CblasRowMajor, CblasLower, num_rows, alpha, Mdata, v, v_inc, beta, y, y_inc);
101 const int num_rows,
float *y,
const int y_inc) {
102 cblas_stpmv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans),
103 CblasNonUnit, num_rows, Mdata, y, y_inc);
106 const int num_rows,
double *y,
const int y_inc) {
107 cblas_dtpmv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans),
108 CblasNonUnit, num_rows, Mdata, y, y_inc);
113 const int num_rows,
float *y,
const int y_inc) {
114 cblas_stpsv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans),
115 CblasNonUnit, num_rows, Mdata, y, y_inc);
118 const int num_rows,
double *y,
const int y_inc) {
119 cblas_dtpsv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans),
120 CblasNonUnit, num_rows, Mdata, y, y_inc);
127 cblas_sspmv(CblasRowMajor, CblasLower, dim, alpha, Mdata,
128 ydata, ystride, beta, xdata, xstride);
133 cblas_dspmv(CblasRowMajor, CblasLower, dim, alpha, Mdata,
134 ydata, ystride, beta, xdata, xstride);
141 cblas_sspr2(CblasRowMajor, CblasLower, dim, alpha, Xdata,
142 incX, Ydata, incY, Adata);
147 cblas_dspr2(CblasRowMajor, CblasLower, dim, alpha, Xdata,
148 incX, Ydata, incY, Adata);
154 cblas_sspr(CblasRowMajor, CblasLower, dim, alpha, Xdata, incX, Adata);
158 cblas_dspr(CblasRowMajor, CblasLower, dim, alpha, Xdata, incX, Adata);
166 cblas_sgemv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows,
167 num_cols, alpha, Mdata, stride, xdata, incX, beta, ydata, incY);
170 MatrixIndexT num_cols,
double alpha,
const double *Mdata,
173 cblas_dgemv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows,
174 num_cols, alpha, Mdata, stride, xdata, incX, beta, ydata, incY);
180 MatrixIndexT num_above,
float alpha,
const float *Mdata,
183 cblas_sgbmv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows,
184 num_cols, num_below, num_above, alpha, Mdata, stride, xdata,
185 incX, beta, ydata, incY);
189 MatrixIndexT num_above,
double alpha,
const double *Mdata,
192 cblas_dgbmv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows,
193 num_cols, num_below, num_above, alpha, Mdata, stride, xdata,
194 incX, beta, ydata, incY);
198 template<
typename Real>
205 if (beta != 1.0)
cblas_Xscal(num_rows, beta, ydata, incY);
207 Real x_i = xdata[
i * incX];
208 if (x_i == 0.0)
continue;
210 cblas_Xaxpy(num_rows, x_i * alpha, Mdata +
i, stride, ydata, incY);
213 if (beta != 1.0)
cblas_Xscal(num_cols, beta, ydata, incY);
215 Real x_i = xdata[
i * incX];
216 if (x_i == 0.0)
continue;
219 Mdata + (
i * stride), 1, ydata, incY);
233 cblas_sgemm(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(transA),
234 static_cast<CBLAS_TRANSPOSE>(transB),
235 num_rows, num_cols, transA ==
kNoTrans ? a_num_cols : a_num_rows,
236 alpha, Adata, a_stride, Bdata, b_stride,
237 beta, Mdata, stride);
248 cblas_dgemm(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(transA),
249 static_cast<CBLAS_TRANSPOSE>(transB),
250 num_rows, num_cols, transA ==
kNoTrans ? a_num_cols : a_num_rows,
251 alpha, Adata, a_stride, Bdata, b_stride,
252 beta, Mdata, stride);
262 cblas_ssymm(CblasRowMajor, CblasLeft, CblasLower, sz, sz, alpha, Adata,
263 a_stride, Bdata, b_stride, beta, Mdata, stride);
271 cblas_dsymm(CblasRowMajor, CblasLeft, CblasLower, sz, sz, alpha, Adata,
272 a_stride, Bdata, b_stride, beta, Mdata, stride);
276 const float *xdata,
MatrixIndexT incX,
const float *ydata,
278 cblas_sger(CblasRowMajor, num_rows, num_cols, alpha, xdata, 1, ydata, 1,
282 const double *xdata,
MatrixIndexT incX,
const double *ydata,
284 cblas_dger(CblasRowMajor, num_rows, num_cols, alpha, xdata, 1, ydata, 1,
297 const MatrixIndexT other_dim_a,
const float alpha,
const float *A,
298 const MatrixIndexT a_stride,
const float beta,
float *C,
300 cblas_ssyrk(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans),
301 dim_c, other_dim_a, alpha, A, a_stride, beta, C, c_stride);
306 const MatrixIndexT other_dim_a,
const double alpha,
const double *A,
307 const MatrixIndexT a_stride,
const double beta,
double *C,
309 cblas_dsyrk(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans),
310 dim_c, other_dim_a, alpha, A, a_stride, beta, C, c_stride);
324 cblas_dsbmv(CblasRowMajor, CblasLower, dim, 0, alpha, A,
325 1, x, 1, beta, y, 1);
335 cblas_ssbmv(CblasRowMajor, CblasLower, dim, 0, alpha, A,
336 1, x, 1, beta, y, 1);
345 double c1, c2, c3, c4;
347 for (i = 0; i + 4 <= dim; i += 4) {
349 c2 = a[i+1] * b[i+1];
350 c3 = a[i+2] * b[i+2];
351 c4 = a[i+3] * b[i+3];
365 float c1, c2, c3, c4;
367 for (i = 0; i + 4 <= dim; i += 4) {
369 c2 = a[i+1] * b[i+1];
370 c3 = a[i+2] * b[i+2];
371 c4 = a[i+3] * b[i+3];
384 #if !defined(HAVE_ATLAS) 385 inline void clapack_Xtptri(KaldiBlasInt *num_rows,
float *Mdata, KaldiBlasInt *result) {
386 stptri_(const_cast<char *>(
"U"), const_cast<char *>(
"N"), num_rows, Mdata, result);
388 inline void clapack_Xtptri(KaldiBlasInt *num_rows,
double *Mdata, KaldiBlasInt *result) {
389 dtptri_(const_cast<char *>(
"U"), const_cast<char *>(
"N"), num_rows, Mdata, result);
393 float *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot,
394 KaldiBlasInt *result) {
395 sgetrf_(num_rows, num_cols, Mdata, stride, pivot, result);
398 double *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot,
399 KaldiBlasInt *result) {
400 dgetrf_(num_rows, num_cols, Mdata, stride, pivot, result);
404 inline void clapack_Xgetri2(KaldiBlasInt *num_rows,
float *Mdata, KaldiBlasInt *stride,
405 KaldiBlasInt *pivot,
float *p_work,
406 KaldiBlasInt *l_work, KaldiBlasInt *result) {
407 sgetri_(num_rows, Mdata, stride, pivot, p_work, l_work, result);
409 inline void clapack_Xgetri2(KaldiBlasInt *num_rows,
double *Mdata, KaldiBlasInt *stride,
410 KaldiBlasInt *pivot,
double *p_work,
411 KaldiBlasInt *l_work, KaldiBlasInt *result) {
412 dgetri_(num_rows, Mdata, stride, pivot, p_work, l_work, result);
416 KaldiBlasInt *num_rows,
float *Mdata, KaldiBlasInt *stride,
417 float *sv,
float *Vdata, KaldiBlasInt *vstride,
418 float *Udata, KaldiBlasInt *ustride,
float *p_work,
419 KaldiBlasInt *l_work, KaldiBlasInt *result) {
421 num_cols, num_rows, Mdata, stride,
422 sv, Vdata, vstride, Udata, ustride,
423 p_work, l_work, result);
426 KaldiBlasInt *num_rows,
double *Mdata, KaldiBlasInt *stride,
427 double *sv,
double *Vdata, KaldiBlasInt *vstride,
428 double *Udata, KaldiBlasInt *ustride,
double *p_work,
429 KaldiBlasInt *l_work, KaldiBlasInt *result) {
431 num_cols, num_rows, Mdata, stride,
432 sv, Vdata, vstride, Udata, ustride,
433 p_work, l_work, result);
437 KaldiBlasInt *ipiv,
float *work, KaldiBlasInt *result) {
438 ssptri_(const_cast<char *>(
"U"), num_rows, Mdata, ipiv, work, result);
441 KaldiBlasInt *ipiv,
double *work, KaldiBlasInt *result) {
442 dsptri_(const_cast<char *>(
"U"), num_rows, Mdata, ipiv, work, result);
446 KaldiBlasInt *ipiv, KaldiBlasInt *result) {
447 ssptrf_(const_cast<char *>(
"U"), num_rows, Mdata, ipiv, result);
450 KaldiBlasInt *ipiv, KaldiBlasInt *result) {
451 dsptrf_(const_cast<char *>(
"U"), num_rows, Mdata, ipiv, result);
456 int *pivot,
int *result) {
457 *result = clapack_sgetrf(CblasColMajor, num_rows, num_cols,
458 Mdata, stride, pivot);
463 int *pivot,
int *result) {
464 *result = clapack_dgetrf(CblasColMajor, num_rows, num_cols,
465 Mdata, stride, pivot);
468 inline int clapack_Xtrtri(
int num_rows,
float *Mdata,
MatrixIndexT stride) {
469 return clapack_strtri(CblasColMajor, CblasUpper, CblasNonUnit, num_rows,
473 inline int clapack_Xtrtri(
int num_rows,
double *Mdata,
MatrixIndexT stride) {
474 return clapack_dtrtri(CblasColMajor, CblasUpper, CblasNonUnit, num_rows,
479 int *pivot,
int *result) {
480 *result = clapack_sgetri(CblasColMajor, num_rows, Mdata, stride, pivot);
483 int *pivot,
int *result) {
484 *result = clapack_dgetri(CblasColMajor, num_rows, Mdata, stride, pivot);
void cblas_Xsbmv1(const MatrixIndexT dim, const double *A, const double alpha, const double *x, const double beta, double *y)
matrix-vector multiply using a banded matrix; we always call this with b = 1 meaning we're multiplyin...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void cblas_Xsyrk(const MatrixTransposeType trans, const MatrixIndexT dim_c, const MatrixIndexT other_dim_a, const float alpha, const float *A, const MatrixIndexT a_stride, const float beta, float *C, const MatrixIndexT c_stride)
float cblas_Xasum(const int N, const float *X, const int incX)
void clapack_Xgetri2(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot, float *p_work, KaldiBlasInt *l_work, KaldiBlasInt *result)
void clapack_Xsptri(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *ipiv, float *work, KaldiBlasInt *result)
void cblas_Xtpmv(MatrixTransposeType trans, const float *Mdata, const int num_rows, float *y, const int y_inc)
void cblas_Xspr2(MatrixIndexT dim, float alpha, const float *Xdata, MatrixIndexT incX, const float *Ydata, MatrixIndexT incY, float *Adata)
void cblas_Xtpsv(MatrixTransposeType trans, const float *Mdata, const int num_rows, float *y, const int y_inc)
void Xgemv_sparsevec(MatrixTransposeType trans, MatrixIndexT num_rows, MatrixIndexT num_cols, Real alpha, const Real *Mdata, MatrixIndexT stride, const Real *xdata, MatrixIndexT incX, Real beta, Real *ydata, MatrixIndexT incY)
void clapack_Xgetrf2(KaldiBlasInt *num_rows, KaldiBlasInt *num_cols, float *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot, KaldiBlasInt *result)
void cblas_Xcopy(const int N, const float *X, const int incX, float *Y, const int incY)
void clapack_Xgesvd(char *v, char *u, KaldiBlasInt *num_cols, KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *stride, float *sv, float *Vdata, KaldiBlasInt *vstride, float *Udata, KaldiBlasInt *ustride, float *p_work, KaldiBlasInt *l_work, KaldiBlasInt *result)
void mul_elements(const MatrixIndexT dim, const double *a, double *b)
This is not really a wrapper for CBLAS as CBLAS does not have this; in future we could extend this so...
float cblas_Xdot(const int N, const float *const X, const int incX, const float *const Y, const int incY)
void clapack_Xtptri(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *result)
void cblas_Xscal(const int N, const float alpha, float *data, const int inc)
void cblas_Xsymm(const float alpha, MatrixIndexT sz, const float *Adata, MatrixIndexT a_stride, const float *Bdata, MatrixIndexT b_stride, const float beta, float *Mdata, MatrixIndexT stride)
void cblas_Xspr(MatrixIndexT dim, float alpha, const float *Xdata, MatrixIndexT incX, float *Adata)
void cblas_Xgemm(const float alpha, MatrixTransposeType transA, const float *Adata, MatrixIndexT a_num_rows, MatrixIndexT a_num_cols, MatrixIndexT a_stride, MatrixTransposeType transB, const float *Bdata, MatrixIndexT b_stride, const float beta, float *Mdata, MatrixIndexT num_rows, MatrixIndexT num_cols, MatrixIndexT stride)
void cblas_Xaxpy(const int N, const float alpha, const float *X, const int incX, float *Y, const int incY)
void cblas_Xgemv(MatrixTransposeType trans, MatrixIndexT num_rows, MatrixIndexT num_cols, float alpha, const float *Mdata, MatrixIndexT stride, const float *xdata, MatrixIndexT incX, float beta, float *ydata, MatrixIndexT incY)
void cblas_Xspmv(const float alpha, const int num_rows, const float *Mdata, const float *v, const int v_inc, const float beta, float *y, const int y_inc)
void cblas_Xgbmv(MatrixTransposeType trans, MatrixIndexT num_rows, MatrixIndexT num_cols, MatrixIndexT num_below, MatrixIndexT num_above, float alpha, const float *Mdata, MatrixIndexT stride, const float *xdata, MatrixIndexT incX, float beta, float *ydata, MatrixIndexT incY)
void clapack_Xsptrf(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *ipiv, KaldiBlasInt *result)
void cblas_Xger(MatrixIndexT num_rows, MatrixIndexT num_cols, float alpha, const float *xdata, MatrixIndexT incX, const float *ydata, MatrixIndexT incY, float *Mdata, MatrixIndexT stride)
void cblas_Xrot(const int N, float *X, const int incX, float *Y, const int incY, const float c, const float s)