diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 279679a4eac..955f99c04cc 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4629,9 +4629,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CUMSUM: case GGML_OP_TRI: case GGML_OP_DIAG: - return true; case GGML_OP_SOLVE_TRI: - return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32; + return true; + default: return false; } diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu index e161d4dc436..177ffc268f1 100644 --- a/ggml/src/ggml-cuda/solve_tri.cu +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -3,6 +3,80 @@ #include "solve_tri.cuh" #define MAX_N_FAST 64 +#define MAX_K_FAST 32 + +static __global__ void get_batch_pointers(const float * A, + float * X, + const float ** A_ptrs, + float ** X_ptrs, + int64_t ne02, + int64_t total_batches, + size_t s02, + size_t s03, + size_t s2, + size_t s3) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_batches) { + return; + } + + const int64_t i3 = idx / ne02; + const int64_t i2 = idx % ne02; + + A_ptrs[idx] = A + i3 * s03 + i2 * s02; + X_ptrs[idx] = X + i3 * s3 + i2 * s2; +} + +static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx, + const float * A, + const float * B, + float * X, + int n, + int k, + int64_t ne02, + int64_t ne03, + size_t s02, + size_t s03, + size_t s12, + size_t s13, + size_t s2, + size_t s3, + cudaStream_t stream) { + const float alpha = 1.0f; + const int64_t total_batches = ne02 * ne03; + if (total_batches == 0) { + return; + } + + // Bulk copy B -> X (contiguous tensors) + if (X != B) { + const int64_t total_elements_BX = n * k * total_batches; + CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + } + + const int id = ggml_cuda_get_device(); + + ggml_cuda_pool_alloc A_ptrs_alloc(ctx.pool(id), total_batches); + ggml_cuda_pool_alloc X_ptrs_alloc(ctx.pool(id), total_batches); + + const float ** A_ptrs_dev = A_ptrs_alloc.get(); + float ** X_ptrs_dev = X_ptrs_alloc.get(); + + get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02, + total_batches, s02, s03, s2, s3); + + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + + // Yes, this is necessary, without this we get RMSE errors + CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH)); + CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N, + CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches)); + + // revert to standard mode from common.cuh + CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH)); + + GGML_UNUSED_VARS(s12, s13); +} // ====================== // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction @@ -63,7 +137,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; - const int half = WARP_SIZE; + const int half = WARP_SIZE; const int nrows_low = (n < half) ? n : half; #pragma unroll @@ -81,8 +155,8 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, #pragma unroll for (int row = half; row < n; ++row) { - float sum = sA[row * n + lane] * x_low; - const int j = half + lane; + float sum = sA[row * n + lane] * x_low; + const int j = half + lane; if (j < row) { sum += sA[row * n + j] * x_high; } @@ -97,7 +171,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, for (int rr = 0; rr < 2; ++rr) { const int row = rr * WARP_SIZE + lane; if (row < n) { - const float val = (row < half) ? x_low : x_high; + const float val = (row < half) ? x_low : x_high; X_batch[row * k + col_idx] = val; } } @@ -176,20 +250,26 @@ static void solve_tri_f32_cuda(const float * A, } void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix) - const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns) + const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular) + const ggml_tensor * src1 = dst->src[1]; // B (n×k) ggml_is_contiguous(src0); ggml_is_contiguous(src1); - const int64_t n = src0->ne[0]; - const int64_t k = src1->ne[0]; + const int64_t n = src0->ne[0]; + const int64_t k = src1->ne[0]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; - GGML_ASSERT(n <= 64); - GGML_ASSERT(k <= 32); - - solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2], - src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), - src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), - dst->nb[3] / sizeof(float), ctx.stream()); + if (n <= MAX_N_FAST && k <= MAX_K_FAST) { + solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, + src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), + dst->nb[3] / sizeof(float), ctx.stream()); + } else { + solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, + ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), + dst->nb[3] / sizeof(float), ctx.stream()); + } } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index b7d6edf7fcb..951a88d5678 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -19,6 +19,9 @@ #define CUDA_R_16F HIPBLAS_R_16F #define CUDA_R_16BF HIPBLAS_R_16B #define CUDA_R_32F HIPBLAS_R_32F +#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT +#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER +#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended #define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned @@ -30,6 +33,7 @@ #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define __all_sync(mask, var) __all(var) #define __any_sync(mask, var) __any(var) +#define cublasStrsmBatched hipblasStrsmBatched #define cublasCreate hipblasCreate #define cublasDestroy hipblasDestroy #define cublasGemmEx hipblasGemmEx diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 8c55a2e4e56..221e67f96a7 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -12,11 +12,16 @@ #define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT #define CUBLAS_OP_N MUBLAS_OP_N #define CUBLAS_OP_T MUBLAS_OP_T +#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH +#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT +#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER +#define CUBLAS_DIAG_NON_UNIT MUBLAS_DIAG_NON_UNIT #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH #define CUDA_R_16F MUSA_R_16F #define CUDA_R_16BF MUSA_R_16BF #define CUDA_R_32F MUSA_R_32F +#define cublasStrsmBatched mublasStrsmBatched #define cublasComputeType_t cudaDataType_t #define cublasCreate mublasCreate #define cublasDestroy mublasDestroy diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 289e2e6d7fd..9d29a365027 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7861,9 +7861,24 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 30, 30, 7, 1 }, { 8, 30, 7, 1 })); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 })); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 64, 64, 2, 2 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 79, 79, 5, 3 }, { 417, 79, 5, 3 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 80, 80, 2, 8 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 79, 80, 2, 8 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 81, 80, 2, 8 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 80, 80, 8, 8 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 79, 80, 8, 8 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 81, 80, 8, 8 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 84, 84, 4, 4 }, { 32, 84, 4, 4 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 95, 95, 8, 8 }, { 40, 95, 8, 8 })); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 })); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 })); - test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 32, 128, 4, 4 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 3, 4 }, { 32, 128, 3, 4 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 32, 128, 4, 1 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 200, 64, 4, 4 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 384, 64, 4, 4 })); for (bool v : {false, true}) { for (bool circular : {false, true}) { @@ -8064,12 +8079,13 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416)); - test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 })); - test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 32, 64, 4, 4 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 })); // qwen3next with CHUNK_SIZE 64 test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 })); // qwen3next with CHUNK_SIZE 128 test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 256, 256, 4, 2 }, { 128, 256, 4, 2 })); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 })); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 })); diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 2db04e95227..dda053498a7 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ