Skip to content

Commit d0966eb

Browse files
committed
Refactor tensor loading and introduce support for half-float (FP16) precision in TornadoVM acceleration.
1 parent 0e5d5e5 commit d0966eb

File tree

11 files changed

+47
-12
lines changed

11 files changed

+47
-12
lines changed

set_paths

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
# Resolve root of this project (LLaMA3) and TornadoVM
88
export LLAMA_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
9-
export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm"
9+
#export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm"
1010

1111
# Set the path to TornadoVM SDK binaries
12-
export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk"
12+
#export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk"
1313

1414
# Add TornadoVM and LLaMA bin directories to PATH
1515
export PATH="${PATH}:${TORNADO_SDK}:${LLAMA_ROOT}"

src/main/java/org/beehive/gpullama3/inference/InferenceCore.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
583583
final Configuration configuration = model.configuration();
584584
final TornadoWeights weights = (TornadoWeights) model.weights();
585585

586-
MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
586+
MemorySegment.copy(weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(), (long) token * configuration.dim() * Short.BYTES, state.embeddingX.getSegment(), 0, configuration.dim() * Short.BYTES);
587587

588588
return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
589589
}

src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import org.beehive.gpullama3.tensor.standard.FloatTensor;
55
import org.beehive.gpullama3.model.Configuration;
66
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
7+
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
78
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
89

910
import java.util.stream.Stream;
@@ -52,6 +53,8 @@ protected StateFields createStateFields(Configuration config) {
5253
fields.wrapHb = new FloatArray(config.hiddenDim());
5354
fields.wrapHb2 = new FloatArray(config.hiddenDim());
5455

56+
fields.embeddingX = new HalfFloatArray(config.dim());
57+
5558
fields.wrapLogits = new FloatArray(config.vocabularySize());
5659
fields.wrapQ = new FloatArray(config.dim());
5760
fields.wrapK = new FloatArray(config.dim());

src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import org.beehive.gpullama3.model.Configuration;
66
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
77
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
8+
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
89
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
910

1011
import java.util.stream.Stream;
@@ -79,6 +80,7 @@ protected StateFields createStateFields(Configuration config) {
7980
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim)).limit(nLayers).toArray(FloatTensor[]::new);
8081

8182
// TornadoVM wrapper arrays for GPU acceleration
83+
fields.embeddingX = new HalfFloatArray(config.dim());
8284
fields.wrapX = new FloatArray(dim);
8385
fields.wrapXb = new FloatArray(dim);
8486
fields.wrapXb2 = new FloatArray(dim);

src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import org.beehive.gpullama3.model.Configuration;
66
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
77
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
8+
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
89
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
910

1011
import java.util.stream.Stream;
@@ -40,6 +41,7 @@ protected StateFields createStateFields(Configuration configuration) {
4041
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
4142

4243
// TornadoVM wrappers with Qwen2 dimensions
44+
fields.embeddingX = new HalfFloatArray(config.dim());
4345
fields.wrapX = new FloatArray(config.dim());
4446
fields.wrapXb = new FloatArray(config.dim());
4547
fields.wrapXb2 = new FloatArray(config.dim());

src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import org.beehive.gpullama3.model.Configuration;
66
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
77
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
8+
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
89
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
910

1011
import java.util.stream.Stream;
@@ -65,6 +66,8 @@ protected StateFields createStateFields(Configuration configuration) {
6566
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
6667

6768
// TornadoVM wrappers with Qwen3-specific sizes
69+
70+
fields.embeddingX = new HalfFloatArray(config.dim());
6871
fields.wrapX = new FloatArray(config.dim());
6972
fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads());
7073
fields.wrapXb2 = new FloatArray(config.dim());
@@ -74,7 +77,7 @@ protected StateFields createStateFields(Configuration configuration) {
7477
fields.wrapQ = new FloatArray(nEmbdHeadK * config.numberOfHeads());
7578
fields.wrapK = new FloatArray(nEmbdKGqa);
7679
fields.wrapV = new FloatArray(nEmbdKGqa);
77-
80+
fields.embeddingX = new HalfFloatArray(config.dim());
7881
fields.wrapKeyCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
7982
fields.wrapValueCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
8083
fields.wrapValueCache.init(0.f);

src/main/java/org/beehive/gpullama3/inference/state/State.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import org.beehive.gpullama3.tensor.standard.FloatTensor;
44
import org.beehive.gpullama3.model.Configuration;
5+
import uk.ac.manchester.tornado.api.types.HalfFloat;
56
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
7+
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
68
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
79

810
/**
@@ -57,6 +59,7 @@ public abstract class State {
5759
public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM.
5860
public final IntArray positionHolder;
5961

62+
public HalfFloatArray embeddingX;
6063
// store inter
6164
public int localSize;
6265
public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size.
@@ -88,6 +91,7 @@ protected State(Configuration config, int batchsize) {
8891
this.keyCache = fields.keyCache;
8992
this.valueCache = fields.valueCache;
9093

94+
this.embeddingX = fields.embeddingX;
9195
this.wrapX = fields.wrapX;
9296
this.wrapXb = fields.wrapXb;
9397
this.wrapXb2 = fields.wrapXb2;
@@ -121,6 +125,7 @@ protected static class StateFields {
121125
public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache;
122126
public IntArray positionHolder;
123127
public FloatArray temp, tempFFN, tempLogits;
128+
public HalfFloatArray embeddingX;
124129
}
125130

126131
@Override

src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
120120

121121
// Load all tensors uniformly as TornadoTensor hierarchy
122122
return new LlamaTornadoWeights(
123-
loadTornadoTensorAsFP32(tokenEmbeddings),
123+
loadTornadoTensor(tokenEmbeddings),
124124
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
125125
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
126126
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),

src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
130130

131131
// Load all tensors uniformly as TornadoTensor hierarchy
132132
return new LlamaTornadoWeights(
133-
loadTornadoTensorAsFP32(tokenEmbeddings),
133+
loadTornadoTensor(tokenEmbeddings),
134134
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
135135
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
136136
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),

src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import uk.ac.manchester.tornado.api.KernelContext;
44
import uk.ac.manchester.tornado.api.math.TornadoMath;
5+
import uk.ac.manchester.tornado.api.types.HalfFloat;
56
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
7+
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
68

79
public class TransformerComputeKernels {
810

@@ -19,6 +21,18 @@ public static void emptyTaskToForceCopyIn(FloatArray buffer) {
1921
}
2022
}
2123

24+
public static void convertFP16toFP32(KernelContext context, HalfFloatArray x, FloatArray wrapX) {
25+
int i = context.globalIdx;
26+
wrapX.set(i, x.get(i).getFloat32());
27+
}
28+
29+
public static void convertFP32toFP16(KernelContext context, FloatArray wrapX, HalfFloatArray x) {
30+
int i = context.globalIdx;
31+
float valInput = wrapX.get(i);
32+
HalfFloat val = new HalfFloat(valInput);
33+
x.set(i,val);
34+
}
35+
2236
/**
2337
* Performs RMS (Root Mean Square) normalization using parallel reduction.
2438
* This is a two-phase reduction: first within work groups, then across work groups.

0 commit comments

Comments (0)