Skip to content

Commit 4e984fa

Browse files
Use quantization-specific activation init in Qwen3 models
1 parent 111dbdd commit 4e984fa

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,12 @@ protected StateFields createStateFields(Configuration configuration) {
6767

6868
// TornadoVM wrappers with Qwen3-specific sizes
6969

70-
fields.embeddingX = new HalfFloatArray(config.dim());
70+
switch (config.modelType()) {
71+
case "FP16" -> fields.createActivationFP16(config.dim());
72+
case "Q8_0" -> fields.createActivationQ8_0(config.dim());
73+
default -> throw new UnsupportedOperationException("Quantization format " + config.modelType());
74+
}
75+
7176
fields.wrapX = new FloatArray(config.dim());
7277
fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads());
7378
fields.wrapXb2 = new FloatArray(config.dim());
@@ -77,7 +82,6 @@ protected StateFields createStateFields(Configuration configuration) {
7782
fields.wrapQ = new FloatArray(nEmbdHeadK * config.numberOfHeads());
7883
fields.wrapK = new FloatArray(nEmbdKGqa);
7984
fields.wrapV = new FloatArray(nEmbdKGqa);
80-
fields.embeddingX = new HalfFloatArray(config.dim());
8185
fields.wrapKeyCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
8286
fields.wrapValueCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
8387
fields.wrapValueCache.init(0.f);

0 commit comments

Comments (0)