cublas-wrappers.h
Go to the documentation of this file.
1 // cudamatrix/cublas-wrappers.h
2 
3 // Copyright 2013 Johns Hopkins University (author: Daniel Povey);
4 // 2017 Shiyin Kang
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 
12 // http://www.apache.org/licenses/LICENSE-2.0
13 
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 #ifndef KALDI_CUDAMATRIX_CUBLAS_WRAPPERS_H_
21 #define KALDI_CUDAMATRIX_CUBLAS_WRAPPERS_H_ 1
22 
23 // Do not include this file directly. It is to be included
24 // by .cc files in this directory.
25 
26 namespace kaldi {
27 #if HAVE_CUDA == 1
28 
29 inline cublasStatus_t cublas_gemm(
30  cublasHandle_t handle, cublasOperation_t transa,
31  cublasOperation_t transb, int m, int n,int k, float alpha,
32  const float *A, int lda, const float *B, int ldb, float beta,
33  float *C, int ldc) {
34  return cublasSgemm_v2(handle,transa,transb,m,n,k,&alpha,A,lda,B,ldb,&beta,C,ldc);
35 }
36 inline cublasStatus_t cublas_gemm(
37  cublasHandle_t handle, cublasOperation_t transa,
38  cublasOperation_t transb, int m, int n,int k, double alpha,
39  const double *A, int lda, const double *B, int ldb, double beta,
40  double *C, int ldc) {
41  return cublasDgemm_v2(handle,transa,transb,m,n,k,&alpha,A,lda,B,ldb,&beta,C,ldc);
42 }
43 inline cublasStatus_t cublas_ger(
44  cublasHandle_t handle, int m, int n, float alpha,
45  const float *x, int incx, const float *y, int incy, float *A, int lda ) {
46  return cublasSger_v2(handle,m,n,&alpha,x,incx,y,incy,A,lda);
47 }
48 inline cublasStatus_t cublas_ger(cublasHandle_t handle, int m, int n, double alpha,
49  const double *x, int incx, const double *y, int incy, double *A, int lda ) {
50  return cublasDger_v2(handle,m,n,&alpha,x,incx,y,incy,A,lda);
51 }
52 inline cublasStatus_t cublas_gemmBatched(
53  cublasHandle_t handle, cublasOperation_t transa,
54  cublasOperation_t transb, int m, int n, int k, float alpha,
55  const float *A[], int lda, const float *B[], int ldb, float beta,
56  float *C[], int ldc, int batchCount) {
57  return cublasSgemmBatched(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc, batchCount);
58 }
59 inline cublasStatus_t cublas_gemmBatched(
60  cublasHandle_t handle, cublasOperation_t transa,
61  cublasOperation_t transb, int m, int n, int k, double alpha,
62  const double *A[], int lda, const double *B[], int ldb, double beta,
63  double *C[], int ldc, int batchCount) {
64  return cublasDgemmBatched(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc, batchCount);
65 }
66 inline cublasStatus_t cublas_trsm(cublasHandle_t handle, int m, int n,
67  float alpha, const float* A, int lda,
68  float* B, int ldb) {
69  return cublasStrsm_v2(handle,CUBLAS_SIDE_LEFT,CUBLAS_FILL_MODE_UPPER,CUBLAS_OP_N,CUBLAS_DIAG_NON_UNIT,m,n,&alpha,A,lda,B,ldb);
70 }
71 inline cublasStatus_t cublas_trsm(cublasHandle_t handle, int m, int n,
72  double alpha, const double* A, int lda,
73  double* B, int ldb) {
74  return cublasDtrsm_v2(handle,CUBLAS_SIDE_LEFT,CUBLAS_FILL_MODE_UPPER,CUBLAS_OP_N,CUBLAS_DIAG_NON_UNIT,m,n,&alpha,A,lda,B,ldb);
75 }
76 inline cublasStatus_t cublas_syrk(
77  cublasHandle_t handle, cublasFillMode_t uplo,
78  cublasOperation_t trans, int n, int k, float alpha,
79  const float *A, int lda, float beta, float *C, int ldc) {
80  return cublasSsyrk_v2(handle,uplo,trans,n,k,&alpha,A,lda,&beta,C,ldc);
81 }
82 inline cublasStatus_t cublas_syrk(
83  cublasHandle_t handle, cublasFillMode_t uplo,
84  cublasOperation_t trans, int n, int k, double alpha,
85  const double *A, int lda, double beta, double *C, int ldc) {
86  return cublasDsyrk_v2(handle,uplo,trans,n,k,&alpha,A,lda,&beta,C,ldc);
87 }
88 inline cublasStatus_t cublas_dot(cublasHandle_t handle, int n, const float *x,
89  int incx, const float *y, int incy,
90  float *result) {
91  return cublasSdot_v2(handle, n, x, incx, y, incy, result);
92 }
93 inline cublasStatus_t cublas_dot(cublasHandle_t handle, int n, const double *x,
94  int incx, const double *y, int incy,
95  double *result) {
96  return cublasDdot_v2(handle, n, x, incx, y, incy, result);
97 }
98 inline cublasStatus_t cublas_asum(cublasHandle_t handle, int n, const float* x,
99  int incx, float *result) {
100  return cublasSasum_v2(handle, n, x, incx, result);
101 }
102 inline cublasStatus_t cublas_asum(cublasHandle_t handle, int n, const double* x,
103  int incx, double *result) {
104  return cublasDasum_v2(handle, n, x, incx, result);
105 }
106 inline cublasStatus_t cublas_nrm2(cublasHandle_t handle, int n, const float* x,
107  int incx, float *result) {
108  return cublasSnrm2_v2(handle, n, x, incx, result);
109 }
110 inline cublasStatus_t cublas_nrm2(cublasHandle_t handle, int n, const double* x,
111  int incx, double *result) {
112  return cublasDnrm2_v2(handle, n, x, incx, result);
113 }
114 inline cudaError_t cublas_copy(cublasHandle_t handle, int n, const float* x,
115  int incx, double* y, int incy) {
116  int dimBlock(CU1DBLOCK);
117  int dimGrid(n_blocks(n, CU1DBLOCK));
118  cublas_copy_kaldi_fd(dimGrid, dimBlock, n, x, incx, y, incy);
119  return cudaGetLastError();
120 }
121 inline cudaError_t cublas_copy(cublasHandle_t handle, int n, const double* x,
122  int incx, float* y, int incy) {
123  int dimBlock(CU1DBLOCK);
124  int dimGrid(n_blocks(n, CU1DBLOCK));
125  cublas_copy_kaldi_df(dimGrid, dimBlock, n, x, incx, y, incy);
126  return cudaGetLastError();
127 }
128 inline cublasStatus_t cublas_copy(cublasHandle_t handle, int n, const float* x,
129  int incx, float* y, int incy) {
130  return cublasScopy_v2(handle,n,x,incx,y,incy);
131 }
132 inline cublasStatus_t cublas_copy(cublasHandle_t handle, int n, const double* x,
133  int incx, double* y, int incy) {
134  return cublasDcopy_v2(handle,n,x,incx,y,incy);
135 }
136 inline cublasStatus_t cublas_scal(cublasHandle_t handle, int n, float alpha,
137  float* mat, int incx) {
138  return cublasSscal_v2(handle, n, &alpha, mat, incx);
139 }
140 inline cublasStatus_t cublas_scal(cublasHandle_t handle, int n, double alpha,
141  double* mat, int incx) {
142  return cublasDscal_v2(handle, n, &alpha, mat, incx);
143 }
144 
145 inline cublasStatus_t cublas_axpy(cublasHandle_t handle, int n, float alpha,
146  const float* x, int incx, float* y, int incy) {
147  return cublasSaxpy_v2(handle, n, &alpha, x, incx, y, incy);
148 }
149 inline cublasStatus_t cublas_axpy(cublasHandle_t handle, int n, double alpha,
150  const double* x, int incx, double* y, int incy) {
151  return cublasDaxpy_v2(handle, n, &alpha, x, incx, y, incy);
152 }
153 inline cublasStatus_t cublas_gemv(
154  cublasHandle_t handle, cublasOperation_t trans,
155  int m, int n, float alpha, const float* A, int lda, const float* x,
156  int incx, float beta, float* y, int incy) {
157  return cublasSgemv_v2(handle,trans,m,n,&alpha,A,lda,x,incx,&beta,y,incy);
158 }
159 inline cublasStatus_t cublas_gemv(
160  cublasHandle_t handle, cublasOperation_t trans,
161  int m, int n, double alpha, const double* A, int lda, const double* x,
162  int incx, double beta, double* y, int incy) {
163  return cublasDgemv_v2(handle,trans,m,n,&alpha,A,lda,x,incx,&beta,y,incy);
164 }
165 
166 inline cublasStatus_t cublas_spmv(
167  cublasHandle_t handle, cublasFillMode_t uplo,
168  int n, float alpha, const float *AP, const float *x, int incx,
169  float beta, float *y, int incy) {
170  return cublasSspmv_v2(handle, uplo, n, &alpha, AP, x, incx, &beta, y, incy);
171 }
172 inline cublasStatus_t cublas_spmv(
173  cublasHandle_t handle, cublasFillMode_t uplo,
174  int n, double alpha, const double *AP, const double *x, int incx,
175  double beta, double *y, int incy) {
176  return cublasDspmv_v2(handle, uplo, n, &alpha, AP, x, incx, &beta, y, incy);
177 }
178 
// Use caution with these: the 'transpose' argument is the opposite of what it
// should really be, because CUDA stores things in column-major order.  We also
// had to switch 'l' to 'u'; we view our packed matrices as lower-triangular,
// row-by-row, but CUDA views the same layout as upper-triangular,
// column-by-column.
184 inline cublasStatus_t cublas_tpmv(cublasHandle_t handle, cublasOperation_t trans,
185  int n, const float* Ap, float* x, int incx) {
186  return cublasStpmv_v2(handle, CUBLAS_FILL_MODE_UPPER, trans, CUBLAS_DIAG_NON_UNIT, n, Ap, x, incx);
187 }
188 inline cublasStatus_t cublas_tpmv(cublasHandle_t handle, cublasOperation_t trans,
189  int n, const double* Ap, double* x,int incx) {
190  return cublasDtpmv_v2(handle, CUBLAS_FILL_MODE_UPPER, trans, CUBLAS_DIAG_NON_UNIT, n, Ap, x, incx);
191 }
192 
193 inline cublasStatus_t cublas_spr(cublasHandle_t handle, cublasFillMode_t uplo,
194  int n, float alpha, const float *x, int incx,
195  float *AP) {
196  return cublasSspr_v2(handle, uplo, n, &alpha, x, incx, AP);
197 }
198 inline cublasStatus_t cublas_spr(cublasHandle_t handle, cublasFillMode_t uplo,
199  int n, double alpha, const double *x, int incx,
200  double *AP) {
201  return cublasDspr_v2(handle, uplo, n, &alpha, x, incx, AP);
202 }
203 
204 //
205 // cuSPARSE wrappers
206 //
207 
208 inline cusparseStatus_t cusparse_csr2csc(cusparseHandle_t handle, int m, int n,
209  int nnz, const float *csrVal,
210  const int *csrRowPtr,
211  const int *csrColInd, float *cscVal,
212  int *cscRowInd, int *cscColPtr,
213  cusparseAction_t copyValues,
214  cusparseIndexBase_t idxBase) {
215  return cusparseScsr2csc(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd,
216  cscVal, cscRowInd, cscColPtr, copyValues, idxBase);
217 }
218 inline cusparseStatus_t cusparse_csr2csc(cusparseHandle_t handle, int m, int n,
219  int nnz, const double *csrVal,
220  const int *csrRowPtr,
221  const int *csrColInd, double *cscVal,
222  int *cscRowInd, int *cscColPtr,
223  cusparseAction_t copyValues,
224  cusparseIndexBase_t idxBase) {
225  return cusparseDcsr2csc(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd,
226  cscVal, cscRowInd, cscColPtr, copyValues, idxBase);
227 }
228 
229 inline cusparseStatus_t cusparse_csrmm(cusparseHandle_t handle,
230  cusparseOperation_t transA, int m, int n,
231  int k, int nnz, const float *alpha,
232  const cusparseMatDescr_t descrA,
233  const float *csrValA,
234  const int *csrRowPtrA,
235  const int *csrColIndA, const float *B,
236  int ldb, const float *beta, float *C,
237  int ldc) {
238  return cusparseScsrmm(handle, transA, m, n, k, nnz, alpha, descrA, csrValA,
239  csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
240 }
241 inline cusparseStatus_t cusparse_csrmm(cusparseHandle_t handle,
242  cusparseOperation_t transA, int m, int n,
243  int k, int nnz, const double *alpha,
244  const cusparseMatDescr_t descrA,
245  const double *csrValA,
246  const int *csrRowPtrA,
247  const int *csrColIndA, const double *B,
248  int ldb, const double *beta, double *C,
249  int ldc) {
250  return cusparseDcsrmm(handle, transA, m, n, k, nnz, alpha, descrA, csrValA,
251  csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
252 }
253 
254 inline cusparseStatus_t cusparse_csrmm2(cusparseHandle_t handle,
255  cusparseOperation_t transA,
256  cusparseOperation_t transB, int m,
257  int n, int k, int nnz,
258  const float *alpha,
259  const cusparseMatDescr_t descrA,
260  const float *csrValA,
261  const int *csrRowPtrA,
262  const int *csrColIndA, const float *B,
263  int ldb, const float *beta, float *C,
264  int ldc) {
265  return cusparseScsrmm2(handle, transA, transB, m, n, k, nnz, alpha, descrA,
266  csrValA, csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
267 }
268 inline cusparseStatus_t cusparse_csrmm2(cusparseHandle_t handle,
269  cusparseOperation_t transA,
270  cusparseOperation_t transB, int m,
271  int n, int k, int nnz,
272  const double *alpha,
273  const cusparseMatDescr_t descrA,
274  const double *csrValA,
275  const int *csrRowPtrA,
276  const int *csrColIndA, const double *B,
277  int ldb, const double *beta, double *C,
278  int ldc) {
279  return cusparseDcsrmm2(handle, transA, transB, m, n, k, nnz, alpha, descrA,
280  csrValA, csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
281 }
282 
283 
284 #endif
285 }
286 // namespace kaldi
287 
288 #endif
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
struct rnnlm::@11::@12 n
#define CU1DBLOCK
Definition: cu-matrixdim.h:57