Skip to content

Commit c52bcaa

Browse files
Update Qwen2 and Deepseek FFN layers to use byte-based Q8_0 kernels
1 parent 4e30022 commit c52bcaa

File tree

1 file changed

+19
-26
lines changed

1 file changed

+19
-26
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -166,37 +166,30 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
166166
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
167167
//Copy-in weights per layer for batched-layered layout
168168
weights.rms_att_weightLayered[layerIndex].asFloatArray(),
169-
weights.wqLayered[layerIndex].getScales(),
170-
weights.wqLayered[layerIndex].getQuants(),
171-
weights.wkLayered[layerIndex].getScales(),
172-
weights.wkLayered[layerIndex].getQuants(),
173-
weights.wvLayered[layerIndex].getScales(),
174-
weights.wvLayered[layerIndex].getQuants(),
175-
weights.woLayered[layerIndex].getScales(),
176-
weights.woLayered[layerIndex].getQuants(),
169+
weights.wqLayered[layerIndex].asByteArray(),
170+
weights.wkLayered[layerIndex].asByteArray(),
171+
weights.wvLayered[layerIndex].asByteArray(),
172+
weights.woLayered[layerIndex].asByteArray(),
177173
weights.q_biasLayered[layerIndex].asFloatArray(),
178174
weights.k_biasLayered[layerIndex].asFloatArray(),
179175
weights.v_biasLayered[layerIndex].asFloatArray(),
180176
weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
181-
weights.w1Layered[layerIndex].getScales(),
182-
weights.w1Layered[layerIndex].getQuants(),
183-
weights.w2Layered[layerIndex].getScales(),
184-
weights.w2Layered[layerIndex].getQuants(),
185-
weights.w3Layered[layerIndex].getScales(),
186-
weights.w3Layered[layerIndex].getQuants()
177+
weights.w1Layered[layerIndex].asByteArray(),
178+
weights.w2Layered[layerIndex].asByteArray(),
179+
weights.w3Layered[layerIndex].asByteArray()
187180
);
188181
unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
189182

190183
unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp,
191184
state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
192185
.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
193186
state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp)
194-
.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
195-
state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].getQuants(), weights.wqLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
196-
.task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
197-
state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].getQuants(), weights.wkLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
198-
.task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
199-
state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].getQuants(), weights.wvLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
187+
.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context,
188+
state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
189+
.task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context,
190+
state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
191+
.task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context,
192+
state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
200193
.task("qbias", TransformerComputeKernelsLayered::addInPlace, state.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim())
201194
.task("kbias", TransformerComputeKernelsLayered::addInPlace, state.wrapK, weights.k_biasLayered[layerIndex].asFloatArray(), config.kvDim())
202195
.task("vbias", TransformerComputeKernelsLayered::addInPlace, state.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim())
@@ -208,16 +201,16 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
208201
state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
209202
config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
210203
state.positionHolder, layerIndex, config.contextLength())
211-
.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
212-
state.wrapXb, state.wrapX, weights.woLayered[layerIndex].getQuants(), weights.woLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
204+
.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
205+
state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
213206
.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
214207
state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
215208
.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
216209
state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN)
217-
.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context,
218-
state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
219-
.task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
220-
state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].getQuants(), weights.w2Layered[layerIndex].getScales(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
210+
.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte, context,
211+
state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
212+
.task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
213+
state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asByteArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
221214
.persistOnDevice(
222215
state.wrapX
223216
);

0 commit comments

Comments (0)