Skip to content

Commit 00a4faa

Browse files
Rename modelType to quantization across configurations and update associated usages.
1 parent 2c8cf24 commit 00a4faa

File tree

10 files changed

+25
-27
lines changed

10 files changed

+25
-27
lines changed

src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,11 @@ protected StateFields createStateFields(Configuration config) {
5454
fields.wrapHb = new FloatArray(config.hiddenDim());
5555
fields.wrapHb2 = new FloatArray(config.hiddenDim());
5656

57-
switch (config.modelType()) {
57+
switch (config.quantization()) {
5858
case "FP16" -> fields.createActivationFP16(config.dim());
5959
case "Q8_0" -> fields.createActivationQ8_0(config.dim());
60-
default -> throw new UnsupportedOperationException("Quantization format " + config.modelType());
60+
default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
6161
}
62-
63-
6462
fields.wrapLogits = new FloatArray(config.vocabularySize());
6563
fields.wrapQ = new FloatArray(config.dim());
6664
fields.wrapK = new FloatArray(config.dim());

src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ protected StateFields createStateFields(Configuration config) {
8080
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim)).limit(nLayers).toArray(FloatTensor[]::new);
8181

8282
// TornadoVM wrapper arrays for GPU acceleration
83-
switch (config.modelType()) {
83+
switch (config.quantization()) {
8484
case "FP16" -> fields.createActivationFP16(config.dim());
8585
case "Q8_0" -> fields.createActivationQ8_0(config.dim());
86-
default -> throw new UnsupportedOperationException("Quantization format " + config.modelType());
86+
default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
8787
}
8888
fields.wrapX = new FloatArray(dim);
8989
fields.wrapXb = new FloatArray(dim);

src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ protected StateFields createStateFields(Configuration configuration) {
4141
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
4242

4343
// TornadoVM wrappers with Qwen2 dimensions
44-
switch (config.modelType()) {
44+
switch (config.quantization()) {
4545
case "FP16" -> fields.createActivationFP16(config.dim());
4646
case "Q8_0" -> fields.createActivationQ8_0(config.dim());
47-
default -> throw new UnsupportedOperationException("Quantization format " + config.modelType());
47+
default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
4848
}
4949
fields.wrapX = new FloatArray(config.dim());
5050
fields.wrapXb = new FloatArray(config.dim());

src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ protected StateFields createStateFields(Configuration configuration) {
6767

6868
// TornadoVM wrappers with Qwen3-specific sizes
6969

70-
switch (config.modelType()) {
70+
switch (config.quantization()) {
7171
case "FP16" -> fields.createActivationFP16(config.dim());
7272
case "Q8_0" -> fields.createActivationQ8_0(config.dim());
73-
default -> throw new UnsupportedOperationException("Quantization format " + config.modelType());
73+
default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
7474
}
7575

7676
fields.wrapX = new FloatArray(config.dim());

src/main/java/org/beehive/gpullama3/model/Configuration.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
public interface Configuration {
44

5-
String modelType();
5+
String quantization();
66

77
/** Transformer embedding dimension */
88
int dim();

src/main/java/org/beehive/gpullama3/model/llama/LlamaConfiguration.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import org.beehive.gpullama3.model.Configuration;
44

55
// @formatter:off
6-
public record LlamaConfiguration(String type,
6+
public record LlamaConfiguration(String quantization,
77
int dim,
88
int hiddenDim,
99
int numberOfLayers,
@@ -15,8 +15,8 @@ public record LlamaConfiguration(String type,
1515
float ropeTheta) implements Configuration {
1616

1717
@Override
18-
public String modelType() {
19-
return type;
18+
public String quantization() {
19+
return quantization;
2020
}
2121

2222
@Override
@@ -57,7 +57,7 @@ public LlamaConfiguration withContextLength(int newContextLength) {
5757
return this; // no change
5858
}
5959
return new LlamaConfiguration(
60-
this.type,
60+
this.quantization,
6161
this.dim,
6262
this.hiddenDim,
6363
this.numberOfLayers,

src/main/java/org/beehive/gpullama3/model/mistral/MistralConfiguration.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import org.beehive.gpullama3.model.Configuration;
44

55
// @formatter:off
6-
public record MistralConfiguration(String type,
6+
public record MistralConfiguration(String quantization,
77
int dim,
88
int hiddenDim,
99
int numberOfLayers,
@@ -15,8 +15,8 @@ public record MistralConfiguration(String type,
1515
float rmsNormEps,
1616
float ropeTheta) implements Configuration {
1717

18-
@Override public String modelType() {
19-
return type;
18+
@Override public String quantization() {
19+
return quantization;
2020
}
2121

2222
public int kvDim() {

src/main/java/org/beehive/gpullama3/model/phi3/Phi3Configuration.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import org.beehive.gpullama3.model.Configuration;
44

55
// @formatter:off
6-
public record Phi3Configuration(String type,
6+
public record Phi3Configuration(String quantization,
77
int dim,
88
int hiddenDim,
99
int numberOfLayers,
@@ -14,8 +14,8 @@ public record Phi3Configuration(String type,
1414
float rmsNormEps,
1515
float ropeTheta) implements Configuration {
1616

17-
@Override public String modelType() {
18-
return type;
17+
@Override public String quantization() {
18+
return quantization;
1919
}
2020

2121
@Override

src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2Configuration.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import org.beehive.gpullama3.model.Configuration;
44

5-
public record Qwen2Configuration(String type,
5+
public record Qwen2Configuration(String quantization,
66
int dim,
77
int hiddenDim,
88
int numberOfLayers,
@@ -17,8 +17,8 @@ public record Qwen2Configuration(String type,
1717
float rmsNormEps,
1818
float ropeTheta) implements Configuration {
1919
@Override
20-
public String modelType() {
21-
return type;
20+
public String quantization() {
21+
return quantization;
2222
}
2323

2424
@Override

src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3Configuration.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import org.beehive.gpullama3.model.Configuration;
44

55
// @formatter:off
6-
public record Qwen3Configuration(String type,
6+
public record Qwen3Configuration(String quantization,
77
int dim,
88
int hiddenDim,
99
int numberOfLayers,
@@ -18,8 +18,8 @@ public record Qwen3Configuration(String type,
1818
float rmsNormEps,
1919
float ropeTheta) implements Configuration {
2020

21-
@Override public String modelType() {
22-
return type;
21+
@Override public String quantization() {
22+
return quantization;
2323
}
2424

2525
@Override

0 commit comments

Comments (0)