@@ -163,23 +163,16 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
         // Transfer Q8_0 weights for this layer (quants and scales)
         unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
                 weights.rms_att_weightLayered[layerIndex].asFloatArray(), //
-                weights.wqLayered[layerIndex].getQuants(), //
-                weights.wqLayered[layerIndex].getScales(), //
-                weights.wkLayered[layerIndex].getQuants(), //
-                weights.wkLayered[layerIndex].getScales(), //
-                weights.wvLayered[layerIndex].getQuants(), //
-                weights.wvLayered[layerIndex].getScales(), //
-                weights.woLayered[layerIndex].getQuants(), //
-                weights.woLayered[layerIndex].getScales(), //
+                weights.wqLayered[layerIndex].asByteArray(),
+                weights.wkLayered[layerIndex].asByteArray(),
+                weights.wvLayered[layerIndex].asByteArray(),
+                weights.woLayered[layerIndex].asByteArray(),
                 weights.rms_att_KNormLayered[layerIndex].asFloatArray(), //
                 weights.rms_att_QNormLayered[layerIndex].asFloatArray(), //
                 weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), //
-                weights.w1Layered[layerIndex].getQuants(), //
-                weights.w1Layered[layerIndex].getScales(), //
-                weights.w2Layered[layerIndex].getQuants(), //
-                weights.w2Layered[layerIndex].getScales(), //
-                weights.w3Layered[layerIndex].getQuants(), //
-                weights.w3Layered[layerIndex].getScales()); //
+                weights.w1Layered[layerIndex].asByteArray(),
+                weights.w2Layered[layerIndex].asByteArray(),
+                weights.w3Layered[layerIndex].asByteArray());
 
         // Configure layer data transfers (EVERY_EXECUTION and device persistence)
         unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
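Note on the packed format: the transfer above no longer ships separate quant and scale arrays per weight tensor; asByteArray() presumably hands the device a single buffer of raw Q8_0 blocks, halving the number of buffers per tensor. A minimal sketch of that assumption, using the standard GGUF Q8_0 block layout (per 32 weights, a little-endian float16 scale followed by 32 signed int8 quants); the Q8_0Block class and dequantize helper below are hypothetical and not part of the repository:

// Hypothetical illustration of the packed Q8_0 layout assumed above:
// per block of 32 weights, a little-endian float16 scale (2 bytes)
// followed by 32 signed int8 quantized values (34 bytes per block).
final class Q8_0Block {
    static final int BLOCK_SIZE = 32;
    static final int BYTES_PER_BLOCK = 2 + BLOCK_SIZE;

    // Dequantize a single element out of the packed buffer (hypothetical helper).
    static float dequantize(byte[] packed, int index) {
        int base = (index / BLOCK_SIZE) * BYTES_PER_BLOCK;
        short halfBits = (short) ((packed[base] & 0xFF) | (packed[base + 1] << 8));
        float scale = Float.float16ToFloat(halfBits);   // requires Java 20+
        return scale * packed[base + 2 + (index % BLOCK_SIZE)];
    }
}

Keeping each block's scale adjacent to its quants in one buffer also likely simplifies the kernel-side reads compared with indexing two separate arrays.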
@@ -200,19 +193,19 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
         int qkvDim1 = config.dim(); // Input dimension
 
         unifiedLayer.task("qmatmul",
-                TransformerComputeKernelsLayered::matrixVectorGeneric,
+                TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte,
                 context, qwen3State.wrapXb, qwen3State.wrapQ,
-                weights.wqLayered[layerIndex].getQuants(), weights.wqLayered[layerIndex].getScales(),
+                weights.wqLayered[layerIndex].asByteArray(),
                 qkvDim1, qDim0, LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("kmatmul",
-                        TransformerComputeKernelsLayered::matrixVectorGeneric,
+                        TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte,
                         context, qwen3State.wrapXb, qwen3State.wrapK,
-                        weights.wkLayered[layerIndex].getQuants(), weights.wkLayered[layerIndex].getScales(),
+                        weights.wkLayered[layerIndex].asByteArray(),
                         qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("vmatmul",
-                        TransformerComputeKernelsLayered::matrixVectorGeneric,
+                        TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte,
                         context, qwen3State.wrapXb, qwen3State.wrapV,
-                        weights.wvLayered[layerIndex].getQuants(), weights.wvLayered[layerIndex].getScales(),
+                        weights.wvLayered[layerIndex].asByteArray(),
                         qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC);
 
         // Qcur: RMS norm with parallel offset for Query
@@ -252,9 +245,9 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
 
         // Output projection (Q8_0 weights)
         unifiedLayer.task("matmul1",
-                TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+                TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
                 context, qwen3State.wrapXb, qwen3State.wrapX,
-                weights.woLayered[layerIndex].getQuants(), weights.woLayered[layerIndex].getScales(),
+                weights.woLayered[layerIndex].asByteArray(),
                 qDim0, config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
 
         // ========== FEED-FORWARD BLOCK ==========
@@ -269,15 +262,15 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
 
         // Fused FFN: w1(x) ⊗ w3(x) with SiLU activation (Q8_0 weights)
         unifiedLayer.task("fused_ffn_w1_w3",
-                TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation,
+                TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte,
                 context, qwen3State.wrapXb, qwen3State.wrapHb,
-                weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(),
-                weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(),
+                weights.w1Layered[layerIndex].asByteArray(),
+                weights.w3Layered[layerIndex].asByteArray(),
                 config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("projectionTwo",
-                        TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+                        TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
                         context, qwen3State.wrapHb, qwen3State.wrapX,
-                        weights.w2Layered[layerIndex].getQuants(), weights.w2Layered[layerIndex].getScales(),
+                        weights.w2Layered[layerIndex].asByteArray(),
                         config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .persistOnDevice(state.wrapX);
 
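The bodies of the Q8Byte kernels wired in above (matrixVectorGenericQ8Byte, matrixVectorGenericWithResidualQ8_0Byte, fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte) are not part of this diff. As a rough single-threaded reference of what a fused dequantize-and-multiply over the packed buffer computes, under the same Q8_0 layout assumption as above and written in plain Java rather than against the TornadoVM KernelContext API (matVecQ8Reference is a hypothetical name):

// Plain-Java reference sketch (not the TornadoVM kernel): y[row] = dot(W[row, :], x),
// where W is stored row-major as packed Q8_0 blocks in 'packed'.
static void matVecQ8Reference(byte[] packed, float[] x, float[] y, int dim1, int dim0) {
    final int BLOCK = 32;
    final int BYTES = 34;                             // 2-byte fp16 scale + 32 int8 quants per block
    int blocksPerRow = dim1 / BLOCK;                  // assumes dim1 is a multiple of 32
    for (int row = 0; row < dim0; row++) {
        float sum = 0f;
        int rowBase = row * blocksPerRow * BYTES;
        for (int b = 0; b < blocksPerRow; b++) {
            int base = rowBase + b * BYTES;
            short halfBits = (short) ((packed[base] & 0xFF) | (packed[base + 1] << 8));
            float scale = Float.float16ToFloat(halfBits);   // requires Java 20+
            float blockSum = 0f;
            for (int i = 0; i < BLOCK; i++) {
                blockSum += packed[base + 2 + i] * x[b * BLOCK + i];
            }
            sum += scale * blockSum;                  // apply the per-block scale once per block
        }
        y[row] = sum;
    }
}

The WithResidual variants presumably add the result into the destination vector (y[row] += sum) rather than overwriting it, matching the residual connections they feed in this graph.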