@@ -166,37 +166,30 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
         unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
                 // Copy-in weights per layer for batched-layered layout
                 weights.rms_att_weightLayered[layerIndex].asFloatArray(),
-                weights.wqLayered[layerIndex].getScales(),
-                weights.wqLayered[layerIndex].getQuants(),
-                weights.wkLayered[layerIndex].getScales(),
-                weights.wkLayered[layerIndex].getQuants(),
-                weights.wvLayered[layerIndex].getScales(),
-                weights.wvLayered[layerIndex].getQuants(),
-                weights.woLayered[layerIndex].getScales(),
-                weights.woLayered[layerIndex].getQuants(),
+                weights.wqLayered[layerIndex].asByteArray(),
+                weights.wkLayered[layerIndex].asByteArray(),
+                weights.wvLayered[layerIndex].asByteArray(),
+                weights.woLayered[layerIndex].asByteArray(),
                 weights.q_biasLayered[layerIndex].asFloatArray(),
                 weights.k_biasLayered[layerIndex].asFloatArray(),
                 weights.v_biasLayered[layerIndex].asFloatArray(),
                 weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
-                weights.w1Layered[layerIndex].getScales(),
-                weights.w1Layered[layerIndex].getQuants(),
-                weights.w2Layered[layerIndex].getScales(),
-                weights.w2Layered[layerIndex].getQuants(),
-                weights.w3Layered[layerIndex].getScales(),
-                weights.w3Layered[layerIndex].getQuants()
+                weights.w1Layered[layerIndex].asByteArray(),
+                weights.w2Layered[layerIndex].asByteArray(),
+                weights.w3Layered[layerIndex].asByteArray()
         );
         unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
 
         unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp,
                         state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
                 .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
                         state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp)
-                .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
-                        state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].getQuants(), weights.wqLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
-                        state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].getQuants(), weights.wkLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
-                        state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].getQuants(), weights.wvLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context,
+                        state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context,
+                        state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context,
+                        state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("qbias", TransformerComputeKernelsLayered::addInPlace, state.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim())
                 .task("kbias", TransformerComputeKernelsLayered::addInPlace, state.wrapK, weights.k_biasLayered[layerIndex].asFloatArray(), config.kvDim())
                 .task("vbias", TransformerComputeKernelsLayered::addInPlace, state.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim())
@@ -208,16 +201,16 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
                         state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
                         config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
                         state.positionHolder, layerIndex, config.contextLength())
-                .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
-                        state.wrapXb, state.wrapX, weights.woLayered[layerIndex].getQuants(), weights.woLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
+                        state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
                         state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
                 .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
                         state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN)
-                .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context,
-                        state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
-                        state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].getQuants(), weights.w2Layered[layerIndex].getScales(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte, context,
+                        state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
+                        state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asByteArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .persistOnDevice(
                         state.wrapX
                 );
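
Note on the change: replacing the getScales()/getQuants() pairs with a single asByteArray() per weight tensor means each Q8_0 tensor is transferred as one packed byte buffer, and the *Q8Byte / *Q8_0Byte kernels presumably dequantize on the fly inside the matrix-vector product. Those kernels are not part of this hunk; the following is only a minimal sketch of that per-row access pattern, assuming the standard GGUF Q8_0 block layout (blocks of 32 values, each stored as a 2-byte little-endian fp16 scale followed by 32 signed int8 quants, 34 bytes per block). The class and method names are illustrative, not taken from the source, and Float.float16ToFloat requires Java 20+.

// Hypothetical sketch, not the project's actual kernel.
public final class Q8_0DotSketch {

    private static final int BLOCK = 32;
    private static final int BLOCK_BYTES = 2 + BLOCK; // fp16 scale + 32 int8 quants

    // Dot product of one byte-packed Q8_0 weight row with a dense float vector x.
    static float dotRowQ8_0(byte[] packed, int row, float[] x, int dim) {
        int blocksPerRow = dim / BLOCK;
        int rowBase = row * blocksPerRow * BLOCK_BYTES;
        float acc = 0f;
        for (int b = 0; b < blocksPerRow; b++) {
            int base = rowBase + b * BLOCK_BYTES;
            // Decode the block's little-endian fp16 scale.
            short half = (short) ((packed[base] & 0xFF) | ((packed[base + 1] & 0xFF) << 8));
            float scale = Float.float16ToFloat(half);
            // Dequantize on the fly: signed int8 quant * scale, accumulated against x.
            for (int i = 0; i < BLOCK; i++) {
                acc += scale * packed[base + 2 + i] * x[b * BLOCK + i];
            }
        }
        return acc;
    }
}

Under this layout a dim x dim Q8_0 tensor occupies dim * (dim / 32) * 34 bytes on the device, which is why a single byte array can replace the separate scale and quant buffers in the transferToDevice call above.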