Skip to content
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4629,9 +4629,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
case GGML_OP_DIAG:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
return true;

default:
return false;
}
Expand Down
110 changes: 95 additions & 15 deletions ggml/src/ggml-cuda/solve_tri.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,80 @@
#include "solve_tri.cuh"

#define MAX_N_FAST 64
#define MAX_K_FAST 32

static __global__ void get_batch_pointers(const float * A,
float * X,
const float ** A_ptrs,
float ** X_ptrs,
int64_t ne02,
int64_t total_batches,
size_t s02,
size_t s03,
size_t s2,
size_t s3) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total_batches) {
return;
}

const int64_t i3 = idx / ne02;
const int64_t i2 = idx % ne02;

A_ptrs[idx] = A + i3 * s03 + i2 * s02;
X_ptrs[idx] = X + i3 * s3 + i2 * s2;
}

static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
const float * A,
const float * B,
float * X,
int n,
int k,
int64_t ne02,
int64_t ne03,
size_t s02,
size_t s03,
size_t s12,
size_t s13,
size_t s2,
size_t s3,
cudaStream_t stream) {
const float alpha = 1.0f;
const int64_t total_batches = ne02 * ne03;
if (total_batches == 0) {
return;
}

// Bulk copy B -> X (contiguous tensors)
if (X != B) {
const int64_t total_elements_BX = n * k * total_batches;
CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
}

const int id = ggml_cuda_get_device();

ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches);

const float ** A_ptrs_dev = A_ptrs_alloc.get();
float ** X_ptrs_dev = X_ptrs_alloc.get();

get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
total_batches, s02, s03, s2, s3);

CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));

// Yes, this is necessary, without this we get RMSE errors
CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));

// revert to standard mode from common.cuh
CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));

GGML_UNUSED_VARS(s12, s13);
}

// ======================
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
Expand Down Expand Up @@ -63,7 +137,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;

const int half = WARP_SIZE;
const int half = WARP_SIZE;
const int nrows_low = (n < half) ? n : half;

#pragma unroll
Expand All @@ -81,8 +155,8 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,

#pragma unroll
for (int row = half; row < n; ++row) {
float sum = sA[row * n + lane] * x_low;
const int j = half + lane;
float sum = sA[row * n + lane] * x_low;
const int j = half + lane;
if (j < row) {
sum += sA[row * n + j] * x_high;
}
Expand All @@ -97,7 +171,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
for (int rr = 0; rr < 2; ++rr) {
const int row = rr * WARP_SIZE + lane;
if (row < n) {
const float val = (row < half) ? x_low : x_high;
const float val = (row < half) ? x_low : x_high;
X_batch[row * k + col_idx] = val;
}
}
Expand Down Expand Up @@ -176,20 +250,26 @@ static void solve_tri_f32_cuda(const float * A,
}

void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix)
const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular)
const ggml_tensor * src1 = dst->src[1]; // B (n×k)

ggml_is_contiguous(src0);
ggml_is_contiguous(src1);

const int64_t n = src0->ne[0];
const int64_t k = src1->ne[0];
const int64_t n = src0->ne[0];
const int64_t k = src1->ne[0];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];

GGML_ASSERT(n <= 64);
GGML_ASSERT(k <= 32);

solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float), ctx.stream());
if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float), ctx.stream());
} else {
solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float), ctx.stream());
}
}
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/vendors/hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_16BF HIPBLAS_R_16B
#define CUDA_R_32F HIPBLAS_R_32F
#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT
#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER
#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
Expand All @@ -30,6 +33,7 @@
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define __all_sync(mask, var) __all(var)
#define __any_sync(mask, var) __any(var)
#define cublasStrsmBatched hipblasStrsmBatched
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/vendors/musa.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH
#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT
#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER
#define CUBLAS_DIAG_NON_UNIT MUBLAS_DIAG_NON_UNIT
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_16BF MUSA_R_16BF
#define CUDA_R_32F MUSA_R_32F
#define cublasStrsmBatched mublasStrsmBatched
#define cublasComputeType_t cudaDataType_t
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
Expand Down
22 changes: 19 additions & 3 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7861,9 +7861,24 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 30, 30, 7, 1 }, { 8, 30, 7, 1 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 64, 64, 2, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 79, 79, 5, 3 }, { 417, 79, 5, 3 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 80, 80, 2, 8 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 79, 80, 2, 8 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 81, 80, 2, 8 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 80, 80, 8, 8 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 79, 80, 8, 8 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 81, 80, 8, 8 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 84, 84, 4, 4 }, { 32, 84, 4, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 95, 95, 8, 8 }, { 40, 95, 8, 8 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 32, 128, 4, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 3, 4 }, { 32, 128, 3, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 32, 128, 4, 1 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 200, 64, 4, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 384, 64, 4, 4 }));

for (bool v : {false, true}) {
for (bool circular : {false, true}) {
Expand Down Expand Up @@ -8064,12 +8079,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));

test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 32, 64, 4, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));
// qwen3next with CHUNK_SIZE 64
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
// qwen3next with CHUNK_SIZE 128
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 256, 256, 4, 2 }, { 128, 256, 4, 2 }));

test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
Expand Down
Binary file modified tools/server/public/index.html.gz
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you added this unintentionally, please remove it from this PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It adds itself automatically due to a GitHub hook installed by the server development code 😐

Binary file not shown.
Loading