#include "common.cuh"
#include "ggml-cuda/vendors/cuda.h"
#include "ggml.h"
#include "solve_tri.cuh"

#include <cublas_api.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <driver_types.h>

#define MAX_N_FAST 64
#define MAX_K_FAST 32

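// Computes one base pointer per batch into A and X; cublasStrsmBatched consumes these device-side pointer arrays.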
static __global__ void get_batch_pointers(const float * A,
                                          float * X,
                                          const float ** A_ptrs,
                                          float ** X_ptrs,
                                          int64_t ne02,
                                          int64_t total_batches,
                                          size_t s02,
                                          size_t s03,
                                          size_t s2,
                                          size_t s3) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total_batches) {
        return;
    }

    const int64_t i3 = idx / ne02;
    const int64_t i2 = idx % ne02;

    A_ptrs[idx] = A + i3 * s03 + i2 * s02;
    X_ptrs[idx] = X + i3 * s3 + i2 * s2;
}

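// cuBLAS path: copy B into X and solve the triangular systems in place with cublasStrsmBatched.
// ggml tensors are row-major, so the lower-triangular A is presented to column-major cuBLAS as an
// upper-triangular matrix applied from the right, which yields the same A * X = B solve.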
static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
                                 const float * A,
                                 const float * B,
                                 float * X,
                                 int n,
                                 int k,
                                 int64_t ne02,
                                 int64_t ne03,
                                 size_t s02,
                                 size_t s03,
                                 size_t s12,
                                 size_t s13,
                                 size_t s2,
                                 size_t s3,
                                 cudaStream_t stream) {
    const float alpha = 1.0f;
    const int64_t total_batches = ne02 * ne03;
    if (total_batches == 0) {
        return;
    }

    // Bulk copy B -> X (contiguous tensors)
    if (X != B) {
        const int64_t total_elements_BX = n * k * total_batches;
        CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
    }

    int id = ggml_cuda_get_device();

    ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
    ggml_cuda_pool_alloc<float *>       X_ptrs_alloc(ctx.pool(id), total_batches);

    const float ** A_ptrs_dev = A_ptrs_alloc.get();
    float **       X_ptrs_dev = X_ptrs_alloc.get();

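    // Fill the device-side pointer arrays, one thread per batch.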
    get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
                                                                        total_batches, s02, s03, s2, s3);

    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));

    // Yes, this is necessary: without it we get RMSE errors
    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
    CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
                                    CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));

    // revert to standard mode from common.cuh
    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));

    GGML_UNUSED_VARS(s12, s13);
}

// ======================
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
// ======================
// When the template parameters are 0 the bounds for the loops in this function
// are not known at compile time and the loops can't be unrolled. As we want to
// keep #pragma unroll for all other cases we suppress the clang transformation
// warning here.
#ifdef __clang__
#    pragma clang diagnostic push
#    pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <int n_template, int k_template>
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
                                          const float * __restrict__ B,
                                          float * __restrict__ X,
                                          const uint3 ne02,
                                          const size_t nb02,
                                          const size_t nb03,
                                          const size_t nb12,
                                          const size_t nb13,
                                          const size_t nb2,
                                          const size_t nb3,
                                          const int n_arg,
                                          const int k_arg) {
    const int n = n_template == 0 ? n_arg : n_template;
    const int k = k_template == 0 ? k_arg : k_template;

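    // One thread block per (i02, i03) batch; each warp lane (threadIdx.x) cooperates on the column of B/X
    // selected by threadIdx.y.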
    const int batch_idx = blockIdx.x;
    const int lane      = threadIdx.x;
    const int col_idx   = threadIdx.y;

    if (col_idx >= k) {
        return;
    }

    const uint2   i02_i03 = fast_div_modulo(batch_idx, ne02);
    const int64_t i02     = i02_i03.y;
    const int64_t i03     = i02_i03.x;

    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
    float *             X_batch = (float *) (X + i02 * nb2 + i03 * nb3);

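    // Stage the n×n matrix A in shared memory once; all k columns of this batch reuse it.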
    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];

    const int offset = threadIdx.x + threadIdx.y * blockDim.x;

#pragma unroll
    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
        const int i0 = i + offset;
        if (i0 < n * n) {
            sA[i0] = A_batch[i0];
        }
    }

    __syncthreads();

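    // Each lane keeps two entries of its solution column in registers: rows [0, WARP_SIZE) in x_low and
    // rows [WARP_SIZE, n) in x_high, initialized with the corresponding right-hand side from B.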
    float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;

    const int half      = WARP_SIZE;
    const int nrows_low = (n < half) ? n : half;

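    // Forward substitution over the first (up to) WARP_SIZE rows: the warp reduces the partial dot product with
    // the already-solved entries, then the lane that owns the diagonal element finalizes its row.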
#pragma unroll
    for (int row = 0; row < nrows_low; ++row) {
        float sum = 0.0f;
        if (lane < row) {
            sum += sA[row * n + lane] * x_low;
        }
        sum = warp_reduce_sum(sum);

        if (lane == row) {
            x_low = (x_low - sum) / sA[row * n + row];
        }
    }

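    // Rows WARP_SIZE..n-1 accumulate the contribution of the fully solved lower half (x_low) plus the
    // already-solved part of the upper half (x_high) before the owning lane divides by the diagonal.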
#pragma unroll
    for (int row = half; row < n; ++row) {
        float     sum = sA[row * n + lane] * x_low;
        const int j   = half + lane;
        if (j < row) {
            sum += sA[row * n + j] * x_high;
        }
        sum = warp_reduce_sum(sum);

        if (lane == row - half) {
            x_high = (x_high - sum) / sA[row * n + row];
        }
    }

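    // Write the solved column of X back to global memory.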
#pragma unroll
    for (int rr = 0; rr < 2; ++rr) {
        const int row = rr * WARP_SIZE + lane;
        if (row < n) {
            const float val = (row < half) ? x_low : x_high;
            X_batch[row * k + col_idx] = val;
        }
    }
}
#ifdef __clang__
#    pragma clang diagnostic pop
#endif // __clang__

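// Host-side launcher for the fast kernel: picks a fully unrolled instantiation for the common n == 64 shapes and
// falls back to the runtime-sized <0, 0> variant otherwise.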
static void solve_tri_f32_cuda(const float * A,
                               const float * B,
                               float * X,
                               int n,
                               int k,
                               int64_t ne02,
                               int64_t ne03,
                               size_t nb02,
                               size_t nb03,
                               size_t nb12,
                               size_t nb13,
                               size_t nb2,
                               size_t nb3,
                               cudaStream_t stream) {
    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
    dim3 threads(WARP_SIZE, k);
    dim3 grid(ne02 * ne03);
    if (n == 64) {
        switch (k) {
            case 32:
                solve_tri_f32_fast<64, 32>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 16:
                solve_tri_f32_fast<64, 16>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 14:
                solve_tri_f32_fast<64, 14>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 12:
                solve_tri_f32_fast<64, 12>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 10:
                solve_tri_f32_fast<64, 10>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 8:
                solve_tri_f32_fast<64, 8>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 6:
                solve_tri_f32_fast<64, 6>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 4:
                solve_tri_f32_fast<64, 4>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 2:
                solve_tri_f32_fast<64, 2>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 1:
                solve_tri_f32_fast<64, 1>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            default:
                solve_tri_f32_fast<0, 0>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
        }
    } else {  // run general case
        solve_tri_f32_fast<0, 0>
            <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
    }
}
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];  // A (n×n, lower triangular)
    const ggml_tensor * src1 = dst->src[1];  // B (n×k)

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const int64_t n    = src0->ne[0];
    const int64_t k    = src1->ne[0];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

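    // Route small systems (n <= 64, k <= 32) to the custom kernel; larger ones use the batched cuBLAS solver.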
    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
        solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
                           src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
                           src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
                           dst->nb[3] / sizeof(float), ctx.stream());
    } else {
        solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
                             ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
                             src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
                             dst->nb[3] / sizeof(float), ctx.stream());
    }
}