Skip to content

Commit 9562505

Browse files
Use quantization-specific activation init in Phi3 models
1 parent c52bcaa commit 9562505

File tree

1 file changed

+5
-1
lines changed

src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ protected StateFields createStateFields(Configuration config) {
8080
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim)).limit(nLayers).toArray(FloatTensor[]::new);
8181

8282
// TornadoVM wrapper arrays for GPU acceleration
83-
fields.embeddingX = new HalfFloatArray(config.dim());
83+
switch (config.modelType()) {
84+
case "FP16" -> fields.createActivationFP16(config.dim());
85+
case "Q8_0" -> fields.createActivationQ8_0(config.dim());
86+
default -> throw new UnsupportedOperationException("Quantization format " + config.modelType());
87+
}
8488
fields.wrapX = new FloatArray(dim);
8589
fields.wrapXb = new FloatArray(dim);
8690
fields.wrapXb2 = new FloatArray(dim);

0 commit comments

Comments (0)