Commit ed5f882

Update Qwen3 FFN layers to use byte-based Q8_0 kernels
1 parent 4e984fa commit ed5f882

1 file changed: +20 -27

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java

@@ -163,23 +163,16 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
         // Transfer Q8_0 weights for this layer (quants and scales)
         unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
                 weights.rms_att_weightLayered[layerIndex].asFloatArray(), //
-                weights.wqLayered[layerIndex].getQuants(), //
-                weights.wqLayered[layerIndex].getScales(), //
-                weights.wkLayered[layerIndex].getQuants(), //
-                weights.wkLayered[layerIndex].getScales(), //
-                weights.wvLayered[layerIndex].getQuants(), //
-                weights.wvLayered[layerIndex].getScales(),//
-                weights.woLayered[layerIndex].getQuants(),//
-                weights.woLayered[layerIndex].getScales(),//
+                weights.wqLayered[layerIndex].asByteArray(),
+                weights.wkLayered[layerIndex].asByteArray(),
+                weights.wvLayered[layerIndex].asByteArray(),
+                weights.woLayered[layerIndex].asByteArray(),
                 weights.rms_att_KNormLayered[layerIndex].asFloatArray(), //
                 weights.rms_att_QNormLayered[layerIndex].asFloatArray(),//
                 weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), //
-                weights.w1Layered[layerIndex].getQuants(), //
-                weights.w1Layered[layerIndex].getScales(), //
-                weights.w2Layered[layerIndex].getQuants(), //
-                weights.w2Layered[layerIndex].getScales(), //
-                weights.w3Layered[layerIndex].getQuants(), //
-                weights.w3Layered[layerIndex].getScales()); //
+                weights.w1Layered[layerIndex].asByteArray(),
+                weights.w2Layered[layerIndex].asByteArray(),
+                weights.w3Layered[layerIndex].asByteArray());
 
         // Configure layer data transfers (EVERY_EXECUTION and device persistence)
         unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
@@ -200,19 +193,19 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
         int qkvDim1 = config.dim(); // Input dimension
 
         unifiedLayer.task("qmatmul",
-                TransformerComputeKernelsLayered::matrixVectorGeneric,
+                TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte,
                 context, qwen3State.wrapXb, qwen3State.wrapQ,
-                weights.wqLayered[layerIndex].getQuants(), weights.wqLayered[layerIndex].getScales(),
+                weights.wqLayered[layerIndex].asByteArray(),
                 qkvDim1, qDim0, LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("kmatmul",
-                        TransformerComputeKernelsLayered::matrixVectorGeneric,
+                        TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte,
                         context, qwen3State.wrapXb, qwen3State.wrapK,
-                        weights.wkLayered[layerIndex].getQuants(), weights.wkLayered[layerIndex].getScales(),
+                        weights.wkLayered[layerIndex].asByteArray(),
                         qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("vmatmul",
-                        TransformerComputeKernelsLayered::matrixVectorGeneric,
+                        TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte,
                         context, qwen3State.wrapXb, qwen3State.wrapV,
-                        weights.wvLayered[layerIndex].getQuants(), weights.wvLayered[layerIndex].getScales(),
+                        weights.wvLayered[layerIndex].asByteArray(),
                         qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC);
 
         // Qcur: RMS norm with parallel offset for Query
@@ -252,9 +245,9 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
 
         // Output projection (Q8_0 weights)
         unifiedLayer.task("matmul1",
-                TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+                TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
                 context, qwen3State.wrapXb, qwen3State.wrapX,
-                weights.woLayered[layerIndex].getQuants(), weights.woLayered[layerIndex].getScales(),
+                weights.woLayered[layerIndex].asByteArray(),
                 qDim0, config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
 
         // ========== FEED-FORWARD BLOCK ==========
@@ -269,15 +262,15 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
 
         // Fused FFN: w1(x) ⊗ w3(x) with SiLU activation (Q8_0 weights)
         unifiedLayer.task("fused_ffn_w1_w3",
-                TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation,
+                TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte,
                 context, qwen3State.wrapXb, qwen3State.wrapHb,
-                weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(),
-                weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(),
+                weights.w1Layered[layerIndex].asByteArray(),
+                weights.w3Layered[layerIndex].asByteArray(),
                 config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("projectionTwo",
-                        TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+                        TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
                         context, qwen3State.wrapHb, qwen3State.wrapX,
-                        weights.w2Layered[layerIndex].getQuants(), weights.w2Layered[layerIndex].getScales(),
+                        weights.w2Layered[layerIndex].asByteArray(),
                         config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .persistOnDevice(state.wrapX);
 
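The commit replaces each getQuants()/getScales() argument pair with a single asByteArray() view, so every Q8_0 tensor now crosses to the device as one flattened buffer and the *Q8Byte kernel variants decode scales and quants from it directly. As a rough illustration of what such a kernel's inner loop could look like, here is a hypothetical plain-Java sketch assuming the standard GGUF Q8_0 block layout (a 2-byte little-endian fp16 scale followed by 32 int8 quants, 34 bytes per block); the class and method names are invented for this sketch, and the actual matrixVectorGenericQ8Byte implementation in TransformerComputeKernelsLayered may lay things out differently.

    // Hypothetical sketch, not the project's kernel code: dot product of one
    // matrix row stored as flattened Q8_0 bytes against a float input vector.
    public final class Q8_0ByteDotSketch {

        static final int BLOCK_SIZE = 32;                  // quants per Q8_0 block (assumed)
        static final int BYTES_PER_BLOCK = 2 + BLOCK_SIZE; // fp16 scale + quants = 34 bytes

        static float dotRowQ8_0(byte[] w, int rowOffsetBytes, float[] x) {
            float sum = 0f;
            for (int b = 0; b < x.length / BLOCK_SIZE; b++) {
                int base = rowOffsetBytes + b * BYTES_PER_BLOCK;
                // Decode the block scale from its first two bytes (little-endian fp16).
                short halfBits = (short) ((w[base] & 0xFF) | ((w[base + 1] & 0xFF) << 8));
                float scale = Float.float16ToFloat(halfBits); // requires Java 20+
                for (int i = 0; i < BLOCK_SIZE; i++) {
                    // Quants are signed bytes; dequantize on the fly and accumulate.
                    sum += scale * w[base + 2 + i] * x[b * BLOCK_SIZE + i];
                }
            }
            return sum;
        }

        public static void main(String[] args) {
            // One row of 32 weights: scale = 1.0 (fp16 0x3C00), quants all = 2.
            byte[] w = new byte[BYTES_PER_BLOCK];
            w[0] = 0x00; w[1] = 0x3C;                      // fp16 1.0, little-endian
            java.util.Arrays.fill(w, 2, BYTES_PER_BLOCK, (byte) 2);
            float[] x = new float[BLOCK_SIZE];
            java.util.Arrays.fill(x, 0.5f);
            System.out.println(dotRowQ8_0(w, 0, x));       // 32 * (1.0 * 2 * 0.5) = 32.0
        }
    }

Whatever the exact block layout, the effect visible in the diff is that each Q8_0 weight tensor is transferred and passed as one buffer instead of a quants/scales pair, halving the number of per-tensor arguments in transferToDevice and in every task definition.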
