44import org .beehive .gpullama3 .tensor .GGMLType ;
55import org .beehive .gpullama3 .tensor .standard .FloatTensor ;
66import uk .ac .manchester .tornado .api .types .HalfFloat ;
7- import uk .ac .manchester .tornado .api .types .arrays .HalfFloatArray ;
8- import uk .ac .manchester .tornado .api .types .arrays .Int8Array ;
9- import uk .ac .manchester .tornado .api .types .arrays .TornadoNativeArray ;
7+ import uk .ac .manchester .tornado .api .types .arrays .*;
108
119import java .lang .foreign .MemorySegment ;
1210import java .lang .foreign .ValueLayout ;
1311import java .nio .ByteOrder ;
12+ import java .util .concurrent .*;
13+ import java .util .stream .IntStream ;
1414
1515public class Q8_0TornadoTensor extends TornadoTensor {
1616
@@ -71,7 +71,10 @@ public float getFloat(int index) {
7171 return quant * scale ;
7272 }
7373
74- public static Q8_0TornadoTensor create (GGMLTensorEntry entry ) {
74+ /**
75+ * Creates a Q8_0TornadoTensor from a GGMLTensorEntry (original implementation).
76+ */
77+ public static Q8_0TornadoTensor createAsQ8_0 (GGMLTensorEntry entry ) {
7578 if (entry .ggmlType () != GGMLType .Q8_0 ) {
7679 throw new IllegalArgumentException ("Expected Q8_0 tensor, got: " + entry .ggmlType () + " for tensor: " + entry .name ());
7780 }
@@ -97,22 +100,96 @@ public static Q8_0TornadoTensor create(GGMLTensorEntry entry) {
97100 ValueLayout .OfShort shortLayout = ValueLayout .JAVA_SHORT_UNALIGNED .withOrder (ByteOrder .LITTLE_ENDIAN );
98101 ValueLayout .OfByte byteLayout = ValueLayout .JAVA_BYTE ;
99102
100- for (int block = 0 ; block < numBlocks ; block ++) {
101- // TODO: use GGML type method for the 34L size
102- long blockOffset = block * 34L ; // 34 bytes per block
103+ // element-wise copy and unpack from MemorySegment to HalfFloatArray scales and Int8Array quants
104+ // use parallel streams and unroll inner loop for better performance
105+ IntStream .range (0 , numBlocks )
106+ .parallel ()
107+ .forEach (block -> {
108+ // TODO: use GGML type method for the 34L size
109+ long blockOffset = block * 34L ; // 34 bytes per block
110+
111+ // read fp16 scale (first 2 bytes of block)
112+ short scaleRaw = q8Segment .get (shortLayout , blockOffset );
113+ scales .set (block , new HalfFloat (scaleRaw ));
114+ int blockStart = block * 32 ;
115+
116+ // read 32 int8 quantized values (remaining bytes of block)
117+ // TODO: use GGML type method for the 32 size
118+ for (int i = 0 ; i < 32 ; i += 4 ) {
119+ // unroll inner loop for better performance
120+ byte q0 = q8Segment .get (byteLayout , blockOffset + 2 + i );
121+ byte q1 = q8Segment .get (byteLayout , blockOffset + 2 + i + 1 );
122+ byte q2 = q8Segment .get (byteLayout , blockOffset + 2 + i + 2 );
123+ byte q3 = q8Segment .get (byteLayout , blockOffset + 2 + i + 3 );
124+
125+ quants .set (blockStart + i , q0 );
126+ quants .set (blockStart + i + 1 , q1 );
127+ quants .set (blockStart + i + 2 , q2 );
128+ quants .set (blockStart + i + 3 , q3 );
129+ }
130+ });
103131
104- // read fp16 scale (first 2 bytes of block)
105- short scaleRaw = q8Segment .get (shortLayout , blockOffset );
106- scales .set (block , new HalfFloat (scaleRaw ));
132+ return new Q8_0TornadoTensor (size , scales , quants , q8Segment );
133+ }
107134
108- // read 32 int8 quantized values (remaining bytes of block)
109- // TODO: use GGML type method for the 32 size
110- for (int i = 0 ; i < 32 ; i ++) {
111- byte quantValue = q8Segment .get (byteLayout , blockOffset + 2 + i );
112- quants .set (block * 32 + i , quantValue );
113- }
135+ /**
136+ * Creates a Q8_0TornadoTensor formulated as FP32TornadoTensor object from a GGMLTensorEntry.
137+ * NOTE: Hack implementation to comply with FP32 inference.
138+ */
139+ public static FP32TornadoTensor createAsFP32 (GGMLTensorEntry entry ) {
140+ if (entry .ggmlType () != GGMLType .Q8_0 ) {
141+ throw new IllegalArgumentException ("Expected Q8_0 tensor, got: " + entry .ggmlType () + " for tensor: " + entry .name ());
114142 }
115143
116- return new Q8_0TornadoTensor (size , scales , quants , q8Segment );
144+ int [] shape = entry .shape ();
145+ int size = FloatTensor .numberOfElements (shape );
146+ int numBlocks = size / GGMLType .Q8_0 .getBlockSize ();
147+
148+ if (size % GGMLType .Q8_0 .getBlockSize () != 0 ) {
149+ throw new IllegalArgumentException ("Q8_0 tensor size must be multiple of " + GGMLType .Q8_0 .getBlockSize () + ", got: " + size + " for tensor: " + entry .name ());
150+ }
151+
152+ // TODO: fix Q8_0 loading in tornado layoyt
153+ // currently we end up to hack it by removing
154+ // tornado header from memory segment
155+ MemorySegment q8Segment = entry .memorySegment ().asSlice (TornadoNativeArray .ARRAY_HEADER );
156+
157+ // allocate the FloatArray to store the result
158+ FloatArray floatArray = new FloatArray (size );
159+
160+ // unpack Q8_0 blocks: [2 bytes fp16 scale][32 bytes int8 quants]
161+ ValueLayout .OfShort shortLayout = ValueLayout .JAVA_SHORT_UNALIGNED .withOrder (ByteOrder .LITTLE_ENDIAN );
162+ ValueLayout .OfByte byteLayout = ValueLayout .JAVA_BYTE ;
163+
164+ // element-wise dequantization and copy from MemorySegment to FloatArray
165+ // use parallel streams and unroll inner loop for better performance
166+ IntStream .range (0 , numBlocks )
167+ .parallel ()
168+ .forEach (block -> {
169+ // TODO: use GGML type method for the 34L size
170+ long blockOffset = block * 34L ; // 34 bytes per block
171+
172+ // read fp16 scale (first 2 bytes of block) and convert to float
173+ short scaleRaw = q8Segment .get (shortLayout , blockOffset );
174+ float scale = Float .float16ToFloat (scaleRaw );
175+ int blockStart = block * 32 ;
176+
177+ // read 32 int8 quantized values (remaining bytes of block)
178+ // TODO: use GGML type method for the 32 size
179+ for (int i = 0 ; i < 32 ; i += 4 ) {
180+ // unroll inner loop for better performance
181+ byte q0 = q8Segment .get (byteLayout , blockOffset + 2 + i );
182+ byte q1 = q8Segment .get (byteLayout , blockOffset + 2 + i + 1 );
183+ byte q2 = q8Segment .get (byteLayout , blockOffset + 2 + i + 2 );
184+ byte q3 = q8Segment .get (byteLayout , blockOffset + 2 + i + 3 );
185+
186+ floatArray .set (blockStart + i , q0 * scale );
187+ floatArray .set (blockStart + i + 1 , q1 * scale );
188+ floatArray .set (blockStart + i + 2 , q2 * scale );
189+ floatArray .set (blockStart + i + 3 , q3 * scale );
190+ }
191+ });
192+
193+ return new FP32TornadoTensor (floatArray );
117194 }
118195}
0 commit comments