11package org .beehive .gpullama3 .tensor .tornado ;
22
3- import org .beehive .gpullama3 .tensor .GGMLTensorEntry ;
43import org .beehive .gpullama3 .tensor .GGMLType ;
5- import org .beehive .gpullama3 .tensor .standard .FloatTensor ;
6- import uk .ac .manchester .tornado .api .types .HalfFloat ;
74import uk .ac .manchester .tornado .api .types .arrays .ByteArray ;
8- import uk .ac .manchester .tornado .api .types .arrays .HalfFloatArray ;
9- import uk .ac .manchester .tornado .api .types .arrays .Int8Array ;
10- import uk .ac .manchester .tornado .api .types .arrays .TornadoNativeArray ;
115
126import java .lang .foreign .MemorySegment ;
13- import java .lang .foreign .ValueLayout ;
14- import java .nio .ByteOrder ;
15- import java .util .concurrent .*;
16- import java .util .stream .IntStream ;
177
8+ /**
9+ * This class represents a quantized tensor in the {@link GGMLType#Q8_0} format.
10+ * It is backed by a {@link ByteArray} containing both the quantized values and the scale factors.
11+ * The underlying {@link ByteArray} contains N Q8_0 blocks, where N is the tensor size divided by 32.:
12+ * Each Q8_0 Block has the following layout:
13+ * [Scale Factor (fp16) - 2 bytes] [Quantized Value 0 (int8) - 1 byte] ... [Quantized Value 31 (int8) - 1 byte]
14+ */
1815public class Q8_0TornadoTensor extends TornadoTensor {
1916
20- private final int size ;
21- private final HalfFloatArray scales ; // One per 32-element block
22- private final Int8Array quants ; // Quantized int8 values
23- private MemorySegment segment ;
24-
25- private final ByteArray tornadoNativeArray ;
26-
27- public Q8_0TornadoTensor (int size , HalfFloatArray scales , Int8Array quants , MemorySegment segment ) {
28- this .size = size ;
29- this .scales = scales ;
30- this .quants = quants ;
31- this .segment = segment ;
32- this .tornadoNativeArray = null ;
33- }
17+ private final ByteArray tornadoNativeArray ; // Unified Q8_0 tensor in the memorySegment of the ByteArray
3418
3519 public Q8_0TornadoTensor (ByteArray byteArray ) {
36- this .size = -1 ;
37- this .scales = null ;
38- this .quants = null ;
39- this .segment = null ;
4020 this .tornadoNativeArray = byteArray ;
4121 }
4222
4323 public static Q8_0TornadoTensor fromTornadoMemorySegment (MemorySegment segment ) {
4424 return new Q8_0TornadoTensor (ByteArray .fromSegmentShallow (segment ));
4525 }
4626
47- public int getSize () {
48- return size ;
49- }
50-
51- /**
52- * Returns the scale factors for GPU kernels.
53- *
54- * @return HalfFloatArray containing fp16 scale factors
55- */
56- public HalfFloatArray getScales () {
57- return scales ;
58- }
59-
60- /**
61- * Returns the quantized values for GPU kernels.
62- *
63- * @return Int8Array containing quantized int8 values
64- */
65- public Int8Array getQuants () {
66- return quants ;
67- }
68-
6927 @ Override
7028 public ByteArray asByteArray () {
7129 return tornadoNativeArray ;
@@ -76,22 +34,4 @@ public GGMLType type() {
7634 return GGMLType .Q8_0 ;
7735 }
7836
79- public MemorySegment asMemorySegment () {
80- return segment ;
81- }
82-
83- /**
84- * Dequantizes and returns a single float value.
85- *
86- * @param index Element index
87- * @return Dequantized float value
88- */
89- public float getFloat (int index ) {
90- assert 0 <= index ;
91- int blockIdx = index / GGMLType .Q8_0 .getBlockSize ();
92- float scale = scales .get (blockIdx ).getFloat32 ();
93- byte quant = quants .get (index );
94- return quant * scale ;
95- }
96-
9737}
0 commit comments