20 #ifndef KALDI_CUDAMATRIX_CUBLAS_WRAPPERS_H_ 21 #define KALDI_CUDAMATRIX_CUBLAS_WRAPPERS_H_ 1 29 inline cublasStatus_t cublas_gemm(
30 cublasHandle_t handle, cublasOperation_t transa,
31 cublasOperation_t transb,
int m,
int n,
int k,
float alpha,
32 const float *A,
int lda,
const float *B,
int ldb,
float beta,
34 return cublasSgemm_v2(handle,transa,transb,m,n,k,&alpha,A,lda,B,ldb,&beta,C,ldc);
36 inline cublasStatus_t cublas_gemm(
37 cublasHandle_t handle, cublasOperation_t transa,
38 cublasOperation_t transb,
int m,
int n,
int k,
double alpha,
39 const double *A,
int lda,
const double *B,
int ldb,
double beta,
41 return cublasDgemm_v2(handle,transa,transb,m,n,k,&alpha,A,lda,B,ldb,&beta,C,ldc);
43 inline cublasStatus_t cublas_ger(
44 cublasHandle_t handle,
int m,
int n,
float alpha,
45 const float *x,
int incx,
const float *y,
int incy,
float *A,
int lda ) {
46 return cublasSger_v2(handle,m,n,&alpha,x,incx,y,incy,A,lda);
48 inline cublasStatus_t cublas_ger(cublasHandle_t handle,
int m,
int n,
double alpha,
49 const double *x,
int incx,
const double *y,
int incy,
double *A,
int lda ) {
50 return cublasDger_v2(handle,m,n,&alpha,x,incx,y,incy,A,lda);
52 inline cublasStatus_t cublas_gemmBatched(
53 cublasHandle_t handle, cublasOperation_t transa,
54 cublasOperation_t transb,
int m,
int n,
int k,
float alpha,
55 const float *A[],
int lda,
const float *B[],
int ldb,
float beta,
56 float *C[],
int ldc,
int batchCount) {
57 return cublasSgemmBatched(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc, batchCount);
59 inline cublasStatus_t cublas_gemmBatched(
60 cublasHandle_t handle, cublasOperation_t transa,
61 cublasOperation_t transb,
int m,
int n,
int k,
double alpha,
62 const double *A[],
int lda,
const double *B[],
int ldb,
double beta,
63 double *C[],
int ldc,
int batchCount) {
64 return cublasDgemmBatched(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc, batchCount);
66 inline cublasStatus_t cublas_trsm(cublasHandle_t handle,
int m,
int n,
67 float alpha,
const float* A,
int lda,
69 return cublasStrsm_v2(handle,CUBLAS_SIDE_LEFT,CUBLAS_FILL_MODE_UPPER,CUBLAS_OP_N,CUBLAS_DIAG_NON_UNIT,m,n,&alpha,A,lda,B,ldb);
71 inline cublasStatus_t cublas_trsm(cublasHandle_t handle,
int m,
int n,
72 double alpha,
const double* A,
int lda,
74 return cublasDtrsm_v2(handle,CUBLAS_SIDE_LEFT,CUBLAS_FILL_MODE_UPPER,CUBLAS_OP_N,CUBLAS_DIAG_NON_UNIT,m,n,&alpha,A,lda,B,ldb);
76 inline cublasStatus_t cublas_syrk(
77 cublasHandle_t handle, cublasFillMode_t uplo,
78 cublasOperation_t trans,
int n,
int k,
float alpha,
79 const float *A,
int lda,
float beta,
float *C,
int ldc) {
80 return cublasSsyrk_v2(handle,uplo,trans,n,k,&alpha,A,lda,&beta,C,ldc);
82 inline cublasStatus_t cublas_syrk(
83 cublasHandle_t handle, cublasFillMode_t uplo,
84 cublasOperation_t trans,
int n,
int k,
double alpha,
85 const double *A,
int lda,
double beta,
double *C,
int ldc) {
86 return cublasDsyrk_v2(handle,uplo,trans,n,k,&alpha,A,lda,&beta,C,ldc);
88 inline cublasStatus_t cublas_dot(cublasHandle_t handle,
int n,
const float *x,
89 int incx,
const float *y,
int incy,
91 return cublasSdot_v2(handle, n, x, incx, y, incy, result);
93 inline cublasStatus_t cublas_dot(cublasHandle_t handle,
int n,
const double *x,
94 int incx,
const double *y,
int incy,
96 return cublasDdot_v2(handle, n, x, incx, y, incy, result);
98 inline cublasStatus_t cublas_asum(cublasHandle_t handle,
int n,
const float* x,
99 int incx,
float *result) {
100 return cublasSasum_v2(handle, n, x, incx, result);
102 inline cublasStatus_t cublas_asum(cublasHandle_t handle,
int n,
const double* x,
103 int incx,
double *result) {
104 return cublasDasum_v2(handle, n, x, incx, result);
106 inline cublasStatus_t cublas_nrm2(cublasHandle_t handle,
int n,
const float* x,
107 int incx,
float *result) {
108 return cublasSnrm2_v2(handle, n, x, incx, result);
110 inline cublasStatus_t cublas_nrm2(cublasHandle_t handle,
int n,
const double* x,
111 int incx,
double *result) {
112 return cublasDnrm2_v2(handle, n, x, incx, result);
114 inline cudaError_t cublas_copy(cublasHandle_t handle,
int n,
const float* x,
115 int incx,
double* y,
int incy) {
118 cublas_copy_kaldi_fd(dimGrid, dimBlock, n, x, incx, y, incy);
119 return cudaGetLastError();
121 inline cudaError_t cublas_copy(cublasHandle_t handle,
int n,
const double* x,
122 int incx,
float* y,
int incy) {
125 cublas_copy_kaldi_df(dimGrid, dimBlock, n, x, incx, y, incy);
126 return cudaGetLastError();
128 inline cublasStatus_t cublas_copy(cublasHandle_t handle,
int n,
const float* x,
129 int incx,
float* y,
int incy) {
130 return cublasScopy_v2(handle,n,x,incx,y,incy);
132 inline cublasStatus_t cublas_copy(cublasHandle_t handle,
int n,
const double* x,
133 int incx,
double* y,
int incy) {
134 return cublasDcopy_v2(handle,n,x,incx,y,incy);
136 inline cublasStatus_t cublas_scal(cublasHandle_t handle,
int n,
float alpha,
137 float* mat,
int incx) {
138 return cublasSscal_v2(handle, n, &alpha, mat, incx);
140 inline cublasStatus_t cublas_scal(cublasHandle_t handle,
int n,
double alpha,
141 double* mat,
int incx) {
142 return cublasDscal_v2(handle, n, &alpha, mat, incx);
145 inline cublasStatus_t cublas_axpy(cublasHandle_t handle,
int n,
float alpha,
146 const float* x,
int incx,
float* y,
int incy) {
147 return cublasSaxpy_v2(handle, n, &alpha, x, incx, y, incy);
149 inline cublasStatus_t cublas_axpy(cublasHandle_t handle,
int n,
double alpha,
150 const double* x,
int incx,
double* y,
int incy) {
151 return cublasDaxpy_v2(handle, n, &alpha, x, incx, y, incy);
153 inline cublasStatus_t cublas_gemv(
154 cublasHandle_t handle, cublasOperation_t trans,
155 int m,
int n,
float alpha,
const float* A,
int lda,
const float* x,
156 int incx,
float beta,
float* y,
int incy) {
157 return cublasSgemv_v2(handle,trans,m,n,&alpha,A,lda,x,incx,&beta,y,incy);
159 inline cublasStatus_t cublas_gemv(
160 cublasHandle_t handle, cublasOperation_t trans,
161 int m,
int n,
double alpha,
const double* A,
int lda,
const double* x,
162 int incx,
double beta,
double* y,
int incy) {
163 return cublasDgemv_v2(handle,trans,m,n,&alpha,A,lda,x,incx,&beta,y,incy);
166 inline cublasStatus_t cublas_spmv(
167 cublasHandle_t handle, cublasFillMode_t uplo,
168 int n,
float alpha,
const float *AP,
const float *x,
int incx,
169 float beta,
float *y,
int incy) {
170 return cublasSspmv_v2(handle, uplo, n, &alpha, AP, x, incx, &beta, y, incy);
172 inline cublasStatus_t cublas_spmv(
173 cublasHandle_t handle, cublasFillMode_t uplo,
174 int n,
double alpha,
const double *AP,
const double *x,
int incx,
175 double beta,
double *y,
int incy) {
176 return cublasDspmv_v2(handle, uplo, n, &alpha, AP, x, incx, &beta, y, incy);
184 inline cublasStatus_t cublas_tpmv(cublasHandle_t handle, cublasOperation_t trans,
185 int n,
const float* Ap,
float* x,
int incx) {
186 return cublasStpmv_v2(handle, CUBLAS_FILL_MODE_UPPER, trans, CUBLAS_DIAG_NON_UNIT, n, Ap, x, incx);
188 inline cublasStatus_t cublas_tpmv(cublasHandle_t handle, cublasOperation_t trans,
189 int n,
const double* Ap,
double* x,
int incx) {
190 return cublasDtpmv_v2(handle, CUBLAS_FILL_MODE_UPPER, trans, CUBLAS_DIAG_NON_UNIT, n, Ap, x, incx);
193 inline cublasStatus_t cublas_spr(cublasHandle_t handle, cublasFillMode_t uplo,
194 int n,
float alpha,
const float *x,
int incx,
196 return cublasSspr_v2(handle, uplo, n, &alpha, x, incx, AP);
198 inline cublasStatus_t cublas_spr(cublasHandle_t handle, cublasFillMode_t uplo,
199 int n,
double alpha,
const double *x,
int incx,
201 return cublasDspr_v2(handle, uplo, n, &alpha, x, incx, AP);
208 inline cusparseStatus_t cusparse_csr2csc(cusparseHandle_t handle,
int m,
int n,
209 int nnz,
const float *csrVal,
210 const int *csrRowPtr,
211 const int *csrColInd,
float *cscVal,
212 int *cscRowInd,
int *cscColPtr,
213 cusparseAction_t copyValues,
214 cusparseIndexBase_t idxBase) {
215 return cusparseScsr2csc(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd,
216 cscVal, cscRowInd, cscColPtr, copyValues, idxBase);
218 inline cusparseStatus_t cusparse_csr2csc(cusparseHandle_t handle,
int m,
int n,
219 int nnz,
const double *csrVal,
220 const int *csrRowPtr,
221 const int *csrColInd,
double *cscVal,
222 int *cscRowInd,
int *cscColPtr,
223 cusparseAction_t copyValues,
224 cusparseIndexBase_t idxBase) {
225 return cusparseDcsr2csc(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd,
226 cscVal, cscRowInd, cscColPtr, copyValues, idxBase);
229 inline cusparseStatus_t cusparse_csrmm(cusparseHandle_t handle,
230 cusparseOperation_t transA,
int m,
int n,
231 int k,
int nnz,
const float *alpha,
232 const cusparseMatDescr_t descrA,
233 const float *csrValA,
234 const int *csrRowPtrA,
235 const int *csrColIndA,
const float *B,
236 int ldb,
const float *beta,
float *C,
238 return cusparseScsrmm(handle, transA, m, n, k, nnz, alpha, descrA, csrValA,
239 csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
241 inline cusparseStatus_t cusparse_csrmm(cusparseHandle_t handle,
242 cusparseOperation_t transA,
int m,
int n,
243 int k,
int nnz,
const double *alpha,
244 const cusparseMatDescr_t descrA,
245 const double *csrValA,
246 const int *csrRowPtrA,
247 const int *csrColIndA,
const double *B,
248 int ldb,
const double *beta,
double *C,
250 return cusparseDcsrmm(handle, transA, m, n, k, nnz, alpha, descrA, csrValA,
251 csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
254 inline cusparseStatus_t cusparse_csrmm2(cusparseHandle_t handle,
255 cusparseOperation_t transA,
256 cusparseOperation_t transB,
int m,
257 int n,
int k,
int nnz,
259 const cusparseMatDescr_t descrA,
260 const float *csrValA,
261 const int *csrRowPtrA,
262 const int *csrColIndA,
const float *B,
263 int ldb,
const float *beta,
float *C,
265 return cusparseScsrmm2(handle, transA, transB, m, n, k, nnz, alpha, descrA,
266 csrValA, csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
268 inline cusparseStatus_t cusparse_csrmm2(cusparseHandle_t handle,
269 cusparseOperation_t transA,
270 cusparseOperation_t transB,
int m,
271 int n,
int k,
int nnz,
273 const cusparseMatDescr_t descrA,
274 const double *csrValA,
275 const int *csrRowPtrA,
276 const int *csrColIndA,
const double *B,
277 int ldb,
const double *beta,
double *C,
279 return cusparseDcsrmm2(handle, transA, transB, m, n, k, nnz, alpha, descrA,
280 csrValA, csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
// This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...