
Commit 60dcf1b

Merge both versions
1 parent b5e41a7 commit 60dcf1b

File tree: 1 file changed (+223, -54 lines)


ggml/src/ggml-cuda/solve_tri.cu

Lines changed: 223 additions & 54 deletions
@@ -1,15 +1,26 @@
 #include "common.cuh"
 #include "ggml-cuda/vendors/cuda.h"
-#include <cublas_api.h>
 #include "ggml.h"
 #include "solve_tri.cuh"
+
+#include <cublas_api.h>
 #include <cublas_v2.h>
 #include <cuda_runtime_api.h>
 #include <driver_types.h>
 
-static __global__ void get_batch_pointers(const float * A, float * X, const float ** A_ptrs, float ** X_ptrs,
-                                          int64_t ne02, int64_t total_batches,
-                                          size_t s02, size_t s03, size_t s2, size_t s3) {
+#define MAX_N_FAST 64
+#define MAX_K_FAST 32
+
+static __global__ void get_batch_pointers(const float * A,
+                                          float * X,
+                                          const float ** A_ptrs,
+                                          float ** X_ptrs,
+                                          int64_t ne02,
+                                          int64_t total_batches,
+                                          size_t s02,
+                                          size_t s03,
+                                          size_t s2,
+                                          size_t s3) {
     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx >= total_batches) {
         return;
@@ -22,22 +33,22 @@ static __global__ void get_batch_pointers(const float * A, float * X, const floa
     X_ptrs[idx] = X + i3 * s3 + i2 * s2;
 }
 
-static void solve_tri_f32_cublas(ggml_backend_cuda_context &ctx,
-                                 const float * A,
-                                 const float * B,
-                                 float * X,
-                                 int n,
-                                 int k,
-                                 int64_t ne02,
-                                 int64_t ne03,
-                                 size_t s02,
-                                 size_t s03,
-                                 size_t s12,
-                                 size_t s13,
-                                 size_t s2,
-                                 size_t s3,
-                                 cudaStream_t stream) {
-    const float alpha = 1.0f;
+static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
+                                 const float * A,
+                                 const float * B,
+                                 float * X,
+                                 int n,
+                                 int k,
+                                 int64_t ne02,
+                                 int64_t ne03,
+                                 size_t s02,
+                                 size_t s03,
+                                 size_t s12,
+                                 size_t s13,
+                                 size_t s2,
+                                 size_t s3,
+                                 cudaStream_t stream) {
+    const float alpha = 1.0f;
     const int64_t total_batches = ne02 * ne03;
     if (total_batches == 0) {
         return;
@@ -46,67 +57,225 @@ static void solve_tri_f32_cublas(ggml_backend_cuda_context &ctx,
     // Bulk copy B -> X (contiguous tensors)
     if (X != B) {
         const int64_t total_elements_BX = n * k * total_batches;
-        CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float),
-                                   cudaMemcpyDeviceToDevice, stream));
+        CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
     }
 
     int id = ggml_cuda_get_device();
 
     ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
-    ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches);
+    ggml_cuda_pool_alloc<float *>       X_ptrs_alloc(ctx.pool(id), total_batches);
 
     const float ** A_ptrs_dev = A_ptrs_alloc.get();
-    float ** X_ptrs_dev = X_ptrs_alloc.get();
+    float **       X_ptrs_dev = X_ptrs_alloc.get();
 
-    get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(
-        A, X, A_ptrs_dev, X_ptrs_dev, ne02, total_batches, s02, s03, s2, s3);
+    get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
+                                                                        total_batches, s02, s03, s2, s3);
 
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 
     // Yes, this is necessary, without this we get RMSE errors
     CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
-    CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id),
-                                    CUBLAS_SIDE_RIGHT,
-                                    CUBLAS_FILL_MODE_UPPER,
-                                    CUBLAS_OP_N,
-                                    CUBLAS_DIAG_NON_UNIT,
-                                    k,
-                                    n,
-                                    &alpha,
-                                    A_ptrs_dev, n,
-                                    X_ptrs_dev, k,
-                                    total_batches));
+    CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
+                                    CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));
 
     // revert to standard mode from common.cuh
     CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));
 
     GGML_UNUSED_VARS(s12, s13);
 }
 
+// ======================
+// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
+// ======================
+// When ncols_template == 0 the bounds for the loops in this function are not
+// known and can't be unrolled. As we want to keep pragma unroll for all other
+// cases we supress the clang transformation warning here.
+#ifdef __clang__
+#    pragma clang diagnostic push
+#    pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <int n_template, int k_template>
+static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
+                                          const float * __restrict__ B,
+                                          float * __restrict__ X,
+                                          const uint3 ne02,
+                                          const size_t nb02,
+                                          const size_t nb03,
+                                          const size_t nb12,
+                                          const size_t nb13,
+                                          const size_t nb2,
+                                          const size_t nb3,
+                                          const int n_arg,
+                                          const int k_arg) {
+    const int n = n_template == 0 ? n_arg : n_template;
+    const int k = k_template == 0 ? k_arg : k_template;
+
+    const int batch_idx = blockIdx.x;
+    const int lane      = threadIdx.x;
+    const int col_idx   = threadIdx.y;
+
+    if (col_idx >= k) {
+        return;
+    }
+
+    const uint2   i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02     = i02_i03.y;
+    const int64_t i03     = i02_i03.x;
+
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float *             X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
+    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
+
+    const int offset = threadIdx.x + threadIdx.y * blockDim.x;
+
+#pragma unroll
+    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
+        const int i0 = i + offset;
+        if (i0 < n * n) {
+            sA[i0] = A_batch[i0];
+        }
+    }
+
+    __syncthreads();
+
+    float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
+    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
+
+    const int half      = WARP_SIZE;
+    const int nrows_low = (n < half) ? n : half;
+
+#pragma unroll
+    for (int row = 0; row < nrows_low; ++row) {
+        float sum = 0.0f;
+        if (lane < row) {
+            sum += sA[row * n + lane] * x_low;
+        }
+        sum = warp_reduce_sum(sum);
+
+        if (lane == row) {
+            x_low = (x_low - sum) / sA[row * n + row];
+        }
+    }
+
+#pragma unroll
+    for (int row = half; row < n; ++row) {
+        float     sum = sA[row * n + lane] * x_low;
+        const int j   = half + lane;
+        if (j < row) {
+            sum += sA[row * n + j] * x_high;
+        }
+        sum = warp_reduce_sum(sum);
+
+        if (lane == row - half) {
+            x_high = (x_high - sum) / sA[row * n + row];
+        }
+    }
+
+#pragma unroll
+    for (int rr = 0; rr < 2; ++rr) {
+        const int row = rr * WARP_SIZE + lane;
+        if (row < n) {
+            const float val = (row < half) ? x_low : x_high;
+            X_batch[row * k + col_idx] = val;
+        }
+    }
+}
+#ifdef __clang__
+#    pragma clang diagnostic pop
+#endif // __clang__
+
+static void solve_tri_f32_cuda(const float * A,
+                               const float * B,
+                               float * X,
+                               int n,
+                               int k,
+                               int64_t ne02,
+                               int64_t ne03,
+                               size_t nb02,
+                               size_t nb03,
+                               size_t nb12,
+                               size_t nb13,
+                               size_t nb2,
+                               size_t nb3,
+                               cudaStream_t stream) {
+    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+    dim3 threads(WARP_SIZE, k);
+    dim3 grid(ne02 * ne03);
+    if (n == 64) {
+        switch (k) {
+            case 32:
+                solve_tri_f32_fast<64, 32>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 16:
+                solve_tri_f32_fast<64, 16>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 14:
+                solve_tri_f32_fast<64, 14>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 12:
+                solve_tri_f32_fast<64, 12>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 10:
+                solve_tri_f32_fast<64, 10>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 8:
+                solve_tri_f32_fast<64, 8>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 6:
+                solve_tri_f32_fast<64, 6>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 4:
+                solve_tri_f32_fast<64, 4>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 2:
+                solve_tri_f32_fast<64, 2>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 1:
+                solve_tri_f32_fast<64, 1>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            default:
+                solve_tri_f32_fast<0, 0>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+        }
+    } else { // run general case
+        solve_tri_f32_fast<0, 0>
+            <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+    }
+}
 
-// ----------------------------------------------------------------------------
-// Public entry point
-// ----------------------------------------------------------------------------
 void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular)
-    const ggml_tensor * src1 = dst->src[1]; // B (n×k)
+    const ggml_tensor * src0 = dst->src[0];  // A (n×n, lower triangular)
+    const ggml_tensor * src1 = dst->src[1];  // B (n×k)
 
     ggml_is_contiguous(src0);
     ggml_is_contiguous(src1);
 
-    const int64_t n = src0->ne[0];
-    const int64_t k = src1->ne[0];
+    const int64_t n    = src0->ne[0];
+    const int64_t k    = src1->ne[0];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-    solve_tri_f32_cublas(ctx,
-                         (const float *) src0->data,
-                         (const float *) src1->data,
-                         (float *) dst->data,
-                         n, k,
-                         ne02, ne03,
-                         src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
-                         src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float),
-                         dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float),
-                         ctx.stream());
+    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
+        solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
+                           src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+                           src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+                           dst->nb[3] / sizeof(float), ctx.stream());
+    } else {
+        solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
+                             ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+                             src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+                             dst->nb[3] / sizeof(float), ctx.stream());
+    }
 }
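Note on the math: both code paths compute the same batched triangular solve A·X = B, with A an n×n lower-triangular matrix and B, X of shape n×k, stored row-major exactly as the fast kernel indexes them (A[r][c] at r * n + c, B/X[r][c] at r * k + c). The sketch below is a minimal CPU reference for one batch, useful as a mental model or validation oracle; the function name solve_tri_ref and the flat-pointer interface are illustrative assumptions, not part of this commit.

// Illustrative CPU sketch (not in the commit): forward substitution per column.
// Each column c of B is independent, matching the per-column work one warp
// performs in solve_tri_f32_fast.
static void solve_tri_ref(const float * A, const float * B, float * X, int n, int k) {
    for (int c = 0; c < k; ++c) {
        for (int r = 0; r < n; ++r) {
            float sum = 0.0f;
            for (int j = 0; j < r; ++j) {
                sum += A[r * n + j] * X[j * k + c];  // strictly lower-triangular part
            }
            X[r * k + c] = (B[r * k + c] - sum) / A[r * n + r];
        }
    }
}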

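Note on the cuBLAS fallback parameters (a sketch of the reasoning, consistent with the call's arguments): ggml tensors are row-major while cuBLAS assumes column-major storage, so cuBLAS sees the same buffers as the transposed matrices. The solve is therefore recast as

    A X = B  (A lower-triangular)   <=>   X^T A^T = B^T  (A^T upper-triangular)

which is a right-sided TRSM on the k×n matrix X^T. That is why cublasStrsmBatched is called with CUBLAS_SIDE_RIGHT and CUBLAS_FILL_MODE_UPPER, dimensions (k, n), lda = n and ldb = k, rather than a left-sided lower-triangular solve.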
0 commit comments
