Skip to content

Commit d74991f

Browse files
Optimize Q8_0 tensor loading with parallel streams and loop unrolling.
1 parent 6702382 commit d74991f

File tree

1 file changed

+58
-31
lines changed

1 file changed

+58
-31
lines changed

src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import java.lang.foreign.MemorySegment;
1010
import java.lang.foreign.ValueLayout;
1111
import java.nio.ByteOrder;
12+
import java.util.concurrent.*;
13+
import java.util.stream.IntStream;
1214

1315
public class Q8_0TornadoTensor extends TornadoTensor {
1416

@@ -98,21 +100,34 @@ public static Q8_0TornadoTensor createAsQ8_0(GGMLTensorEntry entry) {
98100
ValueLayout.OfShort shortLayout = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
99101
ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE;
100102

101-
for (int block = 0; block < numBlocks; block++) {
102-
// TODO: use GGML type method for the 34L size
103-
long blockOffset = block * 34L; // 34 bytes per block
104-
105-
// read fp16 scale (first 2 bytes of block)
106-
short scaleRaw = q8Segment.get(shortLayout, blockOffset);
107-
scales.set(block, new HalfFloat(scaleRaw));
108-
109-
// read 32 int8 quantized values (remaining bytes of block)
110-
// TODO: use GGML type method for the 32 size
111-
for (int i = 0; i < 32; i++) {
112-
byte quantValue = q8Segment.get(byteLayout, blockOffset + 2 + i);
113-
quants.set(block * 32 + i, quantValue);
114-
}
115-
}
103+
// element-wise copy and unpack from MemorySegment to HalfFloatArray scales and Int8Array quants
104+
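// process blocks with a parallel stream (each block writes only its own scale and its own 32-element range of quants) and unroll the inner loop four-wide for better performance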
105+
IntStream.range(0, numBlocks)
106+
.parallel()
107+
.forEach(block -> {
108+
// TODO: obtain the 34-byte block size from the GGML type instead of hard-coding 34L
109+
long blockOffset = block * 34L; // 34 bytes per block
110+
111+
// read fp16 scale (first 2 bytes of block)
112+
short scaleRaw = q8Segment.get(shortLayout, blockOffset);
113+
scales.set(block, new HalfFloat(scaleRaw));
114+
int blockStart = block * 32;
115+
116+
// read 32 int8 quantized values (remaining bytes of block)
117+
// TODO: obtain the 32-element block length from the GGML type instead of hard-coding 32
118+
for (int i = 0; i < 32; i += 4) {
119+
// unroll inner loop for better performance
120+
byte q0 = q8Segment.get(byteLayout, blockOffset + 2 + i);
121+
byte q1 = q8Segment.get(byteLayout, blockOffset + 2 + i + 1);
122+
byte q2 = q8Segment.get(byteLayout, blockOffset + 2 + i + 2);
123+
byte q3 = q8Segment.get(byteLayout, blockOffset + 2 + i + 3);
124+
125+
quants.set(blockStart + i, q0);
126+
quants.set(blockStart + i + 1, q1);
127+
quants.set(blockStart + i + 2, q2);
128+
quants.set(blockStart + i + 3, q3);
129+
}
130+
});
116131

117132
return new Q8_0TornadoTensor(size, scales, quants, q8Segment);
118133
}
@@ -146,22 +161,34 @@ public static FP32TornadoTensor createAsFP32(GGMLTensorEntry entry) {
146161
ValueLayout.OfShort shortLayout = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
147162
ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE;
148163

149-
for (int block = 0; block < numBlocks; block++) {
150-
// TODO: use GGML type method for the 34L size
151-
long blockOffset = block * 34L; // 34 bytes per block
152-
153-
// read fp16 scale (first 2 bytes of block) and convert to float
154-
short scaleRaw = q8Segment.get(shortLayout, blockOffset);
155-
float scale = Float.float16ToFloat(scaleRaw);
156-
157-
// read 32 int8 quantized values (remaining bytes of block)
158-
// TODO: use GGML type method for the 32 size
159-
for (int i = 0; i < 32; i++) {
160-
byte quantValue = q8Segment.get(byteLayout, blockOffset + 2 + i);
161-
float floatValue = quantValue * scale;
162-
floatArray.set(block * 32 + i, floatValue);
163-
}
164-
}
164+
// element-wise dequantization and copy from MemorySegment to FloatArray
165+
// process blocks with a parallel stream (each block writes only its own 32-element range of floatArray) and unroll the inner loop four-wide for better performance
166+
IntStream.range(0, numBlocks)
167+
.parallel()
168+
.forEach(block -> {
169+
// TODO: obtain the 34-byte block size from the GGML type instead of hard-coding 34L
170+
long blockOffset = block * 34L; // 34 bytes per block
171+
172+
// read fp16 scale (first 2 bytes of block) and convert to float
173+
short scaleRaw = q8Segment.get(shortLayout, blockOffset);
174+
float scale = Float.float16ToFloat(scaleRaw);
175+
int blockStart = block * 32;
176+
177+
// read 32 int8 quantized values (remaining bytes of block)
178+
// TODO: obtain the 32-element block length from the GGML type instead of hard-coding 32
179+
for (int i = 0; i < 32; i += 4) {
180+
// unroll inner loop for better performance
181+
byte q0 = q8Segment.get(byteLayout, blockOffset + 2 + i);
182+
byte q1 = q8Segment.get(byteLayout, blockOffset + 2 + i + 1);
183+
byte q2 = q8Segment.get(byteLayout, blockOffset + 2 + i + 2);
184+
byte q3 = q8Segment.get(byteLayout, blockOffset + 2 + i + 3);
185+
186+
floatArray.set(blockStart + i, q0 * scale);
187+
floatArray.set(blockStart + i + 1, q1 * scale);
188+
floatArray.set(blockStart + i + 2, q2 * scale);
189+
floatArray.set(blockStart + i + 3, q3 * scale);
190+
}
191+
});
165192

166193
return new FP32TornadoTensor(floatArray);
167194
}

0 commit comments

Comments
 (0)