Skip to content

Commit 843e30c

Browse files
Add FP16 and Q8_0 activation initialization methods in State class
1 parent 68729ee commit 843e30c

File tree

1 file changed

+13
-4
lines changed
  • src/main/java/org/beehive/gpullama3/inference/state

1 file changed

+13
-4
lines changed

src/main/java/org/beehive/gpullama3/inference/state/State.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@
33
import org.beehive.gpullama3.tensor.standard.FloatTensor;
44
import org.beehive.gpullama3.model.Configuration;
55
import uk.ac.manchester.tornado.api.types.HalfFloat;
6-
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
7-
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
8-
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
9-
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
6+
import uk.ac.manchester.tornado.api.types.arrays.*;
107

118
/**
129
* Represents the base state structure used during LLM inference.
@@ -127,6 +124,18 @@ protected static class StateFields {
127124
public IntArray positionHolder;
128125
public FloatArray temp, tempFFN, tempLogits;
129126
public TornadoNativeArray embeddingX;
127+
128+
public void createActivationFP16(int size) {
129+
this.embeddingX = new HalfFloatArray(size);
130+
}
131+
132+
public void createActivationQ8_0(int size) {
133+
int blockSize = 32;
134+
int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants
135+
int blocksNeeded = (size + blockSize - 1) / blockSize;
136+
int q8BytesNeeded = blocksNeeded * Q8_0_BLOCK_BYTES;
137+
this.embeddingX = new ByteArray(q8BytesNeeded);
138+
}
130139
}
131140

132141
@Override

0 commit comments

Comments
 (0)