Changes from all commits
41 commits
ca2b28a
Implement FP16 support in TornadoVM by introducing HalfFloat arrays, …
mikepapadim Dec 3, 2025
f0411ae
Introduce matrix-vector kernel with residual addition and enhance FP1…
mikepapadim Dec 3, 2025
6334ac3
Fused Q/K/V matrix-vector multiplication into a single kernel to redu…
mikepapadim Dec 3, 2025
46218a7
Fuse RoPE rotation and KV cache copy into a single kernel, update tas…
mikepapadim Dec 3, 2025
b48ec62
Add `mapContextWithQuantize` kernel, integrate into task graph, and d…
mikepapadim Dec 3, 2025
943da78
Refactor logits task graph to optimize kernel setup, update worker gr…
mikepapadim Dec 4, 2025
386dddc
Refactor FP16 FFN layers to streamline task graph setup, update worke…
mikepapadim Dec 4, 2025
b202bb4
Refactor FP16 FFN layers to streamline task graph setup, update worke…
mikepapadim Dec 4, 2025
3eba3b3
Refactor `LogitsFP16Layer` task graph to improve readability, optimiz…
mikepapadim Dec 4, 2025
2e010b1
Add `fusedFeedForwardWithSiLUAndGLUActivation` kernel for HalfFloat a…
mikepapadim Dec 4, 2025
4aef300
Document Transformer Layer Task Flow for `LlamaFP16FFNLayers` with de…
mikepapadim Dec 4, 2025
177ec9d
Set default profiler dump directory relative to `LLAMA_ROOT` when not…
mikepapadim Dec 4, 2025
a1c94fb
Add `fusedRmsNormFFNGateUp` kernel and update FP16 FFN task graph to …
mikepapadim Dec 4, 2025
577b6b1
Increase `BLOCK_SIZE_C` to 16 for Transformer kernel and update FP16 …
mikepapadim Dec 4, 2025
d5c1206
Increase `ropeWithCacheWorker` local work group size to 512 in FP16 F…
mikepapadim Dec 4, 2025
f91108c
Add fused kernels for Qwen3: `ropeRotationWithCacheCopy`, `fusedQKVMa…
mikepapadim Dec 4, 2025
67050bb
Merge branch 'feat/deq-n-compute' of github.com:beehive-lab/GPULlama3…
mikepapadim Dec 4, 2025
cfa3ba0
Add fused Q and K RMSNorm kernel and refactor task graph to consolida…
mikepapadim Dec 4, 2025
abf12d4
Refactor Qwen3 FP16 FFN layers to streamline worker grid setup, updat…
mikepapadim Dec 4, 2025
042b0b5
Add `processHeadsFlashAttentionOptV2` kernel with static memory size …
mikepapadim Dec 4, 2025
1cbe03a
Refactor Qwen3 FP16 FFN layers: remove unused imports, replace explic…
mikepapadim Dec 4, 2025
a4bc159
Refactor Qwen2 FP16 task graph: consolidate attention and FFN tasks w…
mikepapadim Dec 4, 2025
e15c229
Add `fusedQKvBiasAddition` kernel, refactor Qwen2 FP16 task graph to …
mikepapadim Dec 4, 2025
e7d79c9
Add support for HalfFloatArray in Phi3State and initialize FP16 wrapp…
mikepapadim Dec 4, 2025
02b1a2c
Add `splitQKV` and `splitGateUpSiLU` worker grids to Phi3 FP16 FFN la…
mikepapadim Dec 4, 2025
428e5cc
Refactor Phi3 FP16 FFN layers: replace `createRoPEWorker` with generi…
mikepapadim Dec 4, 2025
6c1ac6f
Add Phi3-specific fused kernels for RMSNorm+QKV and RMSNorm+Gate/Up, …
mikepapadim Dec 4, 2025
ed74652
Replace `splitQKV` kernel with `fusedRmsNormQKVMatmulDirect`, refacto…
mikepapadim Dec 4, 2025
8b52fbe
Remove unused `splitQKV` and RMS Apply+QKV Projection kernels, update…
mikepapadim Dec 4, 2025
977f0ba
Add `fusedRmsNormFFNGateUpSiLU` kernel to optimize Phi3 FFN flow, rep…
mikepapadim Dec 4, 2025
7e19032
Remove unused `splitQKV` and `splitGateUpSiLU` workers, clean up comm…
mikepapadim Dec 4, 2025
1e46405
Refactor Phi3 FP16 FFN layer task graph: improve readability by adjus…
mikepapadim Dec 4, 2025
7c63dc4
Refactor LogitsFP16Layer: streamline task graph setup, consolidate gr…
mikepapadim Dec 4, 2025
1a98725
Refactor `TransformerComputeKernelsLayered`: replace `matrixVectorRow…
mikepapadim Dec 7, 2025
d1ec408
Refactor `TransformerComputeKernelsLayered`: rename `matrixVectorRowM…
mikepapadim Dec 7, 2025
d1ff213
Refactor `TransformerComputeKernelsLayered`: rename `matrixVectorRowM…
mikepapadim Dec 7, 2025
36d0584
Merge branch 'main' of github.com:beehive-lab/GPULlama3.java into fea…
mikepapadim Dec 8, 2025
2c0c55c
Fix import
mikepapadim Dec 8, 2025
bbdffd3
Refactor LogitsQ8_0Layer: simplify grid scheduler setup, consolidate …
mikepapadim Dec 8, 2025
e2d8820
Refactor logits handling: replace `init` calls with `clear` for tenso…
mikepapadim Dec 8, 2025
fee6ea4
Refactor FFN layer task graphs: update data transfer logic by removin…
mikepapadim Dec 8, 2025
7 changes: 6 additions & 1 deletion llama-tornado
@@ -422,7 +422,7 @@ def create_parser() -> argparse.ArgumentParser:
)
debug_group.add_argument(
"--profiler-dump-dir",
default="/home/mikepapadim/repos/gpu-llama3.java/prof.json",
default=None,
help="Directory for profiler output",
)

@@ -498,6 +498,11 @@ def main():
parser = create_parser()
args = parser.parse_args()

# Set default profiler log path relative to LLAMA_ROOT
if args.profiler_dump_dir is None:
llama_root = os.environ.get("LLAMA_ROOT")
args.profiler_dump_dir = os.path.join(llama_root, "profiler-log.json")

# Set default seed if not provided
if args.seed is None:
args.seed = int(time.time())
@@ -64,6 +64,8 @@ protected StateFields createStateFields(Configuration config) {
fields.wrapK = new FloatArray(config.dim());
fields.wrapV = new FloatArray(config.dim());

fields.wrapXFP16 = new HalfFloatArray(config.dim());
fields.wrapXbFP16 = new HalfFloatArray(config.dim());
// dim vs kvdim
fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
@@ -87,6 +87,8 @@ protected StateFields createStateFields(Configuration config) {
}
fields.wrapX = new FloatArray(dim);
fields.wrapXb = new FloatArray(dim);
fields.wrapXFP16 = new HalfFloatArray(dim);
fields.wrapXbFP16 = new HalfFloatArray(dim);
fields.wrapXb2 = new FloatArray(dim);
fields.wrapHb = new FloatArray(2 * hiddenDim);
fields.wrapHb2 = new FloatArray(hiddenDim);
@@ -48,6 +48,7 @@ protected StateFields createStateFields(Configuration configuration) {
}
fields.wrapX = new FloatArray(config.dim());
fields.wrapXb = new FloatArray(config.dim());
fields.wrapXbFP16 = new HalfFloatArray(config.dim());
fields.wrapXb2 = new FloatArray(config.dim());
fields.wrapHb = new FloatArray(config.hiddenDim());
fields.wrapHb2 = new FloatArray(config.hiddenDim());
@@ -75,6 +75,8 @@ protected StateFields createStateFields(Configuration configuration) {

fields.wrapX = new FloatArray(config.dim());
fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads());
fields.wrapXbFP16 = new HalfFloatArray(nEmbdHeadK * config.numberOfHeads());

fields.wrapXb2 = new FloatArray(config.dim());
fields.wrapHb = new FloatArray(config.hiddenDim());
fields.wrapHb2 = new FloatArray(config.hiddenDim());
11 changes: 11 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/state/State.java
@@ -4,6 +4,9 @@
import org.beehive.gpullama3.model.Configuration;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.*;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

/**
* Represents the base state structure used during LLM inference.
@@ -58,13 +61,17 @@ public abstract class State {
public final IntArray positionHolder;

public TornadoNativeArray embeddingX;

public final HalfFloatArray wrapXbFP16; // HalfFloat (FP16) wrapper for xb (residual branch activation), optimized for TornadoVM usage.

// store inter
public int localSize;
public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size.
public FloatArray tempFFN; // Temporary buffer for feed-forward network calculations, size adjusted for local workgroup size.
public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size.
public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models.

public HalfFloatArray wrapXFP16;
/** last index in previous block */

protected State(Configuration config, int batchsize) {
@@ -100,6 +107,9 @@ protected State(Configuration config, int batchsize) {
this.wrapK = fields.wrapK;
this.wrapV = fields.wrapV;

this.wrapXFP16 = fields.wrapXFP16;
this.wrapXbFP16 = fields.wrapXbFP16;

// dim vs kvdim
this.wrapKeyCache = fields.wrapKeyCache;
this.wrapValueCache = fields.wrapValueCache;
@@ -136,6 +146,7 @@ public void createActivationQ8_0(int size) {
int q8BytesNeeded = blocksNeeded * Q8_0_BLOCK_BYTES;
this.embeddingX = new ByteArray(q8BytesNeeded);
}
public HalfFloatArray wrapXFP16, wrapXbFP16;
}

@Override
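The new `wrapXFP16` and `wrapXbFP16` fields act as FP16 shadow buffers for the FP32 activations `wrapX` and `wrapXb`, so the HalfFloat matrix-vector kernels added in this PR can consume them directly. Below is a minimal sketch of the kind of conversion task that would fill such a buffer; the class and method names are hypothetical, and the PR's own conversion kernel (e.g. `mapContextWithQuantize` from commit b48ec62) may do more than a plain copy.

```java
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

public final class Fp16CopySketch {

    // Copies the FP32 activation into its FP16 shadow buffer so that
    // HalfFloat matrix-vector kernels can read it without converting per element.
    public static void copyToFp16(FloatArray src, HalfFloatArray dst) {
        for (@Parallel int i = 0; i < src.getSize(); i++) {
            dst.set(i, new HalfFloat(src.get(i)));
        }
    }
}
```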
@@ -1,9 +1,9 @@
package org.beehive.gpullama3.tornadovm;

import org.beehive.gpullama3.tensor.GGMLType;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tensor.GGMLType;
import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
@@ -133,6 +133,8 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {

// Set the position in the state object (used by attention layers)
state.positionHolder.set(0, position);
state.temp.clear();
state.tempFFN.clear();

// 2. Execute each transformer layer graph sequentially
// Each graph computes attention and feed-forward transformations for one layer
@@ -141,7 +143,8 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
.execute();
}

state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f
state.wrapLogits.clear(); // Clear the output logits tensor -> set to 0f
// 3. Execute the final graph that projects the last hidden state to output logits
executionPlan.withGraph(getFinalLogitsGraphIndex())
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
@@ -179,7 +182,7 @@ private int getFinalLogitsGraphIndex() {
/// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer.
public void forceCopyInReadOnlyDataLayered() {
// Execute all TornadoVM graphs
state.wrapX.init(0.0f);
state.wrapX.clear();
state.positionHolder.init(0);

// Execute activation update graph
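The last hunks replace `init(0.0f)` with `clear()` when resetting the scratch and logits buffers before graph execution, presumably so stale values from the previous token's reductions do not leak into the next forward pass. A tiny illustration of the two calls on a TornadoVM `FloatArray`, assuming (as the diff's inline comments state) that `clear()` zero-fills the buffer:

```java
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public final class ClearVsInit {
    public static void main(String[] args) {
        FloatArray scratch = new FloatArray(1024);
        scratch.init(0.0f); // old style in this code base: explicitly fill every element with 0.0f
        scratch.clear();    // new style: reset the buffer to zeros before the next forward pass
        System.out.println(scratch.get(0)); // prints 0.0
    }
}
```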