@@ -106,6 +106,112 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
 #pragma clang diagnostic pop
 #endif // __clang__
 
+// ======================
+// General kernel for larger matrices.
+// Uses a simpler approach with a fixed tile size.
+// ======================
+#define GENERAL_TILE_SIZE 32
+
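+// Solves a lower-triangular system A·X = B per batch by blocked forward
+// substitution: x_i = (b_i - sum_{j<i} A_ij * x_j) / A_ii. Rows are processed
+// in GENERAL_TILE_SIZE chunks staged through shared memory; each block handles
+// one (batch, column) pair.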
+template <int n_template, int k_template>
+static __global__ void solve_tri_f32_general(const float * __restrict__ A,
+                                             const float * __restrict__ B,
+                                             float * __restrict__ X,
+                                             const uint3  ne02,
+                                             const size_t nb02,
+                                             const size_t nb03,
+                                             const size_t nb12,
+                                             const size_t nb13,
+                                             const size_t nb2,
+                                             const size_t nb3,
+                                             const int    n_arg,
+                                             const int    k_arg) {
+    const int n = n_template == 0 ? n_arg : n_template;
+    const int k = k_template == 0 ? k_arg : k_template;
+
+    const int batch_idx = blockIdx.x;
+    const int col_idx   = blockIdx.y;
+    const int tid       = threadIdx.x;
+
+    if (col_idx >= k) {
+        return;
+    }
+
+    const uint2   i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02     = i02_i03.y;
+    const int64_t i03     = i02_i03.x;
+
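+    // All nb* strides are in elements rather than bytes: the host wrapper
+    // divides ggml's byte strides by sizeof(float) before launching.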
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
+    // Shared memory for the current tile.
+    __shared__ float sA[GENERAL_TILE_SIZE * GENERAL_TILE_SIZE];
+    __shared__ float sB[GENERAL_TILE_SIZE];
+    __shared__ float sX[GENERAL_TILE_SIZE];
+
+    // Process the rows in tiles.
+    for (int tile_start = 0; tile_start < n; tile_start += GENERAL_TILE_SIZE) {
+        int tile_end = min(tile_start + GENERAL_TILE_SIZE, n);
+        int tile_n   = tile_end - tile_start;
+
+        // Load the current tile of A.
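+        // Threads stride over the tile cooperatively; entries above the
+        // diagonal are zero-filled since A is lower-triangular.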
+        for (int i = tid; i < tile_n * tile_n; i += blockDim.x) {
+            int local_row  = i / tile_n;
+            int local_col  = i % tile_n;
+            int global_row = tile_start + local_row;
+            int global_col = tile_start + local_col;
+
+            if (global_col <= global_row) {
+                sA[local_row * GENERAL_TILE_SIZE + local_col] = A_batch[global_row * n + global_col];
+            } else {
+                sA[local_row * GENERAL_TILE_SIZE + local_col] = 0.0f;
+            }
+        }
+
+        __syncthreads();
+
+        // Load this tile's slice of B.
+        if (tid < tile_n) {
+            sB[tid] = B_batch[(tile_start + tid) * k + col_idx];
+        }
+
+        __syncthreads();
+
+        // Forward substitution for this tile.
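+        // Rows are inherently sequential: thread `row` computes its value
+        // while the rest of the block waits at the barrier.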
+        for (int row = 0; row < tile_n; ++row) {
+            if (tid == row) {
+                float sum = 0.0f;
+
+                // Sum contributions from previous rows in this tile.
+                for (int j = 0; j < row; ++j) {
+                    sum += sA[row * GENERAL_TILE_SIZE + j] * sX[j];
+                }
+
+                // Sum contributions from previous tiles.
+                if (tile_start > 0) {
+                    int global_row = tile_start + row;
+                    for (int j = 0; j < tile_start; ++j) {
+                        sum += A_batch[global_row * n + j] * X_batch[j * k + col_idx];
+                    }
+                }
+
+                const float a_diag = sA[row * GENERAL_TILE_SIZE + row];
+                sX[row] = (sB[row] - sum) / a_diag;
+            }
+            __syncthreads();
+        }
+
+        // Store the results back to global memory.
+        if (tid < tile_n) {
+            int global_row = tile_start + tid;
+            X_batch[global_row * k + col_idx] = sX[tid];
+        }
+
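+        // Later tiles read this tile's X from global memory and will
+        // overwrite sA/sB, so synchronize before the next tile.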
+        __syncthreads();
+    }
+}
+
 static void solve_tri_f32_cuda(const float * A,
                                const float * B,
                                float * X,
@@ -121,56 +227,68 @@ static void solve_tri_f32_cuda(const float * A,
                                size_t nb3,
                                cudaStream_t stream) {
     const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
-    dim3 threads(WARP_SIZE, k);
-    dim3 grid(ne02 * ne03);
-    if (n == 64) {
-        switch (k) {
-            case 32:
-                solve_tri_f32_fast<64, 32>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 16:
-                solve_tri_f32_fast<64, 16>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 14:
-                solve_tri_f32_fast<64, 14>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 12:
-                solve_tri_f32_fast<64, 12>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 10:
-                solve_tri_f32_fast<64, 10>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 8:
-                solve_tri_f32_fast<64, 8>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 6:
-                solve_tri_f32_fast<64, 6>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 4:
-                solve_tri_f32_fast<64, 4>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 2:
-                solve_tri_f32_fast<64, 2>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 1:
-                solve_tri_f32_fast<64, 1>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            default:
-                solve_tri_f32_fast<0, 0>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+
+    // Choose the kernel based on the matrix size.
+    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
+        // Use the fast kernel for small matrices.
+        dim3 threads(WARP_SIZE, k);
+        dim3 grid(ne02 * ne03);
+        if (n == 64) {
+            switch (k) {
+                case 32:
+                    solve_tri_f32_fast<64, 32>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 16:
+                    solve_tri_f32_fast<64, 16>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 14:
+                    solve_tri_f32_fast<64, 14>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 12:
+                    solve_tri_f32_fast<64, 12>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 10:
+                    solve_tri_f32_fast<64, 10>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 8:
+                    solve_tri_f32_fast<64, 8>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 6:
+                    solve_tri_f32_fast<64, 6>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 4:
+                    solve_tri_f32_fast<64, 4>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 2:
+                    solve_tri_f32_fast<64, 2>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 1:
+                    solve_tri_f32_fast<64, 1>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                default:
+                    solve_tri_f32_fast<0, 0>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+            }
+        } else { // run general case
+            solve_tri_f32_fast<0, 0>
+                <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
         }
-    } else { // run general case
-        solve_tri_f32_fast<0, 0>
+    } else {
+        // Use the general kernel for larger matrices.
+        dim3 threads(256, 1);      // 256 threads per block
+        dim3 grid(ne02 * ne03, k); // one block per (batch, column) pair
+
+        solve_tri_f32_general<0, 0>
             <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
     }
 }
@@ -185,11 +303,8 @@ void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const int64_t n = src0->ne[0];
     const int64_t k = src1->ne[0];
 
-    GGML_ASSERT(n <= 64);
-    GGML_ASSERT(k <= 32);
-
     solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
                        src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
                        src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
                        dst->nb[3] / sizeof(float), ctx.stream());
-}
+}