Skip to content

Commit e63d919

Browse files
Clean up and document Q8_0TornadoTensor
1 parent aef06cc commit e63d919

File tree

1 file changed

+8
-68
lines changed

1 file changed

+8
-68
lines changed
Lines changed: 8 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,29 @@
11
package org.beehive.gpullama3.tensor.tornado;
22

3-
import org.beehive.gpullama3.tensor.GGMLTensorEntry;
43
import org.beehive.gpullama3.tensor.GGMLType;
5-
import org.beehive.gpullama3.tensor.standard.FloatTensor;
6-
import uk.ac.manchester.tornado.api.types.HalfFloat;
74
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
8-
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
9-
import uk.ac.manchester.tornado.api.types.arrays.Int8Array;
10-
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
115

126
import java.lang.foreign.MemorySegment;
13-
import java.lang.foreign.ValueLayout;
14-
import java.nio.ByteOrder;
15-
import java.util.concurrent.*;
16-
import java.util.stream.IntStream;
177

8+
/**
9+
* This class represents a quantized tensor in the {@link GGMLType#Q8_0} format.
10+
* It is backed by a {@link ByteArray} containing both the quantized values and the scale factors.
11+
* The underlying {@link ByteArray} contains N Q8_0 blocks, where N is the tensor size divided by 32.
12+
* Each Q8_0 Block has the following layout:
13+
* [Scale Factor (fp16) - 2 bytes] [Quantized Value 0 (int8) - 1 byte] ... [Quantized Value 31 (int8) - 1 byte]
14+
*/
1815
public class Q8_0TornadoTensor extends TornadoTensor {
1916

20-
private final int size;
21-
private final HalfFloatArray scales; // One per 32-element block
22-
private final Int8Array quants; // Quantized int8 values
23-
private MemorySegment segment;
24-
25-
private final ByteArray tornadoNativeArray;
26-
27-
public Q8_0TornadoTensor(int size, HalfFloatArray scales, Int8Array quants, MemorySegment segment) {
28-
this.size = size;
29-
this.scales = scales;
30-
this.quants = quants;
31-
this.segment = segment;
32-
this.tornadoNativeArray = null;
33-
}
17+
private final ByteArray tornadoNativeArray; // Unified Q8_0 tensor in the memorySegment of the ByteArray
3418

3519
public Q8_0TornadoTensor(ByteArray byteArray) {
36-
this.size = -1;
37-
this.scales = null;
38-
this.quants = null;
39-
this.segment = null;
4020
this.tornadoNativeArray = byteArray;
4121
}
4222

4323
public static Q8_0TornadoTensor fromTornadoMemorySegment(MemorySegment segment) {
4424
return new Q8_0TornadoTensor(ByteArray.fromSegmentShallow(segment));
4525
}
4626

47-
public int getSize() {
48-
return size;
49-
}
50-
51-
/**
52-
* Returns the scale factors for GPU kernels.
53-
*
54-
* @return HalfFloatArray containing fp16 scale factors
55-
*/
56-
public HalfFloatArray getScales() {
57-
return scales;
58-
}
59-
60-
/**
61-
* Returns the quantized values for GPU kernels.
62-
*
63-
* @return Int8Array containing quantized int8 values
64-
*/
65-
public Int8Array getQuants() {
66-
return quants;
67-
}
68-
6927
@Override
7028
public ByteArray asByteArray() {
7129
return tornadoNativeArray;
@@ -76,22 +34,4 @@ public GGMLType type() {
7634
return GGMLType.Q8_0;
7735
}
7836

79-
public MemorySegment asMemorySegment() {
80-
return segment;
81-
}
82-
83-
/**
84-
* Dequantizes and returns a single float value.
85-
*
86-
* @param index Element index
87-
* @return Dequantized float value
88-
*/
89-
public float getFloat(int index) {
90-
assert 0 <= index;
91-
int blockIdx = index / GGMLType.Q8_0.getBlockSize();
92-
float scale = scales.get(blockIdx).getFloat32();
93-
byte quant = quants.get(index);
94-
return quant * scale;
95-
}
96-
9737
}

0 commit comments

Comments
 (0)