diff --git a/llama-tornado b/llama-tornado
index 9c0d6ba8..c98090f8 100755
--- a/llama-tornado
+++ b/llama-tornado
@@ -422,7 +422,7 @@ def create_parser() -> argparse.ArgumentParser:
)
debug_group.add_argument(
"--profiler-dump-dir",
- default="/home/mikepapadim/repos/gpu-llama3.java/prof.json",
+ default=None,
help="Directory for profiler output",
)
@@ -498,6 +498,11 @@ def main():
parser = create_parser()
args = parser.parse_args()
+ # Set default profiler log path relative to LLAMA_ROOT
+ if args.profiler_dump_dir is None:
+ llama_root = os.environ.get("LLAMA_ROOT")
+ args.profiler_dump_dir = os.path.join(llama_root, "profiler-log.json")
+
# Set default seed if not provided
if args.seed is None:
args.seed = int(time.time())
diff --git a/pom.xml b/pom.xml
index 814df7f5..a59830ef 100644
--- a/pom.xml
+++ b/pom.xml
@@ -54,12 +54,12 @@
io.github.beehive-lab
tornado-api
- 2.0.1-dev
+ 2.1.0
io.github.beehive-lab
tornado-runtime
- 2.0.1-dev
+ 2.1.0
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
index a24b9626..21344223 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
@@ -64,6 +64,8 @@ protected StateFields createStateFields(Configuration config) {
fields.wrapK = new FloatArray(config.dim());
fields.wrapV = new FloatArray(config.dim());
+ fields.wrapXFP16 = new HalfFloatArray(config.dim());
+ fields.wrapXbFP16 = new HalfFloatArray(config.dim());
// dim vs kvdim
fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java
index 2ae4d269..79115922 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java
@@ -87,6 +87,8 @@ protected StateFields createStateFields(Configuration config) {
}
fields.wrapX = new FloatArray(dim);
fields.wrapXb = new FloatArray(dim);
+ fields.wrapXFP16 = new HalfFloatArray(dim);
+ fields.wrapXbFP16 = new HalfFloatArray(dim);
fields.wrapXb2 = new FloatArray(dim);
fields.wrapHb = new FloatArray(2 * hiddenDim);
fields.wrapHb2 = new FloatArray(hiddenDim);
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java
index 23939bba..266730ac 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java
@@ -48,6 +48,7 @@ protected StateFields createStateFields(Configuration configuration) {
}
fields.wrapX = new FloatArray(config.dim());
fields.wrapXb = new FloatArray(config.dim());
+ fields.wrapXbFP16 = new HalfFloatArray(config.dim());
fields.wrapXb2 = new FloatArray(config.dim());
fields.wrapHb = new FloatArray(config.hiddenDim());
fields.wrapHb2 = new FloatArray(config.hiddenDim());
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java
index 6f90398f..d70625b9 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java
@@ -75,6 +75,8 @@ protected StateFields createStateFields(Configuration configuration) {
fields.wrapX = new FloatArray(config.dim());
fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads());
+ fields.wrapXbFP16 = new HalfFloatArray(nEmbdHeadK * config.numberOfHeads());
+
fields.wrapXb2 = new FloatArray(config.dim());
fields.wrapHb = new FloatArray(config.hiddenDim());
fields.wrapHb2 = new FloatArray(config.hiddenDim());
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java
index 3f972e1b..f8e9906a 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/State.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java
@@ -4,6 +4,9 @@
import org.beehive.gpullama3.model.Configuration;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.*;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
/**
* Represents the base state structure used during LLM inference.
@@ -58,6 +61,9 @@ public abstract class State {
public final IntArray positionHolder;
public TornadoNativeArray embeddingX;
+
+ public final HalfFloatArray wrapXbFP16; // FP16 (HalfFloatArray) mirror of xb (residual branch activation), used by the FP16 TornadoVM kernels.
+
// store inter
public int localSize;
public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size.
@@ -65,6 +71,7 @@ public abstract class State {
public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size.
public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models.
+ public HalfFloatArray wrapXFP16; // FP16 mirror of x (current activation)
/** last index in previous block */
protected State(Configuration config, int batchsize) {
@@ -100,6 +107,9 @@ protected State(Configuration config, int batchsize) {
this.wrapK = fields.wrapK;
this.wrapV = fields.wrapV;
+ this.wrapXFP16 = fields.wrapXFP16;
+ this.wrapXbFP16 = fields.wrapXbFP16;
+
// dim vs kvdim
this.wrapKeyCache = fields.wrapKeyCache;
this.wrapValueCache = fields.wrapValueCache;
@@ -136,6 +146,7 @@ public void createActivationQ8_0(int size) {
int q8BytesNeeded = blocksNeeded * Q8_0_BLOCK_BYTES;
this.embeddingX = new ByteArray(q8BytesNeeded);
}
+ public HalfFloatArray wrapXFP16, wrapXbFP16;
}
@Override
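
The wrapXFP16 and wrapXbFP16 buffers added above are FP16 mirrors of the FP32 activations wrapX and wrapXb; the conversion kernels introduced later in this patch keep them in sync on the device. For reference, an equivalent host-side conversion might look like the following sketch (illustrative only; the class and method names are hypothetical):

    import uk.ac.manchester.tornado.api.types.HalfFloat;
    import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
    import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

    class Fp16MirrorSketch {
        // Copy an FP32 activation buffer into its FP16 mirror, element by element.
        static void fillMirror(FloatArray src, HalfFloatArray dst) {
            for (int i = 0; i < src.getSize(); i++) {
                dst.set(i, new HalfFloat(src.get(i)));
            }
        }
    }
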
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
index 293d2c0c..eadd2e68 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
@@ -1,9 +1,9 @@
package org.beehive.gpullama3.tornadovm;
-import org.beehive.gpullama3.tensor.GGMLType;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tensor.GGMLType;
import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
@@ -133,6 +133,8 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {
// Set the position in the state object (used by attention layers)
state.positionHolder.set(0, position);
+ state.temp.clear();
+ state.tempFFN.clear();
// 2. Execute each transformer layer graph sequentially
// Each graph computes attention and feed-forward transformations for one layer
@@ -141,7 +143,8 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
.execute();
}
-
+ state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f
+ state.wrapLogits.clear(); // Clear the output logits tensor -> set to 0f
// 3. Execute the final graph that projects the last hidden state to output logits
executionPlan.withGraph(getFinalLogitsGraphIndex())
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
@@ -179,7 +182,7 @@ private int getFinalLogitsGraphIndex() {
/// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer.
public void forceCopyInReadOnlyDataLayered() {
// Execute all TornadoVM graphs
- state.wrapX.init(0.0f);
+ state.wrapX.clear();
state.positionHolder.init(0);
// Execute activation update graph
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java
new file mode 100644
index 00000000..d45a62c4
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java
@@ -0,0 +1,380 @@
+package org.beehive.gpullama3.tornadovm.kernels;
+
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.math.TornadoMath;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+
+/**
+ * Phi3Kernels: Optimized GPU kernels for Phi3 model family.
+ *
+ * Key differences from Qwen/Llama kernels:
+ *
+ * - Generic fused RMS + matmul (single output matrix)
+ * - Phi3 RoPE with headSize/2 offset pattern
+ * - Combined gate/up structure support
+ *
+ */
+public class Phi3Kernels {
+
+ /**
+ * Fused RMSNorm apply + single matrix-vector multiplication.
+ *
+ * Combines RMS normalization application with a generic matmul in one kernel,
+ * reducing memory bandwidth by avoiding intermediate storage.
+ *
+ * Formula: output[row] = sum_j(W[row,j] * rmsWeight[j] * scale * x[j])
+ *
+ * Use cases:
+ *
+ * - Phi3 combined QKV projection (output = wqkv · RMSNorm(x))
+ * - Phi3 combined gate/up projection (output = wUp · RMSNorm(x))
+ * - Any single-matrix projection after RMSNorm
+ *
+ *
+ * @param context Kernel execution context
+ * @param x Input hidden state (FP32) [dim]
+ * @param output Output buffer (FP32) [outputDim]
+ * @param rmsWeights RMS normalization weights (FP32) [dim]
+ * @param rmsScale Precomputed RMS scale factor [1] (from reduction kernel)
+ * @param w Weight matrix (FP16) [outputDim × dim]
+ * @param inputDim Input dimension (dim)
+ * @param outputDim Output dimension
+ * @param localWorkGroupSize Local work group size for reduction
+ */
+ public static void fusedRmsNormMatmul(
+ KernelContext context,
+ FloatArray x, // input (FP32)
+ FloatArray output, // output (FP32)
+ FloatArray rmsWeights, // RMS norm weights
+ FloatArray rmsScale, // temp[0] = scale factor
+ HalfFloatArray w, // weight matrix
+ int inputDim, // input dimension
+ int outputDim, // output dimension
+ int localWorkGroupSize) {
+
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ if (rowId >= outputDim) {
+ return;
+ }
+
+ float scale = rmsScale.get(0);
+
+ // Allocate shared memory for reduction
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ int rowOffset = rowId * inputDim;
+
+ // Each thread computes partial dot product with inline normalization
+ float partialSum = 0.0f;
+ for (int j = localId; j < inputDim; j += localWorkGroupSize) {
+ float normalized = rmsWeights.get(j) * scale * x.get(j);
+ partialSum += w.get(rowOffset + j).getFloat32() * normalized;
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ // Parallel reduction within workgroup
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ // Thread 0 writes final result
+ if (localId == 0) {
+ output.set(rowId, localSum[0]);
+ }
+ }
+
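
For reference, the formula above reduces to the following scalar loop (a plain-Java sketch with hypothetical names, useful for validating the kernel on small inputs; the GPU version never materializes the normalized vector):

    class FusedRmsMatmulRef {
        // output[row] = sum_j( W[row, j] * rmsWeight[j] * scale * x[j] )
        static float[] apply(float[] x, float[] rmsWeights, float scale,
                             float[] w, int inputDim, int outputDim) {
            float[] out = new float[outputDim];
            for (int row = 0; row < outputDim; row++) {
                float sum = 0.0f;
                for (int j = 0; j < inputDim; j++) {
                    sum += w[row * inputDim + j] * (rmsWeights[j] * scale * x[j]);
                }
                out[row] = sum;
            }
            return out;
        }
    }
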
+ /**
+ * Phi3 RoPE rotation with fused KV cache copy.
+ *
+ * Phi3 uses a different RoPE pattern than Llama/Qwen:
+ *
+ * - Pairs elements with offset headSize/2 (not adjacent pairs)
+ * - Each thread processes one dimension pair across all heads
+ * - Iterates over heads internally
+ *
+ *
+ * This fused kernel combines:
+ *
+ * - Phi3-style RoPE rotation for Q and K
+ * - Direct cache write for rotated K
+ * - Direct cache copy for V (no rotation)
+ *
+ *
+ * @param context Kernel execution context
+ * @param positionHolder Current position in sequence [1]
+ * @param sq Query vectors (in/out, rotated) [dim]
+ * @param sk Key vectors (in/out, rotated) [kvDim]
+ * @param sv Value vectors (in only) [kvDim]
+ * @param keyCache Key cache (out) [layers × contextLength × kvDim]
+ * @param valueCache Value cache (out) [layers × contextLength × kvDim]
+ * @param nHeadKv Number of KV heads
+ * @param headSize Dimension per head
+ * @param kvDim Total KV dimension (nHeadKv × headSize)
+ * @param layer Current layer index
+ * @param contextLength Maximum sequence length
+ */
+ public static void ropeRotationWithCacheCopyPhi3(
+ KernelContext context,
+ IntArray positionHolder,
+ FloatArray sq, // Q vector (in/out)
+ FloatArray sk, // K vector (in/out)
+ FloatArray sv, // V vector (in only)
+ FloatArray keyCache, // Key cache (out)
+ FloatArray valueCache, // Value cache (out)
+ int nHeadKv,
+ int headSize,
+ int kvDim,
+ int layer,
+ int contextLength) {
+
+ int idx = context.globalIdx;
+ int dimHalf = headSize / 2;
+
+ // Each thread processes one dimension pair
+ if (idx >= dimHalf) {
+ return;
+ }
+
+ int pos = positionHolder.get(0);
+ int cacheOffset = layer * contextLength * kvDim + pos * kvDim;
+
+ // Calculate frequency for this dimension
+ float freq = 1.0f / TornadoMath.pow(10000.0f, (float) (idx * 2) / (float) headSize);
+ float val = pos * freq;
+ float fcr = TornadoMath.cos(val);
+ float fci = TornadoMath.sin(val);
+
+ // Process Q: all heads (dim = nHeads × headSize)
+ int totalDimQ = sq.getSize();
+ for (int base = 0; base < totalDimQ; base += headSize) {
+ if (base + idx >= totalDimQ || base + idx + dimHalf >= totalDimQ) {
+ break;
+ }
+
+ // Rotate Q with offset pattern
+ float v0 = sq.get(base + idx);
+ float v1 = sq.get(base + idx + dimHalf);
+ sq.set(base + idx, v0 * fcr - v1 * fci);
+ sq.set(base + idx + dimHalf, v0 * fci + v1 * fcr);
+ }
+
+ // Process K: only kvDim elements, with cache write
+ for (int base = 0; base < kvDim; base += headSize) {
+ if (base + idx >= kvDim || base + idx + dimHalf >= kvDim) {
+ break;
+ }
+
+ // Rotate K with offset pattern
+ float k0 = sk.get(base + idx);
+ float k1 = sk.get(base + idx + dimHalf);
+ float rotated0 = k0 * fcr - k1 * fci;
+ float rotated1 = k0 * fci + k1 * fcr;
+
+ // Write rotated K back
+ sk.set(base + idx, rotated0);
+ sk.set(base + idx + dimHalf, rotated1);
+
+ // Fused cache write for K
+ keyCache.set(cacheOffset + base + idx, rotated0);
+ keyCache.set(cacheOffset + base + idx + dimHalf, rotated1);
+
+ // Fused cache copy for V (no rotation needed)
+ valueCache.set(cacheOffset + base + idx, sv.get(base + idx));
+ valueCache.set(cacheOffset + base + idx + dimHalf, sv.get(base + idx + dimHalf));
+ }
+ }
+
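
For reference, the offset-pair rotation above corresponds to this scalar sketch for a single head (hypothetical helper; element i is paired with element i + headSize/2, unlike the adjacent-pair layout used in the Llama-style kernels):

    class Phi3RopeRef {
        // Rotate one head in place: (vec[i], vec[i + headSize/2]) form a rotation pair.
        static void rotateHead(float[] vec, int headBase, int headSize, int pos) {
            int half = headSize / 2;
            for (int i = 0; i < half; i++) {
                float freq = (float) (1.0 / Math.pow(10000.0, (2.0 * i) / headSize));
                float angle = pos * freq;
                float cos = (float) Math.cos(angle);
                float sin = (float) Math.sin(angle);
                float v0 = vec[headBase + i];
                float v1 = vec[headBase + i + half];
                vec[headBase + i] = v0 * cos - v1 * sin;
                vec[headBase + i + half] = v0 * sin + v1 * cos;
            }
        }
    }
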
+ /**
+ * Fused RMSNorm apply + QKV projection with direct output to separate Q, K, V buffers.
+ *
+ * Eliminates the need for a separate splitQKV kernel by routing outputs
+ * directly based on row index:
+ *
+ * - Rows [0, dim): Q projection
+ * - Rows [dim, dim+kvDim): K projection
+ * - Rows [dim+kvDim, dim+2*kvDim): V projection
+ *
+ *
+ * Formula: output[row] = sum_j(Wqkv[row,j] * rmsWeight[j] * scale * x[j])
+ *
+ * @param context Kernel execution context
+ * @param x Input hidden state (FP32) [dim]
+ * @param q Output Q buffer (FP32) [dim]
+ * @param k Output K buffer (FP32) [kvDim]
+ * @param v Output V buffer (FP32) [kvDim]
+ * @param rmsWeights RMS normalization weights (FP32) [dim]
+ * @param rmsScale Precomputed RMS scale factor [1]
+ * @param wqkv Combined QKV weight matrix (FP16) [opSize × dim]
+ * @param dim Model dimension (Q output size)
+ * @param kvDim KV dimension (K/V output size)
+ * @param localWorkGroupSize Local work group size for reduction
+ */
+ public static void fusedRmsNormQKVMatmulDirect(
+ KernelContext context,
+ FloatArray x, // input (FP32)
+ FloatArray q, // output Q (FP32)
+ FloatArray k, // output K (FP32)
+ FloatArray v, // output V (FP32)
+ FloatArray rmsWeights, // RMS norm weights
+ FloatArray rmsScale, // temp[0] = scale factor
+ HalfFloatArray wqkv, // combined QKV weight matrix
+ int dim, // input dim and Q output dim
+ int kvDim, // K/V output dim
+ int localWorkGroupSize) {
+
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ // Total rows = dim (Q) + kvDim (K) + kvDim (V)
+ int totalRows = dim + 2 * kvDim;
+ if (rowId >= totalRows) {
+ return;
+ }
+
+ float scale = rmsScale.get(0);
+
+ // Allocate shared memory for reduction
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ int rowOffset = rowId * dim;
+
+ // Each thread computes partial dot product with inline normalization
+ float partialSum = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ float normalized = rmsWeights.get(j) * scale * x.get(j);
+ partialSum += wqkv.get(rowOffset + j).getFloat32() * normalized;
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ // Parallel reduction within workgroup
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ // Thread 0 writes to appropriate output buffer
+ if (localId == 0) {
+ float result = localSum[0];
+
+ if (rowId < dim) {
+ // Q projection: rows [0, dim)
+ q.set(rowId, result);
+ } else if (rowId < dim + kvDim) {
+ // K projection: rows [dim, dim+kvDim)
+ int kIdx = rowId - dim;
+ k.set(kIdx, result);
+ } else {
+ // V projection: rows [dim+kvDim, dim+2*kvDim)
+ int vIdx = rowId - dim - kvDim;
+ v.set(vIdx, result);
+ }
+ }
+ }
+ /**
+ * Fused RMSNorm apply + Gate/Up projection + SiLU + GLU in one kernel.
+ *
+ * Eliminates the need for separate gateUpSiLU kernel by computing both
+ * gate and up projections per workgroup and applying activation inline.
+ *
+ * For each output index i:
+ *
+ * - gate[i] = dot(wUp[i], RMSNorm(x))
+ * - up[i] = dot(wUp[hiddenDim + i], RMSNorm(x))
+ * - output[i] = SiLU(gate[i]) × up[i]
+ *
+ *
+ * @param context Kernel execution context
+ * @param x Input hidden state (FP32) [dim]
+ * @param output Output buffer (FP32) [hiddenDim] - final FFN result
+ * @param rmsWeights RMS normalization weights (FP32) [dim]
+ * @param rmsScale Precomputed RMS scale factor [1]
+ * @param wUp Combined gate+up weight matrix (FP16) [2×hiddenDim × dim]
+ * @param dim Input dimension
+ * @param hiddenDim Hidden dimension (output size)
+ * @param localWorkGroupSize Local work group size for reduction
+ */
+ public static void fusedRmsNormFFNGateUpSiLU(
+ KernelContext context,
+ FloatArray x, // input (FP32)
+ FloatArray output, // output (FP32) [hiddenDim]
+ FloatArray rmsWeights, // RMS norm weights
+ FloatArray rmsScale, // temp[0] = scale factor
+ HalfFloatArray wUp, // combined gate+up weights [2×hiddenDim × dim]
+ int dim, // input dimension
+ int hiddenDim, // output dimension
+ int localWorkGroupSize) {
+
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ if (rowId >= hiddenDim) {
+ return;
+ }
+
+ float scale = rmsScale.get(0);
+
+ // Allocate shared memory for reduction
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ // === Compute GATE (row i) ===
+ int gateRowOffset = rowId * dim;
+
+ float gatePartialSum = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ float normalized = rmsWeights.get(j) * scale * x.get(j);
+ gatePartialSum += wUp.get(gateRowOffset + j).getFloat32() * normalized;
+ }
+
+ localSum[localId] = gatePartialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ float gateResult = localSum[0];
+
+ // === Compute UP (row hiddenDim + i) ===
+ int upRowOffset = (hiddenDim + rowId) * dim;
+
+ float upPartialSum = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ float normalized = rmsWeights.get(j) * scale * x.get(j);
+ upPartialSum += wUp.get(upRowOffset + j).getFloat32() * normalized;
+ }
+
+ localSum[localId] = upPartialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ float upResult = localSum[0];
+
+ // === Apply SiLU(gate) × up ===
+ if (localId == 0) {
+ float silu = gateResult / (1.0f + TornadoMath.exp(-gateResult));
+ output.set(rowId, silu * upResult);
+ }
+ }
+}
\ No newline at end of file
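
For reference, the combined gate/up layout assumed by fusedRmsNormFFNGateUpSiLU can be checked against this scalar sketch (hypothetical helper; xNorm is the already-normalized input, which the kernel instead computes inline):

    class Phi3GateUpRef {
        // wUp rows [0, hiddenDim) hold the gate projection, rows [hiddenDim, 2*hiddenDim) the up projection.
        static float[] ffn(float[] xNorm, float[] wUp, int dim, int hiddenDim) {
            float[] out = new float[hiddenDim];
            for (int i = 0; i < hiddenDim; i++) {
                float gate = 0.0f;
                float up = 0.0f;
                for (int j = 0; j < dim; j++) {
                    gate += wUp[i * dim + j] * xNorm[j];
                    up   += wUp[(hiddenDim + i) * dim + j] * xNorm[j];
                }
                float silu = gate / (1.0f + (float) Math.exp(-gate)); // SiLU(gate)
                out[i] = silu * up;                                   // GLU gating
            }
            return out;
        }
    }
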
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java
index 930e1774..506d3fdf 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java
@@ -4,6 +4,7 @@
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
// @formatter:off
@@ -292,5 +293,376 @@ private static void processHeadTornado(
}
}
+ /**
+ * Fused RoPE rotation with KV cache copy for Qwen3.
+ * Combines ropeRotation + copyToCache into a single kernel.
+ */
+ public static void ropeRotationWithCacheCopy(
+ KernelContext context,
+ IntArray positionHolder,
+ FloatArray q, // Q vector (in/out)
+ FloatArray k, // K vector (in/out)
+ FloatArray v, // V vector (in only)
+ FloatArray keyCache, // Key cache (out)
+ FloatArray valueCache, // Value cache (out)
+ int numberOfKeyValueHeads,
+ int nEmbdHead,
+ int nEmbdGqa,
+ int layer,
+ int contextLength) {
+
+ int h = context.globalIdx;
+ int ic = context.globalIdy;
+
+ int pos = positionHolder.get(0);
+ int rotn = h < numberOfKeyValueHeads ? 2 : 1;
+ int poffset = h * nEmbdHead;
+ int nComplEmbdHead = nEmbdHead / 2;
+
+ // Compute RoPE frequencies for Qwen3 (theta = 1000000.0f)
+ float theta = 1000000.0f;
+ int i = ic * 2;
+ float freq = 1.0f / TornadoMath.pow(theta, (float) i / (float) nEmbdHead);
+
+ float val = pos * freq;
+ float fcr = TornadoMath.cos(val);
+ float fci = TornadoMath.sin(val);
+
+ // Rotate Q (all heads)
+ float v0q = q.get(poffset + ic);
+ float v1q = q.get(poffset + ic + nComplEmbdHead);
+ q.set(poffset + ic, v0q * fcr - v1q * fci);
+ q.set(poffset + ic + nComplEmbdHead, v0q * fci + v1q * fcr);
+
+ // Rotate K and copy K/V to cache (only for KV heads)
+ if (rotn > 1 && (poffset + ic + nComplEmbdHead) < k.getSize()) {
+ float v0k = k.get(poffset + ic);
+ float v1k = k.get(poffset + ic + nComplEmbdHead);
+ float rotatedK0 = v0k * fcr - v1k * fci;
+ float rotatedK1 = v0k * fci + v1k * fcr;
+
+ // Write rotated K back
+ k.set(poffset + ic, rotatedK0);
+ k.set(poffset + ic + nComplEmbdHead, rotatedK1);
+
+ // Direct cache write (fused - no separate copy kernel!)
+ int cacheOffset = layer * contextLength * nEmbdGqa + pos * nEmbdGqa;
+ int kvIdx = h * nEmbdHead;
+
+ keyCache.set(cacheOffset + kvIdx + ic, rotatedK0);
+ keyCache.set(cacheOffset + kvIdx + ic + nComplEmbdHead, rotatedK1);
+
+ // Copy V to cache (V doesn't need rotation)
+ valueCache.set(cacheOffset + kvIdx + ic, v.get(poffset + ic));
+ valueCache.set(cacheOffset + kvIdx + ic + nComplEmbdHead, v.get(poffset + ic + nComplEmbdHead));
+ }
+ }
+
+ /**
+ * Fused Q/K/V matrix-vector multiplication for Qwen3 GQA.
+ * Q has full head dimension, K/V have reduced KV head dimension.
+ *
+ * Workgroup assignment:
+ * - rowId [0, qDim): Q projection
+ * - rowId [qDim, qDim+kvDim): K projection
+ * - rowId [qDim+kvDim, qDim+2*kvDim): V projection
+ */
+ public static void fusedQKVMatmul(
+ KernelContext context,
+ FloatArray x, // input vector
+ FloatArray q, // output Q
+ FloatArray k, // output K
+ FloatArray v, // output V
+ HalfFloatArray wq, // Q weight matrix
+ HalfFloatArray wk, // K weight matrix
+ HalfFloatArray wv, // V weight matrix
+ int inputDim, // input dimension (config.dim())
+ int qDim, // Q output dimension
+ int kvDim, // KV output dimension
+ int localWorkGroupSize) {
+
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ // Allocate local memory for reduction
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ if (rowId < qDim) {
+ // ========== Q projection ==========
+ int rowOffset = rowId * inputDim;
+
+ float partialSum = 0.0f;
+ for (int j = localId; j < inputDim; j += localWorkGroupSize) {
+ partialSum += wq.get(rowOffset + j).getFloat32() * x.get(j);
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ q.set(rowId, localSum[0]);
+ }
+
+ } else if (rowId < qDim + kvDim) {
+ // ========== K projection ==========
+ int kRow = rowId - qDim;
+ int rowOffset = kRow * inputDim;
+
+ float partialSum = 0.0f;
+ for (int j = localId; j < inputDim; j += localWorkGroupSize) {
+ partialSum += wk.get(rowOffset + j).getFloat32() * x.get(j);
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ k.set(kRow, localSum[0]);
+ }
+
+ } else if (rowId < qDim + 2 * kvDim) {
+ // ========== V projection ==========
+ int vRow = rowId - qDim - kvDim;
+ int rowOffset = vRow * inputDim;
+
+ float partialSum = 0.0f;
+ for (int j = localId; j < inputDim; j += localWorkGroupSize) {
+ partialSum += wv.get(rowOffset + j).getFloat32() * x.get(j);
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ v.set(vRow, localSum[0]);
+ }
+ }
+ }
+
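
As a quick illustration of the row ranges above, a Qwen3-style GQA configuration might be sized like this (the numbers are purely illustrative and not taken from any particular checkpoint):

    class GqaSizingSketch {
        public static void main(String[] args) {
            int headDim = 128, nHeads = 16, nKvHeads = 8;  // illustrative values only
            int qDim = nHeads * headDim;                   // 2048: rowId in [0, 2048) computes Q
            int kvDim = nKvHeads * headDim;                // 1024: [2048, 3072) computes K, [3072, 4096) computes V
            int totalRows = qDim + 2 * kvDim;              // 4096 workgroups, one per output row
            System.out.println("workgroups needed: " + totalRows);
        }
    }
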
+ /**
+ * Fused RMSNorm apply + Q/K/V projection for Qwen3 GQA.
+ * Eliminates intermediate wrapXb buffer write/read.
+ */
+ public static void fusedRmsNormQKVMatmul(
+ KernelContext context,
+ FloatArray x, // raw input (FP32)
+ FloatArray q, // output Q
+ FloatArray k, // output K
+ FloatArray v, // output V
+ FloatArray rmsWeights, // RMS norm weights
+ FloatArray rmsScale, // temp[0] = scale factor
+ HalfFloatArray wq, // Q weight matrix
+ HalfFloatArray wk, // K weight matrix
+ HalfFloatArray wv, // V weight matrix
+ int inputDim, // input dimension (config.dim())
+ int qDim, // Q output dimension
+ int kvDim, // KV output dimension
+ int localWorkGroupSize) {
+
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ float scale = rmsScale.get(0);
+
+ // Allocate local memory for reduction
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ if (rowId < qDim) {
+ // ========== Q projection with inline normalization ==========
+ int rowOffset = rowId * inputDim;
+
+ float partialSum = 0.0f;
+ for (int j = localId; j < inputDim; j += localWorkGroupSize) {
+ float normalized = rmsWeights.get(j) * scale * x.get(j);
+ partialSum += wq.get(rowOffset + j).getFloat32() * normalized;
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ q.set(rowId, localSum[0]);
+ }
+
+ } else if (rowId < qDim + kvDim) {
+ // ========== K projection with inline normalization ==========
+ int kRow = rowId - qDim;
+ int rowOffset = kRow * inputDim;
+
+ float partialSum = 0.0f;
+ for (int j = localId; j < inputDim; j += localWorkGroupSize) {
+ float normalized = rmsWeights.get(j) * scale * x.get(j);
+ partialSum += wk.get(rowOffset + j).getFloat32() * normalized;
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ k.set(kRow, localSum[0]);
+ }
+
+ } else if (rowId < qDim + 2 * kvDim) {
+ // ========== V projection with inline normalization ==========
+ int vRow = rowId - qDim - kvDim;
+ int rowOffset = vRow * inputDim;
+
+ float partialSum = 0.0f;
+ for (int j = localId; j < inputDim; j += localWorkGroupSize) {
+ float normalized = rmsWeights.get(j) * scale * x.get(j);
+ partialSum += wv.get(rowOffset + j).getFloat32() * normalized;
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ v.set(vRow, localSum[0]);
+ }
+ }
+ }
+
+
+ /**
+ * Fused Q and K RMSNorm for Qwen3.
+ * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel.
+ *
+ * Workgroup assignment:
+ * - Workgroups [0, nHeads): Process Q heads
+ * - Workgroups [nHeads, nHeads + nHeadKv): Process K heads
+ */
+ public static void fusedQKRmsNorm(
+ KernelContext context,
+ FloatArray q, // Q vector (in/out)
+ FloatArray k, // K vector (in/out)
+ FloatArray qWeights, // Q RMS norm weights
+ FloatArray kWeights, // K RMS norm weights
+ int nHeads, // number of Q heads
+ int nHeadKv, // number of K heads
+ int nEmbdHead, // head dimension
+ int localMemSize, // local memory size (must be fixed)
+ float rmsNormEps) {
+
+ int groupId = context.groupIdx;
+ int localId = context.localIdx;
+ int localSize = context.localGroupSizeX;
+
+ // Allocate local memory with FIXED size parameter
+ float[] localSum = context.allocateFloatLocalArray(localMemSize);
+
+ if (groupId < nHeads) {
+ // === Process Q head ===
+ int headOffset = groupId * nEmbdHead;
+
+ // Step 1: Compute sum of squares (reduction)
+ float partialSum = 0.0f;
+ for (int i = localId; i < nEmbdHead; i += localSize) {
+ float val = q.get(headOffset + i);
+ partialSum += val * val;
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ // Parallel reduction
+ for (int stride = localSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ // Compute normalization factor
+ float ss = localSum[0];
+ ss = ss / nEmbdHead + rmsNormEps;
+ ss = 1.0f / TornadoMath.sqrt(ss);
+
+ context.localBarrier();
+
+ // Step 2: Apply normalization with weights (in-place)
+ for (int i = localId; i < nEmbdHead; i += localSize) {
+ float normalized = ss * q.get(headOffset + i);
+ q.set(headOffset + i, qWeights.get(i) * normalized);
+ }
+
+ } else if (groupId < nHeads + nHeadKv) {
+ // === Process K head ===
+ int headIdx = groupId - nHeads;
+ int headOffset = headIdx * nEmbdHead;
+
+ // Step 1: Compute sum of squares (reduction)
+ float partialSum = 0.0f;
+ for (int i = localId; i < nEmbdHead; i += localSize) {
+ float val = k.get(headOffset + i);
+ partialSum += val * val;
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ // Parallel reduction
+ for (int stride = localSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ // Compute normalization factor
+ float ss = localSum[0];
+ ss = ss / nEmbdHead + rmsNormEps;
+ ss = 1.0f / TornadoMath.sqrt(ss);
+
+ context.localBarrier();
+
+ // Step 2: Apply normalization with weights (in-place)
+ for (int i = localId; i < nEmbdHead; i += localSize) {
+ float normalized = ss * k.get(headOffset + i);
+ k.set(headOffset + i, kWeights.get(i) * normalized);
+ }
+ }
+ }
}
// @formatter:on
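
For reference, the per-head normalization performed by fusedQKRmsNorm matches this scalar sketch (hypothetical helper): ss = 1/sqrt(mean(v^2) + eps), then out[i] = weight[i] * ss * v[i].

    class HeadRmsNormRef {
        // Normalize one head of length headDim in place, starting at headOffset.
        static void normalizeHead(float[] vec, int headOffset, int headDim, float[] weights, float eps) {
            float sumSquares = 0.0f;
            for (int i = 0; i < headDim; i++) {
                float v = vec[headOffset + i];
                sumSquares += v * v;
            }
            float ss = 1.0f / (float) Math.sqrt(sumSquares / headDim + eps);
            for (int i = 0; i < headDim; i++) {
                vec[headOffset + i] = weights[i] * (ss * vec[headOffset + i]);
            }
        }
    }
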
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java
index fa610960..80617400 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java
@@ -3,10 +3,14 @@
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+
public class TransformerComputeKernels {
/**
@@ -22,6 +26,25 @@ public static void emptyTaskToForceCopyIn(FloatArray buffer) {
}
}
+ public static void convertFP32toFP16v2(KernelContext context, FloatArray input, HalfFloatArray output) {
+ int i = context.globalIdx;
+ HalfFloat val = new HalfFloat(input.get(i));
+ output.set(i,val);
+ }
+
+ public static void mapContextWithQuantize(
+ KernelContext context,
+ HalfFloatArray outputFP16, // Direct FP16 output
+ FloatArray x,
+ FloatArray weights,
+ FloatArray temp) {
+
+ int gid = context.globalIdx;
+ float ss = temp.get(0);
+ float result = weights.get(gid) * (ss * x.get(gid));
+ outputFP16.set(gid, new HalfFloat(result));
+ }
+
public static void convertFP16toFP32(KernelContext context, HalfFloatArray x, FloatArray wrapX) {
int i = context.globalIdx;
wrapX.set(i, x.get(i).getFloat32());
@@ -142,4 +165,12 @@ public static void reductionOneBlock2WithLogits(KernelContext context, FloatArra
output.set(gid, weights.get(gid) * (ss * output.get(gid)));
}
+ public static void mapContextWithQuantizeLogits(KernelContext context, HalfFloatArray output, FloatArray input, FloatArray weights, FloatArray temp) {
+ int gid = context.globalIdx;
+ float ss = temp.get(0);
+ float in = ss * input.get(gid);
+ float interim = weights.get(gid) * in;
+ output.set(gid, new HalfFloat(interim));
+ }
+
}
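
For reference, mapContextWithQuantize and mapContextWithQuantizeLogits both compute weight * (scale * x) per element and store the result as FP16; a scalar equivalent (hypothetical helper) is:

    import uk.ac.manchester.tornado.api.types.HalfFloat;

    class NormalizeToFp16Sketch {
        // Same per-element arithmetic as the kernels above, with the FP32 -> FP16 conversion on store.
        static HalfFloat normalizeElement(float x, float weight, float rmsScale) {
            float result = weight * (rmsScale * x);
            return new HalfFloat(result);
        }
    }
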
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
index ac6619f6..16f8c3b6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
@@ -4,7 +4,11 @@
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.HalfFloat;
-import uk.ac.manchester.tornado.api.types.arrays.*;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.Int8Array;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
public class TransformerComputeKernelsLayered {
@@ -14,6 +18,88 @@ public class TransformerComputeKernelsLayered {
public TransformerComputeKernelsLayered() {
}
+ public static void fusedQKvBiasAddition(KernelContext context, FloatArray q_out, FloatArray k_out, FloatArray qBias, FloatArray v_out, FloatArray kBias, FloatArray vBias, int dimQ, int dimKV) {
+
+ int gid = context.globalIdx;
+
+ if (gid < dimQ) {
+ // 1. Add Q bias
+ q_out.set(gid, q_out.get(gid) + qBias.get(gid));
+
+ // 2. Conditionally Add K and V Bias
+ if (gid < dimKV) {
+ k_out.set(gid, k_out.get(gid) + kBias.get(gid));
+ v_out.set(gid, v_out.get(gid) + vBias.get(gid));
+ }
+ }
+ }
+
+ public static void fusedRmsNormFFNGateUp(KernelContext context, FloatArray x, // raw input (FP32)
+ FloatArray hb, // output
+ FloatArray rmsWeights, // RMS norm weights
+ FloatArray rmsScale, // temp[0] = scale factor
+ HalfFloatArray w1, HalfFloatArray w3, int dim, // input dimension
+ int hiddenDim, // output dimension
+ int localWorkGroupSize) {
+
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ if (rowId >= hiddenDim) {
+ return;
+ }
+
+ float scale = rmsScale.get(0);
+
+ // Allocate shared memory for normalized input (reused for both W1 and W3)
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ int rowOffsetW1 = rowId * dim;
+ int rowOffsetW3 = rowId * dim;
+
+ // === W1 matmul with inline normalization ===
+ float sum1 = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ float normalized = rmsWeights.get(j) * scale * x.get(j);
+ sum1 += w1.get(rowOffsetW1 + j).getFloat32() * normalized;
+ }
+
+ localSum[localId] = sum1;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+ float result1 = localSum[0];
+
+ // === W3 matmul with inline normalization (same computation) ===
+ float sum3 = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ float normalized = rmsWeights.get(j) * scale * x.get(j);
+ sum3 += w3.get(rowOffsetW3 + j).getFloat32() * normalized;
+ }
+
+ localSum[localId] = sum3;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+ float result3 = localSum[0];
+
+ // === SiLU + GLU ===
+ if (localId == 0) {
+ float silu = result1 / (1.0f + TornadoMath.exp(-result1));
+ hb.set(rowId, silu * result3);
+ }
+ }
+
/**
* Performs RMS (Root Mean Square) normalization using parallel reduction. This is the first phase of RMS normalization that computes the variance and scaling factor across all work groups.
*
@@ -136,6 +222,60 @@ public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, Float
}
}
+ /**
+ * Fused RoPE rotation with KV cache copy. Eliminates separate copyToCaches kernel.
+ *
+ * - Rotates Q (full dim)
+ * - Rotates K and writes directly to keyCache
+ * - Copies V directly to valueCache (no rotation needed)
+ */
+ public static void ropeRotationWithCacheCopy(KernelContext context, IntArray positionHolder, FloatArray sq, // Q vector (in/out)
+ FloatArray sk, // K vector (in/out)
+ FloatArray sv, // V vector (in only)
+ FloatArray keyCache, // Key cache (out)
+ FloatArray valueCache, // Value cache (out)
+ int kvDim, int headSize, int layer, int contextLength) {
+
+ int i = context.globalIdx * 2;
+ int pos = positionHolder.get(0);
+
+ // Bounds check for Q rotation (Q has dim elements, processed in pairs)
+ if (i + 1 < sq.getSize()) {
+ // RoPE frequency calculation
+ int head_dim = i % headSize;
+ float freq = 1.0f / TornadoMath.pow(50000.0f, head_dim / (float) headSize);
+ float val = pos * freq;
+ float fcr = TornadoMath.cos(val);
+ float fci = TornadoMath.sin(val);
+
+ // Rotate Q
+ float v0q = sq.get(i);
+ float v1q = sq.get(i + 1);
+ sq.set(i, v0q * fcr - v1q * fci);
+ sq.set(i + 1, v0q * fci + v1q * fcr);
+
+ // Rotate K AND write to cache (only for kvDim elements)
+ if (i + 1 < kvDim) {
+ float v0k = sk.get(i);
+ float v1k = sk.get(i + 1);
+ float rotated0 = v0k * fcr - v1k * fci;
+ float rotated1 = v0k * fci + v1k * fcr;
+
+ // Write rotated K back to sk
+ sk.set(i, rotated0);
+ sk.set(i + 1, rotated1);
+
+ // Direct cache write (fused - no separate copy kernel!)
+ int cacheOffset = layer * contextLength * kvDim + pos * kvDim;
+ keyCache.set(cacheOffset + i, rotated0);
+ keyCache.set(cacheOffset + i + 1, rotated1);
+
+ // Copy V to cache (V doesn't need rotation)
+ valueCache.set(cacheOffset + i, sv.get(i));
+ valueCache.set(cacheOffset + i + 1, sv.get(i + 1));
+ }
+ }
+
+ }
+
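
For reference, the adjacent-pair rotation and cache indexing used above reduce to this scalar sketch (hypothetical helper; the 50000 base matches the kernel's frequency computation):

    class AdjacentPairRopeRef {
        // Rotate the adjacent pair (vec[i], vec[i+1]); i is assumed even.
        static void rotatePair(float[] vec, int i, int headSize, int pos) {
            int headDim = i % headSize;
            float freq = (float) (1.0 / Math.pow(50000.0, headDim / (double) headSize));
            float angle = pos * freq;
            float cos = (float) Math.cos(angle);
            float sin = (float) Math.sin(angle);
            float v0 = vec[i];
            float v1 = vec[i + 1];
            vec[i]     = v0 * cos - v1 * sin;
            vec[i + 1] = v0 * sin + v1 * cos;
        }

        // Rotated K and copied V land at this flat index in the layer-major cache.
        static int cacheIndex(int layer, int contextLength, int kvDim, int pos, int i) {
            return layer * contextLength * kvDim + pos * kvDim + i;
        }
    }
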
public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArray v, int dimQ, int dimKV) {
int totalSize = dimQ + 2 * dimKV;
@@ -348,7 +488,7 @@ public static void processHeadsFlashAttention(KernelContext context, FloatArray
int pos = positionHolder.get(0);
int loff = layer * contextLength * kvDim;
int kvHeadIdx = h / kvMul;
- int BLOCK_SIZE_C = 8;
+ int BLOCK_SIZE_C = 16;
// Allocate shared memory for tiled computation
float[] q_shared = context.allocateFloatLocalArray(headSize);
@@ -455,6 +595,191 @@ public static void processHeadsFlashAttention(KernelContext context, FloatArray
}
}
+ public static void processHeadsFlashAttentionOptV2(KernelContext context, FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize,
+ // headSize is still used in the computation below, but no longer determines local allocation sizes
+ int kvDim, int kvMul, IntArray positionHolder, int layer, int contextLength) {
+
+ // --- STATIC CONSTANTS FOR OPENCL ALLOCATIONS ---
+ // These must be large enough to handle the maximum expected values for
+ // headSize and localSize in your model/hardware setup.
+ // Assuming Max Head Size is 256 and Max Local Size is 256.
+ final int MAX_HEAD_SIZE = 256;
+ final int MAX_LOCAL_SIZE = 256;
+ final int MAX_BLOCK_SIZE_C = 32;
+ final int MAX_TILE_ELEMENTS = MAX_BLOCK_SIZE_C * MAX_HEAD_SIZE;
+
+ int tid = context.localIdx;
+ int h = context.groupIdx;
+ int localSize = context.localGroupSizeX;
+
+ if (h >= nHeads) {
+ return;
+ }
+
+ int pos = positionHolder.get(0);
+ int loff = layer * contextLength * kvDim;
+ int kvHeadIdx = h / kvMul;
+ int BLOCK_SIZE_C = 32;
+
+ // === Shared memory allocations (sized with compile-time constants) ===
+ // Use MAX_HEAD_SIZE rather than the dynamic headSize so the local allocation size is a constant
+ float[] q_shared = context.allocateFloatLocalArray(MAX_HEAD_SIZE);
+ // Use MAX_TILE_ELEMENTS rather than BLOCK_SIZE_C * headSize, for the same reason
+ float[] k_tile = context.allocateFloatLocalArray(MAX_TILE_ELEMENTS);
+ float[] v_tile = context.allocateFloatLocalArray(MAX_TILE_ELEMENTS);
+
+ float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C); // Size is constant (32)
+ float[] exp_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C); // Size is constant (32)
+
+ // Use MAX_LOCAL_SIZE rather than the dynamic localSize, for the same reason
+ float[] reduction_shared = context.allocateFloatLocalArray(MAX_LOCAL_SIZE);
+
+ float[] state_shared = context.allocateFloatLocalArray(4); // Size is constant (4)
+
+ // === Dimension partitioning: each thread handles subset of output dims ===
+ int dimsPerThread = (headSize + localSize - 1) / localSize;
+ int myStartDim = tid * dimsPerThread;
+ int myEndDim = Math.min(myStartDim + dimsPerThread, headSize);
+ int myDimCount = myEndDim - myStartDim;
+
+ // Thread-private output accumulator, statically sized (assumes localSize >= 8, so myDimCount never exceeds MAX_OUTPUT_DIMS)
+ final int MAX_OUTPUT_DIMS = MAX_HEAD_SIZE / 8; // e.g., 32 if MAX_HEAD_SIZE=256
+ float[] output = new float[MAX_OUTPUT_DIMS];
+
+ // Initialize thread-local output
+ for (int i = 0; i < myDimCount; i++) {
+ output[i] = 0.0f;
+ }
+
+ // Initialize shared state
+ if (tid == 0) {
+ state_shared[0] = Float.NEGATIVE_INFINITY;
+ state_shared[1] = 0.0f;
+ }
+
+ // Load query into shared memory (cooperative)
+ // NOTE: Loop bound must still use headSize to read correct data volume
+ for (int i = tid; i < headSize; i += localSize) {
+ q_shared[i] = q.get(h * headSize + i);
+ }
+ context.localBarrier();
+
+ // Process sequence in tiles
+ for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) {
+ int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos);
+ int tileLen = tileEnd - tileC + 1;
+
+ // === Cooperative K/V tile loading ===
+ int totalElements = tileLen * headSize;
+ int elementsPerThread = (totalElements + localSize - 1) / localSize;
+ int startElem = tid * elementsPerThread;
+ int endElem = Math.min(startElem + elementsPerThread, totalElements);
+
+ for (int globalElemIdx = startElem; globalElemIdx < endElem; globalElemIdx++) {
+ int seqIdx = globalElemIdx / headSize;
+ int dimIdx = globalElemIdx % headSize;
+ int kvOffset = loff + (tileC + seqIdx) * kvDim + kvHeadIdx * headSize + dimIdx;
+ int tileMemOffset = seqIdx * headSize + dimIdx;
+
+ // Check bounds just to be safe, though kvDim/headSize should ensure this is valid.
+ if (tileMemOffset < MAX_TILE_ELEMENTS) {
+ k_tile[tileMemOffset] = key_cache.get(kvOffset);
+ v_tile[tileMemOffset] = value_cache.get(kvOffset);
+ }
+ }
+ context.localBarrier();
+
+ // === Compute attention scores (cooperative) ===
+ for (int t = tid; t < tileLen; t += localSize) {
+ float score = 0.0f;
+ for (int d = 0; d < headSize; d++) {
+ score += q_shared[d] * k_tile[t * headSize + d];
+ }
+ s_tile[t] = score / TornadoMath.sqrt(headSize);
+ }
+ context.localBarrier();
+
+ // Parallel reduction for the tile maximum (uses reduction_shared)
+ float threadMax = Float.NEGATIVE_INFINITY;
+ for (int t = tid; t < tileLen; t += localSize) {
+ if (s_tile[t] > threadMax) {
+ threadMax = s_tile[t];
+ }
+ }
+ reduction_shared[tid] = threadMax;
+ context.localBarrier();
+
+ for (int stride = localSize / 2; stride > 0; stride /= 2) {
+ if (tid < stride) {
+ reduction_shared[tid] = Math.max(reduction_shared[tid], reduction_shared[tid + stride]);
+ }
+ context.localBarrier();
+ }
+ float tileMax = reduction_shared[0];
+
+ // === Update running max and rescale if needed ===
+ float prevMax = state_shared[0];
+ float newMax = Math.max(prevMax, tileMax);
+ float scale = 1.0f;
+
+ if (newMax != prevMax && prevMax != Float.NEGATIVE_INFINITY) {
+ scale = TornadoMath.exp(prevMax - newMax);
+ for (int i = 0; i < myDimCount; i++) {
+ output[i] *= scale;
+ }
+ }
+
+ // === Compute exp(score - max) and tile sum (cooperative) ===
+ for (int t = tid; t < tileLen; t += localSize) {
+ exp_tile[t] = TornadoMath.exp(s_tile[t] - newMax);
+ }
+ context.localBarrier();
+
+ // Parallel reduction for the tile sum (uses reduction_shared)
+ float threadSum = 0.0f;
+ for (int t = tid; t < tileLen; t += localSize) {
+ threadSum += exp_tile[t];
+ }
+ reduction_shared[tid] = threadSum;
+ context.localBarrier();
+
+ for (int stride = localSize / 2; stride > 0; stride /= 2) {
+ if (tid < stride) {
+ reduction_shared[tid] += reduction_shared[tid + stride];
+ }
+ context.localBarrier();
+ }
+ float tileSum = reduction_shared[0];
+
+ // Update shared state (thread 0)
+ if (tid == 0) {
+ state_shared[0] = newMax;
+ state_shared[1] = state_shared[1] * scale + tileSum;
+ }
+ context.localBarrier();
+
+ // === Accumulate output (each thread handles its dimensions) ===
+ for (int t = 0; t < tileLen; t++) {
+ float expScore = exp_tile[t];
+ for (int i = 0; i < myDimCount; i++) {
+ int d = myStartDim + i;
+ output[i] += expScore * v_tile[t * headSize + d];
+ }
+ }
+ context.localBarrier();
+ }
+
+ // === Final normalization and write ===
+ float sumExp = state_shared[1];
+ float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f;
+
+ int baseOffset = h * headSize + myStartDim;
+ for (int i = 0; i < myDimCount; i++) {
+ xb.set(baseOffset + i, output[i] * normFactor);
+ }
+ }
+
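
The tiling above relies on the standard online-softmax bookkeeping: whenever a new tile raises the running maximum, the accumulated output and the running sum are rescaled by exp(oldMax - newMax) before the tile's contributions are added, and the final result is divided by the total sum. A compact scalar sketch of that bookkeeping (hypothetical helper, single head, no thread partitioning):

    class OnlineSoftmaxSketch {
        float runningMax = Float.NEGATIVE_INFINITY;
        float runningSum = 0.0f;
        float[] acc;   // unnormalized output accumulator

        OnlineSoftmaxSketch(int headSize) { acc = new float[headSize]; }

        // scores[t] and values[t][d] cover one tile of the sequence.
        void addTile(float[] scores, float[][] values) {
            float tileMax = Float.NEGATIVE_INFINITY;
            for (float s : scores) tileMax = Math.max(tileMax, s);
            float newMax = Math.max(runningMax, tileMax);
            float scale = (runningMax == Float.NEGATIVE_INFINITY)
                    ? 1.0f : (float) Math.exp(runningMax - newMax);
            for (int d = 0; d < acc.length; d++) acc[d] *= scale;   // rescale old contributions
            float tileSum = 0.0f;
            for (int t = 0; t < scores.length; t++) {
                float e = (float) Math.exp(scores[t] - newMax);
                tileSum += e;
                for (int d = 0; d < acc.length; d++) acc[d] += e * values[t][d];
            }
            runningMax = newMax;
            runningSum = runningSum * scale + tileSum;
        }

        float[] finish() {
            float norm = runningSum > 0.0f ? 1.0f / runningSum : 0.0f;
            for (int d = 0; d < acc.length; d++) acc[d] *= norm;
            return acc;
        }
    }
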
/**
* Same as processHeadsFlashAttention but with some optimizations that seem to lower attention's execution time, especially in larger models.
*/
@@ -688,6 +1013,133 @@ public static void matrixVectorGeneric(
hb.set(rowId, sum);
}
}
+
+ /**
+ * Fused Q/K/V matrix-vector multiplication.
+ * Reduces kernel launch overhead and improves input vector cache utilization.
+ *
+ * Workgroup assignment:
+ * - rowId [0, dim): Q projection
+ * - rowId [dim, dim+kvDim): K projection
+ * - rowId [dim+kvDim, dim+2*kvDim): V projection
+ */
+ public static void fusedQKVMatmulX(
+ KernelContext context,
+ HalfFloatArray x, // input vector (FP16)
+ FloatArray q, // output Q (FP32)
+ FloatArray k, // output K (FP32)
+ FloatArray v, // output V (FP32)
+ HalfFloatArray wq, // Q weight matrix
+ HalfFloatArray wk, // K weight matrix
+ HalfFloatArray wv, // V weight matrix
+ int dim, // model dimension (Q output size)
+ int kvDim, // KV dimension (K/V output size)
+ int localWorkGroupSize) {
+
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ // Allocate local memory for reduction
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ if (rowId < dim) {
+ // ========== Q projection ==========
+ int rowOffset = rowId * dim;
+
+ float partialSum = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ partialSum += wq.get(rowOffset + j).getFloat32() * x.get(j).getFloat32();
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ q.set(rowId, localSum[0]);
+ }
+
+ } else if (rowId < dim + kvDim) {
+ // ========== K projection ==========
+ int kRow = rowId - dim;
+ int rowOffset = kRow * dim;
+
+ float partialSum = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ partialSum += wk.get(rowOffset + j).getFloat32() * x.get(j).getFloat32();
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ k.set(kRow, localSum[0]);
+ }
+
+ } else if (rowId < dim + 2 * kvDim) {
+ // ========== V projection ==========
+ int vRow = rowId - dim - kvDim;
+ int rowOffset = vRow * dim;
+
+ float partialSum = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ partialSum += wv.get(rowOffset + j).getFloat32() * x.get(j).getFloat32();
+ }
+
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ v.set(vRow, localSum[0]);
+ }
+ }
+ }
+
+ // @formatter:off
+ public static void matrixVectorGeneric(
+ KernelContext context,
+ HalfFloatArray x,
+ FloatArray hb, // output
+ HalfFloatArray w,
+ int dim1, // inner loop
+ int dim0, // outer loop
+ int localWorkGroupSize) {
+ // One row per workgroup (not per thread)
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+ int localSize = localWorkGroupSize;
+
+ // Early exit if this workgroup is beyond our output dimension
+ if (rowId >= dim0) {
+ return;
+ }
+ float sum = matrixVectorRowMajorOptimizedSingle(context, localSize, x, w, dim1);
+
+ // Thread 0 in each workgroup writes the final result
+ if (localId == 0) {
+ hb.set(rowId, sum);
+ }
+ }
// @formatter:on
/**
@@ -730,6 +1182,26 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA
}
}
+ public static void matrixVectorGenericWithResidual(KernelContext context, HalfFloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) {
+ // One row per workgroup (not per thread)
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+ int localSize = localWorkGroupSize;
+
+ // Early exit if this workgroup is beyond our output dimension
+ if (rowId >= d) {
+ return;
+ }
+
+ float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, n);
+
+ // Thread 0 in each workgroup writes the final result
+ if (localId == 0) {
+ float result = hb.get(rowId) + sum;
+ hb.set(rowId, result);
+ }
+ }
+
/**
* Fused feed-forward network with SiLU activation and GLU gating. Implements the SwiGLU variant used in LLaMA-style models.
*
@@ -752,7 +1224,9 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA
* @param localWorkGroupSize
* Work group size
*/
- public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) {
+
+ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, HalfFloatArray x, HalfFloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d,
+ int localWorkGroupSize) {
// One row per workgroup (not per thread)
int rowId = context.groupIdx;
int localId = context.localIdx;
@@ -768,21 +1242,61 @@ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext contex
if (localId == 0) {
float silu = siluActivation(sum1); // Using the new SiLU method
float result = silu * sum3;
- hb.set(rowId, result);
+ hb.set(rowId, new HalfFloat(result));
}
}
- /**
- * Gaussian Error Linear Unit (GELU) activation function. Approximation formula: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
- *
- * @param x
- * Input value
- * @return Activated value
- */
- public static float geluActivation(float x) {
- float x3 = x * x * x;
- return 0.5f * x * (1.0f + TornadoMath.tanh((0.797885f * (x + 0.044715f * x3))));
- }
+ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) {
+ // One row per workgroup (not per thread)
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ if (rowId >= d) {
+ return;
+ }
+
+ float sum1 = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w1, n);
+ float sum3 = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w3, n);
+
+ // Thread 0 in each workgroup writes the final result
+ if (localId == 0) {
+ float silu = siluActivation(sum1); // Using the new SiLU method
+ float result = silu * sum3;
+ hb.set(rowId, result);
+ }
+ }
+
+ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, HalfFloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) {
+ // One row per workgroup (not per thread)
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ if (rowId >= d) {
+ return;
+ }
+
+ HalfFloat sum1 = matrixVectorRowMajorOptimizedFHF(context, localWorkGroupSize, x, w1, n);
+ HalfFloat sum3 = matrixVectorRowMajorOptimizedFHF(context, localWorkGroupSize, x, w3, n);
+
+ // Thread 0 in each workgroup writes the final result
+ if (localId == 0) {
+ float silu = siluActivation(sum1.getFloat32()); // Using the new SiLU method
+ float result = silu * sum3.getFloat32();
+ hb.set(rowId, result);
+ }
+ }
+
+ /**
+ * Gaussian Error Linear Unit (GELU) activation function. Approximation formula: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
+ *
+ * @param x
+ * Input value
+ * @return Activated value
+ */
+ public static float geluActivation(float x) {
+ float x3 = x * x * x;
+ return 0.5f * x * (1.0f + TornadoMath.tanh((0.797885f * (x + 0.044715f * x3))));
+ }
/**
* Sigmoid-weighted Linear Unit (SiLU) activation function. Also known as Swish activation.
@@ -876,6 +1390,299 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc
return localSum[0];
}
+ public static HalfFloat matrixVectorRowMajorOptimizedFHF(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ // Allocate local memory for reduction
+ HalfFloat[] localSum = context.allocateHalfFloatLocalArray(localSize);
+
+ int rowOffset = rowId * n;
+
+ // Each thread calculates partial dot product
+ float partialSum = 0.0f;
+ // HalfFloat partialSum = new HalfFloat(0f);
+ for (int j = localId; j < n; j += localSize) {
+ int matrixIdx = rowOffset + j;
+ // HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j));
+ partialSum += w.get(matrixIdx).getFloat32() * x.get(j).getFloat32();
+ // partialSum = HalfFloat.add(partialSum, mul);
+ }
+
+ // Store partial sum in local memory
+ localSum[localId] = new HalfFloat(partialSum);
+ context.localBarrier();
+
+ // Parallel reduction within workgroup
+ for (int stride = localSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] = HalfFloat.add(localSum[localId], localSum[localId + stride]);
+ // localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ return localSum[0];
+ }
+
+ public static float matrixVectorRowMajorOptimizedF(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ // Allocate local memory for reduction
+ float[] localSum = context.allocateFloatLocalArray(localSize);
+
+ int rowOffset = rowId * n;
+
+ // Each thread computes a partial dot product, accumulated in FP32
+ float partialSum = 0.0f;
+ for (int j = localId; j < n; j += localSize) {
+ int matrixIdx = rowOffset + j;
+ partialSum += w.get(matrixIdx).getFloat32() * x.get(j).getFloat32();
+ }
+
+ // Store partial sum in local memory
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ // Parallel reduction within workgroup
+ for (int stride = localSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ return localSum[0];
+ }
+
+ public static float matrixVectorRowMajorOptimizedXX(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+ float[] localSum = context.allocateFloatLocalArray(localSize);
+
+ int rowOffset = rowId * n;
+ float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;
+
+ int stride = localSize;
+ int stride2 = localSize << 1;
+ int stride3 = localSize * 3;
+ int stride4 = localSize << 2;
+
+ // Already coalesced: thread 0 reads idx 0, thread 1 reads idx 1, etc.
+ int j = localId;
+ int limit = n - stride3;
+
+ for (; j < limit; j += stride4) {
+ int base = rowOffset + j;
+ // Load the four x values once per unrolled iteration
+ float x0 = x.get(j).getFloat32();
+ float x1 = x.get(j + stride).getFloat32();
+ float x2 = x.get(j + stride2).getFloat32();
+ float x3 = x.get(j + stride3).getFloat32();
+
+ sum0 += w.get(base).getFloat32() * x0;
+ sum1 += w.get(base + stride).getFloat32() * x1;
+ sum2 += w.get(base + stride2).getFloat32() * x2;
+ sum3 += w.get(base + stride3).getFloat32() * x3;
+ }
+
+ for (; j < n; j += stride) {
+ sum0 += w.get(rowOffset + j).getFloat32() * x.get(j).getFloat32();
+ }
+
+ localSum[localId] = (sum0 + sum1) + (sum2 + sum3);
+ context.localBarrier();
+
+ // Reduction with minimal barriers
+ for (int s = localSize >> 1; s > 0; s >>= 1) {
+ if (localId < s) {
+ localSum[localId] += localSum[localId + s];
+ }
+ context.localBarrier();
+ }
+
+ return localSum[0];
+ }
+
+ public static float matrixVectorRowMajorOptimizedSingle(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ // Allocate local memory for reduction
+ float[] localSum = context.allocateFloatLocalArray(localSize);
+
+ int rowOffset = rowId * n;
+
+ HalfFloat partialSum = new HalfFloat(0f);
+ for (int j = localId; j < n; j += localSize) {
+ int matrixIdx = rowOffset + j;
+ HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j));
+ partialSum = HalfFloat.add(partialSum, mul);
+ }
+
+ // Store partial sum in local memory
+ localSum[localId] = partialSum.getFloat32(); // widen the FP16 accumulator to FP32 for the reduction
+ context.localBarrier();
+
+ // Parallel reduction within workgroup
+ for (int stride = localSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ return localSum[0];
+ }
+
+ public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+ float[] localSum = context.allocateFloatLocalArray(localSize);
+
+ int rowOffset = rowId * n;
+
+ // Accumulate in HalfFloat to avoid conversions in inner loop
+ HalfFloat sum0 = new HalfFloat(0f);
+ HalfFloat sum1 = new HalfFloat(0f);
+ HalfFloat sum2 = new HalfFloat(0f);
+ HalfFloat sum3 = new HalfFloat(0f);
+
+ int stride = localSize;
+ int stride2 = localSize << 1;
+ int stride3 = localSize * 3;
+ int stride4 = localSize << 2;
+
+ int j = localId;
+ int limit = n - stride3;
+
+ for (; j < limit; j += stride4) {
+ int base = rowOffset + j;
+
+ // Stay in HalfFloat - no getFloat32() calls
+ HalfFloat x0 = x.get(j);
+ HalfFloat x1 = x.get(j + stride);
+ HalfFloat x2 = x.get(j + stride2);
+ HalfFloat x3 = x.get(j + stride3);
+
+ sum0 = HalfFloat.add(sum0, HalfFloat.mult(w.get(base), x0));
+ sum1 = HalfFloat.add(sum1, HalfFloat.mult(w.get(base + stride), x1));
+ sum2 = HalfFloat.add(sum2, HalfFloat.mult(w.get(base + stride2), x2));
+ sum3 = HalfFloat.add(sum3, HalfFloat.mult(w.get(base + stride3), x3));
+ }
+
+ // Cleanup loop
+ for (; j < n; j += stride) {
+ sum0 = HalfFloat.add(sum0, HalfFloat.mult(w.get(rowOffset + j), x.get(j)));
+ }
+
+ // Convert to float32 only at the end for reduction
+ localSum[localId] = sum0.getFloat32() + sum1.getFloat32() + sum2.getFloat32() + sum3.getFloat32();
+ context.localBarrier();
+
+ for (int s = localSize >> 1; s > 0; s >>= 1) {
+ if (localId < s) {
+ localSum[localId] += localSum[localId + s];
+ }
+ context.localBarrier();
+ }
+
+ return localSum[0];
+ }
+
+ public static void fusedQKVMatmul(KernelContext context, HalfFloatArray x, // input (read once!)
+ FloatArray q, FloatArray k, FloatArray v, // outputs
+ HalfFloatArray wq, HalfFloatArray wk, HalfFloatArray wv, int dim, int kvDim, int localWorkGroupSize) {
+
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ // Determine which output this workgroup computes
+ int totalRows = dim + 2 * kvDim; // Q rows + K rows + V rows
+
+ if (rowId < dim) {
+ // Q projection
+ float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wq, dim);
+ if (localId == 0) {
+ q.set(rowId, sum);
+ }
+ } else if (rowId < dim + kvDim) {
+ // K projection
+ int kRow = rowId - dim;
+ float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wk, dim);
+ if (localId == 0) {
+ k.set(kRow, sum);
+ }
+ } else {
+ // V projection
+ int vRow = rowId - dim - kvDim;
+ float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wv, dim);
+ if (localId == 0) {
+ v.set(vRow, sum);
+ }
+ }
+ }
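+
+ /**
+ * Launch-size sketch for the fused QKV kernel above (illustrative helper, name is an
+ * assumption; the real grid is built in the FP16 FFN layer setup): one workgroup per output
+ * row, so the global work size is (dim + 2 * kvDim) * localWorkGroupSize threads.
+ */
+ public static int fusedQKVGlobalWorkSize(int dim, int kvDim, int localWorkGroupSize) {
+ int totalRows = dim + 2 * kvDim; // Q rows + K rows + V rows
+ return totalRows * localWorkGroupSize;
+ }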
+
+ public static float matrixVectorRowMajorOptimizedx(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ // Allocate local memory for reduction
+ float[] localSum = context.allocateFloatLocalArray(localSize);
+
+ int rowOffset = rowId * n;
+
+ // Each thread calculates partial dot product - UNROLLED BY 4
+ float sum0 = 0.0f;
+ float sum1 = 0.0f;
+ float sum2 = 0.0f;
+ float sum3 = 0.0f;
+
+ int j = localId;
+ int stride = localSize;
+ int stride4 = localSize << 2; // localSize * 4
+ int limit = n - (stride * 3); // Safe limit for 4 elements
+
+ // Main loop unrolled by 4 with separate accumulators
+ for (; j < limit; j += stride4) {
+ int base = rowOffset + j;
+ int j1 = j + stride;
+ int j2 = j + (stride << 1);
+ int j3 = j + stride * 3;
+
+ sum0 += w.get(base).getFloat32() * x.get(j).getFloat32();
+ sum1 += w.get(base + stride).getFloat32() * x.get(j1).getFloat32();
+ sum2 += w.get(base + (stride << 1)).getFloat32() * x.get(j2).getFloat32();
+ sum3 += w.get(base + stride * 3).getFloat32() * x.get(j3).getFloat32();
+ }
+
+ // Handle remainder
+ for (; j < n; j += stride) {
+ sum0 += w.get(rowOffset + j).getFloat32() * x.get(j).getFloat32();
+ }
+
+ // Combine accumulators (tree reduction for better precision)
+ float partialSum = (sum0 + sum1) + (sum2 + sum3);
+
+ // Store partial sum in local memory
+ localSum[localId] = partialSum;
+ context.localBarrier();
+
+ // Parallel reduction within workgroup
+ for (int s = localSize >> 1; s > 0; s >>= 1) {
+ if (localId < s) {
+ localSum[localId] += localSum[localId + s];
+ }
+ context.localBarrier();
+ }
+
+ return localSum[0];
+ }
+
// Second kernel - Combines partial sums and computes final normalization
public static void reductionFinalNormalization(KernelContext context, FloatArray output, int size, float ermsNorm) {
int gid = context.globalIdx;
@@ -1169,10 +1976,7 @@ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext contex
}
}
- public static void fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte(KernelContext context, FloatArray x, FloatArray hb,
- ByteArray w1,
- ByteArray w3,
- int n, int d, int localWorkGroupSize) {
+ public static void fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte(KernelContext context, FloatArray x, FloatArray hb, ByteArray w1, ByteArray w3, int n, int d, int localWorkGroupSize) {
// One row per workgroup (not per thread)
int rowId = context.groupIdx;
int localId = context.localIdx;
@@ -1237,4 +2041,364 @@ public static void processHeadsParallel(FloatArray q, FloatArray key_cache, Floa
}
}
+ /**
+ * Fused Q/K/V matrix-vector multiplication for Q8_0 quantized weights. Reduces kernel launch overhead and improves input vector cache utilization.
+ *
+ * Workgroup assignment: - rowId [0, dim): Q projection - rowId [dim, dim+kvDim): K projection - rowId [dim+kvDim, dim+2*kvDim): V projection
+ */
+ public static void fusedQKVMatmulQ8(KernelContext context, FloatArray x, FloatArray q, FloatArray k, FloatArray v, ByteArray wq, ByteArray wk, ByteArray wv, int dim, int kvDim,
+ int localWorkGroupSize) {
+
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ int blockSize = 32;
+ final int Q8_0_BLOCK_BYTES = 34;
+ int blocksPerRow = (dim + blockSize - 1) / blockSize;
+
+ float[] localSums = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ if (rowId < dim) {
+ // ========== Q projection ==========
+ int rowBlockOffset = rowId * blocksPerRow;
+
+ float partialSum1 = 0.0f;
+ float partialSum2 = 0.0f;
+ float partialSum3 = 0.0f;
+ float partialSum4 = 0.0f;
+
+ for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) {
+ int blockIdx = j / blockSize;
+ int withinBlockIdx = j % blockSize;
+ int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
+
+ HalfFloat scale = wq.getHalfFloat(blockByteOffset);
+ float scaleFloat = scale.getFloat32();
+
+ int quantsOffset = blockByteOffset + 2 + withinBlockIdx;
+ byte quant1 = wq.get(quantsOffset);
+ byte quant2 = wq.get(quantsOffset + 1);
+ byte quant3 = wq.get(quantsOffset + 2);
+ byte quant4 = wq.get(quantsOffset + 3);
+
+ partialSum1 += ((float) quant1 * scaleFloat) * x.get(j);
+ partialSum2 += ((float) quant2 * scaleFloat) * x.get(j + 1);
+ partialSum3 += ((float) quant3 * scaleFloat) * x.get(j + 2);
+ partialSum4 += ((float) quant4 * scaleFloat) * x.get(j + 3);
+ }
+
+ float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4;
+
+ for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) {
+ int blockIdx = j / blockSize;
+ int withinBlockIdx = j % blockSize;
+ int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
+
+ HalfFloat scale = wq.getHalfFloat(blockByteOffset);
+ float scaleFloat = scale.getFloat32();
+
+ byte quant = wq.get(blockByteOffset + 2 + withinBlockIdx);
+ partialSum += ((float) quant * scaleFloat) * x.get(j);
+ }
+
+ localSums[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSums[localId] += localSums[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ q.set(rowId, localSums[0]);
+ }
+
+ } else if (rowId < dim + kvDim) {
+ // ========== K projection ==========
+ int kRow = rowId - dim;
+ int rowBlockOffset = kRow * blocksPerRow;
+
+ float partialSum1 = 0.0f;
+ float partialSum2 = 0.0f;
+ float partialSum3 = 0.0f;
+ float partialSum4 = 0.0f;
+
+ for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) {
+ int blockIdx = j / blockSize;
+ int withinBlockIdx = j % blockSize;
+ int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
+
+ HalfFloat scale = wk.getHalfFloat(blockByteOffset);
+ float scaleFloat = scale.getFloat32();
+
+ int quantsOffset = blockByteOffset + 2 + withinBlockIdx;
+ byte quant1 = wk.get(quantsOffset);
+ byte quant2 = wk.get(quantsOffset + 1);
+ byte quant3 = wk.get(quantsOffset + 2);
+ byte quant4 = wk.get(quantsOffset + 3);
+
+ partialSum1 += ((float) quant1 * scaleFloat) * x.get(j);
+ partialSum2 += ((float) quant2 * scaleFloat) * x.get(j + 1);
+ partialSum3 += ((float) quant3 * scaleFloat) * x.get(j + 2);
+ partialSum4 += ((float) quant4 * scaleFloat) * x.get(j + 3);
+ }
+
+ float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4;
+
+ for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) {
+ int blockIdx = j / blockSize;
+ int withinBlockIdx = j % blockSize;
+ int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
+
+ HalfFloat scale = wk.getHalfFloat(blockByteOffset);
+ float scaleFloat = scale.getFloat32();
+
+ byte quant = wk.get(blockByteOffset + 2 + withinBlockIdx);
+ partialSum += ((float) quant * scaleFloat) * x.get(j);
+ }
+
+ localSums[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSums[localId] += localSums[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ k.set(kRow, localSums[0]);
+ }
+
+ } else if (rowId < dim + 2 * kvDim) {
+ // ========== V projection ==========
+ int vRow = rowId - dim - kvDim;
+ int rowBlockOffset = vRow * blocksPerRow;
+
+ float partialSum1 = 0.0f;
+ float partialSum2 = 0.0f;
+ float partialSum3 = 0.0f;
+ float partialSum4 = 0.0f;
+
+ for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) {
+ int blockIdx = j / blockSize;
+ int withinBlockIdx = j % blockSize;
+ int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
+
+ HalfFloat scale = wv.getHalfFloat(blockByteOffset);
+ float scaleFloat = scale.getFloat32();
+
+ int quantsOffset = blockByteOffset + 2 + withinBlockIdx;
+ byte quant1 = wv.get(quantsOffset);
+ byte quant2 = wv.get(quantsOffset + 1);
+ byte quant3 = wv.get(quantsOffset + 2);
+ byte quant4 = wv.get(quantsOffset + 3);
+
+ partialSum1 += ((float) quant1 * scaleFloat) * x.get(j);
+ partialSum2 += ((float) quant2 * scaleFloat) * x.get(j + 1);
+ partialSum3 += ((float) quant3 * scaleFloat) * x.get(j + 2);
+ partialSum4 += ((float) quant4 * scaleFloat) * x.get(j + 3);
+ }
+
+ float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4;
+
+ for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) {
+ int blockIdx = j / blockSize;
+ int withinBlockIdx = j % blockSize;
+ int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
+
+ HalfFloat scale = wv.getHalfFloat(blockByteOffset);
+ float scaleFloat = scale.getFloat32();
+
+ byte quant = wv.get(blockByteOffset + 2 + withinBlockIdx);
+ partialSum += ((float) quant * scaleFloat) * x.get(j);
+ }
+
+ localSums[localId] = partialSum;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSums[localId] += localSums[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ if (localId == 0) {
+ v.set(vRow, localSums[0]);
+ }
+ }
+ }
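+
+ /**
+ * Q8_0 layout sketch (illustrative helper, not used by the kernels above; the name is an
+ * assumption): each 34-byte block stores a 2-byte FP16 scale followed by 32 int8 quants, so
+ * the dequantized value at column j of a row is quant * scale for the block containing j.
+ */
+ public static float dequantizeQ8_0(ByteArray w, int rowBlockOffset, int j) {
+ final int blockSize = 32;
+ final int blockBytes = 34; // 2-byte FP16 scale + 32 quants
+ int blockByteOffset = (rowBlockOffset + j / blockSize) * blockBytes;
+ float scale = w.getHalfFloat(blockByteOffset).getFloat32();
+ byte quant = w.get(blockByteOffset + 2 + (j % blockSize));
+ return quant * scale;
+ }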
+
+ /**
+ * Fully fused RMS normalization + FFN W1/W3 matmul with SiLU/GLU for Q8_0 weights.
+ * Each workgroup redundantly computes RMS scale to avoid cross-workgroup sync.
+ */
+ public static void fullyFusedRmsNormFFNGateUpQ8(
+ KernelContext context,
+ FloatArray x, // raw input (FP32)
+ FloatArray hb, // output
+ FloatArray rmsWeights, // RMS norm weights
+ ByteArray w1, // Q8_0 quantized
+ ByteArray w3, // Q8_0 quantized
+ int dim, // input dimension
+ int hiddenDim, // output dimension
+ int localWorkGroupSize) {
+
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ if (rowId >= hiddenDim) {
+ return;
+ }
+
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ // ========== RMS Norm: Compute scale (each workgroup does this redundantly) ==========
+ float sumSquares = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ float val = x.get(j);
+ sumSquares += val * val;
+ }
+
+ localSum[localId] = sumSquares;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+
+ float scale = 1.0f / TornadoMath.sqrt(localSum[0] / dim + 1e-5f); // NOTE: epsilon is hardcoded to 1e-5f here rather than taken from config.rmsNormEps()
+
+ // ========== W1 matmul with inline RMS normalization ==========
+ int blockSize = 32;
+ final int Q8_0_BLOCK_BYTES = 34;
+ int blocksPerRow = (dim + blockSize - 1) / blockSize;
+ int rowBlockOffset = rowId * blocksPerRow;
+
+ float partialSum1_a = 0.0f;
+ float partialSum1_b = 0.0f;
+ float partialSum1_c = 0.0f;
+ float partialSum1_d = 0.0f;
+
+ for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) {
+ int blockIdx = j / blockSize;
+ int withinBlockIdx = j % blockSize;
+ int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
+
+ HalfFloat w1Scale = w1.getHalfFloat(blockByteOffset);
+ float w1ScaleFloat = w1Scale.getFloat32();
+
+ int quantsOffset = blockByteOffset + 2 + withinBlockIdx;
+ byte q1 = w1.get(quantsOffset);
+ byte q2 = w1.get(quantsOffset + 1);
+ byte q3 = w1.get(quantsOffset + 2);
+ byte q4 = w1.get(quantsOffset + 3);
+
+ float norm0 = rmsWeights.get(j) * scale * x.get(j);
+ float norm1 = rmsWeights.get(j + 1) * scale * x.get(j + 1);
+ float norm2 = rmsWeights.get(j + 2) * scale * x.get(j + 2);
+ float norm3 = rmsWeights.get(j + 3) * scale * x.get(j + 3);
+
+ partialSum1_a += ((float) q1 * w1ScaleFloat) * norm0;
+ partialSum1_b += ((float) q2 * w1ScaleFloat) * norm1;
+ partialSum1_c += ((float) q3 * w1ScaleFloat) * norm2;
+ partialSum1_d += ((float) q4 * w1ScaleFloat) * norm3;
+ }
+
+ float partialSum1 = partialSum1_a + partialSum1_b + partialSum1_c + partialSum1_d;
+
+ for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) {
+ int blockIdx = j / blockSize;
+ int withinBlockIdx = j % blockSize;
+ int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
+
+ HalfFloat w1Scale = w1.getHalfFloat(blockByteOffset);
+ float w1ScaleFloat = w1Scale.getFloat32();
+
+ byte quant = w1.get(blockByteOffset + 2 + withinBlockIdx);
+ float normalized = rmsWeights.get(j) * scale * x.get(j);
+ partialSum1 += ((float) quant * w1ScaleFloat) * normalized;
+ }
+
+ localSum[localId] = partialSum1;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+ float result1 = localSum[0];
+
+ // ========== W3 matmul with inline RMS normalization ==========
+ float partialSum3_a = 0.0f;
+ float partialSum3_b = 0.0f;
+ float partialSum3_c = 0.0f;
+ float partialSum3_d = 0.0f;
+
+ for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) {
+ int blockIdx = j / blockSize;
+ int withinBlockIdx = j % blockSize;
+ int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
+
+ HalfFloat w3Scale = w3.getHalfFloat(blockByteOffset);
+ float w3ScaleFloat = w3Scale.getFloat32();
+
+ int quantsOffset = blockByteOffset + 2 + withinBlockIdx;
+ byte q1 = w3.get(quantsOffset);
+ byte q2 = w3.get(quantsOffset + 1);
+ byte q3 = w3.get(quantsOffset + 2);
+ byte q4 = w3.get(quantsOffset + 3);
+
+ float norm0 = rmsWeights.get(j) * scale * x.get(j);
+ float norm1 = rmsWeights.get(j + 1) * scale * x.get(j + 1);
+ float norm2 = rmsWeights.get(j + 2) * scale * x.get(j + 2);
+ float norm3 = rmsWeights.get(j + 3) * scale * x.get(j + 3);
+
+ partialSum3_a += ((float) q1 * w3ScaleFloat) * norm0;
+ partialSum3_b += ((float) q2 * w3ScaleFloat) * norm1;
+ partialSum3_c += ((float) q3 * w3ScaleFloat) * norm2;
+ partialSum3_d += ((float) q4 * w3ScaleFloat) * norm3;
+ }
+
+ float partialSum3 = partialSum3_a + partialSum3_b + partialSum3_c + partialSum3_d;
+
+ for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) {
+ int blockIdx = j / blockSize;
+ int withinBlockIdx = j % blockSize;
+ int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
+
+ HalfFloat w3Scale = w3.getHalfFloat(blockByteOffset);
+ float w3ScaleFloat = w3Scale.getFloat32();
+
+ byte quant = w3.get(blockByteOffset + 2 + withinBlockIdx);
+ float normalized = rmsWeights.get(j) * scale * x.get(j);
+ partialSum3 += ((float) quant * w3ScaleFloat) * normalized;
+ }
+
+ localSum[localId] = partialSum3;
+ context.localBarrier();
+
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+ float result3 = localSum[0];
+
+ // ========== SiLU + GLU ==========
+ if (localId == 0) {
+ float silu = result1 / (1.0f + TornadoMath.exp(-result1));
+ hb.set(rowId, silu * result3);
+ }
+ }
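+
+ /**
+ * Host-side FP32 reference of the fused kernel above (illustrative sketch assuming plain
+ * float[] weights instead of Q8_0 blocks; not invoked by any task graph):
+ * hb[row] = SiLU(dot(w1[row], rmsnorm(x))) * dot(w3[row], rmsnorm(x)).
+ */
+ public static void fusedRmsNormGateUpReference(float[] x, float[] hb, float[] rmsWeights, float[][] w1, float[][] w3, float rmsNormEps) {
+ int dim = x.length;
+ float sumSquares = 0.0f;
+ for (float v : x) {
+ sumSquares += v * v;
+ }
+ float scale = 1.0f / (float) Math.sqrt(sumSquares / dim + rmsNormEps);
+ for (int row = 0; row < hb.length; row++) {
+ float sum1 = 0.0f;
+ float sum3 = 0.0f;
+ for (int j = 0; j < dim; j++) {
+ float normalized = rmsWeights[j] * scale * x[j];
+ sum1 += w1[row][j] * normalized;
+ sum3 += w3[row][j] * normalized;
+ }
+ float silu = sum1 / (1.0f + (float) Math.exp(-sum1));
+ hb[row] = silu * sum3;
+ }
+ }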
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
index 96acd650..8d105e89 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
@@ -4,6 +4,7 @@
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
@@ -21,7 +22,8 @@ public class LlamaFP16FFNLayers extends AbstractFFNLayers {
TaskGraph ffnTaskGraphs;
GridScheduler scheduler;
- List<ImmutableTaskGraph> ffnLayerTaskGraphs;
+ List<ImmutableTaskGraph> ffnLayerTaskGraphs;
+
public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) {
super(taskGraph, state, weights, config, schedulerType);
this.ffnLayerTaskGraphs = setupFFNLayered();
@@ -29,47 +31,45 @@ public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Config
@Override
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
- WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim()/2, 128);
WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
- int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid configKvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configKvDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
-
int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC);
WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize());
- WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128);
+
+ int fusedQKVRows = config.dim() + 2 * config.kvDim();
+ int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512);
// Map workers to tasks
for (int i = 0; i < config.numberOfLayers(); i++) {
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
+ // === Attention Block ===
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
+ // === FFN Block ===
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
}
return tornadoForwardScheduler;
}
@Override
- public GridScheduler getGridScheduler() {
+ public GridScheduler getGridScheduler() {
return scheduler;
}
@Override
- public TaskGraph getTaskGraph() {
+ public TaskGraph getTaskGraph() {
return ffnTaskGraphs;
}
@@ -83,22 +83,106 @@ public List getFfnLayerTaskGraphs() {
}
List<ImmutableTaskGraph> setupFFNLayered() {
- state.temp.init(0.0f);
- state.tempFFN.init(0.0f);
- var numLayers = config.numberOfLayers();
-
- return IntStream.range(0, numLayers)
- .mapToObj(i -> {
- var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i);
- if (i == numLayers - 1) setupLastID(ffnLayer.getTaskGraphName());
- return ffnLayer.snapshot();
- })
- .toList();
+ return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> {
+ var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i);
+ if (i == config.numberOfLayers() - 1) {
+ setupLastID(ffnLayer.getTaskGraphName());
+ }
+ return ffnLayer.snapshot();
+ }).toList();
}
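+
+ // Execution sketch (illustrative; the actual wiring lives outside this class): each snapshot
+ // returned by setupFFNLayered() is an ImmutableTaskGraph executed per token with the
+ // GridScheduler from updateGridScheduler(), roughly:
+ //   TornadoExecutionPlan plan = new TornadoExecutionPlan(ffnLayerTaskGraphs.toArray(new ImmutableTaskGraph[0]));
+ //   plan.withGridScheduler(scheduler).execute();
+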
+ // @formatter:off
+ /**
+ * Transformer Layer Task Flow (LlamaFP16FFNLayers)
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ * ATTENTION BLOCK
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * wrapX (FP32)
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ attn_rms_reduce │──▶ temp (partial sums)
+ * └────────┬────────┘
+ * │
+ * ▼ (optional: NON_NVIDIA only)
+ * ┌──────────────────┐
+ * │ attn_rms_finalize│──▶ temp (final scale)
+ * └────────┬─────────┘
+ * │
+ * ▼
+ * ┌─────────────────────┐
+ * │ attn_rms_apply_fp16 │──▶ wrapXbFP16 (normalized, FP16)
+ * └──────────┬──────────┘
+ * │
+ * ▼
+ * ┌────────────────┐ ┌─────────────────────────────┐
+ * │ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │
+ * └───────┬────────┘ └─────────────────────────────┘
+ * │
+ * ▼
+ * ┌───────────────────┐ ┌─────────────────────────────────────┐
+ * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │
+ * └─────────┬─────────┘ └─────────────────────────────────────┘
+ * │
+ * ▼
+ * ┌───────────┐
+ * │ attention │──▶ wrapXb (attention output)
+ * └─────┬─────┘
+ * │
+ * ▼
+ * ┌──────────────────┐
+ * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection)
+ * └────────┬─────────┘
+ * │
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │ FFN BLOCK
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │
+ * ▼
+ * ┌────────────────┐
+ * │ ffn_rms_reduce │──▶ tempFFN (partial sums)
+ * └───────┬────────┘
+ * │
+ * ▼ (optional: NON_NVIDIA only)
+ * ┌─────────────────┐
+ * │ ffn_rms_finalize│──▶ tempFFN (final scale)
+ * └────────┬────────┘
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3)
+ * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU)
+ * │
+ * ▼
+ * ┌──────────────┐
+ * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection)
+ * └──────┬───────┘
+ * │
+ * ▼
+ * wrapX (FP32) ──▶ [next layer or logits]
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * Task Count: 9 tasks on NVIDIA; 11 on non-NVIDIA backends (which add the two rms_finalize steps)
+ *
+ * Data Flow Summary:
+ * Input: wrapX (FP32) - hidden state from previous layer
+ * Output: wrapX (FP32) - updated hidden state with residual connections
+ *
+ * Key Fusion Points:
+ * • qkv_projection: Fused Q/K/V matmuls (3→1 kernel)
+ * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel)
+ * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel)
+ *
+ */
TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) {
var layerTaskGraphName = "layer_" + layerIndex;
TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
+
+ // === Data Setup ===
unifiedLayer.consumeFromDevice(state.wrapX);
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
weights.rms_att_weightLayered[layerIndex].asFloatArray(),
@@ -111,68 +195,174 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
weights.w2Layered[layerIndex].asHalfFloatArray(),
weights.w3Layered[layerIndex].asHalfFloatArray());
unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
- unifiedLayer
- .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
- if (shouldUseFinalNormalization()) {
- unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
- config.dim(), config.rmsNormEps());
- }
- unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp)
- .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
- LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(),
- LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(),
- LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize())
- .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(),
- layerIndex, config.contextLength());
- configureAttention(unifiedLayer, layerIndex);
- unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
- LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
- if (shouldUseFinalNormalization()) {
- unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps());
- }
- unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN)
- .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(),
- weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(),
- config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX);
+
+ // === Attention Block ===
+ // RMS Normalization
+ unifiedLayer.task("attn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, state.temp, state.wrapX,
+ config.dim(), config.rmsNormEps(), state.localSize);
+
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("attn_rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, state.temp, config.dim(), config.rmsNormEps());
+ }
+
+ unifiedLayer.task("attn_rms_apply_fp16",
+ TransformerComputeKernels::mapContextWithQuantize,
+ context, state.wrapXbFP16, state.wrapX,
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);
+
+ // QKV Projection (fused)
+ unifiedLayer.task("qkv_projection",
+ TransformerComputeKernelsLayered::fusedQKVMatmulX,
+ context,
+ state.wrapXbFP16, // input (normalized, FP16)
+ state.wrapQ, // output Q
+ state.wrapK, // output K
+ state.wrapV, // output V
+ weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq
+ weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk
+ weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv
+ config.dim(), // dim
+ config.kvDim(), // kvDim
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // RoPE + KV Cache
+ unifiedLayer.task("rope_and_kv_cache",
+ TransformerComputeKernelsLayered::ropeRotationWithCacheCopy,
+ context,
+ state.positionHolder,
+ state.wrapQ, // Q (in/out)
+ state.wrapK, // K (in/out)
+ state.wrapV, // V (in only)
+ state.wrapKeyCache, // Key cache (out)
+ state.wrapValueCache, // Value cache (out)
+ config.kvDim(),
+ config.headSize(),
+ layerIndex,
+ config.contextLength());
+ // Attention
+ configureAttention(unifiedLayer, layerIndex);
+ // Output Projection (Wo) with residual
+ unifiedLayer.task("attn_output_proj",
+ TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+ context, state.wrapXb, state.wrapX,
+ weights.woLayered[layerIndex].asHalfFloatArray(),
+ config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // === FFN Block ===
+ // RMS Normalization
+ unifiedLayer.task("ffn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, state.tempFFN, state.wrapX,
+ config.dim(), config.rmsNormEps(), state.localSize);
+
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("ffn_rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, state.tempFFN, config.dim(), config.rmsNormEps());
+ }
+
+ unifiedLayer.task("rms_ffn_gate_up",
+ TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp,
+ context,
+ state.wrapX, // raw input (FP32)
+ state.wrapHb, // output
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights
+ state.tempFFN, // RMS scale factor
+ weights.w1Layered[layerIndex].asHalfFloatArray(), // W1
+ weights.w3Layered[layerIndex].asHalfFloatArray(), // W3
+ config.dim(), // input dimension
+ config.hiddenDim(), // output dimension
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // Down projection (W2) with residual
+ unifiedLayer.task("ffn_down_proj",
+ TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+ context, state.wrapHb, state.wrapX,
+ weights.w2Layered[layerIndex].asHalfFloatArray(),
+ config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ unifiedLayer.persistOnDevice(state.wrapX);
+
return unifiedLayer;
}
protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
- // First layer: Transfer initial data to device (one-time transfer)
if (layerIndex == 0) {
- // Transfer all attention-related data: query, key, value matrices and their caches
- unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); //
- unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
- context, state.wrapXb, state.wrapXb2, //
- state.wrapQ, state.wrapK, state.wrapV, //
- state.wrapKeyCache, state.wrapValueCache, //
- state.wrapAtt, state.wrapHb); //
+ // First layer: Transfer initial data to device (one-time transfer)
+ unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
+ state.positionHolder,
+ state.temp, state.tempFFN
+ );
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ // Kernel context
+ context,
+ // Intermediate buffers
+ state.wrapXb, state.wrapXb2,
+ // QKV vectors
+ state.wrapQ, state.wrapK, state.wrapV,
+ // KV cache
+ state.wrapKeyCache, state.wrapValueCache,
+ // Attention & FFN buffers
+ state.wrapAtt, state.wrapHb, state.wrapXbFP16);
} else {
// Subsequent layers: Consume data already on device from previous layer
- unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
- state.wrapQ, state.wrapK, state.wrapV, //
- state.wrapKeyCache, state.wrapValueCache, //
- state.wrapAtt, state.wrapHb, //
- state.positionHolder //
- );
+ unifiedLayer.consumeFromDevice(
+ // Kernel context
+ context,
+ // Intermediate buffers
+ state.wrapXb, state.wrapXb2,
+ // QKV vectors
+ state.wrapQ, state.wrapK, state.wrapV,
+ // KV cache
+ state.wrapKeyCache, state.wrapValueCache,
+ // Attention & FFN buffers
+ state.wrapAtt, state.wrapHb,
+ // Position & misc
+ state.positionHolder, state.wrapXbFP16);
}
return unifiedLayer;
}
private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) {
if (schedulerType == SchedulerType.NVIDIA) {
- return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention,
- context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
- config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
- state.positionHolder, layerIndex, config.contextLength());
+ // Flash Attention (optimized for NVIDIA GPUs)
+ return unifiedLayer.task("attention",
+ TransformerComputeKernelsLayered::processHeadsFlashAttention,
+ context,
+ state.wrapQ, // Query
+ state.wrapKeyCache, // Key cache
+ state.wrapValueCache, // Value cache
+ state.wrapXb, // Output
+ config.numberOfHeads(),
+ config.headSize(),
+ config.kvDim(),
+ config.kvMul(),
+ state.positionHolder,
+ layerIndex,
+ config.contextLength());
} else {
- return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
- config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), state.positionHolder, state.wrapAtt, layerIndex, config.contextLength());
+ // Standard parallel attention (for non-NVIDIA backends)
+ return unifiedLayer.task("attention",
+ TransformerComputeKernelsLayered::processHeadsParallel,
+ state.wrapQ, // Query
+ state.wrapKeyCache, // Key cache
+ state.wrapValueCache, // Value cache
+ state.wrapXb, // Output
+ config.numberOfHeads(),
+ config.headSize(),
+ config.kvDim(),
+ config.kvMul(),
+ config.contextLength(), // seqLen parameter
+ state.positionHolder,
+ state.wrapAtt, // Attention weights buffer
+ layerIndex,
+ config.contextLength());
}
}
+ // @formatter:on
+
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
index a674c1c5..d2a81407 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
@@ -2,8 +2,8 @@
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
-import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
@@ -28,46 +28,81 @@ public class LogitsFP16Layer extends AbstractLayer {
public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
super(name, state, weights, config);
this.lastTaskGraphID = lastTaskGraphID;
- state.tempLogits.init(0.0f);
+ this.schedulerType = schedulerType;
var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor");
this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config);
- this.schedulerType = schedulerType;
}
- /**
- * Builds the logits computation graph.
- */
+ // @formatter:off
private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
- TaskGraph logits = new TaskGraph("logits");
- logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits)
- .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), weights.rms_final_weight_as_floatArray.asFloatArray())
- .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
- if (schedulerType == SchedulerType.NON_NVIDIA) {
- logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps());
- }
- logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits)
- .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(),
- LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);
+ var logits = new TaskGraph("logits");
+ // === Data Setup ===
+ logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
+ logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
+ logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ // Kernel context
+ context,
+ // Output buffer
+ state.wrapLogits,
+ // Intermediate FP16 buffer
+ state.wrapXbFP16,
+ // Weights
+ weights.wclsByteArray.asHalfFloatArray(),
+ weights.rms_final_weight_as_floatArray.asFloatArray());
+
+ // === Final RMS Normalization ===
+ logits.task("rms_reduce",
+ TransformerComputeKernels::reductionOneBlockWithLayer,
+ context,
+ state.tempLogits, // output: partial sums + final scale factor
+ state.wrapX, // input: hidden state
+ config.dim(), // dimension
+ config.rmsNormEps(), // epsilon for numerical stability
+ state.localSize); // local workgroup size
+
+ if (schedulerType == SchedulerType.NON_NVIDIA) {
+ logits.task("rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context,
+ state.tempLogits, // in/out: combines partial sums
+ config.dim(), // dimension
+ config.rmsNormEps()); // epsilon
+ }
+
+ logits.task("rms_apply_fp16",
+ TransformerComputeKernels::mapContextWithQuantizeLogits,
+ context,
+ state.wrapXbFP16, // output: normalized (FP16)
+ state.wrapX, // input: hidden state
+ weights.rms_final_weight_as_floatArray.asFloatArray(), // RMS weights
+ state.tempLogits); // scale factor from reduction
+
+ // === Vocabulary Projection ===
+ logits.task("vocab_proj",
+ TransformerComputeKernelsLayered::matrixVectorGeneric,
+ context,
+ state.wrapXbFP16, // input (FP16)
+ state.wrapLogits, // output
+ weights.wclsByteArray.asHalfFloatArray(), // vocabulary weights
+ config.dim(), // input dimension
+ config.vocabularySize(), // output dimension
+ LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);
+
+ // === Transfer Results to Host ===
logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
return logits;
}
+ // @formatter:on
@Override
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
- WorkerGrid logitsRMS;
- if (weights instanceof Qwen2TornadoWeights) {
- logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32);
- } else {
- logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
- }
-
+ WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256);
var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
- WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
+ var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
-
- tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
- tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS);
- tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS);
+ tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS);
+ tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS);
+ tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
return tornadoForwardScheduler;
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
index 75f9f531..c3b9ddc6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
@@ -3,6 +3,7 @@
import org.beehive.gpullama3.inference.state.Phi3State;
import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
+import org.beehive.gpullama3.tornadovm.kernels.Phi3Kernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
@@ -19,42 +20,27 @@
/**
* Phi3FP16FFNLayers: FP16 FFN layers for Phi3 with Group Query Attention (GQA) support.
*
- * Key Differences from Qwen2/Qwen3:
- * - Uses combined QKV matrix (wqkv) instead of separate Q, K, V matrices
- * - Includes splitQKV task to separate combined buffer
- * - Uses ropeRotationPhi3 kernel for position embeddings
- * - FFN uses single wUp matrix that outputs both Gate and Up (2 * hiddenDim)
- * - Includes splitGateUpAndSiLU task for FFN activation
- * - Uses wDown for final FFN projection
- * - No Q, K, V bias terms
+ * Key Differences from Qwen2/Qwen3:
+ * - Uses a combined QKV matrix (wqkv) instead of separate Q, K, V matrices; the RMS-norm apply,
+ *   QKV matmul and Q/K/V split are fused into a single task (no separate splitQKV task)
+ * - Uses the Phi3-specific RoPE rotation, fused with the KV-cache write
+ * - FFN uses a single wUp matrix producing both Gate and Up (2 * hiddenDim), with SiLU/GLU applied
+ *   inline, and wDown for the final projection; no Q, K, V bias terms
*
* Works directly with Phi3State to access and mutate Phi3-specific state fields.
*/
public class Phi3FP16FFNLayers extends AbstractFFNLayers {
- TaskGraph ffnLayerTaskGraph;
- GridScheduler scheduler;
- List<ImmutableTaskGraph> ffnLayerTaskGraphs;
-
// Typed references to Phi3-specific state and config
private final Phi3State phi3State;
private final Phi3Configuration phi3Config;
-
// Phi3-specific dimension for combined QKV buffer
private final int opSize;
+ TaskGraph ffnLayerTaskGraph;
+ GridScheduler scheduler;
+ List<ImmutableTaskGraph> ffnLayerTaskGraphs;
public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) {
- super(taskGraphName, state, weights, config,schedulerType);
+ super(taskGraphName, state, weights, config, schedulerType);
this.phi3State = state;
this.phi3Config = config;
-
- // Ensure we have Phi3-specific weights
- if (!(weights instanceof Phi3TornadoWeights phi3Weights)) {
- throw new IllegalArgumentException("Phi3FP16FFNLayers requires Phi3TornadoWeights with TornadoTensor layout");
- }
-
- // Calculate opSize for combined QKV buffer
- // opSize = num_heads * head_dim + 2 * (num_key_value_heads * head_dim) = dim + 2 * kvDim
this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize());
ffnLayerTaskGraphs = setupFFNLayered();
}
@@ -64,46 +50,41 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
// RMS norm worker
WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize);
- // Combined QKV matmul worker
- int matmulQkvGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid matmulQkvRowMajorWorker = WorkerGridFactory.genericWorker(matmulQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
-
- // RoPE worker (2D: heads x embedding_head/2)
- int ic = config.headSize() / 2;
- WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), config.headSize());
+ // Fused RMS + QKV matmul worker (one workgroup per output row; opSize = dim + 2*kvDim rows in total)
+ int fusedQkvGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
- // Copy to cache worker
- WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.kvDim(), 32);
+ // Fused RoPE + cache copy worker (Phi3 uses dim/2 pattern)
+ WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128);
// Parallel attention worker
WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize());
- // Matmul1 worker (output projection)
+ // Output projection worker
int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC);
- // FFN workers
- int ffnUpGlobal = (2 * config.hiddenDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid ffnUpWorker = WorkerGridFactory.genericWorker(ffnUpGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ // Fused RMS + FFN gate/up worker
+ int fusedFFNGlobal = (2 * config.hiddenDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid fusedFFNWorker = WorkerGridFactory.genericWorker(fusedFFNGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ // FFN down projection worker
int ffnDownGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid ffnDownWorker = WorkerGridFactory.genericWorker(ffnDownGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
- // Map workers to tasks for each layer
for (int i = 0; i < config.numberOfLayers(); i++) {
- gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", matmulQkvRowMajorWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker);
- gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".wGateUp", ffnUpWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".wDown", ffnDownWorker);
+ // === Attention Block ===
+ gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQkvWorker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker);
+ // === FFN Block ===
+ gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_silu", fusedFFNWorker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", ffnDownWorker);
}
-
return gridScheduler;
}
@@ -131,11 +112,6 @@ public List getFfnLayerTaskGraphs() {
*/
List<ImmutableTaskGraph> setupFFNLayered() {
List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>();
-
- // Initialize buffers using Phi3State directly
- phi3State.temp.init(0.0f);
- phi3State.tempFFN.init(0.0f);
-
for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) {
TaskGraph ffnLayer = setupSinglePhi3FFNLayer((Phi3TornadoWeights) weights, layerIndex);
if (layerIndex == phi3Config.numberOfLayers() - 1) {
@@ -143,160 +119,197 @@ List setupFFNLayered() {
}
ffnGraphs.add(ffnLayer.snapshot());
}
-
return ffnGraphs;
}
+ // @formatter:off
/**
- * Setup a single transformer layer for Phi3 with combined QKV and gate/up FFN
+ * Transformer Layer Task Flow (Phi3FP16FFNLayers - Fully Optimized)
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ * ATTENTION BLOCK
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * wrapX (FP32)
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ attn_rms_reduce │──▶ temp (scale factor for RMSNorm)
+ * └────────┬────────┘
+ * │
+ * ▼
+ * ┌────────────────────────┐
+ * │ attn_rms_qkv_projection│──▶ wrapQ, wrapK, wrapV (direct output)
+ * └───────────┬────────────┘ (fused: RMS apply + QKV matmul + split)
+ * │
+ * ▼
+ * ┌───────────────────┐ ┌─────────────────────────────────────┐
+ * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │
+ * └─────────┬─────────┘ └─────────────────────────────────────┘
+ * │ (fused: Phi3 RoPE + cache write)
+ * ▼
+ * ┌───────────┐
+ * │ attention │──▶ wrapXb (attention output)
+ * └─────┬─────┘
+ * │
+ * ▼
+ * ┌──────────────────┐
+ * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection)
+ * └────────┬─────────┘
+ * │
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │ FFN BLOCK
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │
+ * ▼
+ * ┌────────────────┐
+ * │ ffn_rms_reduce │──▶ tempFFN (scale factor)
+ * └───────┬────────┘
+ * │
+ * ▼ (optional: NON_NVIDIA only)
+ * ┌──────────────────┐
+ * │ ffn_rms_finalize │──▶ tempFFN (final scale)
+ * └────────┬─────────┘
+ * │
+ * ▼
+ * ┌──────────────┐
+ * │ rms_ffn_silu │──▶ wrapHbU = SiLU(RMSNorm(x)·Wgate) ⊙ (RMSNorm(x)·Wup)
+ * └──────┬───────┘ (fused: RMS apply + gate/up matmul + SiLU + GLU)
+ * │
+ * ▼
+ * ┌──────────────┐
+ * │ ffn_down_proj│──▶ wrapX += wDown · wrapHbU (residual connection)
+ * └──────┬───────┘
+ * │
+ * ▼
+ * wrapX (FP32) ──▶ [next layer or logits]
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * Task Count: 8 tasks (NVIDIA) / 9 tasks (non-NVIDIA)
+ * Original: 13 tasks
+ * Reduction: 5 tasks eliminated (38% fewer kernel launches)
+ *
+ * Data Flow Summary:
+ * Input: wrapX (FP32) - hidden state from previous layer
+ * Output: wrapX (FP32) - updated hidden state with residual connections
+ *
+ * Key Fusion Points (vs original 13 tasks):
+ * • attn_rms_qkv_projection: Fused RMS apply + QKV matmul + direct split (3→1 kernel)
+ * • rope_and_kv_cache: Fused Phi3 RoPE rotation + cache write (2→1 kernel)
+ * • rms_ffn_silu: Fused RMS apply + gate/up matmul + SiLU + GLU (3→1 kernel)
+ *
+ * Phi3-Specific:
+ * • Combined wqkv: Single [opSize × dim] matrix for Q+K+V projection
+ * • Direct QKV output: No intermediate buffer, routes by row index
+ * • Phi3 RoPE: Uses headSize/2 offset pattern (different from Llama/Qwen)
+ * • Combined wUp: Single [2×hiddenDim × dim] matrix for gate+up
+ * • Inline SiLU+GLU: No intermediate wrapHb buffer needed
+ *
*/
TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) {
-
- TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex);
+ var taskGraphName = "layer_" + layerIndex;
+ var unifiedLayer = new TaskGraph(taskGraphName);
unifiedLayer.consumeFromDevice(phi3State.wrapX);
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
- // Copy-in weights per layer for batched-layered layout
+ // Attention weights
weights.rms_att_weightLayered[layerIndex].asFloatArray(),
weights.wqkvLayered[layerIndex].asHalfFloatArray(),
weights.woLayered[layerIndex].asHalfFloatArray(),
+ // FFN weights
weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
weights.wUpLayered[layerIndex].asHalfFloatArray(),
- weights.wDownLayered[layerIndex].asHalfFloatArray()
- );
+ weights.wDownLayered[layerIndex].asHalfFloatArray());
unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
- // RMSNorm for attention input
- unifiedLayer.task("reductionsOneBlock",
- TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
- context,
- phi3State.temp,
- phi3State.wrapX,
- phi3Config.dim(),
- phi3Config.rmsNormEps(),
- phi3State.localSize)
- .task("mapContext",
- TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
- context,
- phi3State.wrapXb,
- phi3State.wrapX,
- weights.rms_att_weightLayered[layerIndex].asFloatArray(),
- phi3State.temp);
-
- // Combined QKV projection
- unifiedLayer.task("qkvmatmul",
- TransformerComputeKernelsLayered::matrixVectorGeneric,
- context,
- phi3State.wrapXb,
- phi3State.wrapQkv,
- weights.wqkvLayered[layerIndex].asHalfFloatArray(),
- phi3Config.dim(),
- opSize,
- LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("splitQKV",
- TransformerComputeKernelsLayered::splitQKV,
- phi3State.wrapQkv,
- phi3State.wrapQ,
- phi3State.wrapK,
- phi3State.wrapV,
- phi3Config.dim(),
- phi3Config.headSize() * phi3Config.numberOfKeyValueHeads());
-
- // RoPE rotation (Phi3-specific kernel)
- unifiedLayer.task("rope",
- TransformerComputeKernelsLayered::ropeRotationPhi3,
- context,
- phi3State.positionHolder,
- phi3State.wrapQ,
- phi3State.wrapK,
- phi3Config.kvDim(),
- phi3Config.headSize());
-
- // Copy to caches
- unifiedLayer.task("copyToCaches",
- TransformerComputeKernelsLayered::copyToCache,
- phi3State.wrapKeyCache,
- phi3State.wrapK,
- phi3State.wrapValueCache,
- phi3State.wrapV,
- phi3State.positionHolder,
- phi3Config.kvDim(),
- layerIndex,
- phi3Config.contextLength());
-
- // Parallel attention
- unifiedLayer.task("parallel-attention",
- TransformerComputeKernelsLayered::processHeadsFlashAttention,
- context,
- phi3State.wrapQ,
- phi3State.wrapKeyCache,
- phi3State.wrapValueCache,
- phi3State.wrapXb,
- phi3Config.numberOfHeads(),
- phi3Config.headSize(),
- phi3Config.kvDim(),
- phi3Config.kvMul(),
- phi3State.positionHolder,
- layerIndex,
- phi3Config.contextLength());
-
- // Output projection
- unifiedLayer.task("matmul1",
- TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
- context,
- phi3State.wrapXb,
- phi3State.wrapX,
- weights.woLayered[layerIndex].asHalfFloatArray(),
- phi3Config.dim(),
- phi3Config.dim(),
+ // ═══════════════════════════════════════════════════════════════════════
+ // ATTENTION BLOCK
+ // ═══════════════════════════════════════════════════════════════════════
+
+ // RMS Normalization - compute scale factor
+ unifiedLayer.task("attn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, phi3State.temp, // output: scale factor
+ phi3State.wrapX, // input: hidden state
+ phi3Config.dim(), // dimension
+ phi3Config.rmsNormEps(), // epsilon
+ phi3State.localSize); // local memory size
+
+ unifiedLayer.task("attn_rms_qkv_projection", Phi3Kernels::fusedRmsNormQKVMatmulDirect, context, phi3State.wrapX, // input
+ phi3State.wrapQ, // output Q
+ phi3State.wrapK, // output K
+ phi3State.wrapV, // output V
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(), phi3State.temp, // RMS scale
+ weights.wqkvLayered[layerIndex].asHalfFloatArray(), phi3Config.dim(), // dim
+ phi3Config.kvDim(), // kvDim
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // Fused Phi3 RoPE Rotation + KV Cache Write
+ unifiedLayer.task("rope_and_kv_cache", Phi3Kernels::ropeRotationWithCacheCopyPhi3, context, phi3State.positionHolder, // current position
+ phi3State.wrapQ, // Q vectors (in/out, rotated)
+ phi3State.wrapK, // K vectors (in/out, rotated)
+ phi3State.wrapV, // V vectors (in only)
+ phi3State.wrapKeyCache, // key cache (out)
+ phi3State.wrapValueCache, // value cache (out)
+ phi3Config.numberOfKeyValueHeads(), // nHeadKv
+ phi3Config.headSize(), // head dimension
+ phi3Config.kvDim(), // kvDim
+ layerIndex, // layer index for cache offset
+ phi3Config.contextLength()); // max sequence length
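+ // Phi3 RoPE pairing (illustrative, per the headSize/2 offset noted in the javadoc): within
+ // each head, element i in [0, headSize/2) is rotated together with element i + headSize/2,
+ //   (q[i], q[i + hs/2]) -> (q[i]*cos - q[i + hs/2]*sin, q[i]*sin + q[i + hs/2]*cos)
+ // unlike the pairing used by the Llama/Qwen rotation kernels in this codebase; see
+ // ropeRotationWithCacheCopyPhi3 for the exact indexing.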
+
+ // Flash Attention
+ unifiedLayer.task("attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context, phi3State.wrapQ, // query vectors
+ phi3State.wrapKeyCache, // key cache
+ phi3State.wrapValueCache, // value cache
+ phi3State.wrapXb, // output: attention result
+ phi3Config.numberOfHeads(), // nHeads
+ phi3Config.headSize(), // headSize
+ phi3Config.kvDim(), // kvDim
+ phi3Config.kvMul(), // kvMul (nHeads / nHeadKv)
+ phi3State.positionHolder, // position
+ layerIndex, // layer index
+ phi3Config.contextLength()); // context length
+
+ // Output Projection with Residual
+ unifiedLayer.task("attn_output_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, phi3State.wrapXb, // input: attention output
+ phi3State.wrapX, // output: wrapX += Wo · wrapXb
+ weights.woLayered[layerIndex].asHalfFloatArray(), // Wo [dim × dim]
+ phi3Config.dim(), // input dim
+ phi3Config.dim(), // output dim
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // ═══════════════════════════════════════════════════════════════════════
+ // FFN BLOCK
+ // ═══════════════════════════════════════════════════════════════════════
+
+ // RMS Normalization - compute scale factor
+ unifiedLayer.task("ffn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, phi3State.tempFFN, // output: scale factor
+ phi3State.wrapX, // input: hidden state
+ phi3Config.dim(), // dimension
+ phi3Config.rmsNormEps(), // epsilon
+ phi3State.localSize); // local memory size
+
+ // Final normalization (non-NVIDIA only)
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("ffn_rms_finalize", TransformerComputeKernelsLayered::reductionFinalNormalization, context, phi3State.tempFFN, // scale factor (in/out)
+ phi3Config.dim(), // dimension
+ phi3Config.rmsNormEps()); // epsilon
+ }
+
+ unifiedLayer.task("rms_ffn_silu", Phi3Kernels::fusedRmsNormFFNGateUpSiLU, context, phi3State.wrapX, // input
+ phi3State.wrapHbU, // output (direct to final FFN buffer)
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), phi3State.tempFFN, // RMS scale
+ weights.wUpLayered[layerIndex].asHalfFloatArray(), phi3Config.dim(), // input dim
+ phi3Config.hiddenDim(), // output dim (hiddenDim, not 2×hiddenDim!)
LOCAL_WORK_GROUP_SIZE_ALLOC);
- // FFN section: RMSNorm
- unifiedLayer.task("reductionsOneBlockFFN",
- TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
- context,
- phi3State.tempFFN,
- phi3State.wrapX,
- phi3Config.dim(),
- phi3Config.rmsNormEps(),
- phi3State.localSize)
- .task("mapContextFFN",
- TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
- context,
- phi3State.wrapXb,
- phi3State.wrapX,
- weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
- phi3State.tempFFN);
-
- // FFN: combined Up and Gate projection (outputs 2 * hiddenDim)
- unifiedLayer.task("wGateUp",
- TransformerComputeKernelsLayered::matrixVectorGeneric,
- context,
- phi3State.wrapXb,
- phi3State.wrapHb,
- weights.wUpLayered[layerIndex].asHalfFloatArray(),
- phi3Config.dim(),
- 2 * phi3Config.hiddenDim(),
- LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("gateUpSiLU",
- TransformerComputeKernelsLayered::splitGateUpAndSiLU,
- phi3State.wrapHb,
- phi3State.wrapHbG,
- phi3State.wrapHbU,
- phi3Config.hiddenDim());
-
- // FFN: Down projection with residual
- unifiedLayer.task("wDown",
- TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
- context,
- phi3State.wrapHbU,
- phi3State.wrapX,
- weights.wDownLayered[layerIndex].asHalfFloatArray(),
- phi3Config.hiddenDim(),
- phi3Config.dim(),
- LOCAL_WORK_GROUP_SIZE_ALLOC)
- .persistOnDevice(
- phi3State.wrapX
- );
+ // Down Projection with Residual
+ unifiedLayer.task("ffn_down_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, phi3State.wrapHbU, // input: FFN intermediate
+ phi3State.wrapX, // output: wrapX += wDown · wrapHbU
+ weights.wDownLayered[layerIndex].asHalfFloatArray(), // wDown [dim × hiddenDim]
+ phi3Config.hiddenDim(), // input dim
+ phi3Config.dim(), // output dim
+ LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(phi3State.wrapX);
+
return unifiedLayer;
}
@@ -304,26 +317,26 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) {
* Configure data transfers for first and subsequent layers
*/
protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
- // First layer: Transfer initial data to device (one-time transfer)
if (layerIndex == 0) {
- // Transfer all attention-related data: query, key, value matrices and their caches
- unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); //
- unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
- context, state.wrapXb, state.wrapXb2, //
- state.wrapQ, state.wrapK, state.wrapV, //
- state.wrapKeyCache, state.wrapValueCache, //
- state.wrapAtt, state.wrapHb, //
- phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); //
+ // First layer: Transfer temporary buffers and state every execution
+ unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, phi3State.positionHolder,
+ phi3State.temp, phi3State.tempFFN);
+ // First execution: allocate workspace buffers
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ context, phi3State.wrapXb, phi3State.wrapXb2,
+ phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV,
+ phi3State.wrapKeyCache, phi3State.wrapValueCache,
+ phi3State.wrapAtt, phi3State.wrapHb, phi3State.wrapHbG,
+ phi3State.wrapHbU, phi3State.wrapQkv);
} else {
- // Subsequent layers: Consume data already on device from previous layer
- unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
- state.wrapQ, state.wrapK, state.wrapV, //
- state.wrapKeyCache, state.wrapValueCache, //
- state.wrapAtt, state.wrapHb, //
- state.positionHolder, // /
+ // Subsequent layers: Consume data from previous layer
+ unifiedLayer.consumeFromDevice(context, phi3State.wrapXb, phi3State.wrapXb2,
+ phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, phi3State.wrapKeyCache,
+ phi3State.wrapValueCache, phi3State.wrapAtt, phi3State.wrapHb, phi3State.positionHolder,
phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv);
}
return unifiedLayer;
}
+ // @formatter:on
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
index 858848ea..aded0c81 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
@@ -51,7 +51,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
int ic = config.headSize() / 2;
WorkerGrid ropeWorker = new WorkerGrid2D(h, ic);
ropeWorker.setGlobalWork(h, ic, 1);
- ropeWorker.setLocalWork(1, 1, 1);
+ ropeWorker.setLocalWork(h / 2, ic / 2, 1);
int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal);
@@ -61,13 +61,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal);
configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
- WorkerGrid qBiasWorker = new WorkerGrid1D(config.dim());
- qBiasWorker.setGlobalWork(config.dim(), 1, 1);
- qBiasWorker.setLocalWork(config.dim() / 8, 1, 1);
- WorkerGrid kvBiasWorker = new WorkerGrid1D(config.kvDim());
- kvBiasWorker.setGlobalWork(config.kvDim(), 1, 1);
- kvBiasWorker.setLocalWork(32, 1, 1);
-
int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor);
configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
@@ -87,32 +80,38 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
}
}
+ // WorkerGrid for fused QKV bias addition (dimension is dimQ)
+ WorkerGrid fusedQKVBiasWorker = new WorkerGrid1D(config.dim());
+ fusedQKVBiasWorker.setGlobalWork(config.dim(), 1, 1);
+ fusedQKVBiasWorker.setLocalWork(32, 1, 1); // Or an optimized local size
+
WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * optimalLocalSize, 1, 1);
parallelAttentionWorker.setLocalWork(optimalLocalSize, 1, 1);
- WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim());
- copyToCachesWorker.setGlobalWork(config.kvDim(), 1, 1);
- copyToCachesWorker.setLocalWork(32, 1, 1); // Set local work size to 32 (for copying to caches)
+ int fusedQKVGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid fusedQKVWorker = new WorkerGrid1D(fusedQKVGlobal);
+ fusedQKVWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
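+ // Sizing note (illustrative numbers only): with LOCAL_WORK_GROUP_SIZE_ALLOC = 32 and a
+ // hypothetical dim = 1536, kvDim = 256, this is (1536 + 2*256) * 32 = 65,536 global threads,
+ // i.e. one 32-wide work-group per fused Q/K/V output row.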
+
// Map workers to tasks
for (int i = 0; i < config.numberOfLayers(); i++) {
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qbias", qBiasWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kbias", kvBiasWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vbias", kvBiasWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_qkv_bias", fusedQKVBiasWorkerNorm);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_qkv_bias", fusedQKVBiasWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
}
return tornadoForwardScheduler;
}
@@ -136,15 +135,8 @@ public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
return ffnLayerTaskGraphs;
}
- /**
- * Setup all FFN layers for all transformer layers
- */
 List<ImmutableTaskGraph> setupFFNLayered() {
- List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>();
-
- qwen2State.temp.init(0.0f);
- qwen2State.tempFFN.init(0.0f);
-
+ List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>(qwen2Config.numberOfLayers());
for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) {
TaskGraph ffnLayer = setupSingleQwen2FFNLayer((Qwen2TornadoWeights) weights, layerIndex);
if (layerIndex == qwen2Config.numberOfLayers() - 1) {
@@ -152,59 +144,257 @@ List<ImmutableTaskGraph> setupFFNLayered() {
}
ffnGraphs.add(ffnLayer.snapshot());
}
-
return ffnGraphs;
}
+ // @formatter:off
/**
- * Setup a single transformer layer for Qwen2 with GQA
+ * Transformer Layer Task Flow (Qwen2FP16FFNLayers - Optimized)
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ * ATTENTION BLOCK
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * wrapX (FP32)
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ attn_rms_reduce │──▶ temp (scale factor for RMSNorm)
+ * └────────┬────────┘
+ * │
+ * ▼
+ * ┌─────────────────────────┐
+ * │ attn_rms_qkv_projection │──▶ wrapQ, wrapK, wrapV (FP32)
+ * └───────────┬─────────────┘ (fused: RMS apply + Q/K/V matmuls)
+ * │
+ * ▼
+ * ┌────────────────┐
+ * │ fused_qkv_bias │──▶ wrapQ, wrapK, wrapV += biases
+ * └───────┬────────┘ (fused: Q + K + V bias addition)
+ * │
+ * ▼
+ * ┌───────────────────┐ ┌─────────────────────────────────────┐
+ * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │
+ * └─────────┬─────────┘ └─────────────────────────────────────┘
+ * │ (fused: RoPE rotation + cache write)
+ * ▼
+ * ┌───────────┐
+ * │ attention │──▶ wrapXb (attention output)
+ * └─────┬─────┘
+ * │
+ * ▼
+ * ┌──────────────────┐
+ * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection)
+ * └────────┬─────────┘
+ * │
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │ FFN BLOCK
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │
+ * ▼
+ * ┌────────────────┐
+ * │ ffn_rms_reduce │──▶ tempFFN (scale factor)
+ * └───────┬────────┘
+ * │
+ * ▼ (optional: NON_NVIDIA only)
+ * ┌──────────────────┐
+ * │ ffn_rms_finalize │──▶ tempFFN (final scale)
+ * └────────┬─────────┘
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3)
+ * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU)
+ * │
+ * ▼
+ * ┌──────────────┐
+ * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection)
+ * └──────┬───────┘
+ * │
+ * ▼
+ * wrapX (FP32) ──▶ [next layer or logits]
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * Task Count: 9 tasks (NVIDIA) / 10 tasks (non-NVIDIA)
+ * Previous: 12 tasks
+ * Reduction: 3 tasks eliminated (25% fewer kernel launches)
+ *
+ * Data Flow Summary:
+ * Input: wrapX (FP32) - hidden state from previous layer
+ * Output: wrapX (FP32) - updated hidden state with residual connections
+ *
+ * Key Fusion Points (vs previous 12 tasks):
+ * • fused_qkv_bias: Fused Q + K + V bias addition (3→1 kernel)
+ * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU
+ * (eliminates separate mapContextFFN kernel)
+ *
+ * Qwen2-Specific:
+ * • GQA: nHeads (Q) != nHeadKv (K/V), with kvMul = nHeads / nHeadKv
+ * • Bias terms: Q, K, V projections include bias (unlike Qwen3)
+ * • No Q/K RMSNorm: Unlike Qwen3, Qwen2 doesn't normalize Q/K after projection
+ *
*/
TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) {
var taskGraphName = "layer_" + layerIndex;
TaskGraph unifiedLayer = new TaskGraph(taskGraphName);
unifiedLayer.consumeFromDevice(state.wrapX);
- unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
- weights.rms_att_weightLayered[layerIndex].asFloatArray(), //
- weights.wqLayered[layerIndex].asHalfFloatArray(), //
- weights.wkLayered[layerIndex].asHalfFloatArray(), //
- weights.wvLayered[layerIndex].asHalfFloatArray(), //
- weights.woLayered[layerIndex].asHalfFloatArray(), //
- weights.q_biasLayered[layerIndex].asFloatArray(), //
- weights.k_biasLayered[layerIndex].asFloatArray(), //
- weights.v_biasLayered[layerIndex].asFloatArray(), //
- weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), //
- weights.w1Layered[layerIndex].asHalfFloatArray(), //
- weights.w2Layered[layerIndex].asHalfFloatArray(), //
- weights.w3Layered[layerIndex].asHalfFloatArray()); //
- unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); //
-
- unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.temp, qwen2State.wrapX, config.dim(), config.rmsNormEps(),
- qwen2State.localSize)
- .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(),
- qwen2State.temp)
- .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
- LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(),
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ // Attention weights
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(),
+ weights.wqLayered[layerIndex].asHalfFloatArray(),
+ weights.wkLayered[layerIndex].asHalfFloatArray(),
+ weights.wvLayered[layerIndex].asHalfFloatArray(),
+ weights.woLayered[layerIndex].asHalfFloatArray(),
+ // Qwen2-specific bias terms
+ weights.q_biasLayered[layerIndex].asFloatArray(),
+ weights.k_biasLayered[layerIndex].asFloatArray(),
+ weights.v_biasLayered[layerIndex].asFloatArray(),
+ // FFN weights
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
+ weights.w1Layered[layerIndex].asHalfFloatArray(),
+ weights.w2Layered[layerIndex].asHalfFloatArray(),
+ weights.w3Layered[layerIndex].asHalfFloatArray());
+ unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
+
+ // ═══════════════════════════════════════════════════════════════════════
+ // ATTENTION BLOCK
+ // ═══════════════════════════════════════════════════════════════════════
+
+ // RMS Normalization - compute scale factor
+ unifiedLayer.task("attn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context,
+ qwen2State.temp, // output: scale factor
+ qwen2State.wrapX, // input: hidden state
+ config.dim(), // dimension
+ config.rmsNormEps(), // epsilon
+ qwen2State.localSize); // local memory size
+
+ // Fused RMS Apply + QKV Projection
+ unifiedLayer.task("attn_rms_qkv_projection",
+ Qwen3Kernels::fusedRmsNormQKVMatmul,
+ context,
+ qwen2State.wrapX, // input: raw hidden state (FP32)
+ qwen2State.wrapQ, // output: Q vectors
+ qwen2State.wrapK, // output: K vectors
+ qwen2State.wrapV, // output: V vectors
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights
+ qwen2State.temp, // RMS scale factor from reduction
+ weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq
+ weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk
+ weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv
+ config.dim(), // input dimension
+ config.dim(), // Q output dimension
+ config.kvDim(), // K/V output dimension (GQA)
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // Fused Q/K/V Bias Addition (3→1 kernel fusion)
+ unifiedLayer.task("fused_qkv_bias",
+ TransformerComputeKernelsLayered::fusedQKvBiasAddition,
+ context,
+ qwen2State.wrapQ, // Q (in/out)
+ qwen2State.wrapK, // K (in/out)
+ weights.q_biasLayered[layerIndex].asFloatArray(), // Q bias
+ qwen2State.wrapV, // V (in/out)
+ weights.k_biasLayered[layerIndex].asFloatArray(), // K bias
+ weights.v_biasLayered[layerIndex].asFloatArray(), // V bias
+ config.dim(), // dimQ
+ config.kvDim()); // dimKV
+
+ // Fused RoPE Rotation + KV Cache Write
+ unifiedLayer.task("rope_and_kv_cache",
+ Qwen3Kernels::ropeRotationWithCacheCopy,
+ context,
+ qwen2State.positionHolder, // current sequence position
+ qwen2State.wrapQ, // Q (rotated in-place)
+ qwen2State.wrapK, // K (rotated in-place)
+ qwen2State.wrapV, // V (copied to cache)
+ qwen2State.wrapKeyCache, // key cache (write)
+ qwen2State.wrapValueCache, // value cache (write)
+ config.numberOfKeyValueHeads(), // nHeadKv
+ config.headSize(), // per-head dimension
+ config.kvDim(), // kvDim
+ layerIndex, // layer offset
+ config.contextLength()); // max sequence length
+
+ // Flash Attention
+ unifiedLayer.task("attention",
+ Qwen2Kernels::processHeadsFlashAttention,
+ context,
+ qwen2State.wrapQ, // query vectors
+ qwen2State.wrapKeyCache, // key cache
+ qwen2State.wrapValueCache, // value cache
+ qwen2State.wrapXb, // output: attention result
+ config.numberOfHeads(), // nHeads
+ config.headSize(), // headSize
+ config.kvDim(), // kvDim
+ config.kvMul(), // kvMul (nHeads / nHeadKv)
+ qwen2State.positionHolder, // position
+ layerIndex, // layer index
+ config.contextLength()); // context length
+
+ // Output Projection with Residual
+ unifiedLayer.task("attn_output_proj",
+ TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+ context,
+ qwen2State.wrapXb, // input: attention output
+ qwen2State.wrapX, // output: wrapX += Wo · wrapXb
+ weights.woLayered[layerIndex].asHalfFloatArray(), // Wo
+ config.dim(), // input dim
+ config.dim(), // output dim
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // ═══════════════════════════════════════════════════════════════════════
+ // FFN BLOCK
+ // ═══════════════════════════════════════════════════════════════════════
+
+ // RMS Normalization - compute scale factor
+ unifiedLayer.task("ffn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context,
+ qwen2State.tempFFN, // output: scale factor
+ qwen2State.wrapX, // input: hidden state
+ config.dim(), // dimension
+ config.rmsNormEps(), // epsilon
+ qwen2State.localSize); // local memory size
+
+ // Final normalization (non-NVIDIA only)
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("ffn_rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context,
+ qwen2State.tempFFN, // scale factor (in/out)
+ config.dim(), // dimension
+ config.rmsNormEps()); // epsilon
+ }
+
+ // Fused RMS Apply + Gate/Up Projection + SiLU + GLU
+ // (Replaces mapContextFFN + fusedFeedForwardWithSiLUAndGLUActivation)
+ unifiedLayer.task("rms_ffn_gate_up",
+ TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp,
+ context,
+ qwen2State.wrapX, // input: raw hidden state (FP32)
+ qwen2State.wrapHb, // output: SiLU(x·W1) ⊙ (x·W3)
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights
+ qwen2State.tempFFN, // RMS scale factor
+ weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 (gate)
+ weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 (up)
+ config.dim(), // input dimension
+ config.hiddenDim(), // hidden dimension
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // Down Projection with Residual
+ unifiedLayer.task("ffn_down_proj",
+ TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+ context,
+ qwen2State.wrapHb, // input: FFN intermediate
+ qwen2State.wrapX, // output: wrapX += W2 · wrapHb
+ weights.w2Layered[layerIndex].asHalfFloatArray(), // W2 (down)
+ config.hiddenDim(), // input dim
+ config.dim(), // output dim
LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(),
- LOCAL_WORK_GROUP_SIZE_ALLOC).task("qbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim())
- .task("kbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapK, weights.k_biasLayered[layerIndex].asFloatArray(), config.kvDim())
- .task("vbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim())
- .task("rope", Qwen3Kernels::ropeRotation, context, qwen2State.positionHolder, qwen2State.wrapQ, qwen2State.wrapK, config.numberOfKeyValueHeads(), config.headSize())
- .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, qwen2State.wrapKeyCache, qwen2State.wrapK, qwen2State.wrapValueCache, qwen2State.wrapV, qwen2State.positionHolder,
- config.kvDim(), layerIndex, config.contextLength())
- .task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context, qwen2State.wrapQ, qwen2State.wrapKeyCache, qwen2State.wrapValueCache, qwen2State.wrapXb,
- config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), qwen2State.positionHolder, layerIndex, config.contextLength())
- .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapXb, qwen2State.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(),
- config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.tempFFN, qwen2State.wrapX, config.dim(), config.rmsNormEps(),
- qwen2State.localSize)
- .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
- qwen2State.tempFFN)
- .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen2State.wrapXb, qwen2State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(),
- weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapHb, qwen2State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(),
- config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX);
+ .persistOnDevice(state.wrapX);
return unifiedLayer;
}
@@ -216,7 +406,8 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
// First layer: Transfer initial data to device (one-time transfer)
if (layerIndex == 0) {
// Transfer all attention-related data: query, key, value matrices and their caches
- unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen2State.positionHolder, qwen2State.temp, qwen2State.tempFFN); //
+ unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen2State.positionHolder,
+ qwen2State.temp, qwen2State.tempFFN); //
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
context, qwen2State.wrapXb, qwen2State.wrapXb2, //
qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, //
@@ -234,5 +425,6 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
}
return unifiedLayer;
}
+ // @formatter:on
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
index 379921c3..453a4d7c 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
@@ -14,8 +14,8 @@
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-import java.util.ArrayList;
import java.util.List;
+import java.util.stream.IntStream;
/**
* Qwen3FP16FFNLayers: FP16 FFN layers for Qwen3 with Group Query Attention (GQA) support.
@@ -43,7 +43,7 @@ public class Qwen3FP16FFNLayers extends AbstractFFNLayers {
 List<ImmutableTaskGraph> ffnLayerTaskGraphs;
public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) {
- super(taskGraphName, state, weights, config,schedulerType);
+ super(taskGraphName, state, weights, config, schedulerType);
this.qwen3State = state;
this.qwen3Config = config;
@@ -61,71 +61,44 @@ public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe
@Override
public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize);
-
- // Q matmul worker (GQA: full query heads)
- int matmulQGlobal = nEmbdHeadK * config.numberOfHeads() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid matmulQRowMajorWorker = WorkerGridFactory.genericWorker(matmulQGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
-
- // KV matmul worker (GQA: reduced KV heads)
- int matmulKVGlobal = nEmbdGqa * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid matmulKVRowMajorWorker = WorkerGridFactory.genericWorker(matmulKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
-
- // Current embedding head worker
- WorkerGrid curWorker = WorkerGridFactory.createRmsNormWorker(nEmbdHead, 128);
-
- // Q current worker
- WorkerGrid qCurWorker = WorkerGridFactory.genericWorker(config.numberOfHeads() * nEmbdHead, nEmbdHead);
-
- // K current worker
- WorkerGrid kCurWorker = WorkerGridFactory.genericWorker(config.numberOfKeyValueHeads() * nEmbdHead, nEmbdHead);
-
- // RoPE worker (2D: heads x embedding_head/2)
- int ic = nEmbdHead / 2;
WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), nEmbdHead);
-
- // Copy to cache worker
- WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(nEmbdGqa, 128);
-
// Parallel attention worker
WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), nEmbdHead);
-
- // Matmul1 worker (output projection)
+ // attn_output_proj worker (output projection)
int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC);
-
// FFN workers
int fusedFFNW1W3Global = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid fusedFFNW1W3Worker = WorkerGridFactory.genericWorker(fusedFFNW1W3Global, LOCAL_WORK_GROUP_SIZE_ALLOC);
int projectionTwoGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid projectionTwoWorker = WorkerGridFactory.genericWorker(projectionTwoGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ int qkRmsNormGroups = config.numberOfHeads() + config.numberOfKeyValueHeads();
+ WorkerGrid qkRmsNormWorker = WorkerGridFactory.genericWorker(qkRmsNormGroups * nEmbdHead, nEmbdHead);
- // Map workers to tasks for each layer
- for (int i = 0; i < config.numberOfLayers(); i++) {
- gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
-
- gridScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker);
-
- gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker);
-
- gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker);
-
- gridScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker);
+ int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads();
+ int kvDim0 = nEmbdGqa;
+ int fusedQKVRows = qDim0 + 2 * kvDim0; // Q rows + K rows + V rows
+ int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
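+ // One LOCAL_WORK_GROUP_SIZE_ALLOC-wide work-group per fused output row; presumably the first
+ // qDim0 rows produce Q and the remaining 2 * kvDim0 rows produce K and V (assumed row layout,
+ // mirroring the fused kernel's Q/K/V ordering).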
- gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker);
- gridScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker);
+ // Map workers to tasks for each layer (in task execution order)
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ // === Attention Block ===
+ gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".qk_rmsnorm", qkRmsNormWorker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker);
+ // === FFN Block ===
+ gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
+ if (shouldUseFinalNormalization()) {
+ gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_finalize", rmsNormWorker);
+ }
+ gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNW1W3Worker);
+ gridScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", projectionTwoWorker);
}
-
return gridScheduler;
}
@@ -152,113 +125,273 @@ public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
* Setup all FFN layers for all transformer layers
*/
 List<ImmutableTaskGraph> setupFFNLayered() {
- List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>();
- qwen3State.temp.init(0.0f);
- qwen3State.tempFFN.init(0.0f);
- qwen3State.tempQcur.init(0.0f);
- qwen3State.tempKcur.init(0.0f);
-
- for (int layerIndex = 0; layerIndex < qwen3Config.numberOfLayers(); layerIndex++) {
- TaskGraph ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, layerIndex);
- if (layerIndex == qwen3Config.numberOfLayers() - 1) {
+ return IntStream.range(0, qwen3Config.numberOfLayers()).mapToObj(i -> {
+ var ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, i);
+ if (i == qwen3Config.numberOfLayers() - 1) {
setupLastID(ffnLayer.getTaskGraphName());
}
- ffnGraphs.add(ffnLayer.snapshot());
- }
-
- return ffnGraphs;
+ return ffnLayer.snapshot();
+ }).toList();
}
+ // @formatter:off
/**
- * Setup a single transformer layer for Qwen3 with GQA
+ * Transformer Layer Task Flow (Qwen3FP16FFNLayers)
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ * ATTENTION BLOCK
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * wrapX (FP32)
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ attn_rms_reduce │──▶ temp (scale factor for RMSNorm)
+ * └────────┬────────┘
+ * │
+ * ▼
+ * ┌─────────────────────────┐
+ * │ attn_rms_qkv_projection │──▶ wrapQ, wrapK, wrapV (FP32)
+ * └───────────┬─────────────┘ (fused: RMS apply + Q/K/V matmuls)
+ * │
+ * ▼
+ * ┌─────────────┐
+ * │ qk_rmsnorm │──▶ wrapQ, wrapK normalized in-place
+ * └──────┬──────┘ (fused: Q + K RMSNorm reduction + apply)
+ * │
+ * ▼
+ * ┌───────────────────┐ ┌─────────────────────────────────────┐
+ * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │
+ * └─────────┬─────────┘ └─────────────────────────────────────┘
+ * │ (fused: RoPE rotation + cache write)
+ * ▼
+ * ┌───────────┐
+ * │ attention │──▶ wrapXb (attention output)
+ * └─────┬─────┘
+ * │
+ * ▼
+ * ┌──────────────────┐
+ * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection)
+ * └────────┬─────────┘
+ * │
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │ FFN BLOCK
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │
+ * ▼
+ * ┌────────────────┐
+ * │ ffn_rms_reduce │──▶ tempFFN (scale factor)
+ * └───────┬────────┘
+ * │
+ * ▼ (optional: NON_NVIDIA only)
+ * ┌──────────────────┐
+ * │ ffn_rms_finalize │──▶ tempFFN (final scale)
+ * └────────┬─────────┘
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3)
+ * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU)
+ * │
+ * ▼
+ * ┌──────────────┐
+ * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection)
+ * └──────┬───────┘
+ * │
+ * ▼
+ * wrapX (FP32) ──▶ [next layer or logits]
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * Task Count: 9 tasks (NVIDIA) / 10 tasks (non-NVIDIA)
+ *
+ * Data Flow Summary:
+ * Input: wrapX (FP32) - hidden state from previous layer
+ * Output: wrapX (FP32) - updated hidden state with residual connections
+ *
+ * Key Fusion Points (vs baseline 18 tasks):
+ * • attn_rms_qkv_projection: Fused RMS apply + Q/K/V matmuls (4→1 kernel)
+ * • qk_rmsnorm: Fused Q + K RMSNorm (4→1 kernel)
+ * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel)
+ * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel)
+ *
+ * Qwen3-Specific:
+ * • GQA: nHeads (Q) != nHeadKv (K/V), with gqa = nHeads / nHeadKv
+ * • Q/K RMSNorm: Additional normalization after QKV projection (qk_rmsnorm)
+ * • RoPE theta: 1,000,000 (vs Llama's 10,000 or 500,000)
+ *
*/
TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) {
var taskGraphName = "layer_" + layerIndex;
- TaskGraph unifiedLayer = new TaskGraph(taskGraphName);
+
+ // === Dimension Parameters ===
+ int qDim = nEmbdHeadK * qwen3Config.numberOfHeads(); // Q output size (full heads)
+ int kvDim = nEmbdGqa; // K/V output size (reduced for GQA)
+ int inputDim = qwen3Config.dim(); // Model dimension
+
+ var unifiedLayer = new TaskGraph(taskGraphName);
+
+ // === Data Setup ===
unifiedLayer.consumeFromDevice(state.wrapX);
- unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
- //Copy-in weights per layer for batched-layered layout
- weights.rms_att_weightLayered[layerIndex].asFloatArray(), //
- weights.wqLayered[layerIndex].asHalfFloatArray(), //
- weights.wkLayered[layerIndex].asHalfFloatArray(), //
- weights.wvLayered[layerIndex].asHalfFloatArray(), //
- weights.woLayered[layerIndex].asHalfFloatArray(), //
- //rms_att_KNormLayered
- weights.rms_att_KNormLayered[layerIndex].asFloatArray(), //
- //rms_att_QNormLayered
- weights.rms_att_QNormLayered[layerIndex].asFloatArray(), //
- weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), //
- weights.w1Layered[layerIndex].asHalfFloatArray(), //
- weights.w2Layered[layerIndex].asHalfFloatArray(), //
- weights.w3Layered[layerIndex].asHalfFloatArray() //
- );
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ // Attention weights
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS norm weights
+ weights.wqLayered[layerIndex].asHalfFloatArray(), // Q projection
+ weights.wkLayered[layerIndex].asHalfFloatArray(), // K projection
+ weights.wvLayered[layerIndex].asHalfFloatArray(), // V projection
+ weights.woLayered[layerIndex].asHalfFloatArray(), // Output projection
+ // Qwen3-specific Q/K norm weights
+ weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // K RMSNorm weights
+ weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // Q RMSNorm weights
+ // FFN weights
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // FFN RMS norm weights
+ weights.w1Layered[layerIndex].asHalfFloatArray(), // FFN gate
+ weights.w2Layered[layerIndex].asHalfFloatArray(), // FFN down
+ weights.w3Layered[layerIndex].asHalfFloatArray()); // FFN up
unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
- unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.temp, qwen3State.wrapX, // in
- qwen3Config.dim(), qwen3Config.rmsNormEps(), qwen3State.localSize).task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen3State.wrapXb, // out
- qwen3State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), qwen3State.temp);
- int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads();
- int kvDim0 = nEmbdGqa;
- int qkvDim1 = qwen3Config.dim();
- unifiedLayer.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapQ, // output
- weights.wqLayered[layerIndex].asHalfFloatArray(), qkvDim1, qDim0, LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapK, // output
- weights.wkLayered[layerIndex].asHalfFloatArray(), qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapV, // output
- weights.wvLayered[layerIndex].asHalfFloatArray(), qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC);
-
- // Qcur rmsnorm
- unifiedLayer.task("rmsnormReduction_Qcur", Qwen3Kernels::rmsnormWithParallelOffset, context, qwen3State.tempQcur, // output
- qwen3State.wrapQ, // input
- qwen3State.localSize, // currently 128, should be variable of global nEmbHead
- nEmbdHead, // for normalization
- qwen3Config.rmsNormEps()) // for normalization
- .task("rmsnormMapIndexInPlace_Qcur", Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, context, qwen3State.wrapQ, // output
- weights.rms_att_QNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempQcur);
-
- // Kcur rmsnorm
- unifiedLayer.task("rmsnormReduction_Kcur", Qwen3Kernels::rmsnormWithParallelOffset, context, qwen3State.tempKcur, // output
- qwen3State.wrapK, // input
- qwen3State.localSize, // currently 128, should be variable of global nEmbHead
- nEmbdHead, // for normalization
- qwen3Config.rmsNormEps()) // for normalization
- .task("rmsnormMapIndexInPlace_Kcur", Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, context, qwen3State.wrapK, // output
- weights.rms_att_KNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempKcur);
-
- // rope rotation task graph
- unifiedLayer.task("ropeRotation", Qwen3Kernels::ropeRotation, context, qwen3State.positionHolder, qwen3State.wrapQ, // out
- qwen3State.wrapK, // out
- qwen3Config.numberOfKeyValueHeads(), nEmbdHead);
-
- unifiedLayer.task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, qwen3State.wrapKeyCache, // out
- qwen3State.wrapK, // in
- qwen3State.wrapValueCache, // out
- qwen3State.wrapV, // in
- qwen3State.positionHolder, nEmbdGqa, layerIndex, qwen3Config.contextLength());
-
- unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt, context, qwen3State.wrapQ, qwen3State.wrapKeyCache, qwen3State.wrapValueCache,
- qwen3State.wrapXb, // out
- qwen3Config.numberOfHeads(), nEmbdHead, nEmbdGqa, gqa, qwen3State.positionHolder, layerIndex, qwen3Config.contextLength());
-
- unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapXb, // vector
- qwen3State.wrapX, // out, should be [1024]
- weights.woLayered[layerIndex].asHalfFloatArray(), // matrix
- nEmbdHeadK * qwen3Config.numberOfHeads(), // dim1 = 2048
- qwen3Config.dim(), // dim0 = 1024
+ // ═══════════════════════════════════════════════════════════════════════
+ // ATTENTION BLOCK
+ // ═══════════════════════════════════════════════════════════════════════
+
+ // RMS Normalization - compute scale factor
+ unifiedLayer.task("attn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context,
+ qwen3State.temp, // output: scale factor
+ qwen3State.wrapX, // input: hidden state
+ qwen3Config.dim(), // dimension
+ qwen3Config.rmsNormEps(), // epsilon
+ qwen3State.localSize); // local memory size
+
+ // Fused RMS Apply + QKV Projection
+ unifiedLayer.task("attn_rms_qkv_projection",
+ Qwen3Kernels::fusedRmsNormQKVMatmul,
+ context,
+ qwen3State.wrapX, // input: raw hidden state (FP32)
+ qwen3State.wrapQ, // output: Q vectors
+ qwen3State.wrapK, // output: K vectors
+ qwen3State.wrapV, // output: V vectors
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights
+ qwen3State.temp, // RMS scale factor from reduction
+ weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq [qDim x inputDim]
+ weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk [kvDim x inputDim]
+ weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv [kvDim x inputDim]
+ inputDim, // input dimension
+ qDim, // Q output dimension
+ kvDim, // K/V output dimension (GQA: reduced)
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // Fused Q/K RMSNorm (Qwen3-specific)
+ unifiedLayer.task("qk_rmsnorm",
+ Qwen3Kernels::fusedQKRmsNorm,
+ context,
+ qwen3State.wrapQ, // Q vectors (in/out)
+ qwen3State.wrapK, // K vectors (in/out)
+ weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // Q norm weights
+ weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // K norm weights
+ qwen3Config.numberOfHeads(), // nHeads (Q heads)
+ qwen3Config.numberOfKeyValueHeads(), // nHeadKv (K/V heads, GQA)
+ nEmbdHead, // head dimension
+ nEmbdHead, // local memory size
+ qwen3Config.rmsNormEps()); // epsilon
+
+ // Fused RoPE Rotation + KV Cache Write
+ unifiedLayer.task("rope_and_kv_cache",
+ Qwen3Kernels::ropeRotationWithCacheCopy,
+ context,
+ qwen3State.positionHolder, // current position
+ qwen3State.wrapQ, // Q vectors (in/out, rotated)
+ qwen3State.wrapK, // K vectors (in/out, rotated)
+ qwen3State.wrapV, // V vectors (in only)
+ qwen3State.wrapKeyCache, // key cache (out)
+ qwen3State.wrapValueCache, // value cache (out)
+ qwen3Config.numberOfKeyValueHeads(), // nHeadKv
+ nEmbdHead, // head dimension
+ nEmbdGqa, // kvDim
+ layerIndex, // layer index for cache offset
+ qwen3Config.contextLength()); // max sequence length
+
+ // Flash Attention
+ unifiedLayer.task("attention",
+ TransformerComputeKernelsLayered::processHeadsFlashAttention,
+ context,
+ qwen3State.wrapQ, // query vectors
+ qwen3State.wrapKeyCache, // key cache
+ qwen3State.wrapValueCache, // value cache
+ qwen3State.wrapXb, // output: attention result
+ qwen3Config.numberOfHeads(), // nHeads
+ nEmbdHead, // headSize
+ nEmbdGqa, // kvDim
+ gqa, // kvMul (nHeads / nHeadKv)
+ qwen3State.positionHolder, // position
+ layerIndex, // layer index
+ qwen3Config.contextLength()); // context length
+
+ // Output Projection with Residual
+ unifiedLayer.task("attn_output_proj",
+ TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+ context,
+ qwen3State.wrapXb, // input: attention output
+ qwen3State.wrapX, // output: wrapX += Wo · wrapXb
+ weights.woLayered[layerIndex].asHalfFloatArray(), // Wo [dim x qDim]
+ nEmbdHeadK * qwen3Config.numberOfHeads(), // input dim (qDim)
+ qwen3Config.dim(), // output dim
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // ═══════════════════════════════════════════════════════════════════════
+ // FFN BLOCK
+ // ═══════════════════════════════════════════════════════════════════════
+
+ // RMS Normalization - compute scale factor
+ unifiedLayer.task("ffn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context,
+ qwen3State.tempFFN, // output: scale factor
+ qwen3State.wrapX, // input: hidden state
+ qwen3Config.dim(), // dimension
+ qwen3Config.rmsNormEps(), // epsilon
+ qwen3State.localSize); // local memory size
+
+ // Final normalization (non-NVIDIA only)
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("ffn_rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context,
+ qwen3State.tempFFN, // scale factor (in/out)
+ qwen3Config.dim(), // dimension
+ qwen3Config.rmsNormEps()); // epsilon
+ }
+
+ // Fused RMS Apply + Gate/Up Projection + SiLU + GLU
+ unifiedLayer.task("rms_ffn_gate_up",
+ TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp,
+ context,
+ qwen3State.wrapX, // input: raw hidden state (FP32)
+ qwen3State.wrapHb, // output: SiLU(x·W1) ⊙ (x·W3)
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights
+ qwen3State.tempFFN, // RMS scale factor
+ weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 (gate)
+ weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 (up)
+ qwen3Config.dim(), // input dimension
+ qwen3Config.hiddenDim(), // hidden dimension
LOCAL_WORK_GROUP_SIZE_ALLOC);
- unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.tempFFN, qwen3State.wrapX, qwen3Config.dim(),
- qwen3Config.rmsNormEps(), qwen3State.localSize)
- .task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, qwen3State.tempFFN, qwen3Config.dim(), qwen3Config.rmsNormEps())
- .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
- qwen3State.tempFFN);
+ // Down Projection with Residual
+ unifiedLayer.task("ffn_down_proj",
+ TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+ context,
+ qwen3State.wrapHb, // input: FFN intermediate
+ qwen3State.wrapX, // output: wrapX += W2 · wrapHb
+ weights.w2Layered[layerIndex].asHalfFloatArray(), // W2 (down)
+ qwen3Config.hiddenDim(), // input dim
+ qwen3Config.dim(), // output dim
+ LOCAL_WORK_GROUP_SIZE_ALLOC)
+ .persistOnDevice(qwen3State.wrapX);
- unifiedLayer.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen3State.wrapXb, qwen3State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(),
- weights.w3Layered[layerIndex].asHalfFloatArray(), qwen3Config.dim(), qwen3Config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapHb, qwen3State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(),
- qwen3Config.hiddenDim(), qwen3Config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(qwen3State.wrapX);
return unifiedLayer;
}
+ // @formatter:on
/**
* Configure data transfers for first and subsequent layers
@@ -266,14 +399,14 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
if (layerIndex == 0) {
// First layer: Transfer temporary buffers and QKV state every execution
- unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN);
- unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.tempQcur, qwen3State.tempKcur);
+ unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.positionHolder);
+ unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.temp, qwen3State.tempFFN);
// First execution: allocate workspace buffers
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
context, qwen3State.wrapXb, qwen3State.wrapXb2, //
qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, //
qwen3State.wrapKeyCache, qwen3State.wrapValueCache, //
- qwen3State.wrapAtt, qwen3State.wrapHb);
+ qwen3State.wrapAtt, qwen3State.wrapHb);
} else {
// Subsequent layers: Consume data from previous layer
unifiedLayer.consumeFromDevice(context, qwen3State.wrapXb, qwen3State.wrapXb2, //
@@ -282,9 +415,8 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
qwen3State.wrapValueCache, qwen3State.wrapAtt, //
qwen3State.wrapHb, qwen3State.positionHolder); //
- unifiedLayer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur);
}
return unifiedLayer;
}
-
+ // @formatter:on
}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
index 385b626f..ba1b6a79 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
@@ -42,25 +42,111 @@ public ImmutableTaskGraph getImmutableTaskGraph() {
}
 List<ImmutableTaskGraph> setupFFNLayered() {
- state.temp.init(0.0f);
- state.tempFFN.init(0.0f);
- var numLayers = config.numberOfLayers();
-
- return IntStream.range(0, numLayers).mapToObj(i -> {
+ return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> {
var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i);
- if (i == numLayers - 1) {
+ if (i == config.numberOfLayers() - 1) {
setupLastID(ffnLayer.getTaskGraphName());
}
return ffnLayer.snapshot();
}).toList();
}
+ // @formatter:off
+ /**
+ * Transformer Layer Task Flow (LlamaQ8_0FFNLayers)
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ * ATTENTION BLOCK
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * wrapX (FP32)
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ attn_rms_reduce │──▶ temp (partial sums)
+ * └────────┬────────┘
+ * │
+ * ▼ (optional: NON_NVIDIA only)
+ * ┌──────────────────┐
+ * │ attn_rms_finalize│──▶ temp (final scale)
+ * └────────┬─────────┘
+ * │
+ * ▼
+ * ┌────────────────┐
+ * │ attn_rms_apply │──▶ wrapXb (normalized, FP32)
+ * └───────┬────────┘
+ * │
+ * ▼
+ * ┌────────────────┐ ┌─────────────────────────────┐
+ * │ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │
+ * └───────┬────────┘ └─────────────────────────────┘
+ * │
+ * ▼
+ * ┌───────────────────┐ ┌─────────────────────────────────────┐
+ * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │
+ * └─────────┬─────────┘ └─────────────────────────────────────┘
+ * │
+ * ▼
+ * ┌───────────┐
+ * │ attention │──▶ wrapXb (attention output)
+ * └─────┬─────┘
+ * │
+ * ▼
+ * ┌──────────────────┐
+ * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection)
+ * └────────┬─────────┘
+ * │
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │ FFN BLOCK
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │
+ * ▼
+ * ┌────────────────┐
+ * │ ffn_rms_reduce │──▶ tempFFN (partial sums)
+ * └───────┬────────┘
+ * │
+ * ▼ (optional: NON_NVIDIA only)
+ * ┌─────────────────┐
+ * │ ffn_rms_finalize│──▶ tempFFN (final scale)
+ * └────────┬────────┘
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3)
+ * └────────┬────────┘ (fully fused: RMS reduce/apply + W1/W3 matmuls + SiLU + GLU)
+ * │
+ * ▼
+ * ┌──────────────┐
+ * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection)
+ * └──────┬───────┘
+ * │
+ * ▼
+ * wrapX (FP32) ──▶ [next layer or logits]
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * Task Count: 11 tasks (9 if NVIDIA, which skips the two rms_finalize steps)
+ *
+ * Data Flow Summary:
+ * Input: wrapX (FP32) - hidden state from previous layer
+ * Output: wrapX (FP32) - updated hidden state with residual connections
+ *
+ * Key Fusion Points:
+ * • qkv_projection: Fused Q/K/V matmuls with Q8 dequantization (3→1 kernel)
+ * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel)
+ * • rms_ffn_gate_up: Fully fused RMS norm + W1/W3 matmuls + SiLU + GLU (5→1 kernel)
+ *
+ * Quantization: Q8_0 format (8-bit weights with block-wise scaling)
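+ *
+ * Q8_0 dot-product, illustrated (assumption: GGUF-style blocks of 32 signed int8 weights
+ * sharing one FP16 scale; the exact layout consumed by these kernels may differ):
+ *
+ *   float dot = 0.0f;
+ *   for (int b = 0; b < dim / 32; b++) {
+ *       float scale = halfToFloat(scales[b]);   // hypothetical accessor, one scale per block
+ *       for (int j = 0; j < 32; j++) {
+ *           dot += (quants[b * 32 + j] * scale) * x[b * 32 + j];
+ *       }
+ *   }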
+ *
+ */
TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) {
var layerTaskGraphName = "layer_" + layerIndex;
TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
+
+ // === Data Setup ===
unifiedLayer.consumeFromDevice(state.wrapX);
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
- //Copy-in weights per layer for batched-layered layout
+ // Copy-in weights per layer for batched-layered layout (Q8 format)
weights.rms_att_weightLayered[layerIndex].asFloatArray(),
weights.wqLayered[layerIndex].asByteArray(),
weights.wkLayered[layerIndex].asByteArray(),
@@ -71,35 +157,101 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
weights.w2Layered[layerIndex].asByteArray(),
weights.w3Layered[layerIndex].asByteArray());
unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
- unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
- if (shouldUseFinalNormalization()) {
- unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
- config.dim(), config.rmsNormEps());
- }
- unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp)
- .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapQ,
- weights.wqLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapK,
- weights.wkLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapV,
- weights.wvLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize())
- .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(),
- layerIndex, config.contextLength());
- configureAttention(unifiedLayer, layerIndex);
- unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapXb, state.wrapX,
- weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
- if (shouldUseFinalNormalization()) {
- unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN,
- config.dim(), config.rmsNormEps());
- }
- unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN)
- .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte, context, state.wrapXb, state.wrapHb,
- weights.w1Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray(), config.dim(), config.hiddenDim(),
- LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapHb, state.wrapX,
- weights.w2Layered[layerIndex].asByteArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX);
+
+ // === Attention Block ===
+ // RMS Normalization
+ unifiedLayer.task("attn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, state.temp, state.wrapX,
+ config.dim(), config.rmsNormEps(), state.localSize);
+
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("attn_rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, state.temp, config.dim(), config.rmsNormEps());
+ }
+
+ unifiedLayer.task("attn_rms_apply",
+ TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
+ context, state.wrapXb, state.wrapX,
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);
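+
+        // Illustrative only: the three attn_rms_* tasks implement RMSNorm. Variable names
+        // below are descriptive, not the kernels' actual parameters.
+        //   ss    = sum of x[i]^2 over i in [0, dim)      (attn_rms_reduce, partial sums in temp)
+        //   scale = 1 / sqrt(ss / dim + rmsNormEps)       (finalized on non-NVIDIA backends)
+        //   xb[i] = rms_att_weight[i] * (scale * x[i])    (attn_rms_apply)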
+
+ // QKV Projection (fused with Q8 dequantization)
+ unifiedLayer.task("qkv_projection",
+ TransformerComputeKernelsLayered::fusedQKVMatmulQ8,
+ context,
+ state.wrapXb, // input (FP32)
+ state.wrapQ, // output Q
+ state.wrapK, // output K
+ state.wrapV, // output V
+ weights.wqLayered[layerIndex].asByteArray(), // Wq (Q8)
+ weights.wkLayered[layerIndex].asByteArray(), // Wk (Q8)
+ weights.wvLayered[layerIndex].asByteArray(), // Wv (Q8)
+ config.dim(), // dim
+ config.kvDim(), // kvDim
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // RoPE + KV Cache
+ unifiedLayer.task("rope_and_kv_cache",
+ TransformerComputeKernelsLayered::ropeRotationWithCacheCopy,
+ context,
+ state.positionHolder,
+ state.wrapQ, // Q (in/out)
+ state.wrapK, // K (in/out)
+ state.wrapV, // V (in only)
+ state.wrapKeyCache, // Key cache (out)
+ state.wrapValueCache, // Value cache (out)
+ config.kvDim(),
+ config.headSize(),
+ layerIndex,
+ config.contextLength());
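+
+        // Illustrative only: the fused kernel applies standard RoPE and then appends K/V to the
+        // caches. Assuming the usual base frequency of 10000 (the actual value comes from the
+        // model configuration), each pair (2*i, 2*i+1) within a head is rotated as:
+        //   theta    = pos * Math.pow(10000, -2.0 * i / headSize);
+        //   q[2*i]   = q0 * Math.cos(theta) - q1 * Math.sin(theta);
+        //   q[2*i+1] = q0 * Math.sin(theta) + q1 * Math.cos(theta);   // K rotated the same way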
+
+ // Attention
+ configureAttention(unifiedLayer, layerIndex);
+
+ // Output Projection (Wo) with residual (Q8 dequantization)
+ unifiedLayer.task("attn_output_proj",
+ TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
+ context, state.wrapXb, state.wrapX,
+ weights.woLayered[layerIndex].asByteArray(),
+ config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // === FFN Block ===
+ // RMS Normalization
+ unifiedLayer.task("ffn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, state.tempFFN, state.wrapX,
+ config.dim(), config.rmsNormEps(), state.localSize);
+
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("ffn_rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, state.tempFFN, config.dim(), config.rmsNormEps());
+ }
+
+ // Fully fused: RMS apply + Gate/Up projections + SiLU + GLU (Q8 dequantization)
+ unifiedLayer.task("rms_ffn_gate_up",
+ TransformerComputeKernelsLayered::fullyFusedRmsNormFFNGateUpQ8,
+ context,
+ state.wrapX, // raw input (FP32)
+ state.wrapHb, // output
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights
+ weights.w1Layered[layerIndex].asByteArray(), // W1 (Q8)
+ weights.w3Layered[layerIndex].asByteArray(), // W3 (Q8)
+ config.dim(), // input dimension
+ config.hiddenDim(), // output dimension
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
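+
+        // Illustrative only: per output element i of wrapHb the fused kernel combines the gate
+        // and up projections with SiLU (scalar pseudo-Java, hypothetical helper names):
+        //   float gate = dot(w1Row(i), xhat);                       // xhat = RMS-normalized x
+        //   float up   = dot(w3Row(i), xhat);
+        //   float silu = gate / (1.0f + (float) Math.exp(-gate));   // SiLU(gate)
+        //   hb[i]      = silu * up;                                 // GLU: gate ⊙ up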
+
+ // Down projection (W2) with residual (Q8 dequantization)
+ unifiedLayer.task("ffn_down_proj",
+ TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
+ context, state.wrapHb, state.wrapX,
+ weights.w2Layered[layerIndex].asByteArray(),
+ config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // Keep activation X on device for next layer
+ unifiedLayer.persistOnDevice(state.wrapX);
+
return unifiedLayer;
}
@@ -107,15 +259,20 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
// First layer: Transfer initial data to device (one-time transfer)
if (layerIndex == 0) {
// Transfer all attention-related data: query, key, value matrices and their caches
- unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); //
+ unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
+ state.positionHolder,
+ state.temp, state.tempFFN); //
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
- context, state.wrapXb, state.wrapXb2, //
+ context,
+ state.wrapXb, state.wrapXb2, //
state.wrapQ, state.wrapK, state.wrapV, //
state.wrapKeyCache, state.wrapValueCache, //
state.wrapAtt, state.wrapHb); //
} else {
// Subsequent layers: Consume data already on device from previous layer
- unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
+ unifiedLayer.consumeFromDevice(
+ context,
+ state.wrapXb, state.wrapXb2, //
state.wrapQ, state.wrapK, state.wrapV, //
state.wrapKeyCache, state.wrapValueCache, //
state.wrapAtt, state.wrapHb, //
@@ -127,35 +284,42 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
@Override
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
- WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128);
+ // === Worker Grid Definitions ===
WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
- int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid configKvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configKvDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
-
int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // Fused QKV: dim rows for Q + kvDim rows for K + kvDim rows for V
+ int fusedQkvGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
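+        // Worked example with hypothetical numbers (not taken from a specific model config):
+        // for dim = 4096, kvDim = 1024 and LOCAL_WORK_GROUP_SIZE_ALLOC = 32, the global size is
+        // (4096 + 2 * 1024) * 32 = 196,608 threads, i.e. one 32-thread work-group per output
+        // row across Q, K and V.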
+
+ WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512);
+
WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize());
- WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128);
+ // === Per-Layer Grid Assignments (ordered by task graph flow) ===
for (int i = 0; i < config.numberOfLayers(); i++) {
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
+ // --- Attention Block ---
+ // RMS Normalization
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
+ // --- FFN Block ---
+ // RMS Normalization
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
+ // Fused RMS + Gate/Up Projections
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
+ // Down Projection
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
}
+
return tornadoForwardScheduler;
}
@@ -165,15 +329,25 @@ public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) {
if (schedulerType == SchedulerType.NVIDIA) {
- return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention,
- context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
- config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
- state.positionHolder, layerIndex, config.contextLength());
+ return unifiedLayer.task("attention",
+ TransformerComputeKernelsLayered::processHeadsFlashAttention,
+ context,
+ state.wrapQ, state.wrapKeyCache,
+ state.wrapValueCache, state.wrapXb,
+ config.numberOfHeads(), config.headSize(),
+ config.kvDim(), config.kvMul(),
+ state.positionHolder, layerIndex,
+ config.contextLength());
} else {
- return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel,
- state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
- config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(),
- state.positionHolder, state.wrapAtt, layerIndex, config.contextLength());
+ return unifiedLayer.task("attention",
+ TransformerComputeKernelsLayered::processHeadsParallel,
+ state.wrapQ, state.wrapKeyCache,
+ state.wrapValueCache, state.wrapXb,
+ config.numberOfHeads(), config.headSize(),
+ config.kvDim(), config.kvMul(), config.contextLength(),
+ state.positionHolder, state.wrapAtt, layerIndex,
+ config.contextLength());
}
}
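+
+    // Illustrative only: both kernels above compute standard scaled dot-product attention per
+    // head h over the cached keys/values (scalar form, descriptive names):
+    //   score[t] = dot(q_h, keyCache[t]) / Math.sqrt(headSize)   for t in [0, pos]
+    //   att      = softmax(score[0..pos])
+    //   xb_h     = sum over t of att[t] * valueCache[t]
+    // Judging by the parameter lists, the flash-attention path fuses softmax with the value
+    // accumulation (no wrapAtt needed), while the parallel path stages scores in wrapAtt.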
+ // @formatter:on
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
index c8c5a753..d54bb3ef 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
@@ -2,8 +2,8 @@
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
-import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
@@ -13,12 +13,9 @@
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
-import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-import java.util.SequencedCollection;
-
public class LogitsQ8_0Layer extends AbstractLayer {
private String lastTaskGraphID;
@@ -30,7 +27,6 @@ public class LogitsQ8_0Layer extends AbstractLayer {
public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
super(taskGraphName, state, weights, config);
this.lastTaskGraphID = lastTaskGraphID;
- state.tempLogits.init(0.0f);
var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ8_0Layer", "TornadoTensor");
this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config);
this.schedulerType = schedulerType;
@@ -38,39 +34,68 @@ public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Confi
@Override
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
- WorkerGrid logitsRMS;
- if (weights instanceof Qwen2TornadoWeights) {
- logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32);
- } else {
- logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
- }
-
+ var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256);
var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
- WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
+ var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
-
- tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
- tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS);
+ tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
+ tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS);
tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS);
return tornadoForwardScheduler;
}
+ // @formatter:off
private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
- TaskGraph logits = new TaskGraph("logits");
- logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits)
- .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.asByteArray(),
- weights.rms_final_weight_as_floatArray)
- .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
- if (schedulerType == SchedulerType.NON_NVIDIA) {
- logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps());
- }
- logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits)
- .task("projection", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, //
- context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asByteArray(), //
- config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS) //
- .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
+ var logits = new TaskGraph("logits");
+ // === Data Setup ===
+ logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
+ logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
+ logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ context, //
+ state.wrapLogits, //
+ weights.wclsByteArray.asByteArray(), //
+ weights.rms_final_weight_as_floatArray);
+
+ // === Final RMS Normalization ===
+ logits.task("rms_reduce",
+ TransformerComputeKernels::reductionOneBlockWithLayer,
+ context,
+ state.tempLogits, // output: partial sums + final scale factor
+ state.wrapX, // input: hidden state
+ config.dim(), // dimension
+ config.rmsNormEps(), // epsilon for numerical stability
+ state.localSize); // local workgroup size
+
+ if (schedulerType == SchedulerType.NON_NVIDIA) {
+ logits.task("rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context,
+ state.tempLogits,
+ config.dim(),
+ config.rmsNormEps());
+ }
+ logits.task("mapContextLogits",
+ TransformerComputeKernels::reductionOneBlock2WithLogits,
+ context,
+ state.wrapX,
+ weights.rms_final_weight_as_floatArray.asFloatArray(),
+ state.tempLogits);
+
+ // === Vocabulary Projection ===
+ logits.task("vocab_proj", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, //
+ context,
+ state.wrapX,
+ state.wrapLogits,
+ weights.wclsByteArray.asByteArray(),
+ config.dim(),
+ config.vocabularySize(),
+ LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);
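+
+        // Illustrative only: vocab_proj is a plain matrix-vector product over the classifier
+        // weights, logits[v] = dot(wclsRow(v), x) for v in [0, vocabularySize), with each Q8
+        // weight row dequantized block-wise as sketched in LlamaQ8_0FFNLayers.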
+
+ // === Transfer Results to Host ===
+ logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
return logits;
}
+ // @formatter:on
@Override
public GridScheduler getGridScheduler() {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
index 441bbece..3588af9e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
@@ -69,6 +69,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128);
WorkerGrid splitGateUpSiLUWorker = WorkerGridFactory.genericWorker(config.hiddenDim(), 128);
WorkerGrid splitQKVWorker = WorkerGridFactory.genericWorker(opSize, 128);
+
for (int i = 0; i < config.numberOfLayers(); i++) {
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", qkvDimRowMajorGlobalWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);