
Commit a4d38e4

Extended TRI
1 parent c8554b6 commit a4d38e4

File tree

2 files changed: +169 -54 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion

@@ -4626,7 +4626,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TRI:
             return true;
         case GGML_OP_SOLVE_TRI:
-            return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
+            return true;
         default:
             return false;
     }
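
With the general kernel added below, the CUDA backend can claim support for GGML_OP_SOLVE_TRI unconditionally, so the old n <= 64 / k <= 32 cap is dropped. For orientation, the op solves A·X = B for X, where A is an n×n lower-triangular matrix per batch and B holds k right-hand-side columns. A minimal CPU sketch of one batch (illustrative helper, not part of the commit; it assumes the same row-major, unit-stride layout the kernels index with):

// Reference forward substitution for one batch: A is n*n lower-triangular,
// B and X are n*k with row stride k. Illustrative only.
static void solve_tri_ref(const float * A, const float * B, float * X, int n, int k) {
    for (int col = 0; col < k; ++col) {
        for (int row = 0; row < n; ++row) {
            float sum = 0.0f;
            for (int j = 0; j < row; ++j) {
                sum += A[row * n + j] * X[j * k + col];  // rows already solved
            }
            X[row * k + col] = (B[row * k + col] - sum) / A[row * n + row];
        }
    }
}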

ggml/src/ggml-cuda/solve_tri.cu

Lines changed: 168 additions & 53 deletions

@@ -106,6 +106,112 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
 # pragma clang diagnostic pop
 #endif // __clang__
 
+// ======================
+// General Kernel for larger matrices
+// Uses a simpler approach with fixed tile size
+// ======================
+#define GENERAL_TILE_SIZE 32
+
+template <int n_template, int k_template>
+static __global__ void solve_tri_f32_general(const float * __restrict__ A,
+                                             const float * __restrict__ B,
+                                             float * __restrict__ X,
+                                             const uint3  ne02,
+                                             const size_t nb02,
+                                             const size_t nb03,
+                                             const size_t nb12,
+                                             const size_t nb13,
+                                             const size_t nb2,
+                                             const size_t nb3,
+                                             const int    n_arg,
+                                             const int    k_arg) {
+    const int n = n_template == 0 ? n_arg : n_template;
+    const int k = k_template == 0 ? k_arg : k_template;
+
+    const int batch_idx = blockIdx.x;
+    const int col_idx   = blockIdx.y;
+    const int tid       = threadIdx.x;
+
+    if (col_idx >= k) {
+        return;
+    }
+
+    const uint2   i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02     = i02_i03.y;
+    const int64_t i03     = i02_i03.x;
+
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float *             X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
+    // Shared memory for current tile
+    __shared__ float sA[GENERAL_TILE_SIZE * GENERAL_TILE_SIZE];
+    __shared__ float sB[GENERAL_TILE_SIZE];
+    __shared__ float sX[GENERAL_TILE_SIZE];
+
+    // Process in tiles
+    for (int tile_start = 0; tile_start < n; tile_start += GENERAL_TILE_SIZE) {
+        int tile_end = min(tile_start + GENERAL_TILE_SIZE, n);
+        int tile_n   = tile_end - tile_start;
+
+        // Load tile of A matrix
+        for (int i = tid; i < tile_n * tile_n; i += blockDim.x) {
+            int local_row  = i / tile_n;
+            int local_col  = i % tile_n;
+            int global_row = tile_start + local_row;
+            int global_col = tile_start + local_col;
+
+            if (global_col <= global_row) {
+                sA[local_row * GENERAL_TILE_SIZE + local_col] = A_batch[global_row * n + global_col];
+            } else {
+                sA[local_row * GENERAL_TILE_SIZE + local_col] = 0.0f;
+            }
+        }
+
+        __syncthreads();
+
+        // Load corresponding part of B and initialize X
+        if (tid < tile_n) {
+            sB[tid] = B_batch[(tile_start + tid) * k + col_idx];
+            sX[tid] = sB[tid];
+        }
+
+        __syncthreads();
+
+        // Forward substitution for this tile
+        for (int row = 0; row < tile_n; ++row) {
+            if (tid == row) {
+                float sum = 0.0f;
+
+                // Sum contributions from previous rows in this tile
+                for (int j = 0; j < row; ++j) {
+                    sum += sA[row * GENERAL_TILE_SIZE + j] * sX[j];
+                }
+
+                // Sum contributions from previous tiles
+                if (tile_start > 0) {
+                    int global_row = tile_start + row;
+                    for (int j = 0; j < tile_start; ++j) {
+                        sum += A_batch[global_row * n + j] * X_batch[j * k + col_idx];
+                    }
+                }
+
+                const float a_diag = sA[row * GENERAL_TILE_SIZE + row];
+                sX[row] = (sB[row] - sum) / a_diag;
+            }
+            __syncthreads();
+        }
+
+        // Store results back to global memory
+        if (tid < tile_n) {
+            int global_row = tile_start + tid;
+            X_batch[global_row * k + col_idx] = sX[tid];
+        }
+
+        __syncthreads();
+    }
+}
+
 static void solve_tri_f32_cuda(const float * A,
                                const float * B,
                                float *       X,
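
The general kernel launches one thread block per (batch, column) pair and walks the diagonal in GENERAL_TILE_SIZE chunks: all threads cooperate to stage the A tile in shared memory, rows inside a tile are then solved serially (only the thread with tid == row is active), and the cross-tile term is accumulated from X values already written back to global memory. Stripped of the CUDA details, the blocked recurrence it implements looks like this (CPU sketch for a single column col, with T standing in for GENERAL_TILE_SIZE; same row-major layout as above):

// Blocked forward substitution for one column (sketch, T = 32).
for (int t0 = 0; t0 < n; t0 += T) {
    int tn = (t0 + T < n ? t0 + T : n) - t0;   // rows in this tile
    for (int row = 0; row < tn; ++row) {
        int   g   = t0 + row;                  // global row index
        float sum = 0.0f;
        for (int j = 0; j < t0; ++j) {         // contributions of earlier tiles
            sum += A[g * n + j] * X[j * k + col];
        }
        for (int j = t0; j < g; ++j) {         // contributions inside this tile
            sum += A[g * n + j] * X[j * k + col];
        }
        X[g * k + col] = (B[g * k + col] - sum) / A[g * n + g];
    }
}

The row-by-row dependency makes the diagonal walk inherently serial per column, so this kernel trades peak throughput for supporting arbitrary n beyond the fast path's cap (n <= 64, k <= 32).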
@@ -121,56 +227,68 @@ static void solve_tri_f32_cuda(const float * A,
                                size_t       nb3,
                                cudaStream_t stream) {
     const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
-    dim3 threads(WARP_SIZE, k);
-    dim3 grid(ne02 * ne03);
-    if (n == 64) {
-        switch (k) {
-            case 32:
-                solve_tri_f32_fast<64, 32>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 16:
-                solve_tri_f32_fast<64, 16>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 14:
-                solve_tri_f32_fast<64, 14>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 12:
-                solve_tri_f32_fast<64, 12>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 10:
-                solve_tri_f32_fast<64, 10>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 8:
-                solve_tri_f32_fast<64, 8>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 6:
-                solve_tri_f32_fast<64, 6>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 4:
-                solve_tri_f32_fast<64, 4>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 2:
-                solve_tri_f32_fast<64, 2>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 1:
-                solve_tri_f32_fast<64, 1>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            default:
-                solve_tri_f32_fast<0, 0>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+
+    // Choose kernel based on matrix size
+    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
+        // Use fast kernel for small matrices
+        dim3 threads(WARP_SIZE, k);
+        dim3 grid(ne02 * ne03);
+        if (n == 64) {
+            switch (k) {
+                case 32:
+                    solve_tri_f32_fast<64, 32>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 16:
+                    solve_tri_f32_fast<64, 16>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 14:
+                    solve_tri_f32_fast<64, 14>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 12:
+                    solve_tri_f32_fast<64, 12>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 10:
+                    solve_tri_f32_fast<64, 10>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 8:
+                    solve_tri_f32_fast<64, 8>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 6:
+                    solve_tri_f32_fast<64, 6>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 4:
+                    solve_tri_f32_fast<64, 4>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 2:
+                    solve_tri_f32_fast<64, 2>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 1:
+                    solve_tri_f32_fast<64, 1>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                default:
+                    solve_tri_f32_fast<0, 0>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+            }
+        } else { // run general case
+            solve_tri_f32_fast<0, 0>
+                <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
         }
-    } else { // run general case
-        solve_tri_f32_fast<0, 0>
+    } else {
+        // Use general kernel for larger matrices
+        dim3 threads(256, 1);       // 256 threads per block
+        dim3 grid(ne02 * ne03, k);  // One block per column
+
+        solve_tri_f32_general<0, 0>
             <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
     }
 }
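
MAX_N_FAST and MAX_K_FAST are not defined in this diff; presumably they sit next to solve_tri_f32_fast and encode the same bounds the removed supports_op check and GGML_ASSERTs did, i.e. something along the lines of:

// Assumed values (not shown in this commit), matching the old limits:
#define MAX_N_FAST 64   // fast path: A at most 64 x 64
#define MAX_K_FAST 32   // fast path: at most 32 right-hand-side columns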
@@ -185,11 +303,8 @@ void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const int64_t n = src0->ne[0];
     const int64_t k = src1->ne[0];
 
-    GGML_ASSERT(n <= 64);
-    GGML_ASSERT(k <= 32);
-
     solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
                        src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
                        src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
                        dst->nb[3] / sizeof(float), ctx.stream());
-}
+}
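
For completeness, here is how the new path ends up being exercised. solve_tri_f32_cuda is file-local, so a real test would go through the ggml op; this hypothetical harness with illustrative sizes only shows the argument convention (strides are passed in floats, single batch):

// Hypothetical harness: one batch, n and k beyond the old 64/32 caps.
int n = 128, k = 4;
float *dA, *dB, *dX;
cudaMalloc(&dA, n * n * sizeof(float));  // lower-triangular A, non-zero diagonal
cudaMalloc(&dB, n * k * sizeof(float));  // k right-hand-side columns
cudaMalloc(&dX, n * k * sizeof(float));  // solution, same shape as B
// ... upload A and B ...
solve_tri_f32_cuda(dA, dB, dX, n, k, /*ne02=*/1, /*ne03=*/1,
                   /*nb02=*/(size_t) n * n, /*nb03=*/(size_t) n * n,
                   /*nb12=*/(size_t) n * k, /*nb13=*/(size_t) n * k,
                   /*nb2 =*/(size_t) n * k, /*nb3 =*/(size_t) n * k,
                   /*stream=*/0);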
