
Commit 288e2f1

Merge branch 'main' of github.com:beehive-lab/GPULlama3.java
2 parents: 951911e + 6fb2e83

2 files changed: +97, -29 lines changed


src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 3 additions & 12 deletions

@@ -127,7 +127,7 @@ public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) {
         return switch (ggmlType) {
             case F32 -> FP32TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
             case F16 -> FP16TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
-            case Q8_0 -> Q8_0TornadoTensor.create(entry);
+            case Q8_0 -> Q8_0TornadoTensor.createAsQ8_0(entry);
             case Q4_0 -> throw new UnsupportedOperationException("Q4 format not supported yet");
             default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
         };
@@ -163,16 +163,7 @@ public static TornadoTensor loadTornadoTensorAsFP32(GGMLTensorEntry entry) {
                 }
                 yield new FP32TornadoTensor(tensorFA);
             }
-            case Q8_0 -> {
-                Q8_0TornadoTensor tensorQ8_0 = Q8_0TornadoTensor.create(entry);
-                int numOfElements = tensorQ8_0.getSize();
-                FloatArray tensorFA = new FloatArray(numOfElements);
-                for (int i = 0; i < numOfElements; i++) {
-                    tensorFA.set(i, tensorQ8_0.getFloat(i));
-                }
-                yield new FP32TornadoTensor(tensorFA);
-
-            }
+            case Q8_0 -> Q8_0TornadoTensor.createAsFP32(entry);
             default -> {
                 throw new UnsupportedOperationException("Unsupported tensor type: " + tensor.type());
             }
@@ -200,7 +191,7 @@ public static HalfFloatArray[] loadArrayAsHalfFloatArray(int size, IntFunction<G
     public static Q8_0TornadoTensor[] loadArrayAsQ8_0TornadoTensor(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
         Q8_0TornadoTensor[] array = new Q8_0TornadoTensor[size];
         for (int i = 0; i < size; i++) {
-            array[i] = Q8_0TornadoTensor.create(getTensorEntry.apply(i));
+            array[i] = Q8_0TornadoTensor.createAsQ8_0(getTensorEntry.apply(i));
         }
         return array;
     }
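
The loader change above collapses the FP32 path for Q8_0 weights from an inline dequantization loop (per-element getFloat() copies into a FloatArray) into a single call to the new Q8_0TornadoTensor.createAsFP32 factory, while the quantized path keeps its native layout through the renamed createAsQ8_0. The following self-contained sketch only mirrors the shape of that dispatch; GgmlTypeSketch and LoaderDispatchSketch are illustrative stand-ins, not project classes, and the string bodies merely describe each branch.

enum GgmlTypeSketch { F32, F16, Q8_0, Q4_0 }

public class LoaderDispatchSketch {
    static String loadAsFP32(GgmlTypeSketch type) {
        return switch (type) {
            case F32  -> "wrap the FP32 memory segment directly";
            case F16  -> "widen half floats to FP32";
            case Q8_0 -> "Q8_0TornadoTensor.createAsFP32(entry)"; // replaces the per-element getFloat() copy
            default   -> throw new UnsupportedOperationException("Unsupported tensor type: " + type);
        };
    }

    public static void main(String[] args) {
        System.out.println(loadAsFP32(GgmlTypeSketch.Q8_0)); // prints the Q8_0 branch description
    }
}

As in the original loader, the default arm keeps the switch exhaustive, so unsupported quantization formats still fail fast.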

src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java

Lines changed: 94 additions & 17 deletions

@@ -4,13 +4,13 @@
 import org.beehive.gpullama3.tensor.GGMLType;
 import org.beehive.gpullama3.tensor.standard.FloatTensor;
 import uk.ac.manchester.tornado.api.types.HalfFloat;
-import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.Int8Array;
-import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
+import uk.ac.manchester.tornado.api.types.arrays.*;
 
 import java.lang.foreign.MemorySegment;
 import java.lang.foreign.ValueLayout;
 import java.nio.ByteOrder;
+import java.util.concurrent.*;
+import java.util.stream.IntStream;
 
 public class Q8_0TornadoTensor extends TornadoTensor {
 
@@ -71,7 +71,10 @@ public float getFloat(int index) {
         return quant * scale;
     }
 
-    public static Q8_0TornadoTensor create(GGMLTensorEntry entry) {
+    /**
+     * Creates a Q8_0TornadoTensor from a GGMLTensorEntry (original implementation).
+     */
+    public static Q8_0TornadoTensor createAsQ8_0(GGMLTensorEntry entry) {
         if (entry.ggmlType() != GGMLType.Q8_0) {
             throw new IllegalArgumentException("Expected Q8_0 tensor, got: " + entry.ggmlType() + " for tensor: " + entry.name());
         }
@@ -97,22 +100,96 @@ public static Q8_0TornadoTensor create(GGMLTensorEntry entry) {
         ValueLayout.OfShort shortLayout = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
         ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE;
 
-        for (int block = 0; block < numBlocks; block++) {
-            // TODO: use GGML type method for the 34L size
-            long blockOffset = block * 34L; // 34 bytes per block
+        // element-wise copy and unpack from MemorySegment to HalfFloatArray scales and Int8Array quants
+        // use parallel streams and unroll inner loop for better performance
+        IntStream.range(0, numBlocks)
+                .parallel()
+                .forEach(block -> {
+                    // TODO: use GGML type method for the 34L size
+                    long blockOffset = block * 34L; // 34 bytes per block
+
+                    // read fp16 scale (first 2 bytes of block)
+                    short scaleRaw = q8Segment.get(shortLayout, blockOffset);
+                    scales.set(block, new HalfFloat(scaleRaw));
+                    int blockStart = block * 32;
+
+                    // read 32 int8 quantized values (remaining bytes of block)
+                    // TODO: use GGML type method for the 32 size
+                    for (int i = 0; i < 32; i += 4) {
+                        // unroll inner loop for better performance
+                        byte q0 = q8Segment.get(byteLayout, blockOffset + 2 + i);
+                        byte q1 = q8Segment.get(byteLayout, blockOffset + 2 + i + 1);
+                        byte q2 = q8Segment.get(byteLayout, blockOffset + 2 + i + 2);
+                        byte q3 = q8Segment.get(byteLayout, blockOffset + 2 + i + 3);
+
+                        quants.set(blockStart + i, q0);
+                        quants.set(blockStart + i + 1, q1);
+                        quants.set(blockStart + i + 2, q2);
+                        quants.set(blockStart + i + 3, q3);
+                    }
+                });
 
-            // read fp16 scale (first 2 bytes of block)
-            short scaleRaw = q8Segment.get(shortLayout, blockOffset);
-            scales.set(block, new HalfFloat(scaleRaw));
+        return new Q8_0TornadoTensor(size, scales, quants, q8Segment);
+    }
 
-            // read 32 int8 quantized values (remaining bytes of block)
-            // TODO: use GGML type method for the 32 size
-            for (int i = 0; i < 32; i++) {
-                byte quantValue = q8Segment.get(byteLayout, blockOffset + 2 + i);
-                quants.set(block * 32 + i, quantValue);
-            }
+    /**
+     * Creates a Q8_0TornadoTensor formulated as an FP32TornadoTensor object from a GGMLTensorEntry.
+     * NOTE: Hack implementation to comply with FP32 inference.
+     */
+    public static FP32TornadoTensor createAsFP32(GGMLTensorEntry entry) {
+        if (entry.ggmlType() != GGMLType.Q8_0) {
+            throw new IllegalArgumentException("Expected Q8_0 tensor, got: " + entry.ggmlType() + " for tensor: " + entry.name());
         }
 
-        return new Q8_0TornadoTensor(size, scales, quants, q8Segment);
+        int[] shape = entry.shape();
+        int size = FloatTensor.numberOfElements(shape);
+        int numBlocks = size / GGMLType.Q8_0.getBlockSize();
+
+        if (size % GGMLType.Q8_0.getBlockSize() != 0) {
+            throw new IllegalArgumentException("Q8_0 tensor size must be multiple of " + GGMLType.Q8_0.getBlockSize() + ", got: " + size + " for tensor: " + entry.name());
+        }
+
+        // TODO: fix Q8_0 loading in tornado layout
+        // currently we end up hacking it by removing
+        // tornado header from memory segment
+        MemorySegment q8Segment = entry.memorySegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
+
+        // allocate the FloatArray to store the result
+        FloatArray floatArray = new FloatArray(size);
+
+        // unpack Q8_0 blocks: [2 bytes fp16 scale][32 bytes int8 quants]
+        ValueLayout.OfShort shortLayout = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+        ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE;
+
+        // element-wise dequantization and copy from MemorySegment to FloatArray
+        // use parallel streams and unroll inner loop for better performance
+        IntStream.range(0, numBlocks)
+                .parallel()
+                .forEach(block -> {
+                    // TODO: use GGML type method for the 34L size
+                    long blockOffset = block * 34L; // 34 bytes per block
+
+                    // read fp16 scale (first 2 bytes of block) and convert to float
+                    short scaleRaw = q8Segment.get(shortLayout, blockOffset);
+                    float scale = Float.float16ToFloat(scaleRaw);
+                    int blockStart = block * 32;
+
+                    // read 32 int8 quantized values (remaining bytes of block)
+                    // TODO: use GGML type method for the 32 size
+                    for (int i = 0; i < 32; i += 4) {
+                        // unroll inner loop for better performance
+                        byte q0 = q8Segment.get(byteLayout, blockOffset + 2 + i);
+                        byte q1 = q8Segment.get(byteLayout, blockOffset + 2 + i + 1);
+                        byte q2 = q8Segment.get(byteLayout, blockOffset + 2 + i + 2);
+                        byte q3 = q8Segment.get(byteLayout, blockOffset + 2 + i + 3);
+
+                        floatArray.set(blockStart + i, q0 * scale);
+                        floatArray.set(blockStart + i + 1, q1 * scale);
+                        floatArray.set(blockStart + i + 2, q2 * scale);
+                        floatArray.set(blockStart + i + 3, q3 * scale);
+                    }
+                });
+
+        return new FP32TornadoTensor(floatArray);
     }
 }
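
Both factory methods above walk the same Q8_0 layout: consecutive 34-byte blocks, each holding a 2-byte little-endian fp16 scale followed by 32 int8 quants, with the dequantized value being quant * scale. The sketch below is a minimal, standalone illustration of that unpacking, kept sequential and free of TornadoVM types for readability; Q8_0DequantSketch and its helpers are hypothetical names, and it assumes Java 21+ for java.lang.foreign and Float.float16ToFloat.

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;

// Dequantizes a raw Q8_0 buffer: 34-byte blocks = 2-byte fp16 scale + 32 int8 quants.
public class Q8_0DequantSketch {
    static final int BLOCK_SIZE = 32;    // quants per block
    static final long BLOCK_BYTES = 34L; // 2-byte fp16 scale + 32 int8 quants

    static float[] dequantize(MemorySegment q8, int numBlocks) {
        ValueLayout.OfShort fp16 = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
        float[] out = new float[numBlocks * BLOCK_SIZE];
        for (int block = 0; block < numBlocks; block++) {
            long base = block * BLOCK_BYTES;
            float scale = Float.float16ToFloat(q8.get(fp16, base));      // block scale
            for (int i = 0; i < BLOCK_SIZE; i++) {
                byte q = q8.get(ValueLayout.JAVA_BYTE, base + 2 + i);     // int8 quant
                out[block * BLOCK_SIZE + i] = q * scale;                  // value = quant * scale
            }
        }
        return out;
    }

    public static void main(String[] args) {
        try (Arena arena = Arena.ofConfined()) {
            MemorySegment seg = arena.allocate(BLOCK_BYTES);
            seg.set(ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), 0,
                    Float.floatToFloat16(0.5f));                          // scale = 0.5
            for (int i = 0; i < BLOCK_SIZE; i++) {
                seg.set(ValueLayout.JAVA_BYTE, 2 + i, (byte) (i - 16));   // quants -16..15
            }
            float[] values = dequantize(seg, 1);
            System.out.println(values[0] + " .. " + values[BLOCK_SIZE - 1]); // -8.0 .. 7.5
        }
    }
}

The committed code performs the same unpacking but parallelizes the block loop with IntStream.range(...).parallel() and unrolls the inner copy in steps of four, which matters for large weight tensors; the sketch keeps the scalar form for clarity.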
