@@ -135,7 +135,6 @@ static __global__ void solve_tri_f32_general(const float * __restrict__ A,
135135 if (col_idx >= k) {
136136 return ;
137137 }
138-
139138 const uint2 i02_i03 = fast_div_modulo (batch_idx, ne02);
140139 const int64_t i02 = i02_i03.y ;
141140 const int64_t i03 = i02_i03.x ;
@@ -153,65 +152,53 @@ static __global__ void solve_tri_f32_general(const float * __restrict__ A,
153152 for (int tile_start = 0 ; tile_start < n; tile_start += GENERAL_TILE_SIZE) {
154153 int tile_end = min (tile_start + GENERAL_TILE_SIZE, n);
155154 int tile_n = tile_end - tile_start;
156-
157155 // Load tile of A matrix
158156 for (int i = tid; i < tile_n * tile_n; i += blockDim .x ) {
159157 int local_row = i / tile_n;
160158 int local_col = i % tile_n;
161159 int global_row = tile_start + local_row;
162160 int global_col = tile_start + local_col;
163-
164161 if (global_col <= global_row) {
165162 sA [local_row * GENERAL_TILE_SIZE + local_col] = A_batch[global_row * n + global_col];
166163 } else {
167164 sA [local_row * GENERAL_TILE_SIZE + local_col] = 0 .0f ;
168165 }
169166 }
170-
171167 __syncthreads ();
172-
173168 // Load corresponding part of B and initialize X
174169 if (tid < tile_n) {
175170 sB [tid] = B_batch[(tile_start + tid) * k + col_idx];
176171 sX [tid] = sB [tid];
177172 }
178-
179173 __syncthreads ();
180-
181174 // Forward substitution for this tile
182175 for (int row = 0 ; row < tile_n; ++row) {
183176 if (tid == row) {
184177 float sum = 0 .0f ;
185-
186178 // Sum contributions from previous rows in this tile
187179 for (int j = 0 ; j < row; ++j) {
188180 sum += sA [row * GENERAL_TILE_SIZE + j] * sX [j];
189181 }
190-
191182 // Sum contributions from previous tiles
192183 if (tile_start > 0 ) {
193184 int global_row = tile_start + row;
194185 for (int j = 0 ; j < tile_start; ++j) {
195186 sum += A_batch[global_row * n + j] * X_batch[j * k + col_idx];
196187 }
197188 }
198-
199189 const float a_diag = sA [row * GENERAL_TILE_SIZE + row];
200190 sX [row] = (sB [row] - sum) / a_diag;
201191 }
202192 __syncthreads ();
203193 }
204-
205194 // Store results back to global memory
206195 if (tid < tile_n) {
207196 int global_row = tile_start + tid;
208197 X_batch[global_row * k + col_idx] = sX [tid];
209198 }
210-
211199 __syncthreads ();
212200 }
213201}
214-
215202static void solve_tri_f32_cuda (const float * A,
216203 const float * B,
217204 float * X,
@@ -227,7 +214,6 @@ static void solve_tri_f32_cuda(const float * A,
227214 size_t nb3,
228215 cudaStream_t stream) {
229216 const uint3 ne02_fd = init_fastdiv_values ((uint32_t ) ne02);
230-
231217 // Choose kernel based on matrix size
232218 if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
233219 // Use fast kernel for small matrices
@@ -287,7 +273,6 @@ static void solve_tri_f32_cuda(const float * A,
287273 // Use general kernel for larger matrices
288274 dim3 threads (256 , 1 ); // 256 threads per block
289275 dim3 grid (ne02 * ne03, k); // One block per column
290-
291276 solve_tri_f32_general<0 , 0 >
292277 <<<grid, threads, 0 , stream>>> (A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
293278 }
@@ -307,4 +292,4 @@ void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
307292 src0->ne [3 ], src0->nb [2 ] / sizeof (float ), src0->nb [3 ] / sizeof (float ),
308293 src1->nb [2 ] / sizeof (float ), src1->nb [3 ] / sizeof (float ), dst->nb [2 ] / sizeof (float ),
309294 dst->nb [3 ] / sizeof (float ), ctx.stream ());
310- }
295+ }
0 commit comments