Skip to content

Commit 2c8cf24

Browse files
Update Phi3 FFN layers to use byte-based Q8_0 kernels
1 parent: 9562505 · commit: 2c8cf24

File tree

1 file changed

+12
-20
lines changed

1 file changed

+12
-20
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java

Lines changed: 12 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -137,15 +137,11 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex
137137
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
138138
// Copy-in quantized weights per layer
139139
weights.rms_att_weightLayered[layerIndex].asFloatArray(),
140-
weights.wqkvLayered[layerIndex].getQuants(),
141-
weights.wqkvLayered[layerIndex].getScales(),
142-
weights.woLayered[layerIndex].getQuants(),
143-
weights.woLayered[layerIndex].getScales(),
140+
weights.wqkvLayered[layerIndex].asByteArray(),
141+
weights.woLayered[layerIndex].asByteArray(),
144142
weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
145-
weights.wUpLayered[layerIndex].getQuants(),
146-
weights.wUpLayered[layerIndex].getScales(),
147-
weights.wDownLayered[layerIndex].getQuants(),
148-
weights.wDownLayered[layerIndex].getScales()
143+
weights.wUpLayered[layerIndex].asByteArray(),
144+
weights.wDownLayered[layerIndex].asByteArray()
149145
);
150146
unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
151147

@@ -168,12 +164,11 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex
168164

169165
// Combined QKV projection (quantized)
170166
unifiedLayer.task("qkvmatmul",
171-
TransformerComputeKernelsLayered::matrixVectorGeneric,
167+
TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte,
172168
context,
173169
phi3State.wrapXb,
174170
phi3State.wrapQkv,
175-
weights.wqkvLayered[layerIndex].getQuants(),
176-
weights.wqkvLayered[layerIndex].getScales(),
171+
weights.wqkvLayered[layerIndex].asByteArray(),
177172
phi3Config.dim(),
178173
opSize,
179174
LOCAL_WORK_GROUP_SIZE_ALLOC)
@@ -226,12 +221,11 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex
226221

227222
// Output projection (quantized)
228223
unifiedLayer.task("matmul1",
229-
TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
224+
TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
230225
context,
231226
phi3State.wrapXb,
232227
phi3State.wrapX,
233-
weights.woLayered[layerIndex].getQuants(),
234-
weights.woLayered[layerIndex].getScales(),
228+
weights.woLayered[layerIndex].asByteArray(),
235229
phi3Config.dim(),
236230
phi3Config.dim(),
237231
LOCAL_WORK_GROUP_SIZE_ALLOC);
@@ -255,12 +249,11 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex
255249

256250
// FFN: combined Up and Gate projection (outputs 2 * hiddenDim, quantized)
257251
unifiedLayer.task("wGateUp",
258-
TransformerComputeKernelsLayered::matrixVectorGeneric,
252+
TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte,
259253
context,
260254
phi3State.wrapXb,
261255
phi3State.wrapHb,
262-
weights.wUpLayered[layerIndex].getQuants(),
263-
weights.wUpLayered[layerIndex].getScales(),
256+
weights.wUpLayered[layerIndex].asByteArray(),
264257
phi3Config.dim(),
265258
2 * phi3Config.hiddenDim(),
266259
LOCAL_WORK_GROUP_SIZE_ALLOC)
@@ -273,12 +266,11 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex
273266

274267
// FFN: Down projection with residual (quantized)
275268
unifiedLayer.task("wDown",
276-
TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
269+
TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
277270
context,
278271
phi3State.wrapHbU,
279272
phi3State.wrapX,
280-
weights.wDownLayered[layerIndex].getQuants(),
281-
weights.wDownLayered[layerIndex].getScales(),
273+
weights.wDownLayered[layerIndex].asByteArray(),
282274
phi3Config.hiddenDim(),
283275
phi3Config.dim(),
284276
LOCAL_WORK_GROUP_SIZE_ALLOC)

0 commit comments

Comments
 (0)