diff --git a/llama-tornado b/llama-tornado index 9c0d6ba8..c98090f8 100755 --- a/llama-tornado +++ b/llama-tornado @@ -422,7 +422,7 @@ def create_parser() -> argparse.ArgumentParser: ) debug_group.add_argument( "--profiler-dump-dir", - default="/home/mikepapadim/repos/gpu-llama3.java/prof.json", + default=None, help="Directory for profiler output", ) @@ -498,6 +498,11 @@ def main(): parser = create_parser() args = parser.parse_args() + # Set default profiler log path relative to LLAMA_ROOT + if args.profiler_dump_dir is None: + llama_root = os.environ.get("LLAMA_ROOT") + args.profiler_dump_dir = os.path.join(llama_root, "profiler-log.json") + # Set default seed if not provided if args.seed is None: args.seed = int(time.time()) diff --git a/pom.xml b/pom.xml index 814df7f5..a59830ef 100644 --- a/pom.xml +++ b/pom.xml @@ -54,12 +54,12 @@ io.github.beehive-lab tornado-api - 2.0.1-dev + 2.1.0 io.github.beehive-lab tornado-runtime - 2.0.1-dev + 2.1.0 diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java index a24b9626..21344223 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java @@ -64,6 +64,8 @@ protected StateFields createStateFields(Configuration config) { fields.wrapK = new FloatArray(config.dim()); fields.wrapV = new FloatArray(config.dim()); + fields.wrapXFP16 = new HalfFloatArray(config.dim()); + fields.wrapXbFP16 = new HalfFloatArray(config.dim()); // dim vs kvdim fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java index 2ae4d269..79115922 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java @@ -87,6 +87,8 @@ protected StateFields createStateFields(Configuration config) { } fields.wrapX = new FloatArray(dim); fields.wrapXb = new FloatArray(dim); + fields.wrapXFP16 = new HalfFloatArray(dim); + fields.wrapXbFP16 = new HalfFloatArray(dim); fields.wrapXb2 = new FloatArray(dim); fields.wrapHb = new FloatArray(2 * hiddenDim); fields.wrapHb2 = new FloatArray(hiddenDim); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java index 23939bba..266730ac 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java @@ -48,6 +48,7 @@ protected StateFields createStateFields(Configuration configuration) { } fields.wrapX = new FloatArray(config.dim()); fields.wrapXb = new FloatArray(config.dim()); + fields.wrapXbFP16 = new HalfFloatArray(config.dim()); fields.wrapXb2 = new FloatArray(config.dim()); fields.wrapHb = new FloatArray(config.hiddenDim()); fields.wrapHb2 = new FloatArray(config.hiddenDim()); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java index 6f90398f..d70625b9 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java @@ -75,6 
+75,8 @@ protected StateFields createStateFields(Configuration configuration) { fields.wrapX = new FloatArray(config.dim()); fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads()); + fields.wrapXbFP16 = new HalfFloatArray(nEmbdHeadK * config.numberOfHeads()); + fields.wrapXb2 = new FloatArray(config.dim()); fields.wrapHb = new FloatArray(config.hiddenDim()); fields.wrapHb2 = new FloatArray(config.hiddenDim()); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java index 3f972e1b..f8e9906a 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java @@ -4,6 +4,9 @@ import org.beehive.gpullama3.model.Configuration; import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.*; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; /** * Represents the base state structure used during LLM inference. @@ -58,6 +61,9 @@ public abstract class State { public final IntArray positionHolder; public TornadoNativeArray embeddingX; + + public final HalfFloatArray wrapXbFP16; // FloatArray wrapper for xb (residual branch activation), optimized for TornadoVM usage. + // store inter public int localSize; public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size. @@ -65,6 +71,7 @@ public abstract class State { public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size. public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models. 
+ public HalfFloatArray wrapXFP16; /** last index in previous block */ protected State(Configuration config, int batchsize) { @@ -100,6 +107,9 @@ protected State(Configuration config, int batchsize) { this.wrapK = fields.wrapK; this.wrapV = fields.wrapV; + this.wrapXFP16 = fields.wrapXFP16; + this.wrapXbFP16 = fields.wrapXbFP16; + // dim vs kvdim this.wrapKeyCache = fields.wrapKeyCache; this.wrapValueCache = fields.wrapValueCache; @@ -136,6 +146,7 @@ public void createActivationQ8_0(int size) { int q8BytesNeeded = blocksNeeded * Q8_0_BLOCK_BYTES; this.embeddingX = new ByteArray(q8BytesNeeded); } + public HalfFloatArray wrapXFP16, wrapXbFP16; } @Override diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 293d2c0c..eadd2e68 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -1,9 +1,9 @@ package org.beehive.gpullama3.tornadovm; -import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; @@ -133,6 +133,8 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { // Set the position in the state object (used by attention layers) state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); // 2. Execute each transformer layer graph sequentially // Each graph computes attention and feed-forward transformations for one layer @@ -141,7 +143,8 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) .execute(); } - + state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f + state.wrapLogits.clear(); // Clear the output logits tensor -> set to 0f // 3. Execute the final graph that projects the last hidden state to output logits executionPlan.withGraph(getFinalLogitsGraphIndex()) .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) @@ -179,7 +182,7 @@ private int getFinalLogitsGraphIndex() { /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer. public void forceCopyInReadOnlyDataLayered() { // Execute all TornadoVM graphs - state.wrapX.init(0.0f); + state.wrapX.clear(); state.positionHolder.init(0); // Execute activation update graph diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java new file mode 100644 index 00000000..d45a62c4 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java @@ -0,0 +1,380 @@ +package org.beehive.gpullama3.tornadovm.kernels; + +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +/** + * Phi3Kernels: Optimized GPU kernels for Phi3 model family. + * + *
Key differences from Qwen/Llama kernels:
+ * + */ +public class Phi3Kernels { + + /** + * Fused RMSNorm apply + single matrix-vector multiplication. + * + *
Combines RMS normalization application with a generic matmul in one kernel, + * reducing memory bandwidth by avoiding intermediate storage.
+ * + *
Formula: output[row] = sum_j(W[row,j] * rmsWeight[j] * scale * x[j])
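As a plain-Java reference for the formula above (a sketch; the class name and flattened row-major layout are illustrative, not taken from this patch), the fusion folds the RMSNorm apply into each row's dot product instead of materializing a normalized vector first:

final class FusedRmsNormMatmulSketch {
    // out[row] = sum_j W[row][j] * (rmsWeight[j] * scale * x[j]),
    // where scale = 1 / sqrt(mean(x^2) + eps) comes from the preceding reduction kernel.
    static float[] apply(float[] x, float[] rmsWeight, float scale,
                         float[] w /* outputDim x inputDim, row-major */,
                         int inputDim, int outputDim) {
        float[] out = new float[outputDim];
        for (int row = 0; row < outputDim; row++) {
            float sum = 0.0f;
            for (int j = 0; j < inputDim; j++) {
                float normalized = rmsWeight[j] * scale * x[j]; // inline RMSNorm apply
                sum += w[row * inputDim + j] * normalized;      // row dot product
            }
            out[row] = sum;
        }
        return out;
    }
}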
+ * + *
Use cases:
+ * + * + * @param context Kernel execution context + * @param x Input hidden state (FP32) [dim] + * @param output Output buffer (FP32) [outputDim] + * @param rmsWeights RMS normalization weights (FP32) [dim] + * @param rmsScale Precomputed RMS scale factor [1] (from reduction kernel) + * @param w Weight matrix (FP16) [outputDim × dim] + * @param inputDim Input dimension (dim) + * @param outputDim Output dimension + * @param localWorkGroupSize Local work group size for reduction + */ + public static void fusedRmsNormMatmul( + KernelContext context, + FloatArray x, // input (FP32) + FloatArray output, // output (FP32) + FloatArray rmsWeights, // RMS norm weights + FloatArray rmsScale, // temp[0] = scale factor + HalfFloatArray w, // weight matrix + int inputDim, // input dimension + int outputDim, // output dimension + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= outputDim) { + return; + } + + float scale = rmsScale.get(0); + + // Allocate shared memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + int rowOffset = rowId * inputDim; + + // Each thread computes partial dot product with inline normalization + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum += w.get(rowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + // Thread 0 writes final result + if (localId == 0) { + output.set(rowId, localSum[0]); + } + } + + /** + * Phi3 RoPE rotation with fused KV cache copy. + * + *
Phi3 uses a different RoPE pattern than Llama/Qwen:
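Judging from the rotation kernels in this patch, the most visible difference is the pairing of rotated elements: the Llama/Qwen-style ropeRotationWithCacheCopy pairs adjacent elements (i, i+1), while ropeRotationWithCacheCopyPhi3 pairs elements half a head apart (i, i + headSize/2). A scalar sketch of the two pairings (class and method names, and the theta parameter, are illustrative):

final class RopePairingSketch {
    static void rotate(float[] v, int a, int b, float fcr, float fci) {
        float va = v[a], vb = v[b];
        v[a] = va * fcr - vb * fci;
        v[b] = va * fci + vb * fcr;
    }

    // Llama/Qwen-style: adjacent pairs (i, i+1), as in ropeRotationWithCacheCopy later in this patch.
    static void adjacentPairs(float[] head, int headSize, int pos, float theta) {
        for (int i = 0; i < headSize; i += 2) {
            float freq = 1.0f / (float) Math.pow(theta, i / (float) headSize);
            float val = pos * freq;
            rotate(head, i, i + 1, (float) Math.cos(val), (float) Math.sin(val));
        }
    }

    // Phi3-style: pairs (i, i + headSize/2), as in ropeRotationWithCacheCopyPhi3 below.
    static void halfOffsetPairs(float[] head, int headSize, int pos, float theta) {
        int half = headSize / 2;
        for (int i = 0; i < half; i++) {
            float freq = 1.0f / (float) Math.pow(theta, (i * 2) / (float) headSize);
            float val = pos * freq;
            rotate(head, i, i + half, (float) Math.cos(val), (float) Math.sin(val));
        }
    }
}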
+ * + * + *
This fused kernel combines:
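Concretely, the kernel body below rotates Q, rotates K and writes it straight into keyCache, and copies V straight into valueCache, all against the flat cache layout sketched here in plain Java (names are illustrative):

final class KvCacheLayoutSketch {
    // One contiguous block per layer, one kvDim-wide row per position.
    static int offset(int layer, int pos, int contextLength, int kvDim) {
        return layer * contextLength * kvDim + pos * kvDim;
    }

    // Store a rotated K element and the corresponding (unrotated) V element for one position.
    static void write(float[] keyCache, float[] valueCache, int layer, int pos,
                      int contextLength, int kvDim, int i, float rotatedK, float vValue) {
        int base = offset(layer, pos, contextLength, kvDim);
        keyCache[base + i] = rotatedK;
        valueCache[base + i] = vValue;
    }
}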
+ * + * + * @param context Kernel execution context + * @param positionHolder Current position in sequence [1] + * @param sq Query vectors (in/out, rotated) [dim] + * @param sk Key vectors (in/out, rotated) [kvDim] + * @param sv Value vectors (in only) [kvDim] + * @param keyCache Key cache (out) [layers × contextLength × kvDim] + * @param valueCache Value cache (out) [layers × contextLength × kvDim] + * @param nHeadKv Number of KV heads + * @param headSize Dimension per head + * @param kvDim Total KV dimension (nHeadKv × headSize) + * @param layer Current layer index + * @param contextLength Maximum sequence length + */ + public static void ropeRotationWithCacheCopyPhi3( + KernelContext context, + IntArray positionHolder, + FloatArray sq, // Q vector (in/out) + FloatArray sk, // K vector (in/out) + FloatArray sv, // V vector (in only) + FloatArray keyCache, // Key cache (out) + FloatArray valueCache, // Value cache (out) + int nHeadKv, + int headSize, + int kvDim, + int layer, + int contextLength) { + + int idx = context.globalIdx; + int dimHalf = headSize / 2; + + // Each thread processes one dimension pair + if (idx >= dimHalf) { + return; + } + + int pos = positionHolder.get(0); + int cacheOffset = layer * contextLength * kvDim + pos * kvDim; + + // Calculate frequency for this dimension + float freq = 1.0f / TornadoMath.pow(10000.0f, (float) (idx * 2) / (float) headSize); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Process Q: all heads (dim = nHeads × headSize) + int totalDimQ = sq.getSize(); + for (int base = 0; base < totalDimQ; base += headSize) { + if (base + idx >= totalDimQ || base + idx + dimHalf >= totalDimQ) { + break; + } + + // Rotate Q with offset pattern + float v0 = sq.get(base + idx); + float v1 = sq.get(base + idx + dimHalf); + sq.set(base + idx, v0 * fcr - v1 * fci); + sq.set(base + idx + dimHalf, v0 * fci + v1 * fcr); + } + + // Process K: only kvDim elements, with cache write + for (int base = 0; base < kvDim; base += headSize) { + if (base + idx >= kvDim || base + idx + dimHalf >= kvDim) { + break; + } + + // Rotate K with offset pattern + float k0 = sk.get(base + idx); + float k1 = sk.get(base + idx + dimHalf); + float rotated0 = k0 * fcr - k1 * fci; + float rotated1 = k0 * fci + k1 * fcr; + + // Write rotated K back + sk.set(base + idx, rotated0); + sk.set(base + idx + dimHalf, rotated1); + + // Fused cache write for K + keyCache.set(cacheOffset + base + idx, rotated0); + keyCache.set(cacheOffset + base + idx + dimHalf, rotated1); + + // Fused cache copy for V (no rotation needed) + valueCache.set(cacheOffset + base + idx, sv.get(base + idx)); + valueCache.set(cacheOffset + base + idx + dimHalf, sv.get(base + idx + dimHalf)); + } + } + + /** + * Fused RMSNorm apply + QKV projection with direct output to separate Q, K, V buffers. + * + *
Eliminates the need for a separate splitQKV kernel by routing outputs + * directly based on row index:
+ * + * + *
Formula: output[row] = sum_j(Wqkv[row,j] * rmsWeight[j] * scale * x[j])
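A plain-Java sketch of the row-to-buffer routing behind this formula (class name is illustrative): each row of the combined Wqkv projection lands in Q, K, or V according to where it falls in [0, dim + 2*kvDim).

final class QkvRoutingSketch {
    // Rows [0, dim) -> Q, [dim, dim + kvDim) -> K, [dim + kvDim, dim + 2*kvDim) -> V.
    static void route(int rowId, float result, int dim, int kvDim,
                      float[] q, float[] k, float[] v) {
        if (rowId < dim) {
            q[rowId] = result;
        } else if (rowId < dim + kvDim) {
            k[rowId - dim] = result;
        } else {
            v[rowId - dim - kvDim] = result;
        }
    }
}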
+ * + * @param context Kernel execution context + * @param x Input hidden state (FP32) [dim] + * @param q Output Q buffer (FP32) [dim] + * @param k Output K buffer (FP32) [kvDim] + * @param v Output V buffer (FP32) [kvDim] + * @param rmsWeights RMS normalization weights (FP32) [dim] + * @param rmsScale Precomputed RMS scale factor [1] + * @param wqkv Combined QKV weight matrix (FP16) [opSize × dim] + * @param dim Model dimension (Q output size) + * @param kvDim KV dimension (K/V output size) + * @param localWorkGroupSize Local work group size for reduction + */ + public static void fusedRmsNormQKVMatmulDirect( + KernelContext context, + FloatArray x, // input (FP32) + FloatArray q, // output Q (FP32) + FloatArray k, // output K (FP32) + FloatArray v, // output V (FP32) + FloatArray rmsWeights, // RMS norm weights + FloatArray rmsScale, // temp[0] = scale factor + HalfFloatArray wqkv, // combined QKV weight matrix + int dim, // input dim and Q output dim + int kvDim, // K/V output dim + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Total rows = dim (Q) + kvDim (K) + kvDim (V) + int totalRows = dim + 2 * kvDim; + if (rowId >= totalRows) { + return; + } + + float scale = rmsScale.get(0); + + // Allocate shared memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + int rowOffset = rowId * dim; + + // Each thread computes partial dot product with inline normalization + float partialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum += wqkv.get(rowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + // Thread 0 writes to appropriate output buffer + if (localId == 0) { + float result = localSum[0]; + + if (rowId < dim) { + // Q projection: rows [0, dim) + q.set(rowId, result); + } else if (rowId < dim + kvDim) { + // K projection: rows [dim, dim+kvDim) + int kIdx = rowId - dim; + k.set(kIdx, result); + } else { + // V projection: rows [dim+kvDim, dim+2*kvDim) + int vIdx = rowId - dim - kvDim; + v.set(vIdx, result); + } + } + } + /** + * Fused RMSNorm apply + Gate/Up projection + SiLU + GLU in one kernel. + * + *
Eliminates the need for separate gateUpSiLU kernel by computing both + * gate and up projections per workgroup and applying activation inline.
+ * + *
For each output index i:
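As the kernel body below shows, this is the usual SwiGLU form: gate comes from row i of the combined matrix, up from row hiddenDim + i, and the output is SiLU(gate) * up. A CPU reference sketch (class name and flattened layout are illustrative):

final class GateUpSiluSketch {
    // out[i] = silu(gate_i) * up_i, with gate_i = W[i] . xnorm and up_i = W[hiddenDim + i] . xnorm,
    // where xnorm[j] = rmsWeight[j] * scale * x[j].
    static float[] apply(float[] x, float[] rmsWeight, float scale,
                         float[] wUp /* (2*hiddenDim) x dim, row-major */,
                         int dim, int hiddenDim) {
        float[] out = new float[hiddenDim];
        for (int i = 0; i < hiddenDim; i++) {
            float gate = 0.0f, up = 0.0f;
            for (int j = 0; j < dim; j++) {
                float xn = rmsWeight[j] * scale * x[j];
                gate += wUp[i * dim + j] * xn;
                up += wUp[(hiddenDim + i) * dim + j] * xn;
            }
            float silu = gate / (1.0f + (float) Math.exp(-gate));
            out[i] = silu * up;
        }
        return out;
    }
}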
+ * + * + * @param context Kernel execution context + * @param x Input hidden state (FP32) [dim] + * @param output Output buffer (FP32) [hiddenDim] - final FFN result + * @param rmsWeights RMS normalization weights (FP32) [dim] + * @param rmsScale Precomputed RMS scale factor [1] + * @param wUp Combined gate+up weight matrix (FP16) [2×hiddenDim × dim] + * @param dim Input dimension + * @param hiddenDim Hidden dimension (output size) + * @param localWorkGroupSize Local work group size for reduction + */ + public static void fusedRmsNormFFNGateUpSiLU( + KernelContext context, + FloatArray x, // input (FP32) + FloatArray output, // output (FP32) [hiddenDim] + FloatArray rmsWeights, // RMS norm weights + FloatArray rmsScale, // temp[0] = scale factor + HalfFloatArray wUp, // combined gate+up weights [2×hiddenDim × dim] + int dim, // input dimension + int hiddenDim, // output dimension + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= hiddenDim) { + return; + } + + float scale = rmsScale.get(0); + + // Allocate shared memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + // === Compute GATE (row i) === + int gateRowOffset = rowId * dim; + + float gatePartialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + gatePartialSum += wUp.get(gateRowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = gatePartialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + float gateResult = localSum[0]; + + // === Compute UP (row hiddenDim + i) === + int upRowOffset = (hiddenDim + rowId) * dim; + + float upPartialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + upPartialSum += wUp.get(upRowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = upPartialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + float upResult = localSum[0]; + + // === Apply SiLU(gate) × up === + if (localId == 0) { + float silu = gateResult / (1.0f + TornadoMath.exp(-gateResult)); + output.set(rowId, silu * upResult); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java index 930e1774..506d3fdf 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java @@ -4,6 +4,7 @@ import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; // @formatter:off @@ -292,5 +293,376 @@ private static void processHeadTornado( } } + /** + * Fused RoPE rotation with KV cache copy for Qwen3. + * Combines ropeRotation + copyToCache into a single kernel. 
+ */ + public static void ropeRotationWithCacheCopy( + KernelContext context, + IntArray positionHolder, + FloatArray q, // Q vector (in/out) + FloatArray k, // K vector (in/out) + FloatArray v, // V vector (in only) + FloatArray keyCache, // Key cache (out) + FloatArray valueCache, // Value cache (out) + int numberOfKeyValueHeads, + int nEmbdHead, + int nEmbdGqa, + int layer, + int contextLength) { + + int h = context.globalIdx; + int ic = context.globalIdy; + + int pos = positionHolder.get(0); + int rotn = h < numberOfKeyValueHeads ? 2 : 1; + int poffset = h * nEmbdHead; + int nComplEmbdHead = nEmbdHead / 2; + + // Compute RoPE frequencies for Qwen3 (theta = 1000000.0f) + float theta = 1000000.0f; + int i = ic * 2; + float freq = 1.0f / TornadoMath.pow(theta, (float) i / (float) nEmbdHead); + + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q (all heads) + float v0q = q.get(poffset + ic); + float v1q = q.get(poffset + ic + nComplEmbdHead); + q.set(poffset + ic, v0q * fcr - v1q * fci); + q.set(poffset + ic + nComplEmbdHead, v0q * fci + v1q * fcr); + + // Rotate K and copy K/V to cache (only for KV heads) + if (rotn > 1 && (poffset + ic + nComplEmbdHead) < k.getSize()) { + float v0k = k.get(poffset + ic); + float v1k = k.get(poffset + ic + nComplEmbdHead); + float rotatedK0 = v0k * fcr - v1k * fci; + float rotatedK1 = v0k * fci + v1k * fcr; + + // Write rotated K back + k.set(poffset + ic, rotatedK0); + k.set(poffset + ic + nComplEmbdHead, rotatedK1); + + // Direct cache write (fused - no separate copy kernel!) + int cacheOffset = layer * contextLength * nEmbdGqa + pos * nEmbdGqa; + int kvIdx = h * nEmbdHead; + + keyCache.set(cacheOffset + kvIdx + ic, rotatedK0); + keyCache.set(cacheOffset + kvIdx + ic + nComplEmbdHead, rotatedK1); + + // Copy V to cache (V doesn't need rotation) + valueCache.set(cacheOffset + kvIdx + ic, v.get(poffset + ic)); + valueCache.set(cacheOffset + kvIdx + ic + nComplEmbdHead, v.get(poffset + ic + nComplEmbdHead)); + } + } + + /** + * Fused Q/K/V matrix-vector multiplication for Qwen3 GQA. + * Q has full head dimension, K/V have reduced KV head dimension. 
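For grouped-query attention the Q and K/V projections have different widths, which is why this kernel takes separate qDim and kvDim. A minimal sketch of the dimension bookkeeping assumed here (type and method names are illustrative): qDim = nHeads * headDim, kvDim = nKvHeads * headDim, and query head h reads KV head h / kvMul with kvMul = nHeads / nKvHeads, matching the kvHeadIdx = h / kvMul indexing used by the flash-attention kernels in this patch.

// GQA dimension bookkeeping assumed by the Qwen3 kernels (illustrative sketch).
record GqaDims(int nHeads, int nKvHeads, int headDim) {
    int qDim()  { return nHeads * headDim; }     // full query width
    int kvDim() { return nKvHeads * headDim; }   // reduced key/value width
    int kvMul() { return nHeads / nKvHeads; }    // query heads sharing one KV head
    int kvHeadFor(int queryHead) { return queryHead / kvMul(); }
}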
+ * + * Workgroup assignment: + * - rowId [0, qDim): Q projection + * - rowId [qDim, qDim+kvDim): K projection + * - rowId [qDim+kvDim, qDim+2*kvDim): V projection + */ + public static void fusedQKVMatmul( + KernelContext context, + FloatArray x, // input vector + FloatArray q, // output Q + FloatArray k, // output K + FloatArray v, // output V + HalfFloatArray wq, // Q weight matrix + HalfFloatArray wk, // K weight matrix + HalfFloatArray wv, // V weight matrix + int inputDim, // input dimension (config.dim()) + int qDim, // Q output dimension + int kvDim, // KV output dimension + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowId < qDim) { + // ========== Q projection ========== + int rowOffset = rowId * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partialSum += wq.get(rowOffset + j).getFloat32() * x.get(j); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + q.set(rowId, localSum[0]); + } + + } else if (rowId < qDim + kvDim) { + // ========== K projection ========== + int kRow = rowId - qDim; + int rowOffset = kRow * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partialSum += wk.get(rowOffset + j).getFloat32() * x.get(j); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + k.set(kRow, localSum[0]); + } + + } else if (rowId < qDim + 2 * kvDim) { + // ========== V projection ========== + int vRow = rowId - qDim - kvDim; + int rowOffset = vRow * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partialSum += wv.get(rowOffset + j).getFloat32() * x.get(j); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + v.set(vRow, localSum[0]); + } + } + } + + /** + * Fused RMSNorm apply + Q/K/V projection for Qwen3 GQA. + * Eliminates intermediate wrapXb buffer write/read. 
+ */ + public static void fusedRmsNormQKVMatmul( + KernelContext context, + FloatArray x, // raw input (FP32) + FloatArray q, // output Q + FloatArray k, // output K + FloatArray v, // output V + FloatArray rmsWeights, // RMS norm weights + FloatArray rmsScale, // temp[0] = scale factor + HalfFloatArray wq, // Q weight matrix + HalfFloatArray wk, // K weight matrix + HalfFloatArray wv, // V weight matrix + int inputDim, // input dimension (config.dim()) + int qDim, // Q output dimension + int kvDim, // KV output dimension + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + float scale = rmsScale.get(0); + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowId < qDim) { + // ========== Q projection with inline normalization ========== + int rowOffset = rowId * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum += wq.get(rowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + q.set(rowId, localSum[0]); + } + + } else if (rowId < qDim + kvDim) { + // ========== K projection with inline normalization ========== + int kRow = rowId - qDim; + int rowOffset = kRow * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum += wk.get(rowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + k.set(kRow, localSum[0]); + } + + } else if (rowId < qDim + 2 * kvDim) { + // ========== V projection with inline normalization ========== + int vRow = rowId - qDim - kvDim; + int rowOffset = vRow * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum += wv.get(rowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + v.set(vRow, localSum[0]); + } + } + } + + + /** + * Fused Q and K RMSNorm for Qwen3. + * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel. 
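Per-head RMS normalization, as applied to both Q and K here, reduces over one head slice at a time rather than over the full vector; a scalar sketch of the math (class name is illustrative):

final class PerHeadRmsNormSketch {
    // In place: v[h*headSize + i] = weights[i] * (v[h*headSize + i] / rms(head h)).
    static void apply(float[] v, float[] weights, int nHeads, int headSize, float eps) {
        for (int h = 0; h < nHeads; h++) {
            int off = h * headSize;
            float ss = 0.0f;
            for (int i = 0; i < headSize; i++) {
                ss += v[off + i] * v[off + i];
            }
            float inv = 1.0f / (float) Math.sqrt(ss / headSize + eps);
            for (int i = 0; i < headSize; i++) {
                v[off + i] = weights[i] * (inv * v[off + i]);
            }
        }
    }
}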
+ * + * Workgroup assignment: + * - Workgroups [0, nHeads): Process Q heads + * - Workgroups [nHeads, nHeads + nHeadKv): Process K heads + */ + public static void fusedQKRmsNorm( + KernelContext context, + FloatArray q, // Q vector (in/out) + FloatArray k, // K vector (in/out) + FloatArray qWeights, // Q RMS norm weights + FloatArray kWeights, // K RMS norm weights + int nHeads, // number of Q heads + int nHeadKv, // number of K heads + int nEmbdHead, // head dimension + int localMemSize, // local memory size (must be fixed) + float rmsNormEps) { + + int groupId = context.groupIdx; + int localId = context.localIdx; + int localSize = context.localGroupSizeX; + + // Allocate local memory with FIXED size parameter + float[] localSum = context.allocateFloatLocalArray(localMemSize); + + if (groupId < nHeads) { + // === Process Q head === + int headOffset = groupId * nEmbdHead; + + // Step 1: Compute sum of squares (reduction) + float partialSum = 0.0f; + for (int i = localId; i < nEmbdHead; i += localSize) { + float val = q.get(headOffset + i); + partialSum += val * val; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + // Compute normalization factor + float ss = localSum[0]; + ss = ss / nEmbdHead + rmsNormEps; + ss = 1.0f / TornadoMath.sqrt(ss); + + context.localBarrier(); + + // Step 2: Apply normalization with weights (in-place) + for (int i = localId; i < nEmbdHead; i += localSize) { + float normalized = ss * q.get(headOffset + i); + q.set(headOffset + i, qWeights.get(i) * normalized); + } + + } else if (groupId < nHeads + nHeadKv) { + // === Process K head === + int headIdx = groupId - nHeads; + int headOffset = headIdx * nEmbdHead; + + // Step 1: Compute sum of squares (reduction) + float partialSum = 0.0f; + for (int i = localId; i < nEmbdHead; i += localSize) { + float val = k.get(headOffset + i); + partialSum += val * val; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + // Compute normalization factor + float ss = localSum[0]; + ss = ss / nEmbdHead + rmsNormEps; + ss = 1.0f / TornadoMath.sqrt(ss); + + context.localBarrier(); + + // Step 2: Apply normalization with weights (in-place) + for (int i = localId; i < nEmbdHead; i += localSize) { + float normalized = ss * k.get(headOffset + i); + k.set(headOffset + i, kWeights.get(i) * normalized); + } + } + } } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index fa610960..80617400 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -3,10 +3,14 @@ import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.math.TornadoMath; import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import 
uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + public class TransformerComputeKernels { /** @@ -22,6 +26,25 @@ public static void emptyTaskToForceCopyIn(FloatArray buffer) { } } + public static void convertFP32toFP16v2(KernelContext context, FloatArray input, HalfFloatArray output) { + int i = context.globalIdx; + HalfFloat val = new HalfFloat(input.get(i)); + output.set(i,val); + } + + public static void mapContextWithQuantize( + KernelContext context, + HalfFloatArray outputFP16, // Direct FP16 output + FloatArray x, + FloatArray weights, + FloatArray temp) { + + int gid = context.globalIdx; + float ss = temp.get(0); + float result = weights.get(gid) * (ss * x.get(gid)); + outputFP16.set(gid, new HalfFloat(result)); + } + public static void convertFP16toFP32(KernelContext context, HalfFloatArray x, FloatArray wrapX) { int i = context.globalIdx; wrapX.set(i, x.get(i).getFloat32()); @@ -142,4 +165,12 @@ public static void reductionOneBlock2WithLogits(KernelContext context, FloatArra output.set(gid, weights.get(gid) * (ss * output.get(gid))); } + public static void mapContextWithQuantizeLogits(KernelContext context, HalfFloatArray output, FloatArray input, FloatArray weights, FloatArray temp) { + int gid = context.globalIdx; + float ss = temp.get(0); + float in = ss * input.get(gid); + float interim = weights.get(gid) * in; + output.set(gid, new HalfFloat(interim)); + } + } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index ac6619f6..16f8c3b6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -4,7 +4,11 @@ import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; import uk.ac.manchester.tornado.api.types.HalfFloat; -import uk.ac.manchester.tornado.api.types.arrays.*; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.Int8Array; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; public class TransformerComputeKernelsLayered { @@ -14,6 +18,88 @@ public class TransformerComputeKernelsLayered { public TransformerComputeKernelsLayered() { } + public static void fusedQKvBiasAddition(KernelContext context, FloatArray q_out, FloatArray k_out, FloatArray qBias, FloatArray v_out, FloatArray kBias, FloatArray vBias, int dimQ, int dimKV) { + + int gid = context.globalIdx; + + if (gid < dimQ) { + // 1. Add Q bias + q_out.set(gid, q_out.get(gid) + qBias.get(gid)); + + // 2. 
Conditionally Add K and V Bias + if (gid < dimKV) { + k_out.set(gid, k_out.get(gid) + kBias.get(gid)); + v_out.set(gid, v_out.get(gid) + vBias.get(gid)); + } + } + } + + public static void fusedRmsNormFFNGateUp(KernelContext context, FloatArray x, // raw input (FP32) + FloatArray hb, // output + FloatArray rmsWeights, // RMS norm weights + FloatArray rmsScale, // temp[0] = scale factor + HalfFloatArray w1, HalfFloatArray w3, int dim, // input dimension + int hiddenDim, // output dimension + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= hiddenDim) { + return; + } + + float scale = rmsScale.get(0); + + // Allocate shared memory for normalized input (reused for both W1 and W3) + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + int rowOffsetW1 = rowId * dim; + int rowOffsetW3 = rowId * dim; + + // === W1 matmul with inline normalization === + float sum1 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + sum1 += w1.get(rowOffsetW1 + j).getFloat32() * normalized; + } + + localSum[localId] = sum1; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + float result1 = localSum[0]; + + // === W3 matmul with inline normalization (same computation) === + float sum3 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + sum3 += w3.get(rowOffsetW3 + j).getFloat32() * normalized; + } + + localSum[localId] = sum3; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + float result3 = localSum[0]; + + // === SiLU + GLU === + if (localId == 0) { + float silu = result1 / (1.0f + TornadoMath.exp(-result1)); + hb.set(rowId, silu * result3); + } + } + /** * Performs RMS (Root Mean Square) normalization using parallel reduction. This is the first phase of RMS normalization that computes the variance and scaling factor across all work groups. * @@ -136,6 +222,60 @@ public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, Float } } + /** + * Fused RoPE rotation with KV cache copy. Eliminates separate copyToCaches kernel. 
+ * + * - Rotates Q (full dim) - Rotates K and writes directly to keyCache - Copies V directly to valueCache (no rotation needed) + */ + public static void ropeRotationWithCacheCopy(KernelContext context, IntArray positionHolder, FloatArray sq, // Q vector (in/out) + FloatArray sk, // K vector (in/out) + FloatArray sv, // V vector (in only) + FloatArray keyCache, // Key cache (out) + FloatArray valueCache, // Value cache (out) + int kvDim, int headSize, int layer, int contextLength) { + + int i = context.globalIdx * 2; + int pos = positionHolder.get(0); + + // Bounds check for Q rotation (Q has dim elements, processed in pairs) + if (i + 1 < sq.getSize()) { + // RoPE frequency calculation + int head_dim = i % headSize; + float freq = 1.0f / TornadoMath.pow(50000.0f, head_dim / (float) headSize); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q + float v0q = sq.get(i); + float v1q = sq.get(i + 1); + sq.set(i, v0q * fcr - v1q * fci); + sq.set(i + 1, v0q * fci + v1q * fcr); + + // Rotate K AND write to cache (only for kvDim elements) + if (i + 1 < kvDim) { + float v0k = sk.get(i); + float v1k = sk.get(i + 1); + float rotated0 = v0k * fcr - v1k * fci; + float rotated1 = v0k * fci + v1k * fcr; + + // Write rotated K back to sk + sk.set(i, rotated0); + sk.set(i + 1, rotated1); + + // Direct cache write (fused - no separate copy kernel!) + int cacheOffset = layer * contextLength * kvDim + pos * kvDim; + keyCache.set(cacheOffset + i, rotated0); + keyCache.set(cacheOffset + i + 1, rotated1); + + // Copy V to cache (V doesn't need rotation) + valueCache.set(cacheOffset + i, sv.get(i)); + valueCache.set(cacheOffset + i + 1, sv.get(i + 1)); + } + } + + } + public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArray v, int dimQ, int dimKV) { int totalSize = dimQ + 2 * dimKV; @@ -348,7 +488,7 @@ public static void processHeadsFlashAttention(KernelContext context, FloatArray int pos = positionHolder.get(0); int loff = layer * contextLength * kvDim; int kvHeadIdx = h / kvMul; - int BLOCK_SIZE_C = 8; + int BLOCK_SIZE_C = 16; // Allocate shared memory for tiled computation float[] q_shared = context.allocateFloatLocalArray(headSize); @@ -455,6 +595,191 @@ public static void processHeadsFlashAttention(KernelContext context, FloatArray } } + public static void processHeadsFlashAttentionOptV2(KernelContext context, FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, + // NOTE: Still used for logic, but not for allocation size + int kvDim, int kvMul, IntArray positionHolder, int layer, int contextLength) { + + // --- STATIC CONSTANTS FOR OPENCL ALLOCATIONS --- + // These must be large enough to handle the maximum expected values for + // headSize and localSize in your model/hardware setup. + // Assuming Max Head Size is 256 and Max Local Size is 256. 
+ final int MAX_HEAD_SIZE = 256; + final int MAX_LOCAL_SIZE = 256; + final int MAX_BLOCK_SIZE_C = 32; + final int MAX_TILE_ELEMENTS = MAX_BLOCK_SIZE_C * MAX_HEAD_SIZE; + + int tid = context.localIdx; + int h = context.groupIdx; + int localSize = context.localGroupSizeX; + + if (h >= nHeads) { + return; + } + + int pos = positionHolder.get(0); + int loff = layer * contextLength * kvDim; + int kvHeadIdx = h / kvMul; + int BLOCK_SIZE_C = 32; + + // === Shared memory allocations (FIXED: using static sizes) === + // ERROR FIX 1: Use MAX_HEAD_SIZE instead of dynamic headSize + float[] q_shared = context.allocateFloatLocalArray(MAX_HEAD_SIZE); + // ERROR FIX 2: Use MAX_TILE_ELEMENTS instead of BLOCK_SIZE_C * headSize + float[] k_tile = context.allocateFloatLocalArray(MAX_TILE_ELEMENTS); + float[] v_tile = context.allocateFloatLocalArray(MAX_TILE_ELEMENTS); + + float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C); // Size is constant (32) + float[] exp_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C); // Size is constant (32) + + // ERROR FIX 3: Use MAX_LOCAL_SIZE instead of dynamic localSize + float[] reduction_shared = context.allocateFloatLocalArray(MAX_LOCAL_SIZE); + + float[] state_shared = context.allocateFloatLocalArray(4); // Size is constant (4) + + // === Dimension partitioning: each thread handles subset of output dims === + int dimsPerThread = (headSize + localSize - 1) / localSize; + int myStartDim = tid * dimsPerThread; + int myEndDim = Math.min(myStartDim + dimsPerThread, headSize); + int myDimCount = myEndDim - myStartDim; + + // FIX from previous iteration: ensuring output array is statically sized + final int MAX_OUTPUT_DIMS = MAX_HEAD_SIZE / 8; // e.g., 32 if MAX_HEAD_SIZE=256 + float[] output = new float[MAX_OUTPUT_DIMS]; + + // Initialize thread-local output + for (int i = 0; i < myDimCount; i++) { + output[i] = 0.0f; + } + + // Initialize shared state + if (tid == 0) { + state_shared[0] = Float.NEGATIVE_INFINITY; + state_shared[1] = 0.0f; + } + + // Load query into shared memory (cooperative) + // NOTE: Loop bound must still use headSize to read correct data volume + for (int i = tid; i < headSize; i += localSize) { + q_shared[i] = q.get(h * headSize + i); + } + context.localBarrier(); + + // Process sequence in tiles + for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) { + int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos); + int tileLen = tileEnd - tileC + 1; + + // === Cooperative K/V tile loading === + int totalElements = tileLen * headSize; + int elementsPerThread = (totalElements + localSize - 1) / localSize; + int startElem = tid * elementsPerThread; + int endElem = Math.min(startElem + elementsPerThread, totalElements); + + for (int globalElemIdx = startElem; globalElemIdx < endElem; globalElemIdx++) { + int seqIdx = globalElemIdx / headSize; + int dimIdx = globalElemIdx % headSize; + int kvOffset = loff + (tileC + seqIdx) * kvDim + kvHeadIdx * headSize + dimIdx; + int tileMemOffset = seqIdx * headSize + dimIdx; + + // Check bounds just to be safe, though kvDim/headSize should ensure this is valid. 
+ if (tileMemOffset < MAX_TILE_ELEMENTS) { + k_tile[tileMemOffset] = key_cache.get(kvOffset); + v_tile[tileMemOffset] = value_cache.get(kvOffset); + } + } + context.localBarrier(); + + // === Compute attention scores (cooperative) === + for (int t = tid; t < tileLen; t += localSize) { + float score = 0.0f; + for (int d = 0; d < headSize; d++) { + score += q_shared[d] * k_tile[t * headSize + d]; + } + s_tile[t] = score / TornadoMath.sqrt(headSize); + } + context.localBarrier(); + + // ... (Parallel reduction for tileMax - uses reduction_shared, which is now fixed) + float threadMax = Float.NEGATIVE_INFINITY; + for (int t = tid; t < tileLen; t += localSize) { + if (s_tile[t] > threadMax) { + threadMax = s_tile[t]; + } + } + reduction_shared[tid] = threadMax; + context.localBarrier(); + + for (int stride = localSize / 2; stride > 0; stride /= 2) { + if (tid < stride) { + reduction_shared[tid] = Math.max(reduction_shared[tid], reduction_shared[tid + stride]); + } + context.localBarrier(); + } + float tileMax = reduction_shared[0]; + + // === Update running max and rescale if needed === + float prevMax = state_shared[0]; + float newMax = Math.max(prevMax, tileMax); + float scale = 1.0f; + + if (newMax != prevMax && prevMax != Float.NEGATIVE_INFINITY) { + scale = TornadoMath.exp(prevMax - newMax); + for (int i = 0; i < myDimCount; i++) { + output[i] *= scale; + } + } + + // === Compute exp(score - max) and tile sum (cooperative) === + for (int t = tid; t < tileLen; t += localSize) { + exp_tile[t] = TornadoMath.exp(s_tile[t] - newMax); + } + context.localBarrier(); + + // Parallel reduction for tile sum + // ... (Uses reduction_shared, which is now fixed) + float threadSum = 0.0f; + for (int t = tid; t < tileLen; t += localSize) { + threadSum += exp_tile[t]; + } + reduction_shared[tid] = threadSum; + context.localBarrier(); + + for (int stride = localSize / 2; stride > 0; stride /= 2) { + if (tid < stride) { + reduction_shared[tid] += reduction_shared[tid + stride]; + } + context.localBarrier(); + } + float tileSum = reduction_shared[0]; + + // Update shared state (thread 0) + if (tid == 0) { + state_shared[0] = newMax; + state_shared[1] = state_shared[1] * scale + tileSum; + } + context.localBarrier(); + + // === Accumulate output (each thread handles its dimensions) === + for (int t = 0; t < tileLen; t++) { + float expScore = exp_tile[t]; + for (int i = 0; i < myDimCount; i++) { + int d = myStartDim + i; + output[i] += expScore * v_tile[t * headSize + d]; + } + } + context.localBarrier(); + } + + // === Final normalization and write === + float sumExp = state_shared[1]; + float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; + + int baseOffset = h * headSize + myStartDim; + for (int i = 0; i < myDimCount; i++) { + xb.set(baseOffset + i, output[i] * normFactor); + } + } + /** * Same as processHeadsFlashAttention but with some optimizations that seem to lower attention's execution time, especially in larger models. */ @@ -688,6 +1013,133 @@ public static void matrixVectorGeneric( hb.set(rowId, sum); } } + + /** + * Fused Q/K/V matrix-vector multiplication. + * Reduces kernel launch overhead and improves input vector cache utilization. 
+ * + * Workgroup assignment: + * - rowId [0, dim): Q projection + * - rowId [dim, dim+kvDim): K projection + * - rowId [dim+kvDim, dim+2*kvDim): V projection + */ + public static void fusedQKVMatmulX( + KernelContext context, + HalfFloatArray x, // input vector (FP16) + FloatArray q, // output Q (FP32) + FloatArray k, // output K (FP32) + FloatArray v, // output V (FP32) + HalfFloatArray wq, // Q weight matrix + HalfFloatArray wk, // K weight matrix + HalfFloatArray wv, // V weight matrix + int dim, // model dimension (Q output size) + int kvDim, // KV dimension (K/V output size) + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowId < dim) { + // ========== Q projection ========== + int rowOffset = rowId * dim; + + float partialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partialSum += wq.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + q.set(rowId, localSum[0]); + } + + } else if (rowId < dim + kvDim) { + // ========== K projection ========== + int kRow = rowId - dim; + int rowOffset = kRow * dim; + + float partialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partialSum += wk.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + k.set(kRow, localSum[0]); + } + + } else if (rowId < dim + 2 * kvDim) { + // ========== V projection ========== + int vRow = rowId - dim - kvDim; + int rowOffset = vRow * dim; + + float partialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partialSum += wv.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + v.set(vRow, localSum[0]); + } + } + } + + // @formatter:off + public static void matrixVectorGeneric( + KernelContext context, + HalfFloatArray x, + FloatArray hb, // output + HalfFloatArray w, + int dim1, // inner loop + int dim0, // outer loop + int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= dim0) { + return; + } + float sum = matrixVectorRowMajorOptimizedSingle(context, localSize, x, w, dim1); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + hb.set(rowId, sum); + } + } // @formatter:on /** @@ -730,6 +1182,26 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA } } + public static void matrixVectorGenericWithResidual(KernelContext context, HalfFloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int 
localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float result = hb.get(rowId) + sum; + hb.set(rowId, result); + } + } + /** * Fused feed-forward network with SiLU activation and GLU gating. Implements the SwiGLU variant used in LLaMA-style models. * @@ -752,7 +1224,9 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA * @param localWorkGroupSize * Work group size */ - public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) { + + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, HalfFloatArray x, HalfFloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, + int localWorkGroupSize) { // One row per workgroup (not per thread) int rowId = context.groupIdx; int localId = context.localIdx; @@ -768,21 +1242,61 @@ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext contex if (localId == 0) { float silu = siluActivation(sum1); // Using the new SiLU method float result = silu * sum3; - hb.set(rowId, result); + hb.set(rowId, new HalfFloat(result)); } } - /** - * Gaussian Error Linear Unit (GELU) activation function. Approximation formula: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³))) - * - * @param x - * Input value - * @return Activated value - */ - public static float geluActivation(float x) { - float x3 = x * x * x; - return 0.5f * x * (1.0f + TornadoMath.tanh((0.797885f * (x + 0.044715f * x3)))); - } + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= d) { + return; + } + + float sum1 = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w1, n); + float sum3 = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w3, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float silu = siluActivation(sum1); // Using the new SiLU method + float result = silu * sum3; + hb.set(rowId, result); + } + } + + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, HalfFloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= d) { + return; + } + + HalfFloat sum1 = matrixVectorRowMajorOptimizedFHF(context, localWorkGroupSize, x, w1, n); + HalfFloat sum3 = matrixVectorRowMajorOptimizedFHF(context, localWorkGroupSize, x, w3, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float silu = siluActivation(sum1.getFloat32()); // Using the new SiLU method + float result = silu * sum3.getFloat32(); + hb.set(rowId, result); + } + } + + /** + * Gaussian Error Linear Unit (GELU) activation function. 
Approximation formula: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³))) + * + * @param x + * Input value + * @return Activated value + */ + public static float geluActivation(float x) { + float x3 = x * x * x; + return 0.5f * x * (1.0f + TornadoMath.tanh((0.797885f * (x + 0.044715f * x3)))); + } /** * Sigmoid-weighted Linear Unit (SiLU) activation function. Also known as Swish activation. @@ -876,6 +1390,299 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc return localSum[0]; } + public static HalfFloat matrixVectorRowMajorOptimizedFHF(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + HalfFloat[] localSum = context.allocateHalfFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + // Each thread calculates partial dot product + float partialSum = 0.0f; + // HalfFloat partialSum = new HalfFloat(0f); + for (int j = localId; j < n; j += localSize) { + int matrixIdx = rowOffset + j; + // HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j)); + partialSum += w.get(matrixIdx).getFloat32() * x.get(j).getFloat32(); + // partialSum = HalfFloat.add(partialSum, mul); + } + + // Store partial sum in local memory + localSum[localId] = new HalfFloat(partialSum); + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] = HalfFloat.add(localSum[localId], localSum[localId + stride]); + // localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + return localSum[0]; + } + + public static float matrixVectorRowMajorOptimizedF(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + // Each thread calculates partial dot product + float partialSum = 0.0f; + // HalfFloat partialSum = new HalfFloat(0f); + for (int j = localId; j < n; j += localSize) { + int matrixIdx = rowOffset + j; + // HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j)); + partialSum += w.get(matrixIdx).getFloat32() * x.get(j).getFloat32(); + // partialSum = HalfFloat.add(partialSum, mul); + } + + // Store partial sum in local memory + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + return localSum[0]; + } + + public static float matrixVectorRowMajorOptimizedXX(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f; + + int stride = localSize; + int stride2 = localSize << 1; + int stride3 = localSize * 3; + int stride4 = localSize << 2; + + // Already coalesced: thread 0 reads idx 0, thread 1 reads idx 1, etc. 
+ int j = localId; + int limit = n - stride3; + + for (; j < limit; j += stride4) { + int base = rowOffset + j; + // Hoist x.get() calls - they're reused across all rows + float x0 = x.get(j).getFloat32(); + float x1 = x.get(j + stride).getFloat32(); + float x2 = x.get(j + stride2).getFloat32(); + float x3 = x.get(j + stride3).getFloat32(); + + sum0 += w.get(base).getFloat32() * x0; + sum1 += w.get(base + stride).getFloat32() * x1; + sum2 += w.get(base + stride2).getFloat32() * x2; + sum3 += w.get(base + stride3).getFloat32() * x3; + } + + for (; j < n; j += stride) { + sum0 += w.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + + localSum[localId] = (sum0 + sum1) + (sum2 + sum3); + context.localBarrier(); + + // Reduction with minimal barriers + for (int s = localSize >> 1; s > 0; s >>= 1) { + if (localId < s) { + localSum[localId] += localSum[localId + s]; + } + context.localBarrier(); + } + + return localSum[0]; + } + + public static float matrixVectorRowMajorOptimizedSingle(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + HalfFloat partialSum = new HalfFloat(0f); + for (int j = localId; j < n; j += localSize) { + int matrixIdx = rowOffset + j; + HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j)); + partialSum = HalfFloat.add(partialSum, mul); + } + + // Store partial sum in local memory + localSum[localId] = partialSum.getHalfFloatValue(); + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + return localSum[0]; + } + + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + // Accumulate in HalfFloat to avoid conversions in inner loop + HalfFloat sum0 = new HalfFloat(0f); + HalfFloat sum1 = new HalfFloat(0f); + HalfFloat sum2 = new HalfFloat(0f); + HalfFloat sum3 = new HalfFloat(0f); + + int stride = localSize; + int stride2 = localSize << 1; + int stride3 = localSize * 3; + int stride4 = localSize << 2; + + int j = localId; + int limit = n - stride3; + + for (; j < limit; j += stride4) { + int base = rowOffset + j; + + // Stay in HalfFloat - no getFloat32() calls + HalfFloat x0 = x.get(j); + HalfFloat x1 = x.get(j + stride); + HalfFloat x2 = x.get(j + stride2); + HalfFloat x3 = x.get(j + stride3); + + sum0 = HalfFloat.add(sum0, HalfFloat.mult(w.get(base), x0)); + sum1 = HalfFloat.add(sum1, HalfFloat.mult(w.get(base + stride), x1)); + sum2 = HalfFloat.add(sum2, HalfFloat.mult(w.get(base + stride2), x2)); + sum3 = HalfFloat.add(sum3, HalfFloat.mult(w.get(base + stride3), x3)); + } + + // Cleanup loop + for (; j < n; j += stride) { + sum0 = HalfFloat.add(sum0, HalfFloat.mult(w.get(rowOffset + j), x.get(j))); + } + + // Convert to float32 only at the end for reduction + localSum[localId] = sum0.getFloat32() + sum1.getFloat32() + sum2.getFloat32() + sum3.getFloat32(); + context.localBarrier(); + + for (int s = localSize >> 1; s > 0; s >>= 1) { + if (localId < s) { + localSum[localId] += 
localSum[localId + s]; + } + context.localBarrier(); + } + + return localSum[0]; + } + + public static void fusedQKVMatmul(KernelContext context, HalfFloatArray x, // input (read once!) + FloatArray q, FloatArray k, FloatArray v, // outputs + HalfFloatArray wq, HalfFloatArray wk, HalfFloatArray wv, int dim, int kvDim, int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Determine which output this workgroup computes + int totalRows = dim + 2 * kvDim; // Q rows + K rows + V rows + + if (rowId < dim) { + // Q projection + float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wq, dim); + if (localId == 0) { + q.set(rowId, sum); + } + } else if (rowId < dim + kvDim) { + // K projection + int kRow = rowId - dim; + float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wk, dim); + if (localId == 0) { + k.set(kRow, sum); + } + } else { + // V projection + int vRow = rowId - dim - kvDim; + float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wv, dim); + if (localId == 0) { + v.set(vRow, sum); + } + } + } + + public static float matrixVectorRowMajorOptimizedx(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + // Each thread calculates partial dot product - UNROLLED BY 4 + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + + int j = localId; + int stride = localSize; + int stride4 = localSize << 2; // localSize * 4 + int limit = n - (stride * 3); // Safe limit for 4 elements + + // Main loop unrolled by 4 with separate accumulators + for (; j < limit; j += stride4) { + int base = rowOffset + j; + int j1 = j + stride; + int j2 = j + (stride << 1); + int j3 = j + stride * 3; + + sum0 += w.get(base).getFloat32() * x.get(j).getFloat32(); + sum1 += w.get(base + stride).getFloat32() * x.get(j1).getFloat32(); + sum2 += w.get(base + (stride << 1)).getFloat32() * x.get(j2).getFloat32(); + sum3 += w.get(base + stride * 3).getFloat32() * x.get(j3).getFloat32(); + } + + // Handle remainder + for (; j < n; j += stride) { + sum0 += w.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + + // Combine accumulators (tree reduction for better precision) + float partialSum = (sum0 + sum1) + (sum2 + sum3); + + // Store partial sum in local memory + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int s = localSize >> 1; s > 0; s >>= 1) { + if (localId < s) { + localSum[localId] += localSum[localId + s]; + } + context.localBarrier(); + } + + return localSum[0]; + } + // Second kernel - Combines partial sums and computes final normalization public static void reductionFinalNormalization(KernelContext context, FloatArray output, int size, float ermsNorm) { int gid = context.globalIdx; @@ -1169,10 +1976,7 @@ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext contex } } - public static void fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte(KernelContext context, FloatArray x, FloatArray hb, - ByteArray w1, - ByteArray w3, - int n, int d, int localWorkGroupSize) { + public static void fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte(KernelContext context, FloatArray x, FloatArray hb, ByteArray w1, ByteArray w3, int n, int d, int localWorkGroupSize) { // One row 
per workgroup (not per thread) int rowId = context.groupIdx; int localId = context.localIdx; @@ -1237,4 +2041,364 @@ public static void processHeadsParallel(FloatArray q, FloatArray key_cache, Floa } } + /** + * Fused Q/K/V matrix-vector multiplication for Q8_0 quantized weights. Reduces kernel launch overhead and improves input vector cache utilization. + * + * Workgroup assignment: - rowId [0, dim): Q projection - rowId [dim, dim+kvDim): K projection - rowId [dim+kvDim, dim+2*kvDim): V projection + */ + public static void fusedQKVMatmulQ8(KernelContext context, FloatArray x, FloatArray q, FloatArray k, FloatArray v, ByteArray wq, ByteArray wk, ByteArray wv, int dim, int kvDim, + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + int blockSize = 32; + final int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (dim + blockSize - 1) / blockSize; + + float[] localSums = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowId < dim) { + // ========== Q projection ========== + int rowBlockOffset = rowId * blocksPerRow; + + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wq.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + byte quant1 = wq.get(quantsOffset); + byte quant2 = wq.get(quantsOffset + 1); + byte quant3 = wq.get(quantsOffset + 2); + byte quant4 = wq.get(quantsOffset + 3); + + partialSum1 += ((float) quant1 * scaleFloat) * x.get(j); + partialSum2 += ((float) quant2 * scaleFloat) * x.get(j + 1); + partialSum3 += ((float) quant3 * scaleFloat) * x.get(j + 2); + partialSum4 += ((float) quant4 * scaleFloat) * x.get(j + 3); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wq.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + byte quant = wq.get(blockByteOffset + 2 + withinBlockIdx); + partialSum += ((float) quant * scaleFloat) * x.get(j); + } + + localSums[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + q.set(rowId, localSums[0]); + } + + } else if (rowId < dim + kvDim) { + // ========== K projection ========== + int kRow = rowId - dim; + int rowBlockOffset = kRow * blocksPerRow; + + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wk.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + byte quant1 = wk.get(quantsOffset); + byte quant2 = wk.get(quantsOffset + 1); + byte quant3 = 
wk.get(quantsOffset + 2); + byte quant4 = wk.get(quantsOffset + 3); + + partialSum1 += ((float) quant1 * scaleFloat) * x.get(j); + partialSum2 += ((float) quant2 * scaleFloat) * x.get(j + 1); + partialSum3 += ((float) quant3 * scaleFloat) * x.get(j + 2); + partialSum4 += ((float) quant4 * scaleFloat) * x.get(j + 3); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wk.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + byte quant = wk.get(blockByteOffset + 2 + withinBlockIdx); + partialSum += ((float) quant * scaleFloat) * x.get(j); + } + + localSums[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + k.set(kRow, localSums[0]); + } + + } else if (rowId < dim + 2 * kvDim) { + // ========== V projection ========== + int vRow = rowId - dim - kvDim; + int rowBlockOffset = vRow * blocksPerRow; + + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wv.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + byte quant1 = wv.get(quantsOffset); + byte quant2 = wv.get(quantsOffset + 1); + byte quant3 = wv.get(quantsOffset + 2); + byte quant4 = wv.get(quantsOffset + 3); + + partialSum1 += ((float) quant1 * scaleFloat) * x.get(j); + partialSum2 += ((float) quant2 * scaleFloat) * x.get(j + 1); + partialSum3 += ((float) quant3 * scaleFloat) * x.get(j + 2); + partialSum4 += ((float) quant4 * scaleFloat) * x.get(j + 3); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wv.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + byte quant = wv.get(blockByteOffset + 2 + withinBlockIdx); + partialSum += ((float) quant * scaleFloat) * x.get(j); + } + + localSums[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + v.set(vRow, localSums[0]); + } + } + } + + /** + * Fully fused RMS normalization + FFN W1/W3 matmul with SiLU/GLU for Q8_0 weights. + * Each workgroup redundantly computes RMS scale to avoid cross-workgroup sync. 
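To make the byte arithmetic above easier to follow, here is a minimal CPU-side sketch of the same Q8_0 row-times-vector computation with the RMS scale folded in. It is illustrative only and not part of the patch: plain Java arrays stand in for TornadoVM's ByteArray/FloatArray, the class and method names are hypothetical, the FP16 decode assumes little-endian byte order, and Float.float16ToFloat requires Java 20 or newer.

public final class Q8RowSketch {
    static final int BLOCK_SIZE = 32;   // quants per Q8_0 block
    static final int BLOCK_BYTES = 34;  // 2-byte FP16 scale + 32 int8 quants

    // Dot product of one Q8_0-quantized weight row against x, with RMSNorm applied inline,
    // mirroring the fused kernel: dequantized weight = blockScale * quant.
    static float dotRowQ8WithRmsNorm(byte[] wRow, float[] x, float[] rmsWeights, float rmsNormEps) {
        // Each workgroup recomputes this scale redundantly on the GPU; here it is a plain loop.
        float sumSquares = 0f;
        for (float v : x) {
            sumSquares += v * v;
        }
        float scale = 1f / (float) Math.sqrt(sumSquares / x.length + rmsNormEps);

        float acc = 0f;
        for (int j = 0; j < x.length; j++) {
            int blockByteOffset = (j / BLOCK_SIZE) * BLOCK_BYTES;
            float blockScale = halfToFloat(wRow[blockByteOffset], wRow[blockByteOffset + 1]);
            byte quant = wRow[blockByteOffset + 2 + (j % BLOCK_SIZE)];
            float normalized = rmsWeights[j] * scale * x[j]; // inline RMS apply
            acc += (quant * blockScale) * normalized;
        }
        return acc;
    }

    // Stand-in for ByteArray.getHalfFloat(offset); the byte order here is an assumption.
    static float halfToFloat(byte lo, byte hi) {
        return Float.float16ToFloat((short) (((hi & 0xFF) << 8) | (lo & 0xFF)));
    }
}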
+ */ + public static void fullyFusedRmsNormFFNGateUpQ8( + KernelContext context, + FloatArray x, // raw input (FP32) + FloatArray hb, // output + FloatArray rmsWeights, // RMS norm weights + ByteArray w1, // Q8_0 quantized + ByteArray w3, // Q8_0 quantized + int dim, // input dimension + int hiddenDim, // output dimension + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= hiddenDim) { + return; + } + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + // ========== RMS Norm: Compute scale (each workgroup does this redundantly) ========== + float sumSquares = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float val = x.get(j); + sumSquares += val * val; + } + + localSum[localId] = sumSquares; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + float scale = 1.0f / TornadoMath.sqrt(localSum[0] / dim + 1e-5f); + + // ========== W1 matmul with inline RMS normalization ========== + int blockSize = 32; + final int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (dim + blockSize - 1) / blockSize; + int rowBlockOffset = rowId * blocksPerRow; + + float partialSum1_a = 0.0f; + float partialSum1_b = 0.0f; + float partialSum1_c = 0.0f; + float partialSum1_d = 0.0f; + + for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat w1Scale = w1.getHalfFloat(blockByteOffset); + float w1ScaleFloat = w1Scale.getFloat32(); + + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + byte q1 = w1.get(quantsOffset); + byte q2 = w1.get(quantsOffset + 1); + byte q3 = w1.get(quantsOffset + 2); + byte q4 = w1.get(quantsOffset + 3); + + float norm0 = rmsWeights.get(j) * scale * x.get(j); + float norm1 = rmsWeights.get(j + 1) * scale * x.get(j + 1); + float norm2 = rmsWeights.get(j + 2) * scale * x.get(j + 2); + float norm3 = rmsWeights.get(j + 3) * scale * x.get(j + 3); + + partialSum1_a += ((float) q1 * w1ScaleFloat) * norm0; + partialSum1_b += ((float) q2 * w1ScaleFloat) * norm1; + partialSum1_c += ((float) q3 * w1ScaleFloat) * norm2; + partialSum1_d += ((float) q4 * w1ScaleFloat) * norm3; + } + + float partialSum1 = partialSum1_a + partialSum1_b + partialSum1_c + partialSum1_d; + + for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat w1Scale = w1.getHalfFloat(blockByteOffset); + float w1ScaleFloat = w1Scale.getFloat32(); + + byte quant = w1.get(blockByteOffset + 2 + withinBlockIdx); + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum1 += ((float) quant * w1ScaleFloat) * normalized; + } + + localSum[localId] = partialSum1; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + float result1 = localSum[0]; + + // ========== W3 matmul with inline RMS normalization ========== + float partialSum3_a = 0.0f; + float partialSum3_b = 0.0f; + float partialSum3_c = 0.0f; + float partialSum3_d = 0.0f; + + for (int j = localId * 4; j < dim - 3; j 
+= localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat w3Scale = w3.getHalfFloat(blockByteOffset); + float w3ScaleFloat = w3Scale.getFloat32(); + + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + byte q1 = w3.get(quantsOffset); + byte q2 = w3.get(quantsOffset + 1); + byte q3 = w3.get(quantsOffset + 2); + byte q4 = w3.get(quantsOffset + 3); + + float norm0 = rmsWeights.get(j) * scale * x.get(j); + float norm1 = rmsWeights.get(j + 1) * scale * x.get(j + 1); + float norm2 = rmsWeights.get(j + 2) * scale * x.get(j + 2); + float norm3 = rmsWeights.get(j + 3) * scale * x.get(j + 3); + + partialSum3_a += ((float) q1 * w3ScaleFloat) * norm0; + partialSum3_b += ((float) q2 * w3ScaleFloat) * norm1; + partialSum3_c += ((float) q3 * w3ScaleFloat) * norm2; + partialSum3_d += ((float) q4 * w3ScaleFloat) * norm3; + } + + float partialSum3 = partialSum3_a + partialSum3_b + partialSum3_c + partialSum3_d; + + for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat w3Scale = w3.getHalfFloat(blockByteOffset); + float w3ScaleFloat = w3Scale.getFloat32(); + + byte quant = w3.get(blockByteOffset + 2 + withinBlockIdx); + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum3 += ((float) quant * w3ScaleFloat) * normalized; + } + + localSum[localId] = partialSum3; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + float result3 = localSum[0]; + + // ========== SiLU + GLU ========== + if (localId == 0) { + float silu = result1 / (1.0f + TornadoMath.exp(-result1)); + hb.set(rowId, silu * result3); + } + } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 96acd650..8d105e89 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -4,6 +4,7 @@ import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; @@ -21,7 +22,8 @@ public class LlamaFP16FFNLayers extends AbstractFFNLayers { TaskGraph ffnTaskGraphs; GridScheduler scheduler; - List ffnLayerTaskGraphs; + List ffnLayerTaskGraphs; + public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); this.ffnLayerTaskGraphs = setupFFNLayered(); @@ -29,47 +31,45 @@ public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Config @Override public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid ropeWorker = 
WorkerGridFactory.genericWorker(config.dim()/2, 128); WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configKvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configKvDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); - WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); + + int fusedQKVRows = config.dim() + 2 * config.kvDim(); + int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); // Map workers to tasks for (int i = 0; i < config.numberOfLayers(); i++) { - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + // === Attention Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + // === FFN Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", 
configDimRowMajorGlobalWorker); } return tornadoForwardScheduler; } @Override - public GridScheduler getGridScheduler() { + public GridScheduler getGridScheduler() { return scheduler; } @Override - public TaskGraph getTaskGraph() { + public TaskGraph getTaskGraph() { return ffnTaskGraphs; } @@ -83,22 +83,106 @@ public List getFfnLayerTaskGraphs() { } List setupFFNLayered() { - state.temp.init(0.0f); - state.tempFFN.init(0.0f); - var numLayers = config.numberOfLayers(); - - return IntStream.range(0, numLayers) - .mapToObj(i -> { - var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); - if (i == numLayers - 1) setupLastID(ffnLayer.getTaskGraphName()); - return ffnLayer.snapshot(); - }) - .toList(); + return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { + var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); + if (i == config.numberOfLayers() - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + return ffnLayer.snapshot(); + }).toList(); } + // @formatter:off + /** + * Transformer Layer Task Flow (LlamaFP16FFNLayers) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (partial sums) + * └────────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ attn_rms_finalize│──▶ temp (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌─────────────────────┐ + * │ attn_rms_apply_fp16 │──▶ wrapXbFP16 (normalized, FP16) + * └──────────┬──────────┘ + * │ + * ▼ + * ┌────────────────┐ ┌─────────────────────────────┐ + * │ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │ + * └───────┬────────┘ └─────────────────────────────┘ + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN (partial sums) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌─────────────────┐ + * │ ffn_rms_finalize│──▶ tempFFN (final scale) + * └────────┬────────┘ + * │ + * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3) + * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU) + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 9 tasks (7 if NVIDIA, skipping rms_finalize steps) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state with residual connections + * + * Key Fusion Points: + * • qkv_projection: Fused Q/K/V matmuls (3→1 kernel) + * • rope_and_kv_cache: Fused RoPE rotation + cache write 
(2→1 kernel) + * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel) + * + */ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + // === Data Setup === unifiedLayer.consumeFromDevice(state.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, weights.rms_att_weightLayered[layerIndex].asFloatArray(), @@ -111,68 +195,174 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, weights.w2Layered[layerIndex].asHalfFloatArray(), weights.w3Layered[layerIndex].asHalfFloatArray()); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer - .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, - config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), - layerIndex, config.contextLength()); - configureAttention(unifiedLayer, layerIndex); - unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), - 
weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), - config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); + + // === Attention Block === + // RMS Normalization + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantize, + context, state.wrapXbFP16, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + // QKV Projection (fused) + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulX, + context, + state.wrapXbFP16, // input (FP32) + state.wrapQ, // output Q + state.wrapK, // output K + state.wrapV, // output V + weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq + weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk + weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv + config.dim(), // dim + config.kvDim(), // kvDim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // RoPE + KV Cache + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, // Q (in/out) + state.wrapK, // K (in/out) + state.wrapV, // V (in only) + state.wrapKeyCache, // Key cache (out) + state.wrapValueCache, // Value cache (out) + config.kvDim(), + config.headSize(), + layerIndex, + config.contextLength()); + // Attention + configureAttention(unifiedLayer, layerIndex); + // Output Projection (Wo) with residual + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asHalfFloatArray(), + config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + // === FFN Block === + // RMS Normalization + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, + context, + state.wrapX, // raw input (FP32) + state.wrapHb, // output + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights + state.tempFFN, // RMS scale factor + weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 + weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 + config.dim(), // input dimension + config.hiddenDim(), // output dimension + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Down projection (W2) with residual + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asHalfFloatArray(), + 
config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.persistOnDevice(state.wrapX); + return unifiedLayer; } protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { - // First layer: Transfer initial data to device (one-time transfer) if (layerIndex == 0) { - // Transfer all attention-related data: query, key, value matrices and their caches - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); // - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb); // + // First layer: Transfer initial data to device (one-time transfer) + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, + state.temp, state.tempFFN + ); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Kernel context + context, + // Intermediate buffers + state.wrapXb, state.wrapXb2, + // QKV vectors + state.wrapQ, state.wrapK, state.wrapV, + // KV cache + state.wrapKeyCache, state.wrapValueCache, + // Attention & FFN buffers + state.wrapAtt, state.wrapHb, state.wrapXbFP16); } else { // Subsequent layers: Consume data already on device from previous layer - unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, // - state.positionHolder // - ); + unifiedLayer.consumeFromDevice( + // Kernel context + context, + // Intermediate buffers + state.wrapXb, state.wrapXb2, + // QKV vectors + state.wrapQ, state.wrapK, state.wrapV, + // KV cache + state.wrapKeyCache, state.wrapValueCache, + // Attention & FFN buffers + state.wrapAtt, state.wrapHb, + // Position & misc + state.positionHolder, state.wrapXbFP16); } return unifiedLayer; } private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { if (schedulerType == SchedulerType.NVIDIA) { - return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, - context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), - state.positionHolder, layerIndex, config.contextLength()); + // Flash Attention (optimized for NVIDIA GPUs) + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + config.numberOfHeads(), + config.headSize(), + config.kvDim(), + config.kvMul(), + state.positionHolder, + layerIndex, + config.contextLength()); } else { - return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + // Standard parallel attention (for non-NVIDIA backends) + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + 
config.numberOfHeads(), + config.headSize(), + config.kvDim(), + config.kvMul(), + config.contextLength(), // seqLen parameter + state.positionHolder, + state.wrapAtt, // Attention weights buffer + layerIndex, + config.contextLength()); } } + // @formatter:on + } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index a674c1c5..d2a81407 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -2,8 +2,8 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; @@ -28,46 +28,81 @@ public class LogitsFP16Layer extends AbstractLayer { public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { super(name, state, weights, config); this.lastTaskGraphID = lastTaskGraphID; - state.tempLogits.init(0.0f); + this.schedulerType = schedulerType; var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); - this.schedulerType = schedulerType; } - /** - * Builds the logits computation graph. 
- */ + // @formatter:off private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { - TaskGraph logits = new TaskGraph("logits"); - logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), weights.rms_final_weight_as_floatArray.asFloatArray()) - .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (schedulerType == SchedulerType.NON_NVIDIA) { - logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); - } - logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) - .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), - LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); + var logits = new TaskGraph("logits"); + // === Data Setup === + logits.consumeFromDevice(lastTaskGraphID, state.wrapX); + logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); + logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Kernel context + context, + // Output buffer + state.wrapLogits, + // Intermediate FP16 buffer + state.wrapXbFP16, + // Weights + weights.wclsByteArray.asHalfFloatArray(), + weights.rms_final_weight_as_floatArray.asFloatArray()); + + // === Final RMS Normalization === + logits.task("rms_reduce", + TransformerComputeKernels::reductionOneBlockWithLayer, + context, + state.tempLogits, // output: partial sums + final scale factor + state.wrapX, // input: hidden state + config.dim(), // dimension + config.rmsNormEps(), // epsilon for numerical stability + state.localSize); // local workgroup size + + if (schedulerType == SchedulerType.NON_NVIDIA) { + logits.task("rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.tempLogits, // in/out: combines partial sums + config.dim(), // dimension + config.rmsNormEps()); // epsilon + } + + logits.task("rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantizeLogits, + context, + state.wrapXbFP16, // output: normalized (FP16) + state.wrapX, // input: hidden state + weights.rms_final_weight_as_floatArray.asFloatArray(), // RMS weights + state.tempLogits); // scale factor from reduction + + // === Vocabulary Projection === + logits.task("vocab_proj", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, + state.wrapXbFP16, // input (FP16) + state.wrapLogits, // output + weights.wclsByteArray.asHalfFloatArray(), // vocabulary weights + config.dim(), // input dimension + config.vocabularySize(), // output dimension + LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); + + // === Transfer Results to Host === logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; } + // @formatter:on @Override public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid logitsRMS; - if (weights instanceof Qwen2TornadoWeights) { - logitsRMS = 
WorkerGridFactory.createRmsNormWorker(config.dim(), 32); - } else { - logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); - } - + WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256); var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); + var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - - tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS); - tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); return tornadoForwardScheduler; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 75f9f531..c3b9ddc6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -3,6 +3,7 @@ import org.beehive.gpullama3.inference.state.Phi3State; import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; import org.beehive.gpullama3.model.phi3.Phi3Configuration; +import org.beehive.gpullama3.tornadovm.kernels.Phi3Kernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; @@ -19,42 +20,27 @@ /** * Phi3FP16FFNLayers: FP16 FFN layers for Phi3 with Group Query Attention (GQA) support. * - * Key Differences from Qwen2/Qwen3: - * - Uses combined QKV matrix (wqkv) instead of separate Q, K, V matrices - * - Includes splitQKV task to separate combined buffer - * - Uses ropeRotationPhi3 kernel for position embeddings - * - FFN uses single wUp matrix that outputs both Gate and Up (2 * hiddenDim) - * - Includes splitGateUpAndSiLU task for FFN activation - * - Uses wDown for final FFN projection - * - No Q, K, V bias terms + * Key Differences from Qwen2/Qwen3: - Uses combined QKV matrix (wqkv) instead of separate Q, K, V matrices - Includes splitQKV task to separate combined buffer - Uses ropeRotationPhi3 kernel for + * position embeddings - FFN uses single wUp matrix that outputs both Gate and Up (2 * hiddenDim) - Includes splitGateUpAndSiLU task for FFN activation - Uses wDown for final FFN projection - No Q, K, + * V bias terms * * Works directly with Phi3State to access and mutate Phi3-specific state fields. 
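The row routing implied by the combined QKV layout can be summarised in a few lines. The sketch below is illustrative only (the class, the record, and the example sizes are assumptions, not code from this repository): it shows how an output row index in [0, opSize) maps to a Q, K, or V row, with opSize = dim + 2 * kvDim as computed in this class's constructor.

public final class QkvRoutingSketch {
    enum Target { Q, K, V }
    record Route(Target target, int localRow) { }

    // Mirrors the rowId checks used by the fused QKV kernels: Q rows first, then K, then V.
    static Route route(int rowId, int dim, int kvDim) {
        if (rowId < dim) {
            return new Route(Target.Q, rowId);
        } else if (rowId < dim + kvDim) {
            return new Route(Target.K, rowId - dim);
        }
        return new Route(Target.V, rowId - dim - kvDim);
    }

    public static void main(String[] args) {
        int dim = 3072, headSize = 96, numberOfKeyValueHeads = 32; // example sizes (assumed)
        int kvDim = numberOfKeyValueHeads * headSize;
        int opSize = dim + 2 * kvDim;                              // one workgroup per output row
        System.out.println(route(0, dim, kvDim));                  // Q, row 0
        System.out.println(route(dim + 5, dim, kvDim));            // K, row 5
        System.out.println(route(opSize - 1, dim, kvDim));         // V, last row
    }
}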
*/ public class Phi3FP16FFNLayers extends AbstractFFNLayers { - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; - List ffnLayerTaskGraphs; - // Typed references to Phi3-specific state and config private final Phi3State phi3State; private final Phi3Configuration phi3Config; - // Phi3-specific dimension for combined QKV buffer private final int opSize; + TaskGraph ffnLayerTaskGraph; + GridScheduler scheduler; + List ffnLayerTaskGraphs; public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) { - super(taskGraphName, state, weights, config,schedulerType); + super(taskGraphName, state, weights, config, schedulerType); this.phi3State = state; this.phi3Config = config; - - // Ensure we have Phi3-specific weights - if (!(weights instanceof Phi3TornadoWeights phi3Weights)) { - throw new IllegalArgumentException("Phi3FP16FFNLayers requires Phi3TornadoWeights with TornadoTensor layout"); - } - - // Calculate opSize for combined QKV buffer - // opSize = num_heads * head_dim + 2 * (num_key_value_heads * head_dim) = dim + 2 * kvDim this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); ffnLayerTaskGraphs = setupFFNLayered(); } @@ -64,46 +50,41 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { // RMS norm worker WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize); - // Combined QKV matmul worker - int matmulQkvGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid matmulQkvRowMajorWorker = WorkerGridFactory.genericWorker(matmulQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - - // RoPE worker (2D: heads x embedding_head/2) - int ic = config.headSize() / 2; - WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), config.headSize()); + // Fused RMS + QKV matmul worker + int fusedQkvGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - // Copy to cache worker - WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.kvDim(), 32); + // Fused RoPE + cache copy worker (Phi3 uses dim/2 pattern) + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); // Parallel attention worker WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); - // Matmul1 worker (output projection) + // Output projection worker int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC); - // FFN workers - int ffnUpGlobal = (2 * config.hiddenDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid ffnUpWorker = WorkerGridFactory.genericWorker(ffnUpGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + // Fused RMS + FFN gate/up worker + int fusedFFNGlobal = (2 * config.hiddenDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedFFNWorker = WorkerGridFactory.genericWorker(fusedFFNGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + // FFN down projection worker int ffnDownGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid ffnDownWorker = WorkerGridFactory.genericWorker(ffnDownGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + // Same worker as before - total rows = dim + 2*kvDim = opSize - // Map workers to tasks for each layer for (int i = 0; i < config.numberOfLayers(); i++) { - gridScheduler.addWorkerGrid("layer_" + i + 
".reductionsOneBlock", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", matmulQkvRowMajorWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker); - gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".wGateUp", ffnUpWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".wDown", ffnDownWorker); + // === Attention Block === + gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQkvWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker); + // === FFN Block === + gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_silu", fusedFFNWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", ffnDownWorker); } - return gridScheduler; } @@ -131,11 +112,6 @@ public List getFfnLayerTaskGraphs() { */ List setupFFNLayered() { List ffnGraphs = new ArrayList<>(); - - // Initialize buffers using Phi3State directly - phi3State.temp.init(0.0f); - phi3State.tempFFN.init(0.0f); - for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSinglePhi3FFNLayer((Phi3TornadoWeights) weights, layerIndex); if (layerIndex == phi3Config.numberOfLayers() - 1) { @@ -143,160 +119,197 @@ List setupFFNLayered() { } ffnGraphs.add(ffnLayer.snapshot()); } - return ffnGraphs; } + // @formatter:off /** - * Setup a single transformer layer for Phi3 with combined QKV and gate/up FFN + * Transformer Layer Task Flow (Phi3FP16FFNLayers - Fully Optimized) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (scale factor for RMSNorm) + * └────────┬────────┘ + * │ + * ▼ + * ┌────────────────────────┐ + * │ attn_rms_qkv_projection│──▶ wrapQ, wrapK, wrapV (direct output) + * └───────────┬────────────┘ (fused: RMS apply + QKV matmul + split) + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ (fused: Phi3 RoPE + cache write) + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN 
(scale factor) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ ffn_rms_finalize │──▶ tempFFN (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌──────────────┐ + * │ rms_ffn_silu │──▶ wrapHbU = SiLU(RMSNorm(x)·Wgate) ⊙ (RMSNorm(x)·Wup) + * └──────┬───────┘ (fused: RMS apply + gate/up matmul + SiLU + GLU) + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += wDown · wrapHbU (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 8 tasks (NVIDIA) / 9 tasks (non-NVIDIA) + * Original: 13 tasks + * Reduction: 5 tasks eliminated (38% fewer kernel launches) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state with residual connections + * + * Key Fusion Points (vs original 13 tasks): + * • attn_rms_qkv_projection: Fused RMS apply + QKV matmul + direct split (3→1 kernel) + * • rope_and_kv_cache: Fused Phi3 RoPE rotation + cache write (2→1 kernel) + * • rms_ffn_silu: Fused RMS apply + gate/up matmul + SiLU + GLU (3→1 kernel) + * + * Phi3-Specific: + * • Combined wqkv: Single [opSize × dim] matrix for Q+K+V projection + * • Direct QKV output: No intermediate buffer, routes by row index + * • Phi3 RoPE: Uses headSize/2 offset pattern (different from Llama/Qwen) + * • Combined wUp: Single [2×hiddenDim × dim] matrix for gate+up + * • Inline SiLU+GLU: No intermediate wrapHb buffer needed + * */ + // @formatter:off TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { - - TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex); + var taskGraphName = "layer_" + layerIndex; + var unifiedLayer = new TaskGraph(taskGraphName); unifiedLayer.consumeFromDevice(phi3State.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - // Copy-in weights per layer for batched-layered layout + // Attention weights weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqkvLayered[layerIndex].asHalfFloatArray(), weights.woLayered[layerIndex].asHalfFloatArray(), + // FFN weights weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), weights.wUpLayered[layerIndex].asHalfFloatArray(), - weights.wDownLayered[layerIndex].asHalfFloatArray() - ); + weights.wDownLayered[layerIndex].asHalfFloatArray()); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - // RMSNorm for attention input - unifiedLayer.task("reductionsOneBlock", - TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, - phi3State.temp, - phi3State.wrapX, - phi3Config.dim(), - phi3Config.rmsNormEps(), - phi3State.localSize) - .task("mapContext", - TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, - context, - phi3State.wrapXb, - phi3State.wrapX, - weights.rms_att_weightLayered[layerIndex].asFloatArray(), - phi3State.temp); - - // Combined QKV projection - unifiedLayer.task("qkvmatmul", - TransformerComputeKernelsLayered::matrixVectorGeneric, - context, - phi3State.wrapXb, - phi3State.wrapQkv, - weights.wqkvLayered[layerIndex].asHalfFloatArray(), - phi3Config.dim(), - opSize, - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("splitQKV", - TransformerComputeKernelsLayered::splitQKV, - phi3State.wrapQkv, - phi3State.wrapQ, - phi3State.wrapK, - phi3State.wrapV, - phi3Config.dim(), - phi3Config.headSize() * phi3Config.numberOfKeyValueHeads()); - - // RoPE rotation (Phi3-specific 
kernel) - unifiedLayer.task("rope", - TransformerComputeKernelsLayered::ropeRotationPhi3, - context, - phi3State.positionHolder, - phi3State.wrapQ, - phi3State.wrapK, - phi3Config.kvDim(), - phi3Config.headSize()); - - // Copy to caches - unifiedLayer.task("copyToCaches", - TransformerComputeKernelsLayered::copyToCache, - phi3State.wrapKeyCache, - phi3State.wrapK, - phi3State.wrapValueCache, - phi3State.wrapV, - phi3State.positionHolder, - phi3Config.kvDim(), - layerIndex, - phi3Config.contextLength()); - - // Parallel attention - unifiedLayer.task("parallel-attention", - TransformerComputeKernelsLayered::processHeadsFlashAttention, - context, - phi3State.wrapQ, - phi3State.wrapKeyCache, - phi3State.wrapValueCache, - phi3State.wrapXb, - phi3Config.numberOfHeads(), - phi3Config.headSize(), - phi3Config.kvDim(), - phi3Config.kvMul(), - phi3State.positionHolder, - layerIndex, - phi3Config.contextLength()); - - // Output projection - unifiedLayer.task("matmul1", - TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, - context, - phi3State.wrapXb, - phi3State.wrapX, - weights.woLayered[layerIndex].asHalfFloatArray(), - phi3Config.dim(), - phi3Config.dim(), + // ═══════════════════════════════════════════════════════════════════════ + // ATTENTION BLOCK + // ═══════════════════════════════════════════════════════════════════════ + + // RMS Normalization - compute scale factor + unifiedLayer.task("attn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, phi3State.temp, // output: scale factor + phi3State.wrapX, // input: hidden state + phi3Config.dim(), // dimension + phi3Config.rmsNormEps(), // epsilon + phi3State.localSize); // local memory size + + unifiedLayer.task("attn_rms_qkv_projection", Phi3Kernels::fusedRmsNormQKVMatmulDirect, context, phi3State.wrapX, // input + phi3State.wrapQ, // output Q + phi3State.wrapK, // output K + phi3State.wrapV, // output V + weights.rms_att_weightLayered[layerIndex].asFloatArray(), phi3State.temp, // RMS scale + weights.wqkvLayered[layerIndex].asHalfFloatArray(), phi3Config.dim(), // dim + phi3Config.kvDim(), // kvDim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Fused Phi3 RoPE Rotation + KV Cache Write + unifiedLayer.task("rope_and_kv_cache", Phi3Kernels::ropeRotationWithCacheCopyPhi3, context, phi3State.positionHolder, // current position + phi3State.wrapQ, // Q vectors (in/out, rotated) + phi3State.wrapK, // K vectors (in/out, rotated) + phi3State.wrapV, // V vectors (in only) + phi3State.wrapKeyCache, // key cache (out) + phi3State.wrapValueCache, // value cache (out) + phi3Config.numberOfKeyValueHeads(), // nHeadKv + phi3Config.headSize(), // head dimension + phi3Config.kvDim(), // kvDim + layerIndex, // layer index for cache offset + phi3Config.contextLength()); // max sequence length + + // Flash Attention + unifiedLayer.task("attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context, phi3State.wrapQ, // query vectors + phi3State.wrapKeyCache, // key cache + phi3State.wrapValueCache, // value cache + phi3State.wrapXb, // output: attention result + phi3Config.numberOfHeads(), // nHeads + phi3Config.headSize(), // headSize + phi3Config.kvDim(), // kvDim + phi3Config.kvMul(), // kvMul (nHeads / nHeadKv) + phi3State.positionHolder, // position + layerIndex, // layer index + phi3Config.contextLength()); // context length + + // Output Projection with Residual + unifiedLayer.task("attn_output_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, phi3State.wrapXb, 
// input: attention output + phi3State.wrapX, // output: wrapX += Wo · wrapXb + weights.woLayered[layerIndex].asHalfFloatArray(), // Wo [dim × dim] + phi3Config.dim(), // input dim + phi3Config.dim(), // output dim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // ═══════════════════════════════════════════════════════════════════════ + // FFN BLOCK + // ═══════════════════════════════════════════════════════════════════════ + + // RMS Normalization - compute scale factor + unifiedLayer.task("ffn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, phi3State.tempFFN, // output: scale factor + phi3State.wrapX, // input: hidden state + phi3Config.dim(), // dimension + phi3Config.rmsNormEps(), // epsilon + phi3State.localSize); // local memory size + + // Final normalization (non-NVIDIA only) + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", TransformerComputeKernelsLayered::reductionFinalNormalization, context, phi3State.tempFFN, // scale factor (in/out) + phi3Config.dim(), // dimension + phi3Config.rmsNormEps()); // epsilon + } + + unifiedLayer.task("rms_ffn_silu", Phi3Kernels::fusedRmsNormFFNGateUpSiLU, context, phi3State.wrapX, // input + phi3State.wrapHbU, // output (direct to final FFN buffer) + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), phi3State.tempFFN, // RMS scale + weights.wUpLayered[layerIndex].asHalfFloatArray(), phi3Config.dim(), // input dim + phi3Config.hiddenDim(), // output dim (hiddenDim, not 2×hiddenDim!) LOCAL_WORK_GROUP_SIZE_ALLOC); - // FFN section: RMSNorm - unifiedLayer.task("reductionsOneBlockFFN", - TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, - phi3State.tempFFN, - phi3State.wrapX, - phi3Config.dim(), - phi3Config.rmsNormEps(), - phi3State.localSize) - .task("mapContextFFN", - TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, - context, - phi3State.wrapXb, - phi3State.wrapX, - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - phi3State.tempFFN); - - // FFN: combined Up and Gate projection (outputs 2 * hiddenDim) - unifiedLayer.task("wGateUp", - TransformerComputeKernelsLayered::matrixVectorGeneric, - context, - phi3State.wrapXb, - phi3State.wrapHb, - weights.wUpLayered[layerIndex].asHalfFloatArray(), - phi3Config.dim(), - 2 * phi3Config.hiddenDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("gateUpSiLU", - TransformerComputeKernelsLayered::splitGateUpAndSiLU, - phi3State.wrapHb, - phi3State.wrapHbG, - phi3State.wrapHbU, - phi3Config.hiddenDim()); - - // FFN: Down projection with residual - unifiedLayer.task("wDown", - TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, - context, - phi3State.wrapHbU, - phi3State.wrapX, - weights.wDownLayered[layerIndex].asHalfFloatArray(), - phi3Config.hiddenDim(), - phi3Config.dim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice( - phi3State.wrapX - ); + // Down Projection with Residual + unifiedLayer.task("ffn_down_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, phi3State.wrapHbU, // input: FFN intermediate + phi3State.wrapX, // output: wrapX += wDown · wrapHbU + weights.wDownLayered[layerIndex].asHalfFloatArray(), // wDown [dim × hiddenDim] + phi3Config.hiddenDim(), // input dim + phi3Config.dim(), // output dim + LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(phi3State.wrapX); + return unifiedLayer; } @@ -304,26 +317,26 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { * Configure data transfers for first and subsequent layers */ 
protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { - // First layer: Transfer initial data to device (one-time transfer) if (layerIndex == 0) { - // Transfer all attention-related data: query, key, value matrices and their caches - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); // - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, // - phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); // + // First layer: Transfer temporary buffers and state every execution + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, phi3State.positionHolder, + phi3State.temp, phi3State.tempFFN); + // First execution: allocate workspace buffers + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, phi3State.wrapXb, phi3State.wrapXb2, + phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, + phi3State.wrapKeyCache, phi3State.wrapValueCache, + phi3State.wrapAtt, phi3State.wrapHb, phi3State.wrapHbG, + phi3State.wrapHbU, phi3State.wrapQkv); } else { - // Subsequent layers: Consume data already on device from previous layer - unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, // - state.positionHolder, // / + // Subsequent layers: Consume data from previous layer + unifiedLayer.consumeFromDevice(context, phi3State.wrapXb, phi3State.wrapXb2, + phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, phi3State.wrapKeyCache, + phi3State.wrapValueCache, phi3State.wrapAtt, phi3State.wrapHb, phi3State.positionHolder, phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); } return unifiedLayer; } + // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index 858848ea..aded0c81 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -51,7 +51,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) int ic = config.headSize() / 2; WorkerGrid ropeWorker = new WorkerGrid2D(h, ic); ropeWorker.setGlobalWork(h, ic, 1); - ropeWorker.setLocalWork(1, 1, 1); + ropeWorker.setLocalWork(h / 2, ic / 2, 1); int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal); @@ -61,13 +61,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal); configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - WorkerGrid qBiasWorker = new WorkerGrid1D(config.dim()); - qBiasWorker.setGlobalWork(config.dim(), 1, 1); - qBiasWorker.setLocalWork(config.dim() / 8, 1, 1); - WorkerGrid kvBiasWorker = new WorkerGrid1D(config.kvDim()); - kvBiasWorker.setGlobalWork(config.kvDim(), 1, 1); - kvBiasWorker.setLocalWork(32, 1, 1); - int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configHiddenDimRowMajorWorker = new 
WorkerGrid1D(configHiddenDimRowMajor);
        configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
@@ -87,32 +80,38 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
            }
        }
+        // WorkerGrid for the fused QKV bias addition; sized to dimQ, the largest of the three bias vectors
+        WorkerGrid fusedQKVBiasWorker = new WorkerGrid1D(config.dim());
+        fusedQKVBiasWorker.setGlobalWork(config.dim(), 1, 1);
+        fusedQKVBiasWorker.setLocalWork(32, 1, 1); // local work size of 32; tune per device if needed
+
        WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
        parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * optimalLocalSize, 1, 1);
        parallelAttentionWorker.setLocalWork(optimalLocalSize, 1, 1);
-        WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim());
-        copyToCachesWorker.setGlobalWork(config.kvDim(), 1, 1);
-        copyToCachesWorker.setLocalWork(32, 1, 1); // Set local work size to 32 (for copying to caches)
+        int fusedQKVGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid fusedQKVWorker = new WorkerGrid1D(fusedQKVGlobal);
+        fusedQKVWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
        // Map workers to tasks
        for (int i = 0; i < config.numberOfLayers(); i++) {
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qbias", qBiasWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kbias", kvBiasWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vbias", kvBiasWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_qkv_bias", fusedQKVBiasWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); -
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); } return tornadoForwardScheduler; } @@ -136,15 +135,8 @@ public List getFfnLayerTaskGraphs() { return ffnLayerTaskGraphs; } - /** - * Setup all FFN layers for all transformer layers - */ List setupFFNLayered() { - List ffnGraphs = new ArrayList<>(); - - qwen2State.temp.init(0.0f); - qwen2State.tempFFN.init(0.0f); - + List ffnGraphs = new ArrayList<>(qwen2Config.numberOfLayers()); for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSingleQwen2FFNLayer((Qwen2TornadoWeights) weights, layerIndex); if (layerIndex == qwen2Config.numberOfLayers() - 1) { @@ -152,59 +144,257 @@ List setupFFNLayered() { } ffnGraphs.add(ffnLayer.snapshot()); } - return ffnGraphs; } + // @formatter:off /** - * Setup a single transformer layer for Qwen2 with GQA + * Transformer Layer Task Flow (Qwen2FP16FFNLayers - Optimized) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (scale factor for RMSNorm) + * └────────┬────────┘ + * │ + * ▼ + * ┌─────────────────────────┐ + * │ attn_rms_qkv_projection │──▶ wrapQ, wrapK, wrapV (FP32) + * └───────────┬─────────────┘ (fused: RMS apply + Q/K/V matmuls) + * │ + * ▼ + * ┌────────────────┐ + * │ fused_qkv_bias │──▶ wrapQ, wrapK, wrapV += biases + * └───────┬────────┘ (fused: Q + K + V bias addition) + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ (fused: RoPE rotation + cache write) + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN (scale factor) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ ffn_rms_finalize │──▶ tempFFN (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3) + * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU) + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 9 tasks (NVIDIA) / 10 tasks (non-NVIDIA) + * Previous: 12 tasks + * Reduction: 3 tasks eliminated (25% fewer kernel launches) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state 
with residual connections + * + * Key Fusion Points (vs previous 12 tasks): + * • fused_qkv_bias: Fused Q + K + V bias addition (3→1 kernel) + * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU + * (eliminates separate mapContextFFN kernel) + * + * Qwen2-Specific: + * • GQA: nHeads (Q) != nHeadKv (K/V), with kvMul = nHeads / nHeadKv + * • Bias terms: Q, K, V projections include bias (unlike Qwen3) + * • No Q/K RMSNorm: Unlike Qwen3, Qwen2 doesn't normalize Q/K after projection + * */ TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) { var taskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(taskGraphName); unifiedLayer.consumeFromDevice(state.wrapX); - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - weights.rms_att_weightLayered[layerIndex].asFloatArray(), // - weights.wqLayered[layerIndex].asHalfFloatArray(), // - weights.wkLayered[layerIndex].asHalfFloatArray(), // - weights.wvLayered[layerIndex].asHalfFloatArray(), // - weights.woLayered[layerIndex].asHalfFloatArray(), // - weights.q_biasLayered[layerIndex].asFloatArray(), // - weights.k_biasLayered[layerIndex].asFloatArray(), // - weights.v_biasLayered[layerIndex].asFloatArray(), // - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // - weights.w1Layered[layerIndex].asHalfFloatArray(), // - weights.w2Layered[layerIndex].asHalfFloatArray(), // - weights.w3Layered[layerIndex].asHalfFloatArray()); // - unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); // - - unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.temp, qwen2State.wrapX, config.dim(), config.rmsNormEps(), - qwen2State.localSize) - .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), - qwen2State.temp) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Attention weights + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + // Qwen2-specific bias terms + weights.q_biasLayered[layerIndex].asFloatArray(), + weights.k_biasLayered[layerIndex].asFloatArray(), + weights.v_biasLayered[layerIndex].asFloatArray(), + // FFN weights + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + // ═══════════════════════════════════════════════════════════════════════ + // ATTENTION BLOCK + // ═══════════════════════════════════════════════════════════════════════ + + // RMS Normalization - compute scale factor + unifiedLayer.task("attn_rms_reduce", + 
TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + qwen2State.temp, // output: scale factor + qwen2State.wrapX, // input: hidden state + config.dim(), // dimension + config.rmsNormEps(), // epsilon + qwen2State.localSize); // local memory size + + // Fused RMS Apply + QKV Projection + unifiedLayer.task("attn_rms_qkv_projection", + Qwen3Kernels::fusedRmsNormQKVMatmul, + context, + qwen2State.wrapX, // input: raw hidden state (FP32) + qwen2State.wrapQ, // output: Q vectors + qwen2State.wrapK, // output: K vectors + qwen2State.wrapV, // output: V vectors + weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights + qwen2State.temp, // RMS scale factor from reduction + weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq + weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk + weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv + config.dim(), // input dimension + config.dim(), // Q output dimension + config.kvDim(), // K/V output dimension (GQA) + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Fused Q/K/V Bias Addition (3→1 kernel fusion) + unifiedLayer.task("fused_qkv_bias", + TransformerComputeKernelsLayered::fusedQKvBiasAddition, + context, + qwen2State.wrapQ, // Q (in/out) + qwen2State.wrapK, // K (in/out) + weights.q_biasLayered[layerIndex].asFloatArray(), // Q bias + qwen2State.wrapV, // V (in/out) + weights.k_biasLayered[layerIndex].asFloatArray(), // K bias + weights.v_biasLayered[layerIndex].asFloatArray(), // V bias + config.dim(), // dimQ + config.kvDim()); // dimKV + + // Fused RoPE Rotation + KV Cache Write + unifiedLayer.task("rope_and_kv_cache", + Qwen3Kernels::ropeRotationWithCacheCopy, + context, + qwen2State.positionHolder, // current sequence position + qwen2State.wrapQ, // Q (rotated in-place) + qwen2State.wrapK, // K (rotated in-place) + qwen2State.wrapV, // V (copied to cache) + qwen2State.wrapKeyCache, // key cache (write) + qwen2State.wrapValueCache, // value cache (write) + config.numberOfKeyValueHeads(), // nHeadKv + config.headSize(), // per-head dimension + config.kvDim(), // kvDim + layerIndex, // layer offset + config.contextLength()); // max sequence length + + // Flash Attention + unifiedLayer.task("attention", + Qwen2Kernels::processHeadsFlashAttention, + context, + qwen2State.wrapQ, // query vectors + qwen2State.wrapKeyCache, // key cache + qwen2State.wrapValueCache, // value cache + qwen2State.wrapXb, // output: attention result + config.numberOfHeads(), // nHeads + config.headSize(), // headSize + config.kvDim(), // kvDim + config.kvMul(), // kvMul (nHeads / nHeadKv) + qwen2State.positionHolder, // position + layerIndex, // layer index + config.contextLength()); // context length + + // Output Projection with Residual + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, + qwen2State.wrapXb, // input: attention output + qwen2State.wrapX, // output: wrapX += Wo · wrapXb + weights.woLayered[layerIndex].asHalfFloatArray(), // Wo + config.dim(), // input dim + config.dim(), // output dim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // ═══════════════════════════════════════════════════════════════════════ + // FFN BLOCK + // ═══════════════════════════════════════════════════════════════════════ + + // RMS Normalization - compute scale factor + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + qwen2State.tempFFN, // output: scale factor + qwen2State.wrapX, // input: hidden state + config.dim(), // 
dimension + config.rmsNormEps(), // epsilon + qwen2State.localSize); // local memory size + + // Final normalization (non-NVIDIA only) + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + qwen2State.tempFFN, // scale factor (in/out) + config.dim(), // dimension + config.rmsNormEps()); // epsilon + } + + // Fused RMS Apply + Gate/Up Projection + SiLU + GLU + // (Replaces mapContextFFN + fusedFeedForwardWithSiLUAndGLUActivation) + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, + context, + qwen2State.wrapX, // input: raw hidden state (FP32) + qwen2State.wrapHb, // output: SiLU(x·W1) ⊙ (x·W3) + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights + qwen2State.tempFFN, // RMS scale factor + weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 (gate) + weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 (up) + config.dim(), // input dimension + config.hiddenDim(), // hidden dimension + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Down Projection with Residual + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, + qwen2State.wrapHb, // input: FFN intermediate + qwen2State.wrapX, // output: wrapX += W2 · wrapHb + weights.w2Layered[layerIndex].asHalfFloatArray(), // W2 (down) + config.hiddenDim(), // input dim + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC).task("qbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim()) - .task("kbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapK, weights.k_biasLayered[layerIndex].asFloatArray(), config.kvDim()) - .task("vbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim()) - .task("rope", Qwen3Kernels::ropeRotation, context, qwen2State.positionHolder, qwen2State.wrapQ, qwen2State.wrapK, config.numberOfKeyValueHeads(), config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, qwen2State.wrapKeyCache, qwen2State.wrapK, qwen2State.wrapValueCache, qwen2State.wrapV, qwen2State.positionHolder, - config.kvDim(), layerIndex, config.contextLength()) - .task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context, qwen2State.wrapQ, qwen2State.wrapKeyCache, qwen2State.wrapValueCache, qwen2State.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), qwen2State.positionHolder, layerIndex, config.contextLength()) - .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapXb, qwen2State.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), - config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.tempFFN, qwen2State.wrapX, config.dim(), config.rmsNormEps(), - qwen2State.localSize) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - 
qwen2State.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen2State.wrapXb, qwen2State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), - weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapHb, qwen2State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), - config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); + .persistOnDevice(state.wrapX); return unifiedLayer; } @@ -216,7 +406,8 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye // First layer: Transfer initial data to device (one-time transfer) if (layerIndex == 0) { // Transfer all attention-related data: query, key, value matrices and their caches - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen2State.positionHolder, qwen2State.temp, qwen2State.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen2State.positionHolder, + qwen2State.temp, qwen2State.tempFFN); // unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // context, qwen2State.wrapXb, qwen2State.wrapXb2, // qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, // @@ -234,5 +425,6 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye } return unifiedLayer; } + // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 379921c3..453a4d7c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -14,8 +14,8 @@ import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.ArrayList; import java.util.List; +import java.util.stream.IntStream; /** * Qwen3FP16FFNLayers: FP16 FFN layers for Qwen3 with Group Query Attention (GQA) support. 
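For reference, the GQA bookkeeping used by the Qwen2/Qwen3 task graphs (kvMul, the reduced kvDim, and the per-layer key/value cache offsets) can be summarized with a small plain-Java sketch. This is illustrative only, not part of the patch or of the TornadoVM kernels; the names nHeads, nHeadKv and headSize are assumed to mirror the configuration accessors numberOfHeads(), numberOfKeyValueHeads() and headSize().

final class GqaIndexingSketch {
    /** kvMul (called gqa in Qwen3FP16FFNLayers): how many query heads share one key/value head. */
    static int kvMul(int nHeads, int nHeadKv) {
        return nHeads / nHeadKv;                 // e.g. 32 Q heads / 8 KV heads -> 4
    }

    /** Which key/value head serves a given query head. */
    static int kvHeadFor(int qHead, int nHeads, int nHeadKv) {
        return qHead / kvMul(nHeads, nHeadKv);
    }

    /** Offset of a (position, kvHead) entry inside one layer's key or value cache. */
    static int cacheOffset(int position, int kvHead, int headSize, int nHeadKv) {
        int kvDim = nHeadKv * headSize;          // reduced K/V width (nEmbdGqa / kvDim in the code above)
        return position * kvDim + kvHead * headSize;
    }
}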
@@ -43,7 +43,7 @@ public class Qwen3FP16FFNLayers extends AbstractFFNLayers { List ffnLayerTaskGraphs; public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) { - super(taskGraphName, state, weights, config,schedulerType); + super(taskGraphName, state, weights, config, schedulerType); this.qwen3State = state; this.qwen3Config = config; @@ -61,71 +61,44 @@ public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe @Override public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize); - - // Q matmul worker (GQA: full query heads) - int matmulQGlobal = nEmbdHeadK * config.numberOfHeads() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid matmulQRowMajorWorker = WorkerGridFactory.genericWorker(matmulQGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - - // KV matmul worker (GQA: reduced KV heads) - int matmulKVGlobal = nEmbdGqa * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid matmulKVRowMajorWorker = WorkerGridFactory.genericWorker(matmulKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - - // Current embedding head worker - WorkerGrid curWorker = WorkerGridFactory.createRmsNormWorker(nEmbdHead, 128); - - // Q current worker - WorkerGrid qCurWorker = WorkerGridFactory.genericWorker(config.numberOfHeads() * nEmbdHead, nEmbdHead); - - // K current worker - WorkerGrid kCurWorker = WorkerGridFactory.genericWorker(config.numberOfKeyValueHeads() * nEmbdHead, nEmbdHead); - - // RoPE worker (2D: heads x embedding_head/2) - int ic = nEmbdHead / 2; WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), nEmbdHead); - - // Copy to cache worker - WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(nEmbdGqa, 128); - // Parallel attention worker WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), nEmbdHead); - - // Matmul1 worker (output projection) + // attn_output_proj worker (output projection) int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC); - // FFN workers int fusedFFNW1W3Global = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid fusedFFNW1W3Worker = WorkerGridFactory.genericWorker(fusedFFNW1W3Global, LOCAL_WORK_GROUP_SIZE_ALLOC); int projectionTwoGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid projectionTwoWorker = WorkerGridFactory.genericWorker(projectionTwoGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + int qkRmsNormGroups = config.numberOfHeads() + config.numberOfKeyValueHeads(); + WorkerGrid qkRmsNormWorker = WorkerGridFactory.genericWorker(qkRmsNormGroups * nEmbdHead, nEmbdHead); - // Map workers to tasks for each layer - for (int i = 0; i < config.numberOfLayers(); i++) { - gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - - gridScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker); - - gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker); - - gridScheduler.addWorkerGrid("layer_" 
+ i + ".rmsnormReduction_Kcur", kCurWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker); - - gridScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker); + int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); + int kvDim0 = nEmbdGqa; + int fusedQKVRows = qDim0 + 2 * kvDim0; // Q rows + K rows + V rows + int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker); - gridScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker); + // Map workers to tasks for each layer (in task execution order) + for (int i = 0; i < config.numberOfLayers(); i++) { + // === Attention Block === + gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".qk_rmsnorm", qkRmsNormWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker); + // === FFN Block === + gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + if (shouldUseFinalNormalization()) { + gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_finalize", rmsNormWorker); + } + gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNW1W3Worker); + gridScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", projectionTwoWorker); } - return gridScheduler; } @@ -152,113 +125,273 @@ public List getFfnLayerTaskGraphs() { * Setup all FFN layers for all transformer layers */ List setupFFNLayered() { - List ffnGraphs = new ArrayList<>(); - qwen3State.temp.init(0.0f); - qwen3State.tempFFN.init(0.0f); - qwen3State.tempQcur.init(0.0f); - qwen3State.tempKcur.init(0.0f); - - for (int layerIndex = 0; layerIndex < qwen3Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, layerIndex); - if (layerIndex == qwen3Config.numberOfLayers() - 1) { + return IntStream.range(0, qwen3Config.numberOfLayers()).mapToObj(i -> { + var ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, i); + if (i == qwen3Config.numberOfLayers() - 1) { setupLastID(ffnLayer.getTaskGraphName()); } - ffnGraphs.add(ffnLayer.snapshot()); - } - - return ffnGraphs; + return ffnLayer.snapshot(); + }).toList(); } + // @formatter:off /** - * Setup a single transformer layer for Qwen3 with GQA + * Transformer Layer Task Flow (Qwen3FP16FFNLayers) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (scale factor for RMSNorm) + * 
└────────┬────────┘ + * │ + * ▼ + * ┌─────────────────────────┐ + * │ attn_rms_qkv_projection │──▶ wrapQ, wrapK, wrapV (FP32) + * └───────────┬─────────────┘ (fused: RMS apply + Q/K/V matmuls) + * │ + * ▼ + * ┌─────────────┐ + * │ qk_rmsnorm │──▶ wrapQ, wrapK normalized in-place + * └──────┬──────┘ (fused: Q + K RMSNorm reduction + apply) + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ (fused: RoPE rotation + cache write) + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN (scale factor) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ ffn_rms_finalize │──▶ tempFFN (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3) + * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU) + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 9 tasks (NVIDIA) / 10 tasks (non-NVIDIA) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state with residual connections + * + * Key Fusion Points (vs baseline 18 tasks): + * • attn_rms_qkv_projection: Fused RMS apply + Q/K/V matmuls (4→1 kernel) + * • qk_rmsnorm: Fused Q + K RMSNorm (4→1 kernel) + * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel) + * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel) + * + * Qwen3-Specific: + * • GQA: nHeads (Q) != nHeadKv (K/V), with gqa = nHeads / nHeadKv + * • Q/K RMSNorm: Additional normalization after QKV projection (qk_rmsnorm) + * • RoPE theta: 1,000,000 (vs Llama's 10,000 or 50,000) + * */ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) { var taskGraphName = "layer_" + layerIndex; - TaskGraph unifiedLayer = new TaskGraph(taskGraphName); + + // === Dimension Parameters === + int qDim = nEmbdHeadK * qwen3Config.numberOfHeads(); // Q output size (full heads) + int kvDim = nEmbdGqa; // K/V output size (reduced for GQA) + int inputDim = qwen3Config.dim(); // Model dimension + + var unifiedLayer = new TaskGraph(taskGraphName); + + // === Data Setup === unifiedLayer.consumeFromDevice(state.wrapX); - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - //Copy-in weights per layer for batched-layered layout - weights.rms_att_weightLayered[layerIndex].asFloatArray(), // - weights.wqLayered[layerIndex].asHalfFloatArray(), // - weights.wkLayered[layerIndex].asHalfFloatArray(), // - weights.wvLayered[layerIndex].asHalfFloatArray(), // - weights.woLayered[layerIndex].asHalfFloatArray(), // - //rms_att_KNormLayered - weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // - //rms_att_QNormLayered - 
weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // - weights.w1Layered[layerIndex].asHalfFloatArray(), // - weights.w2Layered[layerIndex].asHalfFloatArray(), // - weights.w3Layered[layerIndex].asHalfFloatArray() // - ); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Attention weights + weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS norm weights + weights.wqLayered[layerIndex].asHalfFloatArray(), // Q projection + weights.wkLayered[layerIndex].asHalfFloatArray(), // K projection + weights.wvLayered[layerIndex].asHalfFloatArray(), // V projection + weights.woLayered[layerIndex].asHalfFloatArray(), // Output projection + // Qwen3-specific Q/K norm weights + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // K RMSNorm weights + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // Q RMSNorm weights + // FFN weights + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // FFN RMS norm weights + weights.w1Layered[layerIndex].asHalfFloatArray(), // FFN gate + weights.w2Layered[layerIndex].asHalfFloatArray(), // FFN down + weights.w3Layered[layerIndex].asHalfFloatArray()); // FFN up unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.temp, qwen3State.wrapX, // in - qwen3Config.dim(), qwen3Config.rmsNormEps(), qwen3State.localSize).task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen3State.wrapXb, // out - qwen3State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), qwen3State.temp); - int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); - int kvDim0 = nEmbdGqa; - int qkvDim1 = qwen3Config.dim(); - unifiedLayer.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapQ, // output - weights.wqLayered[layerIndex].asHalfFloatArray(), qkvDim1, qDim0, LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapK, // output - weights.wkLayered[layerIndex].asHalfFloatArray(), qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapV, // output - weights.wvLayered[layerIndex].asHalfFloatArray(), qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC); - - // Qcur rmsnorm - unifiedLayer.task("rmsnormReduction_Qcur", Qwen3Kernels::rmsnormWithParallelOffset, context, qwen3State.tempQcur, // output - qwen3State.wrapQ, // input - qwen3State.localSize, // currently 128, should be variable of global nEmbHead - nEmbdHead, // for normalization - qwen3Config.rmsNormEps()) // for normalization - .task("rmsnormMapIndexInPlace_Qcur", Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, context, qwen3State.wrapQ, // output - weights.rms_att_QNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempQcur); - - // Kcur rmsnorm - unifiedLayer.task("rmsnormReduction_Kcur", Qwen3Kernels::rmsnormWithParallelOffset, context, qwen3State.tempKcur, // output - qwen3State.wrapK, // input - qwen3State.localSize, // currently 128, should be variable of global nEmbHead - nEmbdHead, // for normalization - qwen3Config.rmsNormEps()) // for normalization - .task("rmsnormMapIndexInPlace_Kcur", Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, context, 
qwen3State.wrapK, // output - weights.rms_att_KNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempKcur); - - // rope rotation task graph - unifiedLayer.task("ropeRotation", Qwen3Kernels::ropeRotation, context, qwen3State.positionHolder, qwen3State.wrapQ, // out - qwen3State.wrapK, // out - qwen3Config.numberOfKeyValueHeads(), nEmbdHead); - - unifiedLayer.task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, qwen3State.wrapKeyCache, // out - qwen3State.wrapK, // in - qwen3State.wrapValueCache, // out - qwen3State.wrapV, // in - qwen3State.positionHolder, nEmbdGqa, layerIndex, qwen3Config.contextLength()); - - unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt, context, qwen3State.wrapQ, qwen3State.wrapKeyCache, qwen3State.wrapValueCache, - qwen3State.wrapXb, // out - qwen3Config.numberOfHeads(), nEmbdHead, nEmbdGqa, gqa, qwen3State.positionHolder, layerIndex, qwen3Config.contextLength()); - - unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapXb, // vector - qwen3State.wrapX, // out, should be [1024] - weights.woLayered[layerIndex].asHalfFloatArray(), // matrix - nEmbdHeadK * qwen3Config.numberOfHeads(), // dim1 = 2048 - qwen3Config.dim(), // dim0 = 1024 + // ═══════════════════════════════════════════════════════════════════════ + // ATTENTION BLOCK + // ═══════════════════════════════════════════════════════════════════════ + + // RMS Normalization - compute scale factor + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + qwen3State.temp, // output: scale factor + qwen3State.wrapX, // input: hidden state + qwen3Config.dim(), // dimension + qwen3Config.rmsNormEps(), // epsilon + qwen3State.localSize); // local memory size + + // Fused RMS Apply + QKV Projection + unifiedLayer.task("attn_rms_qkv_projection", + Qwen3Kernels::fusedRmsNormQKVMatmul, + context, + qwen3State.wrapX, // input: raw hidden state (FP32) + qwen3State.wrapQ, // output: Q vectors + qwen3State.wrapK, // output: K vectors + qwen3State.wrapV, // output: V vectors + weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights + qwen3State.temp, // RMS scale factor from reduction + weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq [qDim x inputDim] + weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk [kvDim x inputDim] + weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv [kvDim x inputDim] + inputDim, // input dimension + qDim, // Q output dimension + kvDim, // K/V output dimension (GQA: reduced) + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Fused Q/K RMSNorm (Qwen3-specific) + unifiedLayer.task("qk_rmsnorm", + Qwen3Kernels::fusedQKRmsNorm, + context, + qwen3State.wrapQ, // Q vectors (in/out) + qwen3State.wrapK, // K vectors (in/out) + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // Q norm weights + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // K norm weights + qwen3Config.numberOfHeads(), // nHeads (Q heads) + qwen3Config.numberOfKeyValueHeads(), // nHeadKv (K/V heads, GQA) + nEmbdHead, // head dimension + nEmbdHead, // local memory size + qwen3Config.rmsNormEps()); // epsilon + + // Fused RoPE Rotation + KV Cache Write + unifiedLayer.task("rope_and_kv_cache", + Qwen3Kernels::ropeRotationWithCacheCopy, + context, + qwen3State.positionHolder, // current position + qwen3State.wrapQ, // Q vectors (in/out, rotated) + qwen3State.wrapK, // K vectors (in/out, rotated) + 
qwen3State.wrapV, // V vectors (in only) + qwen3State.wrapKeyCache, // key cache (out) + qwen3State.wrapValueCache, // value cache (out) + qwen3Config.numberOfKeyValueHeads(), // nHeadKv + nEmbdHead, // head dimension + nEmbdGqa, // kvDim + layerIndex, // layer index for cache offset + qwen3Config.contextLength()); // max sequence length + + // Flash Attention + unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + qwen3State.wrapQ, // query vectors + qwen3State.wrapKeyCache, // key cache + qwen3State.wrapValueCache, // value cache + qwen3State.wrapXb, // output: attention result + qwen3Config.numberOfHeads(), // nHeads + nEmbdHead, // headSize + nEmbdGqa, // kvDim + gqa, // kvMul (nHeads / nHeadKv) + qwen3State.positionHolder, // position + layerIndex, // layer index + qwen3Config.contextLength()); // context length + + // Output Projection with Residual + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, + qwen3State.wrapXb, // input: attention output + qwen3State.wrapX, // output: wrapX += Wo · wrapXb + weights.woLayered[layerIndex].asHalfFloatArray(), // Wo [dim x qDim] + nEmbdHeadK * qwen3Config.numberOfHeads(), // input dim (qDim) + qwen3Config.dim(), // output dim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // ═══════════════════════════════════════════════════════════════════════ + // FFN BLOCK + // ═══════════════════════════════════════════════════════════════════════ + + // RMS Normalization - compute scale factor + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + qwen3State.tempFFN, // output: scale factor + qwen3State.wrapX, // input: hidden state + qwen3Config.dim(), // dimension + qwen3Config.rmsNormEps(), // epsilon + qwen3State.localSize); // local memory size + + // Final normalization (non-NVIDIA only) + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + qwen3State.tempFFN, // scale factor (in/out) + qwen3Config.dim(), // dimension + qwen3Config.rmsNormEps()); // epsilon + } + + // Fused RMS Apply + Gate/Up Projection + SiLU + GLU + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, + context, + qwen3State.wrapX, // input: raw hidden state (FP32) + qwen3State.wrapHb, // output: SiLU(x·W1) ⊙ (x·W3) + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights + qwen3State.tempFFN, // RMS scale factor + weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 (gate) + weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 (up) + qwen3Config.dim(), // input dimension + qwen3Config.hiddenDim(), // hidden dimension LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.tempFFN, qwen3State.wrapX, qwen3Config.dim(), - qwen3Config.rmsNormEps(), qwen3State.localSize) - .task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, qwen3State.tempFFN, qwen3Config.dim(), qwen3Config.rmsNormEps()) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - qwen3State.tempFFN); + // Down Projection with Residual + unifiedLayer.task("ffn_down_proj", + 
TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, + qwen3State.wrapHb, // input: FFN intermediate + qwen3State.wrapX, // output: wrapX += W2 · wrapHb + weights.w2Layered[layerIndex].asHalfFloatArray(), // W2 (down) + qwen3Config.hiddenDim(), // input dim + qwen3Config.dim(), // output dim + LOCAL_WORK_GROUP_SIZE_ALLOC) + .persistOnDevice(qwen3State.wrapX); - unifiedLayer.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen3State.wrapXb, qwen3State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), - weights.w3Layered[layerIndex].asHalfFloatArray(), qwen3Config.dim(), qwen3Config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapHb, qwen3State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), - qwen3Config.hiddenDim(), qwen3Config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(qwen3State.wrapX); return unifiedLayer; } + // @formatter:on /** * Configure data transfers for first and subsequent layers @@ -266,14 +399,14 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { if (layerIndex == 0) { // First layer: Transfer temporary buffers and QKV state every execution - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN); - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.tempQcur, qwen3State.tempKcur); + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.positionHolder); + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.temp, qwen3State.tempFFN); // First execution: allocate workspace buffers unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // context, qwen3State.wrapXb, qwen3State.wrapXb2, // qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, // qwen3State.wrapKeyCache, qwen3State.wrapValueCache, // - qwen3State.wrapAtt, qwen3State.wrapHb); + qwen3State.wrapAtt, qwen3State.wrapHb ); } else { // Subsequent layers: Consume data from previous layer unifiedLayer.consumeFromDevice(context, qwen3State.wrapXb, qwen3State.wrapXb2, // @@ -282,9 +415,8 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye qwen3State.wrapValueCache, qwen3State.wrapAtt, // qwen3State.wrapHb, qwen3State.positionHolder); // - unifiedLayer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur); } return unifiedLayer; } - + // @formatter:on } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index 385b626f..ba1b6a79 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -42,25 +42,111 @@ public ImmutableTaskGraph getImmutableTaskGraph() { } List setupFFNLayered() { - state.temp.init(0.0f); - state.tempFFN.init(0.0f); - var numLayers = config.numberOfLayers(); - - return IntStream.range(0, numLayers).mapToObj(i -> { + return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); - if (i == numLayers - 1) { + if 
(i == config.numberOfLayers() - 1) { setupLastID(ffnLayer.getTaskGraphName()); } return ffnLayer.snapshot(); }).toList(); } + // @formatter:off + /** + * Transformer Layer Task Flow (LlamaQ8FFNLayers) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (partial sums) + * └────────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ attn_rms_finalize│──▶ temp (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌────────────────┐ + * │ attn_rms_apply │──▶ wrapXb (normalized, FP32) + * └───────┬────────┘ + * │ + * ▼ + * ┌────────────────┐ ┌─────────────────────────────┐ + * │ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │ + * └───────┬────────┘ └─────────────────────────────┘ + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN (partial sums) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌─────────────────┐ + * │ ffn_rms_finalize│──▶ tempFFN (final scale) + * └────────┬────────┘ + * │ + * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3) + * └────────┬────────┘ (fully fused: RMS reduce/apply + W1/W3 matmuls + SiLU + GLU) + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 9 tasks (7 if NVIDIA, skipping rms_finalize steps) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state with residual connections + * + * Key Fusion Points: + * • qkv_projection: Fused Q/K/V matmuls with Q8 dequantization (3→1 kernel) + * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel) + * • rms_ffn_gate_up: Fully fused RMS norm + W1/W3 matmuls + SiLU + GLU (5→1 kernel) + * + * Quantization: Q8_0 format (8-bit weights with block-wise scaling) + * + */ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + // === Data Setup === unifiedLayer.consumeFromDevice(state.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - //Copy-in weights per layer for batched-layered layout + // Copy-in weights per layer for batched-layered layout (Q8 format) weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqLayered[layerIndex].asByteArray(), weights.wkLayered[layerIndex].asByteArray(), @@ -71,35 +157,101 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, 
Configuration config, weights.w2Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray()); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, - config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapQ, - weights.wqLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapK, - weights.wkLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapV, - weights.wvLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), - layerIndex, config.contextLength()); - configureAttention(unifiedLayer, layerIndex); - unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapXb, state.wrapX, - weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, - config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte, context, state.wrapXb, state.wrapHb, - weights.w1Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray(), config.dim(), config.hiddenDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapHb, state.wrapX, - weights.w2Layered[layerIndex].asByteArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); + + // === Attention Block === + // RMS Normalization + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), 
state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, state.wrapXb, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + // QKV Projection (fused with Q8 dequantization) + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulQ8, + context, + state.wrapXb, // input (FP32) + state.wrapQ, // output Q + state.wrapK, // output K + state.wrapV, // output V + weights.wqLayered[layerIndex].asByteArray(), // Wq (Q8) + weights.wkLayered[layerIndex].asByteArray(), // Wk (Q8) + weights.wvLayered[layerIndex].asByteArray(), // Wv (Q8) + config.dim(), // dim + config.kvDim(), // kvDim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // RoPE + KV Cache + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, // Q (in/out) + state.wrapK, // K (in/out) + state.wrapV, // V (in only) + state.wrapKeyCache, // Key cache (out) + state.wrapValueCache, // Value cache (out) + config.kvDim(), + config.headSize(), + layerIndex, + config.contextLength()); + + // Attention + configureAttention(unifiedLayer, layerIndex); + + // Output Projection (Wo) with residual (Q8 dequantization) + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asByteArray(), + config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + // === FFN Block === + // RMS Normalization + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + + // Fully fused: RMS apply + Gate/Up projections + SiLU + GLU (Q8 dequantization) + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fullyFusedRmsNormFFNGateUpQ8, + context, + state.wrapX, // raw input (FP32) + state.wrapHb, // output + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights + weights.w1Layered[layerIndex].asByteArray(), // W1 (Q8) + weights.w3Layered[layerIndex].asByteArray(), // W3 (Q8) + config.dim(), // input dimension + config.hiddenDim(), // output dimension + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Down projection (W2) with residual (Q8 dequantization) + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asByteArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Keep activation X on device for next layer + unifiedLayer.persistOnDevice(state.wrapX); + return unifiedLayer; } @@ -107,15 +259,20 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye // First layer: Transfer initial data to device (one-time transfer) if (layerIndex == 0) { // Transfer all attention-related data: query, key, value matrices and their caches - 
unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, + state.temp, state.tempFFN); // unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - context, state.wrapXb, state.wrapXb2, // + context, + state.wrapXb, state.wrapXb2, // state.wrapQ, state.wrapK, state.wrapV, // state.wrapKeyCache, state.wrapValueCache, // state.wrapAtt, state.wrapHb); // } else { // Subsequent layers: Consume data already on device from previous layer - unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // + unifiedLayer.consumeFromDevice( + context, + state.wrapXb, state.wrapXb2, // state.wrapQ, state.wrapK, state.wrapV, // state.wrapKeyCache, state.wrapValueCache, // state.wrapAtt, state.wrapHb, // @@ -127,35 +284,42 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye @Override public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); + // === Worker Grid Definitions === WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configKvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configKvDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Fused QKV: dim rows for Q + kvDim rows for K + kvDim rows for V + int fusedQkvGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); - WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); + // === Per-Layer Grid Assignments (ordered by task graph flow) === for (int i = 0; i < config.numberOfLayers(); i++) { - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", 
rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + // --- Attention Block --- + // RMS Normalization + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + // --- FFN Block --- + // RMS Normalization + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + // Fused RMS + Gate/Up Projections + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + // Down Projection + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); } + return tornadoForwardScheduler; } @@ -165,15 +329,25 @@ public List getFfnLayerTaskGraphs() { private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { if (schedulerType == SchedulerType.NVIDIA) { - return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, - context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), - state.positionHolder, layerIndex, config.contextLength()); + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + state.wrapQ, state.wrapKeyCache, + state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, + config.contextLength()); } else { - return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel, - state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), - state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, state.wrapKeyCache, + state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), config.contextLength(), + state.positionHolder, state.wrapAtt, layerIndex, + config.contextLength()); } } + // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java index c8c5a753..d54bb3ef 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java @@ -2,8 +2,8 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import 
org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; @@ -13,12 +13,9 @@ import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.SequencedCollection; - public class LogitsQ8_0Layer extends AbstractLayer { private String lastTaskGraphID; @@ -30,7 +27,6 @@ public class LogitsQ8_0Layer extends AbstractLayer { public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { super(taskGraphName, state, weights, config); this.lastTaskGraphID = lastTaskGraphID; - state.tempLogits.init(0.0f); var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ8_0Layer", "TornadoTensor"); this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); this.schedulerType = schedulerType; @@ -38,39 +34,68 @@ public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Confi @Override public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid logitsRMS; - if (weights instanceof Qwen2TornadoWeights) { - logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32); - } else { - logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); - } - + var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 
32 : 256); var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); + var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - - tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); + tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); return tornadoForwardScheduler; } + // @formatter:off private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { - TaskGraph logits = new TaskGraph("logits"); - logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.asByteArray(), - weights.rms_final_weight_as_floatArray) - .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (schedulerType == SchedulerType.NON_NVIDIA) { - logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); - } - logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) - .task("projection", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, // - context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asByteArray(), // - config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS) // - .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + var logits = new TaskGraph("logits"); + // === Data Setup === + logits.consumeFromDevice(lastTaskGraphID, state.wrapX); + logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); + logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, // + state.wrapLogits, // + weights.wclsByteArray.asByteArray(), // + weights.rms_final_weight_as_floatArray); + + // === Final RMS Normalization === + logits.task("rms_reduce", + TransformerComputeKernels::reductionOneBlockWithLayer, + context, + state.tempLogits, // output: partial sums + final scale factor + state.wrapX, // input: hidden state + config.dim(), // dimension + config.rmsNormEps(), // epsilon for numerical stability + state.localSize); // local workgroup size + + if (schedulerType == SchedulerType.NON_NVIDIA) { + logits.task("rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.tempLogits, + config.dim(), + config.rmsNormEps()); + } + logits.task("mapContextLogits", + TransformerComputeKernels::reductionOneBlock2WithLogits, + context, + state.wrapX, + weights.rms_final_weight_as_floatArray.asFloatArray(), + state.tempLogits); + + // === Vocabulary Projection (vocab_proj) === + logits.task("vocab_proj", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, // + context, + state.wrapX, + state.wrapLogits, + weights.wclsByteArray.asByteArray(), + config.dim(), + config.vocabularySize(),
+ LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); + + // === Transfer Results to Host === + logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; } + // @formatter:on @Override public GridScheduler getGridScheduler() { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index 441bbece..3588af9e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -69,6 +69,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); WorkerGrid splitGateUpSiLUWorker = WorkerGridFactory.genericWorker(config.hiddenDim(), 128); WorkerGrid splitQKVWorker = WorkerGridFactory.genericWorker(opSize, 128); + for (int i = 0; i < config.numberOfLayers(); i++) { tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", qkvDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
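// ---------------------------------------------------------------------------
// Illustrative reference sketch (plain Java, not part of the diff above): the
// attn_rms_* / ffn_rms_* / logits rms_* tasks split RMSNorm into a reduction
// step (sum of squares into a temp buffer), an optional finalize step on
// non-NVIDIA devices, and an apply step that scales by the learned weights.
// The sequential sketch below only illustrates that math; the class and method
// names are invented here, and the real kernels (reductionOneBlockWithLayer and
// friends) are parallel TornadoVM code.
// ---------------------------------------------------------------------------
final class RmsNormReferenceSketch {

    /** Step 1 (reduce): accumulate the sum of squares of the hidden state. */
    static float reduceSumOfSquares(float[] x) {
        float ss = 0.0f;
        for (float v : x) {
            ss += v * v;
        }
        return ss; // the GPU version writes per-workgroup partial sums into temp
    }

    /** Step 2 (finalize): turn the accumulated sum into the scale 1/sqrt(mean + eps). */
    static float finalizeScale(float sumOfSquares, int dim, float eps) {
        return (float) (1.0 / Math.sqrt(sumOfSquares / dim + eps));
    }

    /** Step 3 (apply): out[i] = weight[i] * scale * x[i]. */
    static void apply(float[] out, float[] x, float[] weight, float scale) {
        for (int i = 0; i < x.length; i++) {
            out[i] = weight[i] * scale * x[i];
        }
    }
}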
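// ---------------------------------------------------------------------------
// Illustrative reference sketch (plain Java, not part of the diff above): the
// Q8_0 comments describe 8-bit weights with block-wise scaling. The sketch
// below shows that dequantization idea for a single matrix row, ASSUMING the
// common GGUF-style Q8_0 block layout (one fp16 scale followed by 32 signed
// int8 values). It is not the TornadoVM matrixVectorGenericQ8Byte kernel, and
// the byte layout produced by this repository's weight loader may differ.
// Requires Java 20+ for Float.float16ToFloat.
// ---------------------------------------------------------------------------
final class Q8_0ReferenceSketch {
    static final int BLOCK_SIZE = 32;               // values per Q8_0 block (assumed)
    static final int BLOCK_BYTES = 2 + BLOCK_SIZE;  // fp16 scale + 32 int8 weights

    /** Dot product of one quantized weight row with an FP32 activation vector. */
    static float dotRowQ8(byte[] weights, int row, float[] x, int dim) {
        float acc = 0.0f;
        int blocksPerRow = dim / BLOCK_SIZE;
        for (int b = 0; b < blocksPerRow; b++) {
            int base = (row * blocksPerRow + b) * BLOCK_BYTES;
            // fp16 scale assumed little-endian in the first two bytes of the block
            short scaleBits = (short) ((weights[base] & 0xFF) | (weights[base + 1] << 8));
            float scale = Float.float16ToFloat(scaleBits);
            float blockAcc = 0.0f;
            for (int i = 0; i < BLOCK_SIZE; i++) {
                blockAcc += weights[base + 2 + i] * x[b * BLOCK_SIZE + i]; // int8 * fp32
            }
            acc += scale * blockAcc; // dequantize once per block
        }
        return acc;
    }
}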
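// ---------------------------------------------------------------------------
// Illustrative reference sketch (plain Java, not part of the diff above): the
// fused QKV worker grid is sized as (dim + 2 * kvDim) * LOCAL_WORK_GROUP_SIZE_ALLOC
// global threads with LOCAL_WORK_GROUP_SIZE_ALLOC local threads, i.e. one
// work-group per output row across the Q, K and V projections. The helper below
// shows one plausible row-to-output mapping such a fused kernel could use; this
// is an assumption for illustration, not the actual fusedQKVMatmulQ8 code.
// ---------------------------------------------------------------------------
final class FusedQkvGridSketch {

    /** Returns 0 for a Q row, 1 for a K row, 2 for a V row, given the work-group id. */
    static int targetBuffer(int groupId, int dim, int kvDim) {
        if (groupId < dim) {
            return 0;                        // rows [0, dim) produce Q
        } else if (groupId < dim + kvDim) {
            return 1;                        // rows [dim, dim + kvDim) produce K
        } else {
            return 2;                        // rows [dim + kvDim, dim + 2*kvDim) produce V
        }
    }

    /** Row index inside the buffer selected by targetBuffer. */
    static int rowInBuffer(int groupId, int dim, int kvDim) {
        int target = targetBuffer(groupId, dim, kvDim);
        return target == 0 ? groupId
                : target == 1 ? groupId - dim
                : groupId - dim - kvDim;
    }
}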