@@ -25,14 +25,14 @@ namespace cuda
2525using cublas::getHandle;
2626
2727cublasOperation_t
28- toCblasTranspose (af_blas_transpose opt)
28+ toCblasTranspose (af_transpose_t opt)
2929{
3030 cublasOperation_t out = CUBLAS_OP_N;
3131 switch (opt) {
32- case AF_NO_TRANSPOSE : out = CUBLAS_OP_N; break ;
33- case AF_TRANSPOSE : out = CUBLAS_OP_T; break ;
34- case AF_CONJUGATE_TRANSPOSE : out = CUBLAS_OP_C; break ;
35- default : AF_ERROR (" INVALID af_blas_transpose " , AF_ERR_INVALID_ARG);
32+ case AF_NO_TRANS : out = CUBLAS_OP_N; break ;
33+ case AF_TRANS : out = CUBLAS_OP_T; break ;
34+ case AF_CONJ_TRANS : out = CUBLAS_OP_C; break ;
35+ default : AF_ERROR (" INVALID af_transpose_t " , AF_ERR_INVALID_ARG);
3636 }
3737 return out;
3838}
@@ -117,7 +117,7 @@ using namespace std;
117117
118118template <typename T>
119119Array<T> matmul (const Array<T> &lhs, const Array<T> &rhs,
120- af_blas_transpose optLhs, af_blas_transpose optRhs)
120+ af_transpose_t optLhs, af_transpose_t optRhs)
121121{
122122 cublasOperation_t lOpts = toCblasTranspose (optLhs);
123123 cublasOperation_t rOpts = toCblasTranspose (optRhs);
@@ -170,7 +170,7 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
170170
171171template <typename T>
172172Array<T> dot (const Array<T> &lhs, const Array<T> &rhs,
173- af_blas_transpose optLhs, af_blas_transpose optRhs)
173+ af_transpose_t optLhs, af_transpose_t optRhs)
174174{
175175 int N = lhs.dims ()[0 ];
176176
@@ -186,7 +186,7 @@ Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
186186}
187187
188188template <typename T>
189- void trsm (const Array<T> &lhs, Array<T> &rhs, af_blas_transpose trans,
189+ void trsm (const Array<T> &lhs, Array<T> &rhs, af_transpose_t trans,
190190 bool is_upper, bool is_left, bool is_unit)
191191{
192192 // dim4 lDims = lhs.dims();
@@ -214,7 +214,7 @@ void trsm(const Array<T> &lhs, Array<T> &rhs, af_blas_transpose trans,
214214
215215#define INSTANTIATE_BLAS (TYPE ) \
216216 template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
217- af_blas_transpose optLhs, af_blas_transpose optRhs);
217+ af_transpose_t optLhs, af_transpose_t optRhs);
218218
219219INSTANTIATE_BLAS (float )
220220INSTANTIATE_BLAS (cfloat)
@@ -223,14 +223,14 @@ INSTANTIATE_BLAS(cdouble)
223223
224224#define INSTANTIATE_DOT (TYPE ) \
225225 template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
226- af_blas_transpose optLhs, af_blas_transpose optRhs);
226+ af_transpose_t optLhs, af_transpose_t optRhs);
227227
228228INSTANTIATE_DOT (float )
229229INSTANTIATE_DOT (double )
230230
231231#define INSTANTIATE_TRSM (TYPE ) \
232232 template void trsm<TYPE>(const Array<TYPE> &lhs, Array<TYPE> &rhs, \
233- af_blas_transpose trans, bool is_upper, bool is_left, bool is_unit);
233+ af_transpose_t trans, bool is_upper, bool is_left, bool is_unit);
234234
235235INSTANTIATE_TRSM (float )
236236INSTANTIATE_TRSM (cfloat)
0 commit comments