Commit b5e41a7

Just use cuBLAS for everything...

1 parent 2a1bde9 commit b5e41a7

3 files changed: +103 -270 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
@@ -4624,9 +4624,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_FILL:
         case GGML_OP_CUMSUM:
         case GGML_OP_TRI:
-            return true;
         case GGML_OP_SOLVE_TRI:
             return true;
+
         default:
             return false;
     }
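
The effect of this hunk is that GGML_OP_TRI now falls through to GGML_OP_SOLVE_TRI, so both ops unconditionally report support (the added line 4629 appears to be empty in this rendering). A minimal sketch of the resulting case grouping; the switch variable is written here as op->op, which is an assumption since the truncated hunk header does not show the parameter name:

    switch (op->op) {
        // ... other ops ...
        case GGML_OP_FILL:
        case GGML_OP_CUMSUM:
        case GGML_OP_TRI:
        case GGML_OP_SOLVE_TRI:
            return true;

        default:
            return false;
    }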

ggml/src/ggml-cuda/solve_tri.cu

Lines changed: 83 additions & 266 deletions
@@ -1,295 +1,112 @@
 #include "common.cuh"
+#include "ggml-cuda/vendors/cuda.h"
+#include <cublas_api.h>
 #include "ggml.h"
 #include "solve_tri.cuh"
-
-#define MAX_N_FAST 64
-
-// ======================
-// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
-// ======================
-// When ncols_template == 0 the bounds for the loops in this function are not
-// known and can't be unrolled. As we want to keep pragma unroll for all other
-// cases we supress the clang transformation warning here.
-#ifdef __clang__
-#    pragma clang diagnostic push
-#    pragma clang diagnostic ignored "-Wpass-failed"
-#endif // __clang__
-template <int n_template, int k_template>
-static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
-                                          const float * __restrict__ B,
-                                          float * __restrict__ X,
-                                          const uint3 ne02,
-                                          const size_t nb02,
-                                          const size_t nb03,
-                                          const size_t nb12,
-                                          const size_t nb13,
-                                          const size_t nb2,
-                                          const size_t nb3,
-                                          const int n_arg,
-                                          const int k_arg) {
-    const int n = n_template == 0 ? n_arg : n_template;
-    const int k = k_template == 0 ? k_arg : k_template;
-
-    const int batch_idx = blockIdx.x;
-    const int lane = threadIdx.x;
-    const int col_idx = threadIdx.y;
-
-    if (col_idx >= k) {
+#include <cublas_v2.h>
+#include <cuda_runtime_api.h>
+#include <driver_types.h>
+
+static __global__ void get_batch_pointers(const float * A, float * X, const float ** A_ptrs, float ** X_ptrs,
+                                          int64_t ne02, int64_t total_batches,
+                                          size_t s02, size_t s03, size_t s2, size_t s3) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= total_batches) {
         return;
     }
 
-    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
-    const int64_t i02 = i02_i03.y;
-    const int64_t i03 = i02_i03.x;
-
-    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
-    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
-    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
-
-    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
-
-    const int offset = threadIdx.x + threadIdx.y * blockDim.x;
-
-#pragma unroll
-    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
-        const int i0 = i + offset;
-        if (i0 < n * n) {
-            sA[i0] = A_batch[i0];
-        }
-    }
-
-    __syncthreads();
-
-    float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
-    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
-
-    const int half = WARP_SIZE;
-    const int nrows_low = (n < half) ? n : half;
+    const int64_t i3 = idx / ne02;
+    const int64_t i2 = idx % ne02;
 
-#pragma unroll
-    for (int row = 0; row < nrows_low; ++row) {
-        float sum = 0.0f;
-        if (lane < row) {
-            sum += sA[row * n + lane] * x_low;
-        }
-        sum = warp_reduce_sum(sum);
+    A_ptrs[idx] = A + i3 * s03 + i2 * s02;
+    X_ptrs[idx] = X + i3 * s3 + i2 * s2;
+}
 
-        if (lane == row) {
-            x_low = (x_low - sum) / sA[row * n + row];
-        }
+static void solve_tri_f32_cublas(ggml_backend_cuda_context &ctx,
+                                 const float * A,
+                                 const float * B,
+                                 float * X,
+                                 int n,
+                                 int k,
+                                 int64_t ne02,
+                                 int64_t ne03,
+                                 size_t s02,
+                                 size_t s03,
+                                 size_t s12,
+                                 size_t s13,
+                                 size_t s2,
+                                 size_t s3,
+                                 cudaStream_t stream) {
+    const float alpha = 1.0f;
+    const int64_t total_batches = ne02 * ne03;
+    if (total_batches == 0) {
+        return;
     }
 
-#pragma unroll
-    for (int row = half; row < n; ++row) {
-        float sum = sA[row * n + lane] * x_low;
-        const int j = half + lane;
-        if (j < row) {
-            sum += sA[row * n + j] * x_high;
-        }
-        sum = warp_reduce_sum(sum);
-
-        if (lane == row - half) {
-            x_high = (x_high - sum) / sA[row * n + row];
-        }
+    // Bulk copy B -> X (contiguous tensors)
+    if (X != B) {
+        const int64_t total_elements_BX = n * k * total_batches;
+        CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float),
+                                   cudaMemcpyDeviceToDevice, stream));
     }
 
-#pragma unroll
-    for (int rr = 0; rr < 2; ++rr) {
-        const int row = rr * WARP_SIZE + lane;
-        if (row < n) {
-            const float val = (row < half) ? x_low : x_high;
-            X_batch[row * k + col_idx] = val;
-        }
-    }
-}
-#ifdef __clang__
-#    pragma clang diagnostic pop
-#endif // __clang__
+    int id = ggml_cuda_get_device();
 
-// ======================
-// General Kernel for larger matrices
-// Uses a simpler approach with fixed tile size
-// ======================
-#define GENERAL_TILE_SIZE 32
+    ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
+    ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches);
 
-template <int n_template, int k_template>
-static __global__ void solve_tri_f32_general(const float * __restrict__ A,
-                                             const float * __restrict__ B,
-                                             float * __restrict__ X,
-                                             const uint3 ne02,
-                                             const size_t nb02,
-                                             const size_t nb03,
-                                             const size_t nb12,
-                                             const size_t nb13,
-                                             const size_t nb2,
-                                             const size_t nb3,
-                                             const int n_arg,
-                                             const int k_arg) {
-    const int n = n_template == 0 ? n_arg : n_template;
-    const int k = k_template == 0 ? k_arg : k_template;
+    const float ** A_ptrs_dev = A_ptrs_alloc.get();
+    float ** X_ptrs_dev = X_ptrs_alloc.get();
 
-    const int batch_idx = blockIdx.x;
-    const int col_idx = blockIdx.y;
-    const int tid = threadIdx.x;
+    get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(
+        A, X, A_ptrs_dev, X_ptrs_dev, ne02, total_batches, s02, s03, s2, s3);
 
-    if (col_idx >= k) {
-        return;
-    }
-    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
-    const int64_t i02 = i02_i03.y;
-    const int64_t i03 = i02_i03.x;
+    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 
-    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
-    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
-    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+    // Yes, this is necessary, without this we get RMSE errors
+    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
+    CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id),
+                                    CUBLAS_SIDE_RIGHT,
+                                    CUBLAS_FILL_MODE_UPPER,
+                                    CUBLAS_OP_N,
+                                    CUBLAS_DIAG_NON_UNIT,
+                                    k,
+                                    n,
+                                    &alpha,
+                                    A_ptrs_dev, n,
+                                    X_ptrs_dev, k,
+                                    total_batches));
 
-    // Shared memory for current tile
-    __shared__ float sA[GENERAL_TILE_SIZE * GENERAL_TILE_SIZE];
-    __shared__ float sB[GENERAL_TILE_SIZE];
-    __shared__ float sX[GENERAL_TILE_SIZE];
+    // revert to standard mode from common.cuh
+    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));
 
-    // Process in tiles
-    for (int tile_start = 0; tile_start < n; tile_start += GENERAL_TILE_SIZE) {
-        int tile_end = min(tile_start + GENERAL_TILE_SIZE, n);
-        int tile_n = tile_end - tile_start;
-        // Load tile of A matrix
-        for (int i = tid; i < tile_n * tile_n; i += blockDim.x) {
-            int local_row = i / tile_n;
-            int local_col = i % tile_n;
-            int global_row = tile_start + local_row;
-            int global_col = tile_start + local_col;
-            if (global_col <= global_row) {
-                sA[local_row * GENERAL_TILE_SIZE + local_col] = A_batch[global_row * n + global_col];
-            } else {
-                sA[local_row * GENERAL_TILE_SIZE + local_col] = 0.0f;
-            }
-        }
-        __syncthreads();
-        // Load corresponding part of B and initialize X
-        if (tid < tile_n) {
-            sB[tid] = B_batch[(tile_start + tid) * k + col_idx];
-            sX[tid] = sB[tid];
-        }
-        __syncthreads();
-        // Forward substitution for this tile
-        for (int row = 0; row < tile_n; ++row) {
-            if (tid == row) {
-                float sum = 0.0f;
-                // Sum contributions from previous rows in this tile
-                for (int j = 0; j < row; ++j) {
-                    sum += sA[row * GENERAL_TILE_SIZE + j] * sX[j];
-                }
-                // Sum contributions from previous tiles
-                if (tile_start > 0) {
-                    int global_row = tile_start + row;
-                    for (int j = 0; j < tile_start; ++j) {
-                        sum += A_batch[global_row * n + j] * X_batch[j * k + col_idx];
-                    }
-                }
-                const float a_diag = sA[row * GENERAL_TILE_SIZE + row];
-                sX[row] = (sB[row] - sum) / a_diag;
-            }
-            __syncthreads();
-        }
-        // Store results back to global memory
-        if (tid < tile_n) {
-            int global_row = tile_start + tid;
-            X_batch[global_row * k + col_idx] = sX[tid];
-        }
-        __syncthreads();
-    }
-}
-static void solve_tri_f32_cuda(const float * A,
-                               const float * B,
-                               float * X,
-                               int n,
-                               int k,
-                               int64_t ne02,
-                               int64_t ne03,
-                               size_t nb02,
-                               size_t nb03,
-                               size_t nb12,
-                               size_t nb13,
-                               size_t nb2,
-                               size_t nb3,
-                               cudaStream_t stream) {
-    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
-    // Choose kernel based on matrix size
-    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
-        // Use fast kernel for small matrices
-        dim3 threads(WARP_SIZE, k);
-        dim3 grid(ne02 * ne03);
-        if (n == 64) {
-            switch (k) {
-                case 32:
-                    solve_tri_f32_fast<64, 32>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 16:
-                    solve_tri_f32_fast<64, 16>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 14:
-                    solve_tri_f32_fast<64, 14>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 12:
-                    solve_tri_f32_fast<64, 12>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 10:
-                    solve_tri_f32_fast<64, 10>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 8:
-                    solve_tri_f32_fast<64, 8>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 6:
-                    solve_tri_f32_fast<64, 6>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 4:
-                    solve_tri_f32_fast<64, 4>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 2:
-                    solve_tri_f32_fast<64, 2>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 1:
-                    solve_tri_f32_fast<64, 1>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                default:
-                    solve_tri_f32_fast<0, 0>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
-            }
-        } else { // run general case
-            solve_tri_f32_fast<0, 0>
-                <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
-        }
-    } else {
-        // Use general kernel for larger matrices
-        dim3 threads(256, 1); // 256 threads per block
-        dim3 grid(ne02 * ne03, k); // One block per column
-        solve_tri_f32_general<0, 0>
-            <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
-    }
+    GGML_UNUSED_VARS(s12, s13);
 }
 
+
+// ----------------------------------------------------------------------------
+// Public entry point
+// ----------------------------------------------------------------------------
 void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix)
-    const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
+    const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular)
+    const ggml_tensor * src1 = dst->src[1]; // B (n×k)
 
     ggml_is_contiguous(src0);
     ggml_is_contiguous(src1);
 
     const int64_t n = src0->ne[0];
     const int64_t k = src1->ne[0];
-
-    solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
-                       src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
-                       src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
-                       dst->nb[3] / sizeof(float), ctx.stream());
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    solve_tri_f32_cublas(ctx,
+                         (const float *) src0->data,
+                         (const float *) src1->data,
+                         (float *) dst->data,
+                         n, k,
+                         ne02, ne03,
+                         src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+                         src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float),
+                         dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float),
+                         ctx.stream());
 }
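
The cuBLAS path above leans on the usual row-major/column-major duality: ggml stores A (n×n, lower triangular) and B/X (n×k) row-major, and a row-major matrix reinterpreted as column-major with the same leading dimension is its transpose. Solving A·X = B is therefore equivalent to the right-side solve Xᵀ·Aᵀ = Bᵀ with Aᵀ upper triangular, which is why the batched TRSM is issued with CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, m = k, n = n, lda = n and ldb = k on the X buffer that was pre-filled with B. Below is a minimal, self-contained sketch of the same mapping for a single 3×3 system; the sizes, values and the plain (non-batched) cublasStrsm call are illustrative only and not part of the commit:

    // Minimal illustration of the SIDE_RIGHT / FILL_MODE_UPPER mapping used above.
    // Sizes and values are made up for the example; error checking is omitted for brevity.
    #include <cublas_v2.h>
    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        const int n = 3, k = 2;

        // Row-major lower-triangular A (n x n); X starts out holding B (n x k)
        // and is solved in place, as in the commit.
        const float A[n * n] = { 2, 0, 0,
                                 1, 3, 0,
                                 4, 5, 6 };
        float       X[n * k] = {  2,  4,   // B row 0
                                  4, 11,   // B row 1
                                 21, 47 }; // B row 2 -> expected X rows: {1,2}, {1,3}, {2,4}

        float * dA = nullptr;
        float * dX = nullptr;
        cudaMalloc(&dA, sizeof(A));
        cudaMalloc(&dX, sizeof(X));
        cudaMemcpy(dA, A, sizeof(A), cudaMemcpyHostToDevice);
        cudaMemcpy(dX, X, sizeof(X), cudaMemcpyHostToDevice);

        cublasHandle_t handle;
        cublasCreate(&handle);

        // Column-major cuBLAS sees the row-major buffers as A^T (upper triangular, lda = n)
        // and X^T (k x n, ldb = k), so A*X = B becomes the right-side solve X^T * A^T = B^T.
        const float alpha = 1.0f;
        cublasStrsm(handle,
                    CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER,
                    CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT,
                    k, n,        // m = k (rows of X^T), n = n (cols of X^T)
                    &alpha,
                    dA, n,       // lda = n
                    dX, k);      // ldb = k

        cudaMemcpy(X, dX, sizeof(X), cudaMemcpyDeviceToHost);
        for (int i = 0; i < n; ++i) {
            printf("X row %d: %.1f %.1f\n", i, X[i * k + 0], X[i * k + 1]);
        }

        cublasDestroy(handle);
        cudaFree(dA);
        cudaFree(dX);
        return 0;
    }

In the commit the same solve goes through cublasStrsmBatched instead, with get_batch_pointers filling the per-batch A/X pointer arrays directly on the device (one pointer pair per i2/i3 slice), and the math mode is dropped to CUBLAS_DEFAULT_MATH around the call so the triangular solve does not run through TF32, then restored to CUBLAS_TF32_TENSOR_OP_MATH afterwards.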
