From e119ff73efa8aa4d48c651e2d762e5107631f22d Mon Sep 17 00:00:00 2001 From: amcamd Date: Thu, 5 Jun 2025 17:13:30 -0400 Subject: [PATCH] update for hipblasVersionMajor >=3 --- csrc/ops.hip | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/csrc/ops.hip b/csrc/ops.hip index eef616d48..a9c3e0202 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -269,6 +269,15 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in const void * beta = &fbeta; hipblasStatus_t status; +#if hipblasVersionMajor >= 3 + status = hipblasGemmEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta, + C, HIP_R_32I, ldc, + HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT); +#else status = hipblasGemmEx(context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, @@ -276,6 +285,7 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, C, HIPBLAS_R_32I, ldc, HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); +#endif if (status != HIPBLAS_STATUS_SUCCESS) { @@ -299,6 +309,15 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i //printf("%i %i %i\n", strideA, strideB, strideC); //printf("%i\n", batchCount); +#if hipblasVersionMajor >= 3 + status = hipblasGemmStridedBatchedEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIP_R_8I, lda, (long long int)strideA, B, HIP_R_8I, ldb, (long long int)strideB, beta, + C, HIP_R_32I, ldc, (long long int)strideC, batchCount, + HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT); +#else status = hipblasGemmStridedBatchedEx(context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, @@ -306,6 +325,7 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); +#endif if (status != HIPBLAS_STATUS_SUCCESS) {