@@ -143,7 +143,6 @@ static __global__ void solve_tri_f32_general(const float * __restrict__ A,
143143 if (col_idx >= k) {
144144 return ;
145145 }
146-
147146 const uint2 i02_i03 = fast_div_modulo (batch_idx, ne02);
148147 const int64_t i02 = i02_i03.y ;
149148 const int64_t i03 = i02_i03.x ;
@@ -161,65 +160,53 @@ static __global__ void solve_tri_f32_general(const float * __restrict__ A,
161160 for (int tile_start = 0 ; tile_start < n; tile_start += GENERAL_TILE_SIZE) {
162161 int tile_end = min (tile_start + GENERAL_TILE_SIZE, n);
163162 int tile_n = tile_end - tile_start;
164-
165163 // Load tile of A matrix
166164 for (int i = tid; i < tile_n * tile_n; i += blockDim .x ) {
167165 int local_row = i / tile_n;
168166 int local_col = i % tile_n;
169167 int global_row = tile_start + local_row;
170168 int global_col = tile_start + local_col;
171-
172169 if (global_col <= global_row) {
173170 sA [local_row * GENERAL_TILE_SIZE + local_col] = A_batch[global_row * n + global_col];
174171 } else {
175172 sA [local_row * GENERAL_TILE_SIZE + local_col] = 0 .0f ;
176173 }
177174 }
178-
179175 __syncthreads ();
180-
181176 // Load corresponding part of B and initialize X
182177 if (tid < tile_n) {
183178 sB [tid] = B_batch[(tile_start + tid) * k + col_idx];
184179 sX [tid] = sB [tid];
185180 }
186-
187181 __syncthreads ();
188-
189182 // Forward substitution for this tile
190183 for (int row = 0 ; row < tile_n; ++row) {
191184 if (tid == row) {
192185 float sum = 0 .0f ;
193-
194186 // Sum contributions from previous rows in this tile
195187 for (int j = 0 ; j < row; ++j) {
196188 sum += sA [row * GENERAL_TILE_SIZE + j] * sX [j];
197189 }
198-
199190 // Sum contributions from previous tiles
200191 if (tile_start > 0 ) {
201192 int global_row = tile_start + row;
202193 for (int j = 0 ; j < tile_start; ++j) {
203194 sum += A_batch[global_row * n + j] * X_batch[j * k + col_idx];
204195 }
205196 }
206-
207197 const float a_diag = sA [row * GENERAL_TILE_SIZE + row];
208198 sX [row] = (sB [row] - sum) / a_diag;
209199 }
210200 __syncthreads ();
211201 }
212-
213202 // Store results back to global memory
214203 if (tid < tile_n) {
215204 int global_row = tile_start + tid;
216205 X_batch[global_row * k + col_idx] = sX [tid];
217206 }
218-
219207 __syncthreads ();
220208 }
221209}
222-
223210static void solve_tri_f32_cuda (const float * A,
224211 const float * B,
225212 float * X,
@@ -235,7 +222,6 @@ static void solve_tri_f32_cuda(const float * A,
235222 size_t nb3,
236223 cudaStream_t stream) {
237224 const uint3 ne02_fd = init_fastdiv_values ((uint32_t ) ne02);
238-
239225 // Choose kernel based on matrix size
240226 if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
241227 // Use fast kernel for small matrices
@@ -295,7 +281,6 @@ static void solve_tri_f32_cuda(const float * A,
295281 // Use general kernel for larger matrices
296282 dim3 threads (256 , 1 ); // 256 threads per block
297283 dim3 grid (ne02 * ne03, k); // One block per column
298-
299284 solve_tri_f32_general<0 , 0 >
300285 <<<grid, threads, 0 , stream>>> (A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
301286 }
@@ -315,4 +300,4 @@ void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
315300 src0->ne [3 ], src0->nb [2 ] / sizeof (float ), src0->nb [3 ] / sizeof (float ),
316301 src1->nb [2 ] / sizeof (float ), src1->nb [3 ] / sizeof (float ), dst->nb [2 ] / sizeof (float ),
317302 dst->nb [3 ] / sizeof (float ), ctx.stream ());
318- }
303+ }
0 commit comments