From ca2b28a60c0181ee31d03149ab6a5b59cf4a21d1 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 3 Dec 2025 17:29:30 +0200
Subject: [PATCH 01/42] Implement FP16 support in TornadoVM by introducing
 HalfFloat arrays, optimized matrix-vector kernels, and SiLU-GLU activation

---
 .../gpullama3/inference/state/LlamaState.java |   3 +
 .../gpullama3/inference/state/State.java      |  11 ++
 .../tornadovm/TornadoVMMasterPlan.java        |   1 +
 .../kernels/TransformerComputeKernels.java    |  10 +
 .../TransformerComputeKernelsLayered.java     | 174 ++++++++++++++++++
 .../layers/type/fp16/LogitsFP16Layer.java     |  11 +-
 6 files changed, 207 insertions(+), 3 deletions(-)

diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
index 9f9fdcdb..bfb5a9de 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
@@ -4,6 +4,7 @@
 import org.beehive.gpullama3.tensor.standard.FloatTensor;
 import org.beehive.gpullama3.model.Configuration;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;

 import java.util.stream.Stream;

@@ -57,6 +58,8 @@ protected StateFields createStateFields(Configuration config) {
         fields.wrapK = new FloatArray(config.dim());
         fields.wrapV = new FloatArray(config.dim());

+        fields.wrapXFP16 = new HalfFloatArray(config.dim());
+        fields.wrapXbFP16 = new HalfFloatArray(config.dim());
+
         // dim vs kvdim
         fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
         fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java
index 01d94936..b052fa79 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/State.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java
@@ -3,8 +3,11 @@
 import org.beehive.gpullama3.tensor.standard.FloatTensor;
 import org.beehive.gpullama3.model.Configuration;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;

 /**
  * Represents the base state structure used during LLM inference.
  * This class provides a common foundation for handling state-related data and functionalities
@@ -57,6 +60,9 @@ public abstract class State {
     public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM.
     public final IntArray positionHolder;
+
+    public final HalfFloatArray wrapXbFP16; // HalfFloatArray (FP16) wrapper for xb (residual branch activation), used as input to the half-precision matmul kernels.
+
     public int localSize;
     public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size.
     public FloatArray tempFFN; // Temporary buffer for feed-forward network calculations, size adjusted for local workgroup size.
     public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size.
     public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models.
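+    // FP16 mirror of wrapX, filled on-device by convertFP32toFP16v2 before the half-precision projection kernels.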
+    public HalfFloatArray wrapXFP16;

     /** last index in previous block */
     protected State(Configuration config, int batchsize) {
@@ -98,6 +105,9 @@ protected State(Configuration config, int batchsize) {
         this.wrapK = fields.wrapK;
         this.wrapV = fields.wrapV;

+        this.wrapXFP16 = fields.wrapXFP16;
+        this.wrapXbFP16 = fields.wrapXbFP16;
+
         // dim vs kvdim
         this.wrapKeyCache = fields.wrapKeyCache;
         this.wrapValueCache = fields.wrapValueCache;
@@ -121,6 +131,7 @@ protected static class StateFields {
         public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache;
         public IntArray positionHolder;
         public FloatArray temp, tempFFN, tempLogits;
+        public HalfFloatArray wrapXFP16, wrapXbFP16;
     }

     @Override
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
index 293d2c0c..b1195c65 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
@@ -180,6 +180,7 @@ private int getFinalLogitsGraphIndex() {
     public void forceCopyInReadOnlyDataLayered() {
         // Execute all TornadoVM graphs
         state.wrapX.init(0.0f);
+        state.wrapXFP16.clear();
         state.positionHolder.init(0);

         // Execute activation update graph
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java
index 7f69e496..9e1b5fb7 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java
@@ -2,7 +2,11 @@

 import uk.ac.manchester.tornado.api.KernelContext;
 import uk.ac.manchester.tornado.api.math.TornadoMath;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

 public class TransformerComputeKernels {

@@ -19,6 +23,12 @@ public static void emptyTaskToForceCopyIn(FloatArray buffer) {
         }
     }

+    public static void convertFP32toFP16v2(KernelContext context, FloatArray input, HalfFloatArray output) {
+        int i = context.globalIdx;
+        HalfFloat val = new HalfFloat(input.get(i));
+        output.set(i, val);
+    }
+
     /**
      * Performs RMS (Root Mean Square) normalization using parallel reduction.
      * This is a two-phase reduction: first within work groups, then across work groups.
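The conversion kernel above assigns one element per GPU thread via context.globalIdx, so its worker grid must cover the full array length (the logits graph later registers it as the dequantizeX task). A sequential Java sketch of the same semantics, assuming TornadoVM's HalfFloat(float) constructor performs the usual FP32-to-FP16 narrowing:

    import uk.ac.manchester.tornado.api.types.HalfFloat;
    import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
    import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

    // Reference loop for convertFP32toFP16v2: output[i] = (fp16) input[i].
    static void convertFP32toFP16Reference(FloatArray input, HalfFloatArray output) {
        for (int i = 0; i < input.getSize(); i++) {
            output.set(i, new HalfFloat(input.get(i)));
        }
    }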
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index dfe4ef27..fd484644 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -3,6 +3,7 @@ import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.Int8Array; @@ -690,6 +691,32 @@ public static void matrixVectorGeneric( hb.set(rowId, sum); } } + + // @formatter:off + public static void matrixVectorGeneric( + KernelContext context, + HalfFloatArray x, + FloatArray hb, // output + HalfFloatArray w, + int dim1, // inner loop + int dim0, // outer loop + int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= dim0) { + return; + } + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, dim1); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + hb.set(rowId, sum); + } + } // @formatter:on /** @@ -774,6 +801,26 @@ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext contex } } + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, HalfFloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= d) { + return; + } + + HalfFloat sum1 = matrixVectorRowMajorOptimizedFHF(context, localWorkGroupSize, x, w1, n); + HalfFloat sum3 = matrixVectorRowMajorOptimizedFHF(context, localWorkGroupSize, x, w3, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float silu = siluActivation(sum1.getFloat32()); // Using the new SiLU method + float result = silu * sum3.getFloat32(); + hb.set(rowId, result); + } + } + /** * Gaussian Error Linear Unit (GELU) activation function. 
Approximation formula: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
     *
@@ -878,6 +925,133 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc
         return localSum[0];
     }

+    public static HalfFloat matrixVectorRowMajorOptimizedFHF(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
+
+        // Allocate local memory for reduction
+        HalfFloat[] localSum = context.allocateHalfFloatLocalArray(localSize);
+
+        int rowOffset = rowId * n;
+
+        // Each thread accumulates its partial dot product in FP32; accumulating directly in FP16
+        // (HalfFloat.mult/HalfFloat.add) loses too much precision over long rows.
+        float partialSum = 0.0f;
+        for (int j = localId; j < n; j += localSize) {
+            int matrixIdx = rowOffset + j;
+            partialSum += w.get(matrixIdx).getFloat32() * x.get(j).getFloat32();
+        }
+
+        // Store partial sum in local memory; only the reduction tree operates on FP16 values
+        localSum[localId] = new HalfFloat(partialSum);
+        context.localBarrier();
+
+        // Parallel reduction within workgroup
+        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
+            if (localId < stride) {
+                localSum[localId] = HalfFloat.add(localSum[localId], localSum[localId + stride]);
+            }
+            context.localBarrier();
+        }
+
+        return localSum[0];
+    }
+
+    public static float matrixVectorRowMajorOptimizedF(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
+
+        // Allocate local memory for reduction
+        float[] localSum = context.allocateFloatLocalArray(localSize);
+
+        int rowOffset = rowId * n;
+
+        // Each thread calculates its partial dot product, widening the FP16 operands to FP32
+        float partialSum = 0.0f;
+        for (int j = localId; j < n; j += localSize) {
+            int matrixIdx = rowOffset + j;
+            partialSum += w.get(matrixIdx).getFloat32() * x.get(j).getFloat32();
+        }
+
+        // Store partial sum in local memory
+        localSum[localId] = partialSum;
+        context.localBarrier();
+
+        // Parallel reduction within workgroup
+        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
+            if (localId < stride) {
+                localSum[localId] += localSum[localId + stride];
+            }
+            context.localBarrier();
+        }
+
+        return localSum[0];
+    }
+
+public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
+    int rowId = context.groupIdx;
+    int localId = context.localIdx;
+
+    // Allocate local memory for reduction
+    float[] localSum = context.allocateFloatLocalArray(localSize);
+
+    int rowOffset = rowId * n;
+
+    // Each thread calculates partial dot product - UNROLLED BY 4
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
+    float sum2 = 0.0f;
+    float sum3 = 0.0f;
+
+    int j = localId;
+    int stride = localSize;
+    int stride4 = localSize << 2; // localSize * 4
+    int limit = n - (stride * 3); // Safe limit for 4 elements
+
+    // Main loop unrolled by 4 with separate accumulators
+    for (; j < limit; j += stride4) {
+        int base = rowOffset + j;
+        int j1 = j + stride;
+        int j2 = j + (stride << 1);
+        int j3 = j + stride * 3;
+
+        sum0 += w.get(base).getFloat32() * x.get(j).getFloat32();
+        sum1 += w.get(base + stride).getFloat32() * x.get(j1).getFloat32();
+        sum2 += w.get(base + (stride << 1)).getFloat32() * x.get(j2).getFloat32();
+ sum3 += w.get(base + stride * 3).getFloat32() * x.get(j3).getFloat32(); + } + + // Handle remainder + for (; j < n; j += stride) { + sum0 += w.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + + // Combine accumulators (tree reduction for better precision) + float partialSum = (sum0 + sum1) + (sum2 + sum3); + + // Store partial sum in local memory + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int s = localSize >> 1; s > 0; s >>= 1) { + if (localId < s) { + localSum[localId] += localSum[localId + s]; + } + context.localBarrier(); + } + + return localSum[0]; +} + // Second kernel - Combines partial sums and computes final normalization public static void reductionFinalNormalization(KernelContext context, FloatArray output, int size, float ermsNorm) { int gid = context.globalIdx; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index a674c1c5..81001bba 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -29,6 +29,7 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration super(name, state, weights, config); this.lastTaskGraphID = lastTaskGraphID; state.tempLogits.init(0.0f); + var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); this.schedulerType = schedulerType; @@ -39,15 +40,18 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration */ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { TaskGraph logits = new TaskGraph("logits"); - logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits) + logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits, state.wrapXFP16) .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), weights.rms_final_weight_as_floatArray.asFloatArray()) .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); if (schedulerType == SchedulerType.NON_NVIDIA) { logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); } logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) - .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), - LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); + .task("dequantizeX", TransformerComputeKernels::convertFP32toFP16v2, context, state.wrapX, state.wrapXFP16) + .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, // + context, state.wrapXFP16, state.wrapLogits, // + weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), // + 
LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); // logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; } @@ -65,6 +69,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); + tornadoForwardScheduler.addWorkerGrid("logits.dequantizeX", logitsRMS); tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS); tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); From f0411aec4a6af573a316f65e43fa1f5d38973245 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 00:53:17 +0200 Subject: [PATCH 02/42] Introduce matrix-vector kernel with residual addition and enhance FP16 task graph setup --- .../TransformerComputeKernelsLayered.java | 20 ++ .../layers/type/fp16/LlamaFP16FFNLayers.java | 334 +++++++++--------- 2 files changed, 189 insertions(+), 165 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index fd484644..f419e999 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -759,6 +759,26 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA } } + public static void matrixVectorGenericWithResidual(KernelContext context, HalfFloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float result = hb.get(rowId) + sum; + hb.set(rowId, result); + } + } + /** * Fused feed-forward network with SiLU activation and GLU gating. Implements the SwiGLU variant used in LLaMA-style models. 
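+     * Computes out[row] = silu(dot(w1[row], x)) * dot(w3[row], x), where silu(z) = z / (1 + exp(-z)).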
* diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 96acd650..83b5e05e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -1,178 +1,182 @@ -package org.beehive.gpullama3.tornadovm.layers.type.fp16; - -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; -import uk.ac.manchester.tornado.api.enums.DataTransferMode; - -import java.util.List; -import java.util.stream.IntStream; - -public class LlamaFP16FFNLayers extends AbstractFFNLayers { - - TaskGraph ffnTaskGraphs; - GridScheduler scheduler; - List ffnLayerTaskGraphs; - public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) { - super(taskGraph, state, weights, config, schedulerType); - this.ffnLayerTaskGraphs = setupFFNLayered(); - } + package org.beehive.gpullama3.tornadovm.layers.type.fp16; + + import org.beehive.gpullama3.inference.state.State; + import org.beehive.gpullama3.inference.weights.Weights; + import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; + import org.beehive.gpullama3.model.Configuration; + import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; + import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; + import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; + import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; + import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; + import uk.ac.manchester.tornado.api.GridScheduler; + import uk.ac.manchester.tornado.api.ImmutableTaskGraph; + import uk.ac.manchester.tornado.api.TaskGraph; + import uk.ac.manchester.tornado.api.WorkerGrid; + import uk.ac.manchester.tornado.api.enums.DataTransferMode; + + import java.util.List; + import java.util.stream.IntStream; + + public class LlamaFP16FFNLayers extends AbstractFFNLayers { + + TaskGraph ffnTaskGraphs; + GridScheduler scheduler; + List ffnLayerTaskGraphs; + public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + this.ffnLayerTaskGraphs = setupFFNLayered(); + } - @Override - public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim()/2, 128); - WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); - - int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configDimRowMajorGlobalWorker = 
WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - - int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configKvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configKvDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - - int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); - - WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); - WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); - - // Map workers to tasks - for (int i = 0; i < config.numberOfLayers(); i++) { - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim()/2, 128); + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configKvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configKvDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); + // Map workers to tasks + for (int i = 0; i < config.numberOfLayers(); i++) { + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", 
configKvDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".quantizeXb", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + } + return tornadoForwardScheduler; } - return tornadoForwardScheduler; - } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } - @Override - public TaskGraph getTaskGraph() { - return ffnTaskGraphs; - } + @Override + public TaskGraph getTaskGraph() { + return ffnTaskGraphs; + } - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } - public List getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } - List setupFFNLayered() { - state.temp.init(0.0f); - state.tempFFN.init(0.0f); - var numLayers = config.numberOfLayers(); - - return IntStream.range(0, numLayers) - .mapToObj(i -> { - var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); - if (i == numLayers - 1) setupLastID(ffnLayer.getTaskGraphName()); - return ffnLayer.snapshot(); - }) - .toList(); - } + List setupFFNLayered() { + state.temp.init(0.0f); + state.tempFFN.init(0.0f); + // state.wrapXbFP16.clear(); - TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { - var layerTaskGraphName = "layer_" + layerIndex; - TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); - unifiedLayer.consumeFromDevice(state.wrapX); - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - weights.rms_att_weightLayered[layerIndex].asFloatArray(), - weights.wqLayered[layerIndex].asHalfFloatArray(), - weights.wkLayered[layerIndex].asHalfFloatArray(), - weights.wvLayered[layerIndex].asHalfFloatArray(), - weights.woLayered[layerIndex].asHalfFloatArray(), - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - weights.w1Layered[layerIndex].asHalfFloatArray(), - weights.w2Layered[layerIndex].asHalfFloatArray(), - weights.w3Layered[layerIndex].asHalfFloatArray()); - unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer - .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, - config.dim(), 
config.rmsNormEps()); - } - unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), - layerIndex, config.contextLength()); - configureAttention(unifiedLayer, layerIndex); - unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), - weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), - config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); - return unifiedLayer; - } + var numLayers = config.numberOfLayers(); - protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { - // First layer: Transfer initial data to device (one-time transfer) - if (layerIndex == 0) { - // Transfer all attention-related data: query, key, value matrices and their caches - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); // - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb); // - } else { - // Subsequent layers: Consume data already on device from previous layer - 
unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, // - state.positionHolder // - ); + return IntStream.range(0, numLayers) + .mapToObj(i -> { + var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); + if (i == numLayers - 1) setupLastID(ffnLayer.getTaskGraphName()); + return ffnLayer.snapshot(); + }) + .toList(); + } + + TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + unifiedLayer + .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, + config.dim(), config.rmsNormEps()); + } + unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) + .task("quantizeXb", TransformerComputeKernels::convertFP32toFP16v2, context, state.wrapXb, state.wrapXbFP16) + .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) + .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), + layerIndex, config.contextLength()); + configureAttention(unifiedLayer, layerIndex); + unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, 
state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) + .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), + config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + // First layer: Transfer initial data to device (one-time transfer) + if (layerIndex == 0) { + // Transfer all attention-related data: query, key, value matrices and their caches + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder); +// , state.temp, state.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + context, state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb, state.wrapXbFP16 , state.temp, state.tempFFN); // + } else { + // Subsequent layers: Consume data already on device from previous layer + unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb, // + state.positionHolder, state.wrapXbFP16, state.temp, state.tempFFN// + ); + } + return unifiedLayer; } - return unifiedLayer; - } - private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { - if (schedulerType == SchedulerType.NVIDIA) { - return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, - context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), - state.positionHolder, layerIndex, config.contextLength()); - } else { - return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { + if (schedulerType == SchedulerType.NVIDIA) { + return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, config.contextLength()); + } else { + return unifiedLayer.task("parallel-attention", 
TransformerComputeKernelsLayered::processHeadsParallel, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + } } } -} From 6334ac382c67b1f0774140a64ef6d4611d4c140f Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 01:22:15 +0200 Subject: [PATCH 03/42] Fused Q/K/V matrix-vector multiplication into a single kernel to reduce overhead, improve cache utilization, and update task graph setup to integrate fused kernel. --- .../TransformerComputeKernelsLayered.java | 184 +++++++++++++++++- .../layers/type/fp16/LlamaFP16FFNLayers.java | 76 +++++--- 2 files changed, 228 insertions(+), 32 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index f419e999..08fe4a28 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -692,6 +692,107 @@ public static void matrixVectorGeneric( } } + /** + * Fused Q/K/V matrix-vector multiplication. + * Reduces kernel launch overhead and improves input vector cache utilization. + * + * Workgroup assignment: + * - rowId [0, dim): Q projection + * - rowId [dim, dim+kvDim): K projection + * - rowId [dim+kvDim, dim+2*kvDim): V projection + */ + public static void fusedQKVMatmulX( + KernelContext context, + HalfFloatArray x, // input vector (FP16) + FloatArray q, // output Q (FP32) + FloatArray k, // output K (FP32) + FloatArray v, // output V (FP32) + HalfFloatArray wq, // Q weight matrix + HalfFloatArray wk, // K weight matrix + HalfFloatArray wv, // V weight matrix + int dim, // model dimension (Q output size) + int kvDim, // KV dimension (K/V output size) + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowId < dim) { + // ========== Q projection ========== + int rowOffset = rowId * dim; + + float partialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partialSum += wq.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + q.set(rowId, localSum[0]); + } + + } else if (rowId < dim + kvDim) { + // ========== K projection ========== + int kRow = rowId - dim; + int rowOffset = kRow * dim; + + float partialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partialSum += wk.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + k.set(kRow, localSum[0]); + } + + } else if (rowId < dim + 2 * kvDim) { + // ========== V projection ========== + int vRow = rowId - dim - kvDim; + int rowOffset = vRow * dim; + + 
float partialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partialSum += wv.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + v.set(vRow, localSum[0]); + } + } + } + // @formatter:off public static void matrixVectorGeneric( KernelContext context, @@ -1016,7 +1117,88 @@ public static float matrixVectorRowMajorOptimizedF(KernelContext context, int lo return localSum[0]; } -public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, + HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f; + + int stride = localSize; + int stride2 = localSize << 1; + int stride3 = localSize * 3; + int stride4 = localSize << 2; + + // Already coalesced: thread 0 reads idx 0, thread 1 reads idx 1, etc. + int j = localId; + int limit = n - stride3; + + for (; j < limit; j += stride4) { + int base = rowOffset + j; + // Hoist x.get() calls - they're reused across all rows + float x0 = x.get(j).getFloat32(); + float x1 = x.get(j + stride).getFloat32(); + float x2 = x.get(j + stride2).getFloat32(); + float x3 = x.get(j + stride3).getFloat32(); + + sum0 += w.get(base).getFloat32() * x0; + sum1 += w.get(base + stride).getFloat32() * x1; + sum2 += w.get(base + stride2).getFloat32() * x2; + sum3 += w.get(base + stride3).getFloat32() * x3; + } + + for (; j < n; j += stride) { + sum0 += w.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + + localSum[localId] = (sum0 + sum1) + (sum2 + sum3); + context.localBarrier(); + + // Reduction with minimal barriers + for (int s = localSize >> 1; s > 0; s >>= 1) { + if (localId < s) { + localSum[localId] += localSum[localId + s]; + } + context.localBarrier(); + } + + return localSum[0]; + } + + public static void fusedQKVMatmul( + KernelContext context, + HalfFloatArray x, // input (read once!) 
+ FloatArray q, FloatArray k, FloatArray v, // outputs + HalfFloatArray wq, HalfFloatArray wk, HalfFloatArray wv, + int dim, int kvDim, int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Determine which output this workgroup computes + int totalRows = dim + 2 * kvDim; // Q rows + K rows + V rows + + if (rowId < dim) { + // Q projection + float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wq, dim); + if (localId == 0) q.set(rowId, sum); + } else if (rowId < dim + kvDim) { + // K projection + int kRow = rowId - dim; + float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wk, dim); + if (localId == 0) k.set(kRow, sum); + } else { + // V projection + int vRow = rowId - dim - kvDim; + float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wv, dim); + if (localId == 0) v.set(vRow, sum); + } + } + + +public static float matrixVectorRowMajorOptimizedx(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { int rowId = context.groupIdx; int localId = context.localIdx; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 83b5e05e..a3e157f7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -44,11 +44,18 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); + + int fusedQKVRows = config.dim() + 2 * config.kvDim(); + int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + // Map workers to tasks for (int i = 0; i < config.numberOfLayers(); i++) { - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fusedQKV", fusedQKVWorker); + + // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); +// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); +// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); @@ -84,8 +91,8 @@ public List getFfnLayerTaskGraphs() { } List setupFFNLayered() { - state.temp.init(0.0f); - state.tempFFN.init(0.0f); +// state.temp.init(0.0f); +// state.tempFFN.init(0.0f); // state.wrapXbFP16.clear(); var numLayers = config.numberOfLayers(); @@ -115,33 +122,40 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, weights.w3Layered[layerIndex].asHalfFloatArray()); unifiedLayer = 
configureLayerDataTransfers(unifiedLayer, layerIndex); unifiedLayer - .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, - config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) - .task("quantizeXb", TransformerComputeKernels::convertFP32toFP16v2, context, state.wrapXb, state.wrapXbFP16) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), - layerIndex, config.contextLength()); - configureAttention(unifiedLayer, layerIndex); - unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), + .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, + config.dim(), config.rmsNormEps()); + } + unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) + .task("quantizeXb", TransformerComputeKernels::convertFP32toFP16v2, context, state.wrapXb, state.wrapXbFP16) + .task("fusedQKV", TransformerComputeKernelsLayered::fusedQKVMatmulX, + context, + state.wrapXbFP16, // input (FP16) + state.wrapQ, // output Q + state.wrapK, // output K + state.wrapV, // output V + weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq + weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk + weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv + config.dim(), // dim + config.kvDim(), // kvDim LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalizationFFN", 
TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), - weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), - config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); +// .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) +// .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) +// .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) + .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength()); + configureAttention(unifiedLayer, layerIndex); + unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) + .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .persistOnDevice(state.wrapX); // return unifiedLayer; } From 46218a7a970c43855f881835376b6f20a0e0c595 Mon Sep 17 00:00:00 2001 From: mikepapadim 
Date: Thu, 4 Dec 2025 01:36:03 +0200 Subject: [PATCH 04/42] Fuse RoPE rotation and KV cache copy into a single kernel, update task graph to integrate `ropeRotationWithCacheCopy` kernel, and remove redundant kernels (`rope` and `copyToCaches`). --- .../TransformerComputeKernelsLayered.java | 63 +++++++++++++++++++ .../layers/type/fp16/LlamaFP16FFNLayers.java | 23 +++++-- 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index 08fe4a28..345a8a6d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -139,6 +139,69 @@ public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, Float } } + /** + * Fused RoPE rotation with KV cache copy. + * Eliminates separate copyToCaches kernel. + * + * - Rotates Q (full dim) + * - Rotates K and writes directly to keyCache + * - Copies V directly to valueCache (no rotation needed) + */ + public static void ropeRotationWithCacheCopy( + KernelContext context, + IntArray positionHolder, + FloatArray sq, // Q vector (in/out) + FloatArray sk, // K vector (in/out) + FloatArray sv, // V vector (in only) + FloatArray keyCache, // Key cache (out) + FloatArray valueCache, // Value cache (out) + int kvDim, + int headSize, + int layer, + int contextLength) { + + int i = context.globalIdx * 2; + int pos = positionHolder.get(0); + + // Bounds check for Q rotation (Q has dim elements, processed in pairs) + if (i + 1 < sq.getSize()) { + // RoPE frequency calculation + int head_dim = i % headSize; + float freq = 1.0f / TornadoMath.pow(50000.0f, head_dim / (float) headSize); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q + float v0q = sq.get(i); + float v1q = sq.get(i + 1); + sq.set(i, v0q * fcr - v1q * fci); + sq.set(i + 1, v0q * fci + v1q * fcr); + + // Rotate K AND write to cache (only for kvDim elements) + if (i + 1 < kvDim) { + float v0k = sk.get(i); + float v1k = sk.get(i + 1); + float rotated0 = v0k * fcr - v1k * fci; + float rotated1 = v0k * fci + v1k * fcr; + + // Write rotated K back to sk + sk.set(i, rotated0); + sk.set(i + 1, rotated1); + + // Direct cache write (fused - no separate copy kernel!) 
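+                    // Cache layout is [layer][position][kvDim], flattened row-major, so this offset selects the (layer, pos) slot.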
+ int cacheOffset = layer * contextLength * kvDim + pos * kvDim; + keyCache.set(cacheOffset + i, rotated0); + keyCache.set(cacheOffset + i + 1, rotated1); + + // Copy V to cache (V doesn't need rotation) + valueCache.set(cacheOffset + i, sv.get(i)); + valueCache.set(cacheOffset + i + 1, sv.get(i + 1)); + } + } + + } + public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArray v, int dimQ, int dimKV) { int totalSize = dimQ + 2 * dimKV; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index a3e157f7..59e90e3f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -48,6 +48,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) int fusedQKVRows = config.dim() + 2 * config.kvDim(); int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); // Map workers to tasks for (int i = 0; i < config.numberOfLayers(); i++) { @@ -56,7 +57,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); +// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); @@ -66,7 +67,9 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".quantizeXb", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); +// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ropeWithCache", ropeWithCacheWorker); + } return tornadoForwardScheduler; } @@ -144,8 +147,20 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, // .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) // .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) // .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, 
state.wrapXbFP16, state.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength()); +// .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) +// .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength()); + .task("ropeWithCache", TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, // Q (in/out) + state.wrapK, // K (in/out) + state.wrapV, // V (in only) + state.wrapKeyCache, // Key cache (out) + state.wrapValueCache, // Value cache (out) + config.kvDim(), + config.headSize(), + layerIndex, + config.contextLength()); configureAttention(unifiedLayer, layerIndex); unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); From b48ec62c52571215cffefd5e018574728682f9d2 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 01:53:51 +0200 Subject: [PATCH 05/42] Add `mapContextWithQuantize` kernel, integrate into task graph, and deprecate `mapContext` and `quantizeXb` --- .../kernels/TransformerComputeKernels.java | 13 +++++++++++++ .../kernels/TransformerComputeKernelsLayered.java | 2 ++ .../layers/type/fp16/LlamaFP16FFNLayers.java | 11 +++++++---- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index 9e1b5fb7..58988d17 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -29,6 +29,19 @@ public static void convertFP32toFP16v2(KernelContext context, FloatArray input, output.set(i,val); } + public static void mapContextWithQuantize( + KernelContext context, + HalfFloatArray outputFP16, // Direct FP16 output + FloatArray x, + FloatArray weights, + FloatArray temp) { + + int gid = context.globalIdx; + float ss = temp.get(0); + float result = weights.get(gid) * (ss * x.get(gid)); + outputFP16.set(gid, new HalfFloat(result)); + } + /** * Performs RMS (Root Mean Square) normalization using parallel reduction. * This is a two-phase reduction: first within work groups, then across work groups. 
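Note: `mapContextWithQuantize` folds what were two passes over xb (the RMS apply
in `mapContext` and the FP32-to-FP16 conversion in `quantizeXb`) into a single
read-modify-write per element. A host-side sketch of the arithmetic, assuming
temp[0] already holds the inverse-RMS scale from the preceding reduction
(illustrative helper, not part of this tree; the FP16 round trip uses the
JDK 20+ float16 helpers):

    // Per-element reference: normalize, weight, then quantize to FP16 on store.
    static void mapContextWithQuantizeReference(float[] x, float[] rmsWeights, float ss, float[] outFp16) {
        for (int i = 0; i < x.length; i++) {
            float result = rmsWeights[i] * (ss * x[i]);  // RMS-norm apply
            // emulate the HalfFloat store: round to FP16, read back as FP32
            outFp16[i] = Float.float16ToFloat(Float.floatToFloat16(result));
        }
    }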
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index 345a8a6d..19961bd2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -965,6 +965,8 @@ public static void matrixVectorGenericWithResidual(KernelContext context, HalfFl * @param localWorkGroupSize * Work group size */ + + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) { // One row per workgroup (not per thread) int rowId = context.groupIdx; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 59e90e3f..43b9212e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -62,9 +62,9 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); +// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".quantizeXb", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextQuantized", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); @@ -130,8 +130,11 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, config.dim(), config.rmsNormEps()); } - unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) - .task("quantizeXb", TransformerComputeKernels::convertFP32toFP16v2, context, state.wrapXb, state.wrapXbFP16) + unifiedLayer.task("mapContextQuantized", TransformerComputeKernels::mapContextWithQuantize, + context, state.wrapXbFP16, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) +// unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) +// .task("quantizeXb", TransformerComputeKernels::convertFP32toFP16v2, context, state.wrapXb, state.wrapXbFP16) .task("fusedQKV", 
TransformerComputeKernelsLayered::fusedQKVMatmulX, context, state.wrapXbFP16, // input (FP16) From 943da78ff7718e98da40299f823f0ca57990d676 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 14:58:08 +0200 Subject: [PATCH 06/42] Refactor logits task graph to optimize kernel setup, update worker grids, and deprecate redundant tasks in FP16 layer. --- .../layers/type/fp16/LogitsFP16Layer.java | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index 81001bba..88ef7162 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -2,8 +2,8 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; @@ -28,7 +28,7 @@ public class LogitsFP16Layer extends AbstractLayer { public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { super(name, state, weights, config); this.lastTaskGraphID = lastTaskGraphID; - state.tempLogits.init(0.0f); + state.tempLogits.clear(); var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); @@ -40,18 +40,20 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration */ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { TaskGraph logits = new TaskGraph("logits"); - logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits, state.wrapXFP16) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), weights.rms_final_weight_as_floatArray.asFloatArray()) - .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (schedulerType == SchedulerType.NON_NVIDIA) { - logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); - } - logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) - .task("dequantizeX", TransformerComputeKernels::convertFP32toFP16v2, context, state.wrapX, state.wrapXFP16) - .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, // - context, state.wrapXFP16, state.wrapLogits, // - weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), // - LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); // + logits.consumeFromDevice(lastTaskGraphID, state.wrapX) 
// + .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, // + state.wrapLogits, state.wrapXbFP16, // + weights.wclsByteArray.asHalfFloatArray(), // + weights.rms_final_weight_as_floatArray.asFloatArray()) // + .task("rms_reduce", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + if (schedulerType == SchedulerType.NON_NVIDIA) { + logits.task("rms_finalize", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); + } + logits.task("rms_apply_fp16", TransformerComputeKernels::mapContextWithQuantizeLogits, context, state.wrapXbFP16, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) + .task("vocab_proj", TransformerComputeKernelsLayered::matrixVectorGeneric, // + context, state.wrapXbFP16, state.wrapLogits, // + weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), // + LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); // logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; } @@ -69,10 +71,9 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - tornadoForwardScheduler.addWorkerGrid("logits.dequantizeX", logitsRMS); - tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS); - tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); + tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS); return tornadoForwardScheduler; } From 386dddcfc044e17efc4f803c450142420863929f Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 16:10:49 +0200 Subject: [PATCH 07/42] Refactor FP16 FFN layers to streamline task graph setup, update worker grid assignments, and enhance attention and FFN block configurations. 
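A note on the renamed RMSNorm tasks: `*_rms_reduce` computes the per-workgroup
partial sums of x squared, the optional `*_rms_finalize` (emitted only on
backends where workgroups cannot synchronize globally, hence the NON_NVIDIA
guard) combines those partials into the final scale, and the apply step
multiplies each element by it. A scalar sketch of what the reduce/finalize pair
leaves in temp[0] (illustrative helper, not part of this tree):

    // Reference for the scale the rms_reduce/rms_finalize tasks produce;
    // the apply step then computes w[i] * scale * x[i] per element.
    static float rmsNormScale(float[] x, float eps) {
        float sumSquares = 0.0f;
        for (float v : x) {
            sumSquares += v * v;                     // reduce phase: sum of squares
        }
        float mean = sumSquares / x.length;
        return 1.0f / (float) Math.sqrt(mean + eps); // finalize: inverse RMS
    }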
--- .../layers/type/fp16/LlamaFP16FFNLayers.java | 457 ++++++++++-------- 1 file changed, 265 insertions(+), 192 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 43b9212e..a4b6f922 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -1,214 +1,287 @@ - package org.beehive.gpullama3.tornadovm.layers.type.fp16; - - import org.beehive.gpullama3.inference.state.State; - import org.beehive.gpullama3.inference.weights.Weights; - import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; - import org.beehive.gpullama3.model.Configuration; - import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; - import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; - import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; - import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; - import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; - import uk.ac.manchester.tornado.api.GridScheduler; - import uk.ac.manchester.tornado.api.ImmutableTaskGraph; - import uk.ac.manchester.tornado.api.TaskGraph; - import uk.ac.manchester.tornado.api.WorkerGrid; - import uk.ac.manchester.tornado.api.enums.DataTransferMode; - - import java.util.List; - import java.util.stream.IntStream; - - public class LlamaFP16FFNLayers extends AbstractFFNLayers { - - TaskGraph ffnTaskGraphs; - GridScheduler scheduler; - List ffnLayerTaskGraphs; - public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) { - super(taskGraph, state, weights, config, schedulerType); - this.ffnLayerTaskGraphs = setupFFNLayered(); +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +public class LlamaFP16FFNLayers extends AbstractFFNLayers { + + TaskGraph ffnTaskGraphs; + GridScheduler scheduler; + List ffnLayerTaskGraphs; + + public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + this.ffnLayerTaskGraphs = setupFFNLayered(); + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int 
configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + + int fusedQKVRows = config.dim() + 2 * config.kvDim(); + int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); + + // Map workers to tasks + for (int i = 0; i < config.numberOfLayers(); i++) { + // === Attention Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + // === FFN Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_apply", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_gate_up", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); } + return tornadoForwardScheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return ffnTaskGraphs; + } - @Override - public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim()/2, 128); - WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); - - int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - - int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configKvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configKvDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - - int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); - - WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); - WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); - - int fusedQKVRows = config.dim() + 2 * config.kvDim(); - int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - WorkerGrid ropeWithCacheWorker = 
WorkerGridFactory.genericWorker(config.dim() / 2, 128); - - // Map workers to tasks - for (int i = 0; i < config.numberOfLayers(); i++) { - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fusedQKV", fusedQKVWorker); - - // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); -// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); -// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); -// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); -// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextQuantized", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); -// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ropeWithCache", ropeWithCacheWorker); + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + List setupFFNLayered() { + return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { + var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); + if (i == config.numberOfLayers() - 1) { + setupLastID(ffnLayer.getTaskGraphName()); } - return tornadoForwardScheduler; - } + return ffnLayer.snapshot(); + }).toList(); + } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } + // @formatter:off + TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); - @Override - public TaskGraph getTaskGraph() { - return ffnTaskGraphs; - } + // === Data Setup === + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } + // === Attention Block === + // RMS Normalization + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + 
config.dim(), config.rmsNormEps(), state.localSize); - public List getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); } - List setupFFNLayered() { -// state.temp.init(0.0f); -// state.tempFFN.init(0.0f); - // state.wrapXbFP16.clear(); + unifiedLayer.task("attn_rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantize, + context, state.wrapXbFP16, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + // QKV Projection (fused) + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulX, + context, + state.wrapXbFP16, // input (FP16) + state.wrapQ, // output Q + state.wrapK, // output K + state.wrapV, // output V + weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq + weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk + weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv + config.dim(), // dim + config.kvDim(), // kvDim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // RoPE + KV Cache + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, // Q (in/out) + state.wrapK, // K (in/out) + state.wrapV, // V (in only) + state.wrapKeyCache, // Key cache (out) + state.wrapValueCache, // Value cache (out) + config.kvDim(), + config.headSize(), + layerIndex, + config.contextLength()); + + // Attention + configureAttention(unifiedLayer, layerIndex); - var numLayers = config.numberOfLayers(); + // Output Projection (Wo) with residual + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asHalfFloatArray(), + config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); - return IntStream.range(0, numLayers) - .mapToObj(i -> { - var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); - if (i == numLayers - 1) setupLastID(ffnLayer.getTaskGraphName()); - return ffnLayer.snapshot(); - }) - .toList(); + // === FFN Block === + // RMS Normalization + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); } - TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { - var layerTaskGraphName = "layer_" + layerIndex; - TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); - unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.task("ffn_rms_apply", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, state.wrapXb, state.wrapX, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN); + + // Gate + Up projection with SiLU activation (W1, W3) + unifiedLayer.task("ffn_gate_up", + TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, + context, state.wrapXb, state.wrapHb, + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), + config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + // 
Down projection (W2) with residual + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asHalfFloatArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + // First layer: Transfer initial data to device (one-time transfer) + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - weights.rms_att_weightLayered[layerIndex].asFloatArray(), - weights.wqLayered[layerIndex].asHalfFloatArray(), - weights.wkLayered[layerIndex].asHalfFloatArray(), - weights.wvLayered[layerIndex].asHalfFloatArray(), - weights.woLayered[layerIndex].asHalfFloatArray(), - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - weights.w1Layered[layerIndex].asHalfFloatArray(), - weights.w2Layered[layerIndex].asHalfFloatArray(), - weights.w3Layered[layerIndex].asHalfFloatArray()); - unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer - .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, - config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContextQuantized", TransformerComputeKernels::mapContextWithQuantize, - context, state.wrapXbFP16, state.wrapX, - weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) -// unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) -// .task("quantizeXb", TransformerComputeKernels::convertFP32toFP16v2, context, state.wrapXb, state.wrapXbFP16) - .task("fusedQKV", TransformerComputeKernelsLayered::fusedQKVMatmulX, - context, - state.wrapXbFP16, // input (FP16) - state.wrapQ, // output Q - state.wrapK, // output K - state.wrapV, // output V - weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq - weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk - weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv - config.dim(), // dim - config.kvDim(), // kvDim - LOCAL_WORK_GROUP_SIZE_ALLOC) -// .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) -// .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) -// .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXbFP16, state.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) -// .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) -// .task("copyToCaches", 
TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength()); - .task("ropeWithCache", TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + // Kernel context context, - state.positionHolder, - state.wrapQ, // Q (in/out) - state.wrapK, // K (in/out) - state.wrapV, // V (in only) - state.wrapKeyCache, // Key cache (out) - state.wrapValueCache, // Value cache (out) + // Intermediate buffers + state.wrapXb, state.wrapXb2, + // QKV vectors + state.wrapQ, state.wrapK, state.wrapV, + // KV cache + state.wrapKeyCache, state.wrapValueCache, + // Attention & FFN buffers + state.wrapAtt, state.wrapHb, state.wrapXbFP16, + // Reduction temporaries + state.temp, state.tempFFN); + } else { + // Subsequent layers: Consume data already on device from previous layer + unifiedLayer.consumeFromDevice( + // Kernel context + context, + // Intermediate buffers + state.wrapXb, state.wrapXb2, + // QKV vectors + state.wrapQ, state.wrapK, state.wrapV, + // KV cache + state.wrapKeyCache, state.wrapValueCache, + // Attention & FFN buffers + state.wrapAtt, state.wrapHb, + // Position & misc + state.positionHolder, state.wrapXbFP16, + // Reduction temporaries + state.temp, state.tempFFN); + } + return unifiedLayer; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { + if (schedulerType == SchedulerType.NVIDIA) { + // Flash Attention (optimized for NVIDIA GPUs) + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + config.numberOfHeads(), + config.headSize(), config.kvDim(), + config.kvMul(), + state.positionHolder, + layerIndex, + config.contextLength()); + } else { + // Standard parallel attention (for non-NVIDIA backends) + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + config.numberOfHeads(), config.headSize(), + config.kvDim(), + config.kvMul(), + config.contextLength(), // seqLen parameter + state.positionHolder, + state.wrapAtt, // Attention weights buffer layerIndex, config.contextLength()); - configureAttention(unifiedLayer, layerIndex); - unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), 
weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice(state.wrapX); // - return unifiedLayer; - } - - protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { - // First layer: Transfer initial data to device (one-time transfer) - if (layerIndex == 0) { - // Transfer all attention-related data: query, key, value matrices and their caches - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder); -// , state.temp, state.tempFFN); // - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, state.wrapXbFP16 , state.temp, state.tempFFN); // - } else { - // Subsequent layers: Consume data already on device from previous layer - unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, // - state.positionHolder, state.wrapXbFP16, state.temp, state.tempFFN// - ); - } - return unifiedLayer; - } - - private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { - if (schedulerType == SchedulerType.NVIDIA) { - return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, - context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), - state.positionHolder, layerIndex, config.contextLength()); - } else { - return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); - } } } + // @formatter:on + +} From b202bb41eb59caf1013905ac98c48655f00ec177 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 16:11:44 +0200 Subject: [PATCH 08/42] Refactor FP16 FFN layers to streamline task graph setup, update worker grid assignments, and enhance attention and FFN block configurations. 
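A note on the `config.kvMul()` argument threaded through both attention
variants above: with grouped-query attention, kvMul consecutive query heads
share a single KV head, so the kernels index the caches with h / kvMul (as
processHeadsFlashAttention does). An illustrative one-liner, assuming
kvMul = numberOfHeads / numberOfKVHeads:

    // Which KV head a given query head reads under grouped-query attention.
    static int kvHeadFor(int queryHead, int kvMul) {
        return queryHead / kvMul; // kvMul query heads share one KV head
    }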
--- .../tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index a4b6f922..a8e60a76 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -159,10 +159,8 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, config.headSize(), layerIndex, config.contextLength()); - // Attention configureAttention(unifiedLayer, layerIndex); - // Output Projection (Wo) with residual unifiedLayer.task("attn_output_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, From 3eba3b3644ae3f8570db2ff5a69c4cb712714b25 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 16:14:31 +0200 Subject: [PATCH 09/42] Refactor `LogitsFP16Layer` task graph to improve readability, optimize kernel setup, and enhance FP16 task processing. --- .../layers/type/fp16/LogitsFP16Layer.java | 65 +++++++++++++++---- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index 88ef7162..3ed5e444 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -35,29 +35,70 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration this.schedulerType = schedulerType; } + /** * Builds the logits computation graph. 
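     * Stages: a final RMSNorm reduction (plus a finalize pass on non-NVIDIA
     * backends), a fused RMS-apply-with-FP16-quantize step, and the vocabulary
     * projection into wrapLogits.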
*/ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { TaskGraph logits = new TaskGraph("logits"); - logits.consumeFromDevice(lastTaskGraphID, state.wrapX) // - .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, // - state.wrapLogits, state.wrapXbFP16, // - weights.wclsByteArray.asHalfFloatArray(), // - weights.rms_final_weight_as_floatArray.asFloatArray()) // - .task("rms_reduce", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + + // === Data Setup === + logits.consumeFromDevice(lastTaskGraphID, state.wrapX); + logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Kernel context + context, + // Output buffer + state.wrapLogits, + // Intermediate FP16 buffer + state.wrapXbFP16, + // Weights + weights.wclsByteArray.asHalfFloatArray(), + weights.rms_final_weight_as_floatArray.asFloatArray()); + + // === Final RMS Normalization === + logits.task("rms_reduce", + TransformerComputeKernels::reductionOneBlockWithLayer, + context, + state.tempLogits, // output: partial sums + final scale factor + state.wrapX, // input: hidden state + config.dim(), // dimension + config.rmsNormEps(), // epsilon for numerical stability + state.localSize); // local workgroup size + if (schedulerType == SchedulerType.NON_NVIDIA) { - logits.task("rms_finalize", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); + logits.task("rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.tempLogits, // in/out: combines partial sums + config.dim(), // dimension + config.rmsNormEps()); // epsilon } - logits.task("rms_apply_fp16", TransformerComputeKernels::mapContextWithQuantizeLogits, context, state.wrapXbFP16, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) - .task("vocab_proj", TransformerComputeKernelsLayered::matrixVectorGeneric, // - context, state.wrapXbFP16, state.wrapLogits, // - weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), // - LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); // + + logits.task("rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantizeLogits, + context, + state.wrapXbFP16, // output: normalized (FP16) + state.wrapX, // input: hidden state + weights.rms_final_weight_as_floatArray.asFloatArray(), // RMS weights + state.tempLogits); // scale factor from reduction + + // === Vocabulary Projection === + logits.task("vocab_proj", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, + state.wrapXbFP16, // input (FP16) + state.wrapLogits, // output + weights.wclsByteArray.asHalfFloatArray(), // vocabulary weights + config.dim(), // input dimension + config.vocabularySize(), // output dimension + LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); + + // === Transfer Results to Host === logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; } + @Override public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { WorkerGrid logitsRMS; From 2e010b1dd226cb66c598620a6ddd5a8a5a40b0ae Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 16:15:40 +0200 Subject: [PATCH 10/42] Add `fusedFeedForwardWithSiLUAndGLUActivation` kernel for HalfFloat arrays and `mapContextWithQuantizeLogits` kernel, enhancing FP16 computation capabilities --- 
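Note on the shared epilogue: both the new HalfFloatArray overload and the
existing FP32 variant finish with the same SwiGLU gating. A scalar reference,
assuming sum1 = x·W1[row] and sum3 = x·W3[row] are the two row-major matvec
reductions (illustrative helper, not part of this tree):

    // SiLU(z) = z * sigmoid(z); the gated output is SiLU(x·W1[row]) * (x·W3[row]).
    static float swigluEpilogue(float sum1, float sum3) {
        float silu = sum1 / (1.0f + (float) Math.exp(-sum1)); // == sum1 * sigmoid(sum1)
        return silu * sum3;
    }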
.../kernels/TransformerComputeKernels.java | 8 +++++++ .../TransformerComputeKernelsLayered.java | 21 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index 58988d17..871673e6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -117,4 +117,12 @@ public static void reductionOneBlock2WithLogits(KernelContext context, FloatArra output.set(gid, weights.get(gid) * (ss * output.get(gid))); } + public static void mapContextWithQuantizeLogits(KernelContext context, HalfFloatArray output, FloatArray input, FloatArray weights, FloatArray temp) { + int gid = context.globalIdx; + float ss = temp.get(0); + float in = ss * input.get(gid); + float interim = weights.get(gid) * in; + output.set(gid, new HalfFloat(interim)); + } + } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index 19961bd2..00e5358f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -967,6 +967,27 @@ public static void matrixVectorGenericWithResidual(KernelContext context, HalfFl */ + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, HalfFloatArray x, HalfFloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= d) { + return; + } + + float sum1 = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w1, n); + float sum3 = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w3, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float silu = siluActivation(sum1); // Using the new SiLU method + float result = silu * sum3; + hb.set(rowId, new HalfFloat(result)); + } + } + + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) { // One row per workgroup (not per thread) int rowId = context.groupIdx; From 4aef300287c72739ed920c7b8df24936bf311023 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 16:57:02 +0200 Subject: [PATCH 11/42] Document Transformer Layer Task Flow for `LlamaFP16FFNLayers` with detailed data flow, task breakdown, and fusion points --- .../layers/type/fp16/LlamaFP16FFNLayers.java | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index a8e60a76..9cd7c7bf 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -94,6 +94,96 @@ List setupFFNLayered() { } // @formatter:off + /** + * Transformer Layer Task Flow (LlamaFP16FFNLayers) + * + * 
══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (partial sums) + * └────────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ attn_rms_finalize│──▶ temp (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌─────────────────────┐ + * │ attn_rms_apply_fp16 │──▶ wrapXbFP16 (normalized, FP16) + * └──────────┬──────────┘ + * │ + * ▼ + * ┌────────────────┐ ┌─────────────────────────────┐ + * │ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │ + * └───────┬────────┘ └─────────────────────────────┘ + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN (partial sums) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌─────────────────┐ + * │ ffn_rms_finalize│──▶ tempFFN (final scale) + * └────────┬────────┘ + * │ + * ▼ + * ┌───────────────┐ + * │ ffn_rms_apply │──▶ wrapXb (normalized, FP32) + * └───────┬───────┘ + * │ + * ▼ + * ┌─────────────┐ + * │ ffn_gate_up │──▶ wrapHb = SiLU(xb·W1) ⊙ (xb·W3) + * └──────┬──────┘ + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 10 tasks (8 if NVIDIA, skipping rms_finalize steps) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state with residual connections + * + * Key Fusion Points: + * • qkv_projection: Fused Q/K/V matmuls (3→1 kernel) + * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel) + * • ffn_gate_up: Fused W1/W3 matmuls + SiLU + GLU (3→1 kernel) + * + */ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); From 177ec9d273aec6af93a79edc4161438f97f6d47c Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 17:14:52 +0200 Subject: [PATCH 12/42] Set default profiler dump directory relative to `LLAMA_ROOT` when not provided --- llama-tornado | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llama-tornado b/llama-tornado index 9c0d6ba8..c98090f8 100755 --- a/llama-tornado +++ b/llama-tornado @@ -422,7 +422,7 @@ def create_parser() -> argparse.ArgumentParser: ) debug_group.add_argument( "--profiler-dump-dir", - default="/home/mikepapadim/repos/gpu-llama3.java/prof.json", + default=None, help="Directory for profiler output", ) @@ -498,6 +498,11 @@ def main(): parser = create_parser() args = parser.parse_args() + # Set default profiler log path 
relative to LLAMA_ROOT
+    if args.profiler_dump_dir is None:
+        llama_root = os.environ.get("LLAMA_ROOT", ".")  # assume CWD when LLAMA_ROOT is unset, so join() never sees None
+        args.profiler_dump_dir = os.path.join(llama_root, "profiler-log.json")
+
     # Set default seed if not provided
     if args.seed is None:
         args.seed = int(time.time())

From a1c94fb2404a75c6a019eb0fad985130543ce318 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 4 Dec 2025 17:38:02 +0200
Subject: [PATCH 13/42] Add `fusedRmsNormFFNGateUp` kernel and update FP16 FFN
 task graph to incorporate fused RMS normalization, gate, and up-projection

---
 .../TransformerComputeKernelsLayered.java     | 70 +++++++++++++++++++
 .../layers/type/fp16/LlamaFP16FFNLayers.java  | 47 ++++++++----
 2 files changed, 102 insertions(+), 15 deletions(-)

diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
index 00e5358f..1c79d731 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
@@ -17,6 +17,76 @@ public class TransformerComputeKernelsLayered {
     public TransformerComputeKernelsLayered() {
     }
 
+    public static void fusedRmsNormFFNGateUp(
+            KernelContext context,
+            FloatArray x,            // raw input (FP32)
+            FloatArray hb,           // output
+            FloatArray rmsWeights,   // RMS norm weights
+            FloatArray rmsScale,     // temp[0] = scale factor
+            HalfFloatArray w1,
+            HalfFloatArray w3,
+            int dim,                 // input dimension
+            int hiddenDim,           // output dimension
+            int localWorkGroupSize) {
+
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
+
+        if (rowId >= hiddenDim) {
+            return;
+        }
+
+        float scale = rmsScale.get(0);
+
+        // Local scratch for the per-workgroup dot-product reductions
+        float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+        int rowOffsetW1 = rowId * dim;
+        int rowOffsetW3 = rowId * dim;
+
+        // === W1 matmul with inline normalization ===
+        float sum1 = 0.0f;
+        for (int j = localId; j < dim; j += localWorkGroupSize) {
+            float normalized = rmsWeights.get(j) * scale * x.get(j);
+            sum1 += w1.get(rowOffsetW1 + j).getFloat32() * normalized;
+        }
+
+        localSum[localId] = sum1;
+        context.localBarrier();
+
+        for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+            if (localId < stride) {
+                localSum[localId] += localSum[localId + stride];
+            }
+            context.localBarrier();
+        }
+        float result1 = localSum[0];
+
+        // === W3 matmul with inline normalization (same computation) ===
+        float sum3 = 0.0f;
+        for (int j = localId; j < dim; j += localWorkGroupSize) {
+            float normalized = rmsWeights.get(j) * scale * x.get(j);
+            sum3 += w3.get(rowOffsetW3 + j).getFloat32() * normalized;
+        }
+
+        localSum[localId] = sum3;
+        context.localBarrier();
+
+        for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+            if (localId < stride) {
+                localSum[localId] += localSum[localId + stride];
+            }
+            context.localBarrier();
+        }
+        float result3 = localSum[0];
+
+        // === SiLU + GLU ===
+        if (localId == 0) {
+            float silu = result1 / (1.0f + TornadoMath.exp(-result1));
+            hb.set(rowId, silu * result3);
+        }
+    }
+
     /**
      * Performs RMS (Root Mean Square) normalization using parallel reduction. This is the first phase of RMS normalization that computes the variance and scaling factor across all work groups.
 *
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
index 9cd7c7bf..c1dcd939 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
@@ -57,8 +57,12 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
             // === FFN Block ===
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_apply", rmsNormWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_gate_up", configHiddenDimRowMajorWorker);
+//            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_apply", rmsNormWorker);
+//            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_gate_up", configHiddenDimRowMajorWorker);
+
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
+
+
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
         }
         return tornadoForwardScheduler;
@@ -224,7 +228,7 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         unifiedLayer.task("qkv_projection",
                 TransformerComputeKernelsLayered::fusedQKVMatmulX,
                 context,
-                state.wrapXbFP16,           // input (FP16)
+                state.wrapXbFP16,           // input (FP16, produced by attn_rms_apply_fp16)
                 state.wrapQ,                // output Q
                 state.wrapK,                // output K
                 state.wrapV,                // output V
@@ -271,18 +275,31 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
                     context, state.tempFFN, config.dim(), config.rmsNormEps());
         }
 
-        unifiedLayer.task("ffn_rms_apply",
-                TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
-                context, state.wrapXb, state.wrapX,
-                weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN);
-
-        // Gate + Up projection with SiLU activation (W1, W3)
-        unifiedLayer.task("ffn_gate_up",
-                TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation,
-                context, state.wrapXb, state.wrapHb,
-                weights.w1Layered[layerIndex].asHalfFloatArray(),
-                weights.w3Layered[layerIndex].asHalfFloatArray(),
-                config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+//        unifiedLayer.task("ffn_rms_apply",
+//                TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
+//                context, state.wrapXb, state.wrapX,
+//                weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN);
+//
+//        // Gate + Up projection with SiLU activation (W1, W3)
+//        unifiedLayer.task("ffn_gate_up",
+//                TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation,
+//                context, state.wrapXb, state.wrapHb,
+//                weights.w1Layered[layerIndex].asHalfFloatArray(),
+//                weights.w3Layered[layerIndex].asHalfFloatArray(),
+//                config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        unifiedLayer.task("rms_ffn_gate_up",
+                TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp,
+                context,
+                state.wrapX,                                              // raw input (FP32)
+                state.wrapHb,                                             // output
+                weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights
+                state.tempFFN,                                            // RMS scale factor
+                weights.w1Layered[layerIndex].asHalfFloatArray(),         // W1
+                weights.w3Layered[layerIndex].asHalfFloatArray(),         // W3
+                config.dim(),                                             // input dimension
+                config.hiddenDim(),                                       // output
dimension + LOCAL_WORK_GROUP_SIZE_ALLOC); // Down projection (W2) with residual unifiedLayer.task("ffn_down_proj", From 577b6b1503b961145f36d47f80357c57467fee7c Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 17:54:50 +0200 Subject: [PATCH 14/42] Increase `BLOCK_SIZE_C` to 16 for Transformer kernel and update FP16 FFN task graphs by removing deprecated tasks, consolidating RMS normalization and FFN operations into `rms_ffn_gate_up`. --- .../TransformerComputeKernelsLayered.java | 2 +- .../layers/type/fp16/LlamaFP16FFNLayers.java | 39 ++++--------------- 2 files changed, 9 insertions(+), 32 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index 1c79d731..6af9ac22 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -485,7 +485,7 @@ public static void processHeadsFlashAttention(KernelContext context, FloatArray int pos = positionHolder.get(0); int loff = layer * contextLength * kvDim; int kvHeadIdx = h / kvMul; - int BLOCK_SIZE_C = 8; + int BLOCK_SIZE_C = 16; // Allocate shared memory for tiled computation float[] q_shared = context.allocateFloatLocalArray(headSize); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index c1dcd939..91bf2499 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -57,12 +57,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); // === FFN Block === tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); -// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_apply", rmsNormWorker); -// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_gate_up", configHiddenDimRowMajorWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); - - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); } return tornadoForwardScheduler; @@ -157,16 +152,11 @@ List setupFFNLayered() { * └────────┬────────┘ * │ * ▼ - * ┌───────────────┐ - * │ ffn_rms_apply │──▶ wrapXb (normalized, FP32) - * └───────┬───────┘ - * │ - * ▼ - * ┌─────────────┐ - * │ ffn_gate_up │──▶ wrapHb = SiLU(xb·W1) ⊙ (xb·W3) - * └──────┬──────┘ - * │ - * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3) + * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU) + * │ + * ▼ * ┌──────────────┐ * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) * └──────┬───────┘ @@ -176,16 +166,16 @@ List setupFFNLayered() { * * ══════════════════════════════════════════════════════════════════════════════ * - * Task Count: 10 tasks (8 if NVIDIA, skipping rms_finalize steps) + * Task Count: 9 tasks (7 if NVIDIA, skipping rms_finalize steps) * * Data Flow Summary: * Input: wrapX (FP32) - hidden state from previous layer * Output: wrapX (FP32) - updated hidden state with 
residual connections * * Key Fusion Points: - * • qkv_projection: Fused Q/K/V matmuls (3→1 kernel) + * • qkv_projection: Fused Q/K/V matmuls (3→1 kernel) * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel) - * • ffn_gate_up: Fused W1/W3 matmuls + SiLU + GLU (3→1 kernel) + * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel) * */ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { @@ -275,19 +265,6 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, context, state.tempFFN, config.dim(), config.rmsNormEps()); } -// unifiedLayer.task("ffn_rms_apply", -// TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, -// context, state.wrapXb, state.wrapX, -// weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN); -// -// // Gate + Up projection with SiLU activation (W1, W3) -// unifiedLayer.task("ffn_gate_up", -// TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, -// context, state.wrapXb, state.wrapHb, -// weights.w1Layered[layerIndex].asHalfFloatArray(), -// weights.w3Layered[layerIndex].asHalfFloatArray(), -// config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.task("rms_ffn_gate_up", TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, context, From d5c1206ed0ece127a7769074857dae52c73eaade Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 18:11:02 +0200 Subject: [PATCH 15/42] Increase `ropeWithCacheWorker` local work group size to 512 in FP16 FFN layers to optimize worker grid configuration. --- .../tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 91bf2499..b0634dbb 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -44,7 +44,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) int fusedQKVRows = config.dim() + 2 * config.kvDim(); int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); // Map workers to tasks for (int i = 0; i < config.numberOfLayers(); i++) { From f91108c8cac92db47348269b6ea3ec83aab053a0 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 19:34:25 +0200 Subject: [PATCH 16/42] Add fused kernels for Qwen3: `ropeRotationWithCacheCopy`, `fusedQKVMatmul`, and `fusedRmsNormQKVMatmul`. Refactor workers and task graphs to utilize new computations and streamline layer configurations for improved performance and reduced memory transfers. 
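
For reviewers, a minimal host-side sketch (not part of the patch) of the workgroup-to-row
partitioning used by `fusedQKVMatmul` and `fusedRmsNormQKVMatmul`: one workgroup of
LOCAL_WORK_GROUP_SIZE_ALLOC threads is launched per output row, so the grid spans
qDim + 2 * kvDim rows, with qDim = nEmbdHeadK * nHeads and kvDim = nEmbdGqa as set up
in this patch. The helper name below is hypothetical, for illustration only:

    // Sketch: which projection a given workgroup (rowId) computes in the fused QKV kernels.
    static String projectionForRow(int rowId, int qDim, int kvDim) {
        if (rowId < qDim) {
            return "Q[" + rowId + "]";                  // rows [0, qDim) compute Q
        } else if (rowId < qDim + kvDim) {
            return "K[" + (rowId - qDim) + "]";         // rows [qDim, qDim + kvDim) compute K
        } else if (rowId < qDim + 2 * kvDim) {
            return "V[" + (rowId - qDim - kvDim) + "]"; // remaining rows compute V
        }
        return "idle";                                  // guard: rowId beyond the launched range
    }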
--- .../gpullama3/inference/state/Qwen3State.java | 3 + .../tornadovm/kernels/Qwen3Kernels.java | 272 ++++++++++++++++++ .../layers/type/fp16/Qwen3FP16FFNLayers.java | 127 ++++---- 3 files changed, 349 insertions(+), 53 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java index d6a6d087..870462e7 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java @@ -5,6 +5,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -67,6 +68,8 @@ protected StateFields createStateFields(Configuration configuration) { // TornadoVM wrappers with Qwen3-specific sizes fields.wrapX = new FloatArray(config.dim()); fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads()); + fields.wrapXbFP16 = new HalfFloatArray(nEmbdHeadK * config.numberOfHeads()); + fields.wrapXb2 = new FloatArray(config.dim()); fields.wrapHb = new FloatArray(config.hiddenDim()); fields.wrapHb2 = new FloatArray(config.hiddenDim()); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java index 930e1774..f018fd7c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java @@ -4,6 +4,7 @@ import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; // @formatter:off @@ -292,5 +293,276 @@ private static void processHeadTornado( } } + /** + * Fused RoPE rotation with KV cache copy for Qwen3. + * Combines ropeRotation + copyToCache into a single kernel. + */ + public static void ropeRotationWithCacheCopy( + KernelContext context, + IntArray positionHolder, + FloatArray q, // Q vector (in/out) + FloatArray k, // K vector (in/out) + FloatArray v, // V vector (in only) + FloatArray keyCache, // Key cache (out) + FloatArray valueCache, // Value cache (out) + int numberOfKeyValueHeads, + int nEmbdHead, + int nEmbdGqa, + int layer, + int contextLength) { + + int h = context.globalIdx; + int ic = context.globalIdy; + + int pos = positionHolder.get(0); + int rotn = h < numberOfKeyValueHeads ? 
2 : 1; + int poffset = h * nEmbdHead; + int nComplEmbdHead = nEmbdHead / 2; + + // Compute RoPE frequencies for Qwen3 (theta = 1000000.0f) + float theta = 1000000.0f; + int i = ic * 2; + float freq = 1.0f / TornadoMath.pow(theta, (float) i / (float) nEmbdHead); + + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q (all heads) + float v0q = q.get(poffset + ic); + float v1q = q.get(poffset + ic + nComplEmbdHead); + q.set(poffset + ic, v0q * fcr - v1q * fci); + q.set(poffset + ic + nComplEmbdHead, v0q * fci + v1q * fcr); + + // Rotate K and copy K/V to cache (only for KV heads) + if (rotn > 1 && (poffset + ic + nComplEmbdHead) < k.getSize()) { + float v0k = k.get(poffset + ic); + float v1k = k.get(poffset + ic + nComplEmbdHead); + float rotatedK0 = v0k * fcr - v1k * fci; + float rotatedK1 = v0k * fci + v1k * fcr; + + // Write rotated K back + k.set(poffset + ic, rotatedK0); + k.set(poffset + ic + nComplEmbdHead, rotatedK1); + + // Direct cache write (fused - no separate copy kernel!) + int cacheOffset = layer * contextLength * nEmbdGqa + pos * nEmbdGqa; + int kvIdx = h * nEmbdHead; + + keyCache.set(cacheOffset + kvIdx + ic, rotatedK0); + keyCache.set(cacheOffset + kvIdx + ic + nComplEmbdHead, rotatedK1); + + // Copy V to cache (V doesn't need rotation) + valueCache.set(cacheOffset + kvIdx + ic, v.get(poffset + ic)); + valueCache.set(cacheOffset + kvIdx + ic + nComplEmbdHead, v.get(poffset + ic + nComplEmbdHead)); + } + } + + /** + * Fused Q/K/V matrix-vector multiplication for Qwen3 GQA. + * Q has full head dimension, K/V have reduced KV head dimension. + * + * Workgroup assignment: + * - rowId [0, qDim): Q projection + * - rowId [qDim, qDim+kvDim): K projection + * - rowId [qDim+kvDim, qDim+2*kvDim): V projection + */ + public static void fusedQKVMatmul( + KernelContext context, + FloatArray x, // input vector + FloatArray q, // output Q + FloatArray k, // output K + FloatArray v, // output V + HalfFloatArray wq, // Q weight matrix + HalfFloatArray wk, // K weight matrix + HalfFloatArray wv, // V weight matrix + int inputDim, // input dimension (config.dim()) + int qDim, // Q output dimension + int kvDim, // KV output dimension + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowId < qDim) { + // ========== Q projection ========== + int rowOffset = rowId * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partialSum += wq.get(rowOffset + j).getFloat32() * x.get(j); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + q.set(rowId, localSum[0]); + } + + } else if (rowId < qDim + kvDim) { + // ========== K projection ========== + int kRow = rowId - qDim; + int rowOffset = kRow * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partialSum += wk.get(rowOffset + j).getFloat32() * x.get(j); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + 
+ if (localId == 0) { + k.set(kRow, localSum[0]); + } + + } else if (rowId < qDim + 2 * kvDim) { + // ========== V projection ========== + int vRow = rowId - qDim - kvDim; + int rowOffset = vRow * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partialSum += wv.get(rowOffset + j).getFloat32() * x.get(j); + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + v.set(vRow, localSum[0]); + } + } + } + + /** + * Fused RMSNorm apply + Q/K/V projection for Qwen3 GQA. + * Eliminates intermediate wrapXb buffer write/read. + */ + public static void fusedRmsNormQKVMatmul( + KernelContext context, + FloatArray x, // raw input (FP32) + FloatArray q, // output Q + FloatArray k, // output K + FloatArray v, // output V + FloatArray rmsWeights, // RMS norm weights + FloatArray rmsScale, // temp[0] = scale factor + HalfFloatArray wq, // Q weight matrix + HalfFloatArray wk, // K weight matrix + HalfFloatArray wv, // V weight matrix + int inputDim, // input dimension (config.dim()) + int qDim, // Q output dimension + int kvDim, // KV output dimension + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + float scale = rmsScale.get(0); + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowId < qDim) { + // ========== Q projection with inline normalization ========== + int rowOffset = rowId * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum += wq.get(rowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + q.set(rowId, localSum[0]); + } + + } else if (rowId < qDim + kvDim) { + // ========== K projection with inline normalization ========== + int kRow = rowId - qDim; + int rowOffset = kRow * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum += wk.get(rowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + k.set(kRow, localSum[0]); + } + + } else if (rowId < qDim + 2 * kvDim) { + // ========== V projection with inline normalization ========== + int vRow = rowId - qDim - kvDim; + int rowOffset = vRow * inputDim; + + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum += wv.get(rowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + 
context.localBarrier(); + } + + if (localId == 0) { + v.set(vRow, localSum[0]); + } + } + } + } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 379921c3..e466cee1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -70,9 +70,6 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { int matmulKVGlobal = nEmbdGqa * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid matmulKVRowMajorWorker = WorkerGridFactory.genericWorker(matmulKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - // Current embedding head worker - WorkerGrid curWorker = WorkerGridFactory.createRmsNormWorker(nEmbdHead, 128); - // Q current worker WorkerGrid qCurWorker = WorkerGridFactory.genericWorker(config.numberOfHeads() * nEmbdHead, nEmbdHead); @@ -80,12 +77,8 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { WorkerGrid kCurWorker = WorkerGridFactory.genericWorker(config.numberOfKeyValueHeads() * nEmbdHead, nEmbdHead); // RoPE worker (2D: heads x embedding_head/2) - int ic = nEmbdHead / 2; WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), nEmbdHead); - // Copy to cache worker - WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(nEmbdGqa, 128); - // Parallel attention worker WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), nEmbdHead); @@ -100,6 +93,14 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { int projectionTwoGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid projectionTwoWorker = WorkerGridFactory.genericWorker(projectionTwoGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); + int kvDim0 = nEmbdGqa; + // Add this 1: + int fusedQKVRows = qDim0 + 2 * kvDim0; // Q rows + K rows + V rows + int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + // Map workers to tasks for each layer for (int i = 0; i < config.numberOfLayers(); i++) { gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); @@ -109,21 +110,26 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { gridScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker); gridScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker); gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker); gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker); gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); + + gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker); + + gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNW1W3Worker); + 
+ gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker); - gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker); - gridScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker); +// gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker); + gridScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", projectionTwoWorker); } return gridScheduler; @@ -153,11 +159,6 @@ public List getFfnLayerTaskGraphs() { */ List setupFFNLayered() { List ffnGraphs = new ArrayList<>(); - qwen3State.temp.init(0.0f); - qwen3State.tempFFN.init(0.0f); - qwen3State.tempQcur.init(0.0f); - qwen3State.tempKcur.init(0.0f); - for (int layerIndex = 0; layerIndex < qwen3Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, layerIndex); if (layerIndex == qwen3Config.numberOfLayers() - 1) { @@ -174,6 +175,9 @@ List setupFFNLayered() { */ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) { var taskGraphName = "layer_" + layerIndex; + int qDim = nEmbdHeadK * qwen3Config.numberOfHeads(); // Q output size + int kvDim = nEmbdGqa; // K/V output size + int qkvDim1 = qwen3Config.dim(); TaskGraph unifiedLayer = new TaskGraph(taskGraphName); unifiedLayer.consumeFromDevice(state.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // @@ -194,18 +198,23 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) ); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.temp, qwen3State.wrapX, // in - qwen3Config.dim(), qwen3Config.rmsNormEps(), qwen3State.localSize).task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen3State.wrapXb, // out - qwen3State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), qwen3State.temp); - - int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); - int kvDim0 = nEmbdGqa; - int qkvDim1 = qwen3Config.dim(); - unifiedLayer.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapQ, // output - weights.wqLayered[layerIndex].asHalfFloatArray(), qkvDim1, qDim0, LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapK, // output - weights.wkLayered[layerIndex].asHalfFloatArray(), qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapV, // output - weights.wvLayered[layerIndex].asHalfFloatArray(), qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC); + qwen3Config.dim(), qwen3Config.rmsNormEps(), qwen3State.localSize); + + unifiedLayer.task("attn_rms_qkv_projection", Qwen3Kernels::fusedRmsNormQKVMatmul, + context, + qwen3State.wrapX, // raw input (not normalized) + qwen3State.wrapQ, // output Q + qwen3State.wrapK, // output K + qwen3State.wrapV, // output V + weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights + 
qwen3State.temp, // RMS scale factor + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + qkvDim1, // input dim (config.dim()) + qDim, // Q output dim + kvDim, // KV output dim + LOCAL_WORK_GROUP_SIZE_ALLOC); // Qcur rmsnorm unifiedLayer.task("rmsnormReduction_Qcur", Qwen3Kernels::rmsnormWithParallelOffset, context, qwen3State.tempQcur, // output @@ -225,18 +234,21 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) .task("rmsnormMapIndexInPlace_Kcur", Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, context, qwen3State.wrapK, // output weights.rms_att_KNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempKcur); - // rope rotation task graph - unifiedLayer.task("ropeRotation", Qwen3Kernels::ropeRotation, context, qwen3State.positionHolder, qwen3State.wrapQ, // out - qwen3State.wrapK, // out - qwen3Config.numberOfKeyValueHeads(), nEmbdHead); - - unifiedLayer.task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, qwen3State.wrapKeyCache, // out - qwen3State.wrapK, // in - qwen3State.wrapValueCache, // out - qwen3State.wrapV, // in - qwen3State.positionHolder, nEmbdGqa, layerIndex, qwen3Config.contextLength()); - - unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt, context, qwen3State.wrapQ, qwen3State.wrapKeyCache, qwen3State.wrapValueCache, + unifiedLayer.task("rope_and_kv_cache", Qwen3Kernels::ropeRotationWithCacheCopy, + context, + qwen3State.positionHolder, + qwen3State.wrapQ, // Q (in/out) + qwen3State.wrapK, // K (in/out) + qwen3State.wrapV, // V (in only) + qwen3State.wrapKeyCache, // Key cache (out) + qwen3State.wrapValueCache, // Value cache (out) + qwen3Config.numberOfKeyValueHeads(), + nEmbdHead, + nEmbdGqa, + layerIndex, + qwen3Config.contextLength()); + + unifiedLayer.task("attention", TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt, context, qwen3State.wrapQ, qwen3State.wrapKeyCache, qwen3State.wrapValueCache, qwen3State.wrapXb, // out qwen3Config.numberOfHeads(), nEmbdHead, nEmbdGqa, gqa, qwen3State.positionHolder, layerIndex, qwen3Config.contextLength()); @@ -247,15 +259,24 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3Config.dim(), // dim0 = 1024 LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.tempFFN, qwen3State.wrapX, qwen3Config.dim(), + unifiedLayer.task("ffn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.tempFFN, qwen3State.wrapX, qwen3Config.dim(), qwen3Config.rmsNormEps(), qwen3State.localSize) - .task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, qwen3State.tempFFN, qwen3Config.dim(), qwen3Config.rmsNormEps()) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - qwen3State.tempFFN); + .task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, qwen3State.tempFFN, qwen3Config.dim(), qwen3Config.rmsNormEps()); + + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, + context, + qwen3State.wrapX, // raw input (FP32) + qwen3State.wrapHb, // output + 
weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights
+ qwen3State.tempFFN, // RMS scale factor
+ weights.w1Layered[layerIndex].asHalfFloatArray(), // W1
+ weights.w3Layered[layerIndex].asHalfFloatArray(), // W3
+ qwen3Config.dim(), // input dimension
+ qwen3Config.hiddenDim(), // output dimension
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
- unifiedLayer.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen3State.wrapXb, qwen3State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(),
- weights.w3Layered[layerIndex].asHalfFloatArray(), qwen3Config.dim(), qwen3Config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapHb, qwen3State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(),
+ unifiedLayer.task("ffn_down_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapHb, qwen3State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(),
 qwen3Config.hiddenDim(), qwen3Config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(qwen3State.wrapX);
 return unifiedLayer;
 }
@@ -266,21 +287,21 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
 protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
 if (layerIndex == 0) {
 // First layer: Transfer temporary buffers and QKV state every execution
- unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN);
- unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.tempQcur, qwen3State.tempKcur);
+ unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.positionHolder);
+ // temp, tempFFN, tempQcur and tempKcur now transfer once via FIRST_EXECUTION below
 // First execution: allocate workspace buffers
 unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
 context, qwen3State.wrapXb, qwen3State.wrapXb2, //
 qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, //
 qwen3State.wrapKeyCache, qwen3State.wrapValueCache, //
- qwen3State.wrapAtt, qwen3State.wrapHb);
+ qwen3State.wrapAtt, qwen3State.wrapHb, qwen3State.temp, qwen3State.tempFFN, qwen3State.tempQcur, qwen3State.tempKcur);
 } else {
 // Subsequent layers: Consume data from previous layer
 unifiedLayer.consumeFromDevice(context, qwen3State.wrapXb, qwen3State.wrapXb2, //
 qwen3State.wrapQ, qwen3State.wrapK, //
 qwen3State.wrapV, qwen3State.wrapKeyCache, //
 qwen3State.wrapValueCache, qwen3State.wrapAtt, //
- qwen3State.wrapHb, qwen3State.positionHolder); //
+ qwen3State.wrapHb, qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN); //
 unifiedLayer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur);
 }

From cfa3ba033c3cb39f0a1d9fee4b7dce5c5360a03a Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 4 Dec 2025 19:51:02 +0200
Subject: [PATCH 17/42] Add fused Q and K RMSNorm kernel and refactor task graph to consolidate Q/K RMSNorm into a single operation. Clean up deprecated workers, update task names, and streamline layer configuration. 
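
For context, a scalar reference (sketch only, not part of the patch) of the per-head
computation that the fused kernel parallelizes with one workgroup per head; `eps`
corresponds to rmsNormEps() and `weights` to the Q or K norm weights:

    // Scalar reference for one head of fusedQKRmsNorm: RMS-normalize in place, then scale by weights.
    static void rmsNormHeadInPlace(float[] vec, int headOffset, int headDim, float[] weights, float eps) {
        float ss = 0.0f;
        for (int i = 0; i < headDim; i++) {
            float v = vec[headOffset + i];
            ss += v * v;                                   // sum of squares (the kernel's parallel reduction)
        }
        float scale = 1.0f / (float) Math.sqrt(ss / headDim + eps);
        for (int i = 0; i < headDim; i++) {
            vec[headOffset + i] = weights[i] * (scale * vec[headOffset + i]);  // weighted normalize, in place
        }
    }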
--- .../tornadovm/TornadoVMMasterPlan.java | 1 - .../tornadovm/kernels/Qwen3Kernels.java | 133 ++++++++++++++++++ .../layers/type/fp16/Qwen3FP16FFNLayers.java | 67 +++------ 3 files changed, 153 insertions(+), 48 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index b1195c65..293d2c0c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -180,7 +180,6 @@ private int getFinalLogitsGraphIndex() { public void forceCopyInReadOnlyDataLayered() { // Execute all TornadoVM graphs state.wrapX.init(0.0f); - state.wrapXFP16.clear(); state.positionHolder.init(0); // Execute activation update graph diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java index f018fd7c..88f006b6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java @@ -564,5 +564,138 @@ public static void fusedRmsNormQKVMatmul( } } + /** + * Fused Q and K RMSNorm for Qwen3. + * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel. + * + * Workgroup assignment: + * - Workgroups [0, nHeads): Process Q heads + * - Workgroups [nHeads, nHeads + nHeadKv): Process K heads + * + * Each workgroup computes reduction and applies normalization for one head. + */ + /** + * Fused Q and K RMSNorm for Qwen3. + * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel. + * + * Workgroup assignment: + * - Workgroups [0, nHeads): Process Q heads + * - Workgroups [nHeads, nHeads + nHeadKv): Process K heads + */ + /** + * Fused Q and K RMSNorm for Qwen3. + * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel. + * + * Workgroup assignment: + * - Workgroups [0, nHeads): Process Q heads + * - Workgroups [nHeads, nHeads + nHeadKv): Process K heads + */ + /** + * Fused Q and K RMSNorm for Qwen3. + * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel. + * + * Workgroup assignment: + * - Workgroups [0, nHeads): Process Q heads + * - Workgroups [nHeads, nHeads + nHeadKv): Process K heads + */ + /** + * Fused Q and K RMSNorm for Qwen3. + * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel. 
+ * + * Workgroup assignment: + * - Workgroups [0, nHeads): Process Q heads + * - Workgroups [nHeads, nHeads + nHeadKv): Process K heads + */ + public static void fusedQKRmsNorm( + KernelContext context, + FloatArray q, // Q vector (in/out) + FloatArray k, // K vector (in/out) + FloatArray qWeights, // Q RMS norm weights + FloatArray kWeights, // K RMS norm weights + int nHeads, // number of Q heads + int nHeadKv, // number of K heads + int nEmbdHead, // head dimension + int localMemSize, // local memory size (must be fixed) + float rmsNormEps) { + + int groupId = context.groupIdx; + int localId = context.localIdx; + int localSize = context.localGroupSizeX; + + // Allocate local memory with FIXED size parameter + float[] localSum = context.allocateFloatLocalArray(localMemSize); + + if (groupId < nHeads) { + // === Process Q head === + int headOffset = groupId * nEmbdHead; + + // Step 1: Compute sum of squares (reduction) + float partialSum = 0.0f; + for (int i = localId; i < nEmbdHead; i += localSize) { + float val = q.get(headOffset + i); + partialSum += val * val; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + // Compute normalization factor + float ss = localSum[0]; + ss = ss / nEmbdHead + rmsNormEps; + ss = 1.0f / TornadoMath.sqrt(ss); + + context.localBarrier(); + + // Step 2: Apply normalization with weights (in-place) + for (int i = localId; i < nEmbdHead; i += localSize) { + float normalized = ss * q.get(headOffset + i); + q.set(headOffset + i, qWeights.get(i) * normalized); + } + + } else if (groupId < nHeads + nHeadKv) { + // === Process K head === + int headIdx = groupId - nHeads; + int headOffset = headIdx * nEmbdHead; + + // Step 1: Compute sum of squares (reduction) + float partialSum = 0.0f; + for (int i = localId; i < nEmbdHead; i += localSize) { + float val = k.get(headOffset + i); + partialSum += val * val; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + // Compute normalization factor + float ss = localSum[0]; + ss = ss / nEmbdHead + rmsNormEps; + ss = 1.0f / TornadoMath.sqrt(ss); + + context.localBarrier(); + + // Step 2: Apply normalization with weights (in-place) + for (int i = localId; i < nEmbdHead; i += localSize) { + float normalized = ss * k.get(headOffset + i); + k.set(headOffset + i, kWeights.get(i) * normalized); + } + } + } } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index e466cee1..2322af3c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -64,25 +64,13 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { // Q matmul worker (GQA: full query heads) int matmulQGlobal = nEmbdHeadK * config.numberOfHeads() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid matmulQRowMajorWorker = WorkerGridFactory.genericWorker(matmulQGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - // KV matmul worker (GQA: reduced KV heads) - int 
matmulKVGlobal = nEmbdGqa * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid matmulKVRowMajorWorker = WorkerGridFactory.genericWorker(matmulKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - - // Q current worker - WorkerGrid qCurWorker = WorkerGridFactory.genericWorker(config.numberOfHeads() * nEmbdHead, nEmbdHead); - - // K current worker - WorkerGrid kCurWorker = WorkerGridFactory.genericWorker(config.numberOfKeyValueHeads() * nEmbdHead, nEmbdHead); - - // RoPE worker (2D: heads x embedding_head/2) WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), nEmbdHead); // Parallel attention worker WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), nEmbdHead); - // Matmul1 worker (output projection) + // attn_output_proj worker (output projection) int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC); @@ -92,7 +80,8 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { int projectionTwoGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid projectionTwoWorker = WorkerGridFactory.genericWorker(projectionTwoGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - + int qkRmsNormGroups = config.numberOfHeads() + config.numberOfKeyValueHeads(); + WorkerGrid qkRmsNormWorker = WorkerGridFactory.genericWorker(qkRmsNormGroups * nEmbdHead, nEmbdHead); int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); int kvDim0 = nEmbdGqa; @@ -103,32 +92,22 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { // Map workers to tasks for each layer for (int i = 0; i < config.numberOfLayers(); i++) { - gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker); - - gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker); - - gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker); - + + gridScheduler.addWorkerGrid("layer_" + i + ".qk_rmsnorm", qkRmsNormWorker); gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker); gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNW1W3Worker); gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker); + gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker); gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); -// gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker); gridScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", projectionTwoWorker); } @@ -197,7 +176,7 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) 
weights.w3Layered[layerIndex].asHalfFloatArray() // ); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.temp, qwen3State.wrapX, // in + unifiedLayer.task("attn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.temp, qwen3State.wrapX, // in qwen3Config.dim(), qwen3Config.rmsNormEps(), qwen3State.localSize); unifiedLayer.task("attn_rms_qkv_projection", Qwen3Kernels::fusedRmsNormQKVMatmul, @@ -216,23 +195,17 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) kvDim, // KV output dim LOCAL_WORK_GROUP_SIZE_ALLOC); - // Qcur rmsnorm - unifiedLayer.task("rmsnormReduction_Qcur", Qwen3Kernels::rmsnormWithParallelOffset, context, qwen3State.tempQcur, // output - qwen3State.wrapQ, // input - qwen3State.localSize, // currently 128, should be variable of global nEmbHead - nEmbdHead, // for normalization - qwen3Config.rmsNormEps()) // for normalization - .task("rmsnormMapIndexInPlace_Qcur", Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, context, qwen3State.wrapQ, // output - weights.rms_att_QNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempQcur); - - // Kcur rmsnorm - unifiedLayer.task("rmsnormReduction_Kcur", Qwen3Kernels::rmsnormWithParallelOffset, context, qwen3State.tempKcur, // output - qwen3State.wrapK, // input - qwen3State.localSize, // currently 128, should be variable of global nEmbHead - nEmbdHead, // for normalization - qwen3Config.rmsNormEps()) // for normalization - .task("rmsnormMapIndexInPlace_Kcur", Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, context, qwen3State.wrapK, // output - weights.rms_att_KNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempKcur); + unifiedLayer.task("qk_rmsnorm", Qwen3Kernels::fusedQKRmsNorm, + context, + qwen3State.wrapQ, + qwen3State.wrapK, + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), + qwen3Config.numberOfHeads(), + qwen3Config.numberOfKeyValueHeads(), + nEmbdHead, + nEmbdHead, // localMemSize = nEmbdHead (e.g., 128) + qwen3Config.rmsNormEps()); unifiedLayer.task("rope_and_kv_cache", Qwen3Kernels::ropeRotationWithCacheCopy, context, @@ -252,7 +225,7 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapXb, // out qwen3Config.numberOfHeads(), nEmbdHead, nEmbdGqa, gqa, qwen3State.positionHolder, layerIndex, qwen3Config.contextLength()); - unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapXb, // vector + unifiedLayer.task("attn_output_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapXb, // vector qwen3State.wrapX, // out, should be [1024] weights.woLayered[layerIndex].asHalfFloatArray(), // matrix nEmbdHeadK * qwen3Config.numberOfHeads(), // dim1 = 2048 From abf12d424bd0c7688dfb220512b5b7754c670a12 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 20:10:01 +0200 Subject: [PATCH 18/42] Refactor Qwen3 FP16 FFN layers to streamline worker grid setup, update task graphs with fused kernels, reorganize attention and FFN block mapping, and integrate final normalization for non-NVIDIA devices. Add detailed Transformer layer task flow documentation. 
--- .../layers/type/fp16/Qwen3FP16FFNLayers.java | 358 ++++++++++++------ 1 file changed, 249 insertions(+), 109 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 2322af3c..2c2ce71b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -16,6 +16,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.stream.IntStream; /** * Qwen3FP16FFNLayers: FP16 FFN layers for Qwen3 with Group Query Attention (GQA) support. @@ -43,7 +44,7 @@ public class Qwen3FP16FFNLayers extends AbstractFFNLayers { List ffnLayerTaskGraphs; public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) { - super(taskGraphName, state, weights, config,schedulerType); + super(taskGraphName, state, weights, config, schedulerType); this.qwen3State = state; this.qwen3Config = config; @@ -61,19 +62,12 @@ public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe @Override public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize); - - // Q matmul worker (GQA: full query heads) - int matmulQGlobal = nEmbdHeadK * config.numberOfHeads() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), nEmbdHead); - // Parallel attention worker WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), nEmbdHead); - // attn_output_proj worker (output projection) int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC); - // FFN workers int fusedFFNW1W3Global = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid fusedFFNW1W3Worker = WorkerGridFactory.genericWorker(fusedFFNW1W3Global, LOCAL_WORK_GROUP_SIZE_ALLOC); @@ -85,32 +79,27 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); int kvDim0 = nEmbdGqa; - // Add this 1: int fusedQKVRows = qDim0 + 2 * kvDim0; // Q rows + K rows + V rows int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - // Map workers to tasks for each layer + // Map workers to tasks for each layer (in task execution order) for (int i = 0; i < config.numberOfLayers(); i++) { + // === Attention Block === gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".qk_rmsnorm", qkRmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker); - - gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNW1W3Worker); - gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker); - + // === FFN Block === 
gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); + if (shouldUseFinalNormalization()) { + gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_finalize", rmsNormWorker); + } + gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNW1W3Worker); gridScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", projectionTwoWorker); } - return gridScheduler; } @@ -137,122 +126,273 @@ public List getFfnLayerTaskGraphs() { * Setup all FFN layers for all transformer layers */ List setupFFNLayered() { - List ffnGraphs = new ArrayList<>(); - for (int layerIndex = 0; layerIndex < qwen3Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, layerIndex); - if (layerIndex == qwen3Config.numberOfLayers() - 1) { + return IntStream.range(0, qwen3Config.numberOfLayers()).mapToObj(i -> { + var ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, i); + if (i == qwen3Config.numberOfLayers() - 1) { setupLastID(ffnLayer.getTaskGraphName()); } - ffnGraphs.add(ffnLayer.snapshot()); - } - - return ffnGraphs; + return ffnLayer.snapshot(); + }).toList(); } + // @formatter:off /** - * Setup a single transformer layer for Qwen3 with GQA + * Transformer Layer Task Flow (Qwen3FP16FFNLayers) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (scale factor for RMSNorm) + * └────────┬────────┘ + * │ + * ▼ + * ┌─────────────────────────┐ + * │ attn_rms_qkv_projection │──▶ wrapQ, wrapK, wrapV (FP32) + * └───────────┬─────────────┘ (fused: RMS apply + Q/K/V matmuls) + * │ + * ▼ + * ┌─────────────┐ + * │ qk_rmsnorm │──▶ wrapQ, wrapK normalized in-place + * └──────┬──────┘ (fused: Q + K RMSNorm reduction + apply) + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ (fused: RoPE rotation + cache write) + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN (scale factor) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ ffn_rms_finalize │──▶ tempFFN (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3) + * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU) + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 9 tasks (NVIDIA) / 10 tasks (non-NVIDIA) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX 
(FP32) - updated hidden state with residual connections + * + * Key Fusion Points (vs baseline 18 tasks): + * • attn_rms_qkv_projection: Fused RMS apply + Q/K/V matmuls (4→1 kernel) + * • qk_rmsnorm: Fused Q + K RMSNorm (4→1 kernel) + * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel) + * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel) + * + * Qwen3-Specific: + * • GQA: nHeads (Q) != nHeadKv (K/V), with gqa = nHeads / nHeadKv + * • Q/K RMSNorm: Additional normalization after QKV projection (qk_rmsnorm) + * • RoPE theta: 1,000,000 (vs Llama's 10,000 or 50,000) + * */ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) { var taskGraphName = "layer_" + layerIndex; - int qDim = nEmbdHeadK * qwen3Config.numberOfHeads(); // Q output size - int kvDim = nEmbdGqa; // K/V output size - int qkvDim1 = qwen3Config.dim(); + + // === Dimension Parameters === + int qDim = nEmbdHeadK * qwen3Config.numberOfHeads(); // Q output size (full heads) + int kvDim = nEmbdGqa; // K/V output size (reduced for GQA) + int inputDim = qwen3Config.dim(); // Model dimension + TaskGraph unifiedLayer = new TaskGraph(taskGraphName); + + // === Data Setup === unifiedLayer.consumeFromDevice(state.wrapX); - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - //Copy-in weights per layer for batched-layered layout - weights.rms_att_weightLayered[layerIndex].asFloatArray(), // - weights.wqLayered[layerIndex].asHalfFloatArray(), // - weights.wkLayered[layerIndex].asHalfFloatArray(), // - weights.wvLayered[layerIndex].asHalfFloatArray(), // - weights.woLayered[layerIndex].asHalfFloatArray(), // - //rms_att_KNormLayered - weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // - //rms_att_QNormLayered - weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // - weights.w1Layered[layerIndex].asHalfFloatArray(), // - weights.w2Layered[layerIndex].asHalfFloatArray(), // - weights.w3Layered[layerIndex].asHalfFloatArray() // - ); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Attention weights + weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS norm weights + weights.wqLayered[layerIndex].asHalfFloatArray(), // Q projection + weights.wkLayered[layerIndex].asHalfFloatArray(), // K projection + weights.wvLayered[layerIndex].asHalfFloatArray(), // V projection + weights.woLayered[layerIndex].asHalfFloatArray(), // Output projection + // Qwen3-specific Q/K norm weights + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // K RMSNorm weights + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // Q RMSNorm weights + // FFN weights + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // FFN RMS norm weights + weights.w1Layered[layerIndex].asHalfFloatArray(), // FFN gate + weights.w2Layered[layerIndex].asHalfFloatArray(), // FFN down + weights.w3Layered[layerIndex].asHalfFloatArray()); // FFN up unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("attn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.temp, qwen3State.wrapX, // in - qwen3Config.dim(), qwen3Config.rmsNormEps(), qwen3State.localSize); - unifiedLayer.task("attn_rms_qkv_projection", Qwen3Kernels::fusedRmsNormQKVMatmul, + // ═══════════════════════════════════════════════════════════════════════ + // ATTENTION BLOCK + // 
═══════════════════════════════════════════════════════════════════════ + + // RMS Normalization - compute scale factor + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + qwen3State.temp, // output: scale factor + qwen3State.wrapX, // input: hidden state + qwen3Config.dim(), // dimension + qwen3Config.rmsNormEps(), // epsilon + qwen3State.localSize); // local memory size + + // Fused RMS Apply + QKV Projection + unifiedLayer.task("attn_rms_qkv_projection", + Qwen3Kernels::fusedRmsNormQKVMatmul, context, - qwen3State.wrapX, // raw input (not normalized) - qwen3State.wrapQ, // output Q - qwen3State.wrapK, // output K - qwen3State.wrapV, // output V + qwen3State.wrapX, // input: raw hidden state (FP32) + qwen3State.wrapQ, // output: Q vectors + qwen3State.wrapK, // output: K vectors + qwen3State.wrapV, // output: V vectors weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights - qwen3State.temp, // RMS scale factor - weights.wqLayered[layerIndex].asHalfFloatArray(), - weights.wkLayered[layerIndex].asHalfFloatArray(), - weights.wvLayered[layerIndex].asHalfFloatArray(), - qkvDim1, // input dim (config.dim()) - qDim, // Q output dim - kvDim, // KV output dim + qwen3State.temp, // RMS scale factor from reduction + weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq [qDim x inputDim] + weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk [kvDim x inputDim] + weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv [kvDim x inputDim] + inputDim, // input dimension + qDim, // Q output dimension + kvDim, // K/V output dimension (GQA: reduced) LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.task("qk_rmsnorm", Qwen3Kernels::fusedQKRmsNorm, + // Fused Q/K RMSNorm (Qwen3-specific) + unifiedLayer.task("qk_rmsnorm", + Qwen3Kernels::fusedQKRmsNorm, context, - qwen3State.wrapQ, - qwen3State.wrapK, - weights.rms_att_QNormLayered[layerIndex].asFloatArray(), - weights.rms_att_KNormLayered[layerIndex].asFloatArray(), - qwen3Config.numberOfHeads(), - qwen3Config.numberOfKeyValueHeads(), - nEmbdHead, - nEmbdHead, // localMemSize = nEmbdHead (e.g., 128) - qwen3Config.rmsNormEps()); - - unifiedLayer.task("rope_and_kv_cache", Qwen3Kernels::ropeRotationWithCacheCopy, + qwen3State.wrapQ, // Q vectors (in/out) + qwen3State.wrapK, // K vectors (in/out) + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // Q norm weights + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // K norm weights + qwen3Config.numberOfHeads(), // nHeads (Q heads) + qwen3Config.numberOfKeyValueHeads(), // nHeadKv (K/V heads, GQA) + nEmbdHead, // head dimension + nEmbdHead, // local memory size + qwen3Config.rmsNormEps()); // epsilon + + // Fused RoPE Rotation + KV Cache Write + unifiedLayer.task("rope_and_kv_cache", + Qwen3Kernels::ropeRotationWithCacheCopy, context, - qwen3State.positionHolder, - qwen3State.wrapQ, // Q (in/out) - qwen3State.wrapK, // K (in/out) - qwen3State.wrapV, // V (in only) - qwen3State.wrapKeyCache, // Key cache (out) - qwen3State.wrapValueCache, // Value cache (out) - qwen3Config.numberOfKeyValueHeads(), - nEmbdHead, - nEmbdGqa, - layerIndex, - qwen3Config.contextLength()); - - unifiedLayer.task("attention", TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt, context, qwen3State.wrapQ, qwen3State.wrapKeyCache, qwen3State.wrapValueCache, - qwen3State.wrapXb, // out - qwen3Config.numberOfHeads(), nEmbdHead, nEmbdGqa, gqa, qwen3State.positionHolder, layerIndex, qwen3Config.contextLength()); - - 
unifiedLayer.task("attn_output_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapXb, // vector - qwen3State.wrapX, // out, should be [1024] - weights.woLayered[layerIndex].asHalfFloatArray(), // matrix - nEmbdHeadK * qwen3Config.numberOfHeads(), // dim1 = 2048 - qwen3Config.dim(), // dim0 = 1024 + qwen3State.positionHolder, // current position + qwen3State.wrapQ, // Q vectors (in/out, rotated) + qwen3State.wrapK, // K vectors (in/out, rotated) + qwen3State.wrapV, // V vectors (in only) + qwen3State.wrapKeyCache, // key cache (out) + qwen3State.wrapValueCache, // value cache (out) + qwen3Config.numberOfKeyValueHeads(), // nHeadKv + nEmbdHead, // head dimension + nEmbdGqa, // kvDim + layerIndex, // layer index for cache offset + qwen3Config.contextLength()); // max sequence length + + // Flash Attention + unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt, + context, + qwen3State.wrapQ, // query vectors + qwen3State.wrapKeyCache, // key cache + qwen3State.wrapValueCache, // value cache + qwen3State.wrapXb, // output: attention result + qwen3Config.numberOfHeads(), // nHeads + nEmbdHead, // headSize + nEmbdGqa, // kvDim + gqa, // kvMul (nHeads / nHeadKv) + qwen3State.positionHolder, // position + layerIndex, // layer index + qwen3Config.contextLength()); // context length + + // Output Projection with Residual + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, + qwen3State.wrapXb, // input: attention output + qwen3State.wrapX, // output: wrapX += Wo · wrapXb + weights.woLayered[layerIndex].asHalfFloatArray(), // Wo [dim x qDim] + nEmbdHeadK * qwen3Config.numberOfHeads(), // input dim (qDim) + qwen3Config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.task("ffn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.tempFFN, qwen3State.wrapX, qwen3Config.dim(), - qwen3Config.rmsNormEps(), qwen3State.localSize) - .task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, qwen3State.tempFFN, qwen3Config.dim(), qwen3Config.rmsNormEps()); + // ═══════════════════════════════════════════════════════════════════════ + // FFN BLOCK + // ═══════════════════════════════════════════════════════════════════════ + // RMS Normalization - compute scale factor + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + qwen3State.tempFFN, // output: scale factor + qwen3State.wrapX, // input: hidden state + qwen3Config.dim(), // dimension + qwen3Config.rmsNormEps(), // epsilon + qwen3State.localSize); // local memory size + + // Final normalization (non-NVIDIA only) + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + qwen3State.tempFFN, // scale factor (in/out) + qwen3Config.dim(), // dimension + qwen3Config.rmsNormEps()); // epsilon + } + + // Fused RMS Apply + Gate/Up Projection + SiLU + GLU unifiedLayer.task("rms_ffn_gate_up", TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, context, - qwen3State.wrapX, // raw input (FP32) - qwen3State.wrapHb, // output - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights - qwen3State.tempFFN, // RMS scale factor - weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 - 
weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 - qwen3Config.dim(), // input dimension - qwen3Config.hiddenDim(), // output dimension + qwen3State.wrapX, // input: raw hidden state (FP32) + qwen3State.wrapHb, // output: SiLU(x·W1) ⊙ (x·W3) + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights + qwen3State.tempFFN, // RMS scale factor + weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 (gate) + weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 (up) + qwen3Config.dim(), // input dimension + qwen3Config.hiddenDim(), // hidden dimension LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.task("ffn_down_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapHb, qwen3State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), - qwen3Config.hiddenDim(), qwen3Config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(qwen3State.wrapX); + // Down Projection with Residual + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, + qwen3State.wrapHb, // input: FFN intermediate + qwen3State.wrapX, // output: wrapX += W2 · wrapHb + weights.w2Layered[layerIndex].asHalfFloatArray(), // W2 (down) + qwen3Config.hiddenDim(), // input dim + qwen3Config.dim(), // output dim + LOCAL_WORK_GROUP_SIZE_ALLOC) + .persistOnDevice(qwen3State.wrapX); + return unifiedLayer; } + // @formatter:on /** * Configure data transfers for first and subsequent layers @@ -280,5 +420,5 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye } return unifiedLayer; } - + // @formatter:on } \ No newline at end of file From 042b0b5d781bb4184d70f93b4b66f0e6f09efad4 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 20:21:03 +0200 Subject: [PATCH 19/42] Add `processHeadsFlashAttentionOptV2` kernel with static memory size fixes, improved attention computation logic, and optimized handling of large models. Update task graph to revert to `processHeadsFlashAttention` for compatibility. --- .../tornadovm/kernels/Qwen3Kernels.java | 35 +--- .../TransformerComputeKernelsLayered.java | 194 ++++++++++++++++++ .../layers/type/fp16/Qwen3FP16FFNLayers.java | 2 +- 3 files changed, 196 insertions(+), 35 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java index 88f006b6..506d3fdf 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java @@ -564,40 +564,7 @@ public static void fusedRmsNormQKVMatmul( } } - /** - * Fused Q and K RMSNorm for Qwen3. - * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel. - * - * Workgroup assignment: - * - Workgroups [0, nHeads): Process Q heads - * - Workgroups [nHeads, nHeads + nHeadKv): Process K heads - * - * Each workgroup computes reduction and applies normalization for one head. - */ - /** - * Fused Q and K RMSNorm for Qwen3. - * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel. - * - * Workgroup assignment: - * - Workgroups [0, nHeads): Process Q heads - * - Workgroups [nHeads, nHeads + nHeadKv): Process K heads - */ - /** - * Fused Q and K RMSNorm for Qwen3. - * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel. 
- * - * Workgroup assignment: - * - Workgroups [0, nHeads): Process Q heads - * - Workgroups [nHeads, nHeads + nHeadKv): Process K heads - */ - /** - * Fused Q and K RMSNorm for Qwen3. - * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel. - * - * Workgroup assignment: - * - Workgroups [0, nHeads): Process Q heads - * - Workgroups [nHeads, nHeads + nHeadKv): Process K heads - */ + /** * Fused Q and K RMSNorm for Qwen3. * Combines rmsnormReduction + rmsnormMapIndexInPlace for both Q and K into one kernel. diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index 6af9ac22..ec6d7e0c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -592,6 +592,200 @@ public static void processHeadsFlashAttention(KernelContext context, FloatArray } } + public static void processHeadsFlashAttentionOptV2( + KernelContext context, + FloatArray q, + FloatArray key_cache, + FloatArray value_cache, + FloatArray xb, + int nHeads, + int headSize, // NOTE: Still used for logic, but not for allocation size + int kvDim, + int kvMul, + IntArray positionHolder, + int layer, + int contextLength) { + + // --- STATIC CONSTANTS FOR OPENCL ALLOCATIONS --- + // These must be large enough to handle the maximum expected values for + // headSize and localSize in your model/hardware setup. + // Assuming Max Head Size is 256 and Max Local Size is 256. + final int MAX_HEAD_SIZE = 256; + final int MAX_LOCAL_SIZE = 256; + final int MAX_BLOCK_SIZE_C = 32; + final int MAX_TILE_ELEMENTS = MAX_BLOCK_SIZE_C * MAX_HEAD_SIZE; + + int tid = context.localIdx; + int h = context.groupIdx; + int localSize = context.localGroupSizeX; + + if (h >= nHeads) { + return; + } + + int pos = positionHolder.get(0); + int loff = layer * contextLength * kvDim; + int kvHeadIdx = h / kvMul; + int BLOCK_SIZE_C = 32; + + // === Shared memory allocations (FIXED: using static sizes) === + // ERROR FIX 1: Use MAX_HEAD_SIZE instead of dynamic headSize + float[] q_shared = context.allocateFloatLocalArray(MAX_HEAD_SIZE); + // ERROR FIX 2: Use MAX_TILE_ELEMENTS instead of BLOCK_SIZE_C * headSize + float[] k_tile = context.allocateFloatLocalArray(MAX_TILE_ELEMENTS); + float[] v_tile = context.allocateFloatLocalArray(MAX_TILE_ELEMENTS); + + float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C); // Size is constant (32) + float[] exp_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C); // Size is constant (32) + + // ERROR FIX 3: Use MAX_LOCAL_SIZE instead of dynamic localSize + float[] reduction_shared = context.allocateFloatLocalArray(MAX_LOCAL_SIZE); + + float[] state_shared = context.allocateFloatLocalArray(4); // Size is constant (4) + + // === Dimension partitioning: each thread handles subset of output dims === + int dimsPerThread = (headSize + localSize - 1) / localSize; + int myStartDim = tid * dimsPerThread; + int myEndDim = Math.min(myStartDim + dimsPerThread, headSize); + int myDimCount = myEndDim - myStartDim; + + // FIX from previous iteration: ensuring output array is statically sized + final int MAX_OUTPUT_DIMS = MAX_HEAD_SIZE / 8; // e.g., 32 if MAX_HEAD_SIZE=256 + float[] output = new float[MAX_OUTPUT_DIMS]; + + // Initialize thread-local output + for (int i = 0; i < myDimCount; i++) { + output[i] = 0.0f; + } + + // 
Initialize shared state + if (tid == 0) { + state_shared[0] = Float.NEGATIVE_INFINITY; + state_shared[1] = 0.0f; + } + + // Load query into shared memory (cooperative) + // NOTE: Loop bound must still use headSize to read correct data volume + for (int i = tid; i < headSize; i += localSize) { + q_shared[i] = q.get(h * headSize + i); + } + context.localBarrier(); + + // Process sequence in tiles + for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) { + int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos); + int tileLen = tileEnd - tileC + 1; + + // === Cooperative K/V tile loading === + int totalElements = tileLen * headSize; + int elementsPerThread = (totalElements + localSize - 1) / localSize; + int startElem = tid * elementsPerThread; + int endElem = Math.min(startElem + elementsPerThread, totalElements); + + for (int globalElemIdx = startElem; globalElemIdx < endElem; globalElemIdx++) { + int seqIdx = globalElemIdx / headSize; + int dimIdx = globalElemIdx % headSize; + int kvOffset = loff + (tileC + seqIdx) * kvDim + kvHeadIdx * headSize + dimIdx; + int tileMemOffset = seqIdx * headSize + dimIdx; + + // Check bounds just to be safe, though kvDim/headSize should ensure this is valid. + if (tileMemOffset < MAX_TILE_ELEMENTS) { + k_tile[tileMemOffset] = key_cache.get(kvOffset); + v_tile[tileMemOffset] = value_cache.get(kvOffset); + } + } + context.localBarrier(); + + // === Compute attention scores (cooperative) === + for (int t = tid; t < tileLen; t += localSize) { + float score = 0.0f; + for (int d = 0; d < headSize; d++) { + score += q_shared[d] * k_tile[t * headSize + d]; + } + s_tile[t] = score / TornadoMath.sqrt(headSize); + } + context.localBarrier(); + + // ... (Parallel reduction for tileMax - uses reduction_shared, which is now fixed) + float threadMax = Float.NEGATIVE_INFINITY; + for (int t = tid; t < tileLen; t += localSize) { + if (s_tile[t] > threadMax) { + threadMax = s_tile[t]; + } + } + reduction_shared[tid] = threadMax; + context.localBarrier(); + + for (int stride = localSize / 2; stride > 0; stride /= 2) { + if (tid < stride) { + reduction_shared[tid] = Math.max(reduction_shared[tid], reduction_shared[tid + stride]); + } + context.localBarrier(); + } + float tileMax = reduction_shared[0]; + + // === Update running max and rescale if needed === + float prevMax = state_shared[0]; + float newMax = Math.max(prevMax, tileMax); + float scale = 1.0f; + + if (newMax != prevMax && prevMax != Float.NEGATIVE_INFINITY) { + scale = TornadoMath.exp(prevMax - newMax); + for (int i = 0; i < myDimCount; i++) { + output[i] *= scale; + } + } + + // === Compute exp(score - max) and tile sum (cooperative) === + for (int t = tid; t < tileLen; t += localSize) { + exp_tile[t] = TornadoMath.exp(s_tile[t] - newMax); + } + context.localBarrier(); + + // Parallel reduction for tile sum + // ... 
(Uses reduction_shared, which is now fixed) + float threadSum = 0.0f; + for (int t = tid; t < tileLen; t += localSize) { + threadSum += exp_tile[t]; + } + reduction_shared[tid] = threadSum; + context.localBarrier(); + + for (int stride = localSize / 2; stride > 0; stride /= 2) { + if (tid < stride) { + reduction_shared[tid] += reduction_shared[tid + stride]; + } + context.localBarrier(); + } + float tileSum = reduction_shared[0]; + + // Update shared state (thread 0) + if (tid == 0) { + state_shared[0] = newMax; + state_shared[1] = state_shared[1] * scale + tileSum; + } + context.localBarrier(); + + // === Accumulate output (each thread handles its dimensions) === + for (int t = 0; t < tileLen; t++) { + float expScore = exp_tile[t]; + for (int i = 0; i < myDimCount; i++) { + int d = myStartDim + i; + output[i] += expScore * v_tile[t * headSize + d]; + } + } + context.localBarrier(); + } + + // === Final normalization and write === + float sumExp = state_shared[1]; + float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; + + int baseOffset = h * headSize + myStartDim; + for (int i = 0; i < myDimCount; i++) { + xb.set(baseOffset + i, output[i] * normFactor); + } + } /** * Same as processHeadsFlashAttention but with some optimizations that seem to lower attention's execution time, especially in larger models. */ diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 2c2ce71b..7b4477cb 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -315,7 +315,7 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) // Flash Attention unifiedLayer.task("attention", - TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt, + TransformerComputeKernelsLayered::processHeadsFlashAttention, context, qwen3State.wrapQ, // query vectors qwen3State.wrapKeyCache, // key cache From 1cbe03a4e146dff544eac191eda26f94a8d07e5b Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 20:30:02 +0200 Subject: [PATCH 20/42] Refactor Qwen3 FP16 FFN layers: remove unused imports, replace explicit TaskGraph type with `var`, streamline task graph configuration by removing unused temp variables. 
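A minimal before/after sketch of the `var` change (mirroring the hunk below):

    // Before: the type is spelled out on both sides of the assignment
    TaskGraph unifiedLayer = new TaskGraph(taskGraphName);

    // After: the type is unambiguous from the constructor call
    var unifiedLayer = new TaskGraph(taskGraphName);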
--- .../tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 7b4477cb..27565dca 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -14,7 +14,6 @@ import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.ArrayList; import java.util.List; import java.util.stream.IntStream; @@ -230,7 +229,7 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) int kvDim = nEmbdGqa; // K/V output size (reduced for GQA) int inputDim = qwen3Config.dim(); // Model dimension - TaskGraph unifiedLayer = new TaskGraph(taskGraphName); + var unifiedLayer = new TaskGraph(taskGraphName); // === Data Setup === unifiedLayer.consumeFromDevice(state.wrapX); @@ -407,7 +406,7 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye context, qwen3State.wrapXb, qwen3State.wrapXb2, // qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, // qwen3State.wrapKeyCache, qwen3State.wrapValueCache, // - qwen3State.wrapAtt, qwen3State.wrapHb, qwen3State.temp, qwen3State.tempFFN, qwen3State.tempQcur, qwen3State.tempKcur); + qwen3State.wrapAtt, qwen3State.wrapHb, qwen3State.temp, qwen3State.tempFFN); } else { // Subsequent layers: Consume data from previous layer unifiedLayer.consumeFromDevice(context, qwen3State.wrapXb, qwen3State.wrapXb2, // @@ -416,7 +415,6 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye qwen3State.wrapValueCache, qwen3State.wrapAtt, // qwen3State.wrapHb, qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN); // - unifiedLayer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur); } return unifiedLayer; } From a4bc159bbbef9c2759500fbc05730232c8c7b208 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 21:26:14 +0200 Subject: [PATCH 21/42] Refactor Qwen2 FP16 task graph: consolidate attention and FFN tasks with fused kernels, update worker grid configurations, and streamline data transfer logic. 
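For reference, the fused RMS-apply + matmul tasks used here all share one
per-row contract: each output row is a dot product against the RMS-normalized
input, using the scale factor produced by the preceding reduction task. A CPU
sketch of that contract, assuming plain float arrays (names illustrative, not
the kernel signature):

    // out[r] = sum_j W[r][j] * (g[j] * scale * x[j])
    static void fusedRmsMatvecReference(float[] out, float[][] w,
            float[] g, float[] x, float scale) {
        for (int r = 0; r < out.length; r++) {
            float acc = 0.0f;
            for (int j = 0; j < x.length; j++) {
                acc += w[r][j] * (g[j] * scale * x[j]);
            }
            out[r] = acc;
        }
    }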
--- .../gpullama3/inference/state/Qwen2State.java | 2 + .../layers/type/fp16/Qwen2FP16FFNLayers.java | 106 +++++++++++------- 2 files changed, 65 insertions(+), 43 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java index da6d7046..a4ef530b 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java @@ -5,6 +5,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -42,6 +43,7 @@ protected StateFields createStateFields(Configuration configuration) { // TornadoVM wrappers with Qwen2 dimensions fields.wrapX = new FloatArray(config.dim()); fields.wrapXb = new FloatArray(config.dim()); + fields.wrapXbFP16 = new HalfFloatArray(config.dim()); fields.wrapXb2 = new FloatArray(config.dim()); fields.wrapHb = new FloatArray(config.hiddenDim()); fields.wrapHb2 = new FloatArray(config.hiddenDim()); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index 858848ea..8cd8bbc7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -51,7 +51,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) int ic = config.headSize() / 2; WorkerGrid ropeWorker = new WorkerGrid2D(h, ic); ropeWorker.setGlobalWork(h, ic, 1); - ropeWorker.setLocalWork(1, 1, 1); + ropeWorker.setLocalWork(h / 2, ic / 2, 1); int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal); @@ -95,24 +95,25 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) copyToCachesWorker.setGlobalWork(config.kvDim(), 1, 1); copyToCachesWorker.setLocalWork(32, 1, 1); // Set local work size to 32 (for copying to caches) + int fusedQKVGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQKVWorker = new WorkerGrid1D(fusedQKVGlobal); + fusedQKVWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); + // Map workers to tasks for (int i = 0; i < config.numberOfLayers(); i++) { - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qbias", qBiasWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kbias", kvBiasWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vbias", kvBiasWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); - 
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); } return tornadoForwardScheduler; } @@ -136,15 +137,8 @@ public List getFfnLayerTaskGraphs() { return ffnLayerTaskGraphs; } - /** - * Setup all FFN layers for all transformer layers - */ List setupFFNLayered() { - List ffnGraphs = new ArrayList<>(); - - qwen2State.temp.init(0.0f); - qwen2State.tempFFN.init(0.0f); - + List ffnGraphs = new ArrayList<>(qwen2Config.numberOfLayers()); for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSingleQwen2FFNLayer((Qwen2TornadoWeights) weights, layerIndex); if (layerIndex == qwen2Config.numberOfLayers() - 1) { @@ -152,15 +146,16 @@ List setupFFNLayered() { } ffnGraphs.add(ffnLayer.snapshot()); } - return ffnGraphs; } + // @formatter:off /** * Setup a single transformer layer for Qwen2 with GQA */ TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) { var taskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(taskGraphName); unifiedLayer.consumeFromDevice(state.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // @@ -178,36 +173,61 @@ TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) weights.w3Layered[layerIndex].asHalfFloatArray()); // unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); // - unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.temp, qwen2State.wrapX, config.dim(), config.rmsNormEps(), - qwen2State.localSize) - .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), - qwen2State.temp) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapK, 
weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC).task("qbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim()) + unifiedLayer.task("attn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.temp, qwen2State.wrapX, config.dim(), config.rmsNormEps(), qwen2State.localSize); + unifiedLayer.task("attn_rms_qkv_projection", Qwen3Kernels::fusedRmsNormQKVMatmul, context, + qwen2State.wrapX, // input + qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, // outputs + weights.rms_att_weightLayered[layerIndex].asFloatArray(), // rms weights + qwen2State.temp, // scale + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + config.dim(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + + .task("qbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim()) .task("kbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapK, weights.k_biasLayered[layerIndex].asFloatArray(), config.kvDim()) - .task("vbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim()) - .task("rope", Qwen3Kernels::ropeRotation, context, qwen2State.positionHolder, qwen2State.wrapQ, qwen2State.wrapK, config.numberOfKeyValueHeads(), config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, qwen2State.wrapKeyCache, qwen2State.wrapK, qwen2State.wrapValueCache, qwen2State.wrapV, qwen2State.positionHolder, - config.kvDim(), layerIndex, config.contextLength()) - .task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context, qwen2State.wrapQ, qwen2State.wrapKeyCache, qwen2State.wrapValueCache, qwen2State.wrapXb, + .task("vbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim()); + unifiedLayer.task("rope_and_kv_cache", + Qwen3Kernels::ropeRotationWithCacheCopy, + context, + qwen2State.positionHolder, // current sequence position + qwen2State.wrapQ, // Q (rotated in-place) + qwen2State.wrapK, // K (rotated in-place) + qwen2State.wrapV, // V (unchanged, copied to cache) + qwen2State.wrapKeyCache, // key cache (write) + qwen2State.wrapValueCache, // value cache (write) + config.numberOfKeyValueHeads(), // nHeadKv + config.headSize(), // per-head dimension + config.kvDim(), // kvDim after group reduction + layerIndex, // layer offset + config.contextLength()) // max sequence length + .task("attention", Qwen2Kernels::processHeadsFlashAttention, context, qwen2State.wrapQ, qwen2State.wrapKeyCache, qwen2State.wrapValueCache, qwen2State.wrapXb, config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), qwen2State.positionHolder, layerIndex, config.contextLength()) - .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapXb, qwen2State.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), + .task("attn_output_proj", 
TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapXb, qwen2State.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.tempFFN, qwen2State.wrapX, config.dim(), config.rmsNormEps(), + .task("ffn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.tempFFN, qwen2State.wrapX, config.dim(), config.rmsNormEps(), qwen2State.localSize) .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), qwen2State.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen2State.wrapXb, qwen2State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), + .task("rms_ffn_gate_up", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen2State.wrapXb, qwen2State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapHb, qwen2State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), +// unifiedLayer.task("rms_ffn_gate_up", TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, +// context, +// qwen2State.wrapXb, // input: raw hidden state (FP32/FP16 as appropriate) +// qwen2State.wrapHb, // output: SiLU(x·W1) ⊙ (x·W3) +// weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights +// qwen2State.tempFFN, // RMS scale factor (can also be computed inside) +// weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 (gate) +// weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 (up) +// config.dim(), // input dimension +// config.hiddenDim(), // hidden dimension +// LOCAL_WORK_GROUP_SIZE_ALLOC) // local work size + .task("ffn_down_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapHb, qwen2State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); return unifiedLayer; } + // @formatter:on /** * Configure data transfers for first and subsequent layers @@ -216,12 +236,12 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye // First layer: Transfer initial data to device (one-time transfer) if (layerIndex == 0) { // Transfer all attention-related data: query, key, value matrices and their caches - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen2State.positionHolder, qwen2State.temp, qwen2State.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen2State.positionHolder); // unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // context, qwen2State.wrapXb, qwen2State.wrapXb2, // qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, // qwen2State.wrapKeyCache, qwen2State.wrapValueCache, // - qwen2State.wrapAtt, qwen2State.wrapHb); // + qwen2State.wrapAtt, qwen2State.wrapHb, qwen2State.temp, qwen2State.tempFFN); // } else { // Subsequent layers: Consume data already on device from previous layer unifiedLayer.consumeFromDevice( // @@ -229,7 +249,7 @@ 
protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, // qwen2State.wrapKeyCache, qwen2State.wrapValueCache, // qwen2State.wrapAtt, qwen2State.wrapHb, // - qwen2State.positionHolder // + qwen2State.positionHolder, qwen2State.temp, qwen2State.tempFFN // ); } return unifiedLayer; From e15c229c282f2752ac583854e0a9f9f367424dc4 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 22:19:02 +0200 Subject: [PATCH 22/42] Add `fusedQKvBiasAddition` kernel, refactor Qwen2 FP16 task graph to consolidate Q/K/V bias addition into a single operation, and update worker grid configurations. Streamline attention block with optimized task mapping and detailed layer flow documentation. --- .../TransformerComputeKernelsLayered.java | 21 ++ .../layers/type/fp16/Qwen2FP16FFNLayers.java | 325 +++++++++++++----- 2 files changed, 269 insertions(+), 77 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index ec6d7e0c..7bd1da29 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -17,6 +17,27 @@ public class TransformerComputeKernelsLayered { public TransformerComputeKernelsLayered() { } + public static void fusedQKvBiasAddition( + KernelContext context, + FloatArray q_out, FloatArray k_out, FloatArray qBias, + FloatArray v_out, FloatArray kBias, FloatArray vBias, + int dimQ, int dimKV) { + + int gid = context.globalIdx; + + if (gid < dimQ) { + // 1. Add Q bias + q_out.set(gid, q_out.get(gid) + qBias.get(gid)); + + // 2. 
Conditionally Add K and V Bias + if (gid < dimKV) { + k_out.set(gid, k_out.get(gid) + kBias.get(gid)); + v_out.set(gid, v_out.get(gid) + vBias.get(gid)); + } + } + } + + public static void fusedRmsNormFFNGateUp( KernelContext context, FloatArray x, // raw input (FP32) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index 8cd8bbc7..b6bfbb28 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -61,13 +61,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal); configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - WorkerGrid qBiasWorker = new WorkerGrid1D(config.dim()); - qBiasWorker.setGlobalWork(config.dim(), 1, 1); - qBiasWorker.setLocalWork(config.dim() / 8, 1, 1); - WorkerGrid kvBiasWorker = new WorkerGrid1D(config.kvDim()); - kvBiasWorker.setGlobalWork(config.kvDim(), 1, 1); - kvBiasWorker.setLocalWork(32, 1, 1); - int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor); configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); @@ -87,25 +80,30 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) } } + // WorkerGrid for fused QKV bias addition (dimension is dimQ) + WorkerGrid fusedQKVBiasWorker = new WorkerGrid1D(config.dim()); + fusedQKVBiasWorker.setGlobalWork(config.dim(), 1, 1); + fusedQKVBiasWorker.setLocalWork(32, 1, 1); // Or an optimized local size + WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * optimalLocalSize, 1, 1); parallelAttentionWorker.setLocalWork(optimalLocalSize, 1, 1); - WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim()); - copyToCachesWorker.setGlobalWork(config.kvDim(), 1, 1); - copyToCachesWorker.setLocalWork(32, 1, 1); // Set local work size to 32 (for copying to caches) - int fusedQKVGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid fusedQKVWorker = new WorkerGrid1D(fusedQKVGlobal); fusedQKVWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); + // Fused QKV bias worker (covers dimQ which is largest) + WorkerGrid fusedQKVBiasWorkerNorm = new WorkerGrid1D(config.dim()); + fusedQKVBiasWorkerNorm.setGlobalWork(config.dim(), 1, 1); + fusedQKVBiasWorkerNorm.setLocalWork(32, 1, 1); + // Map workers to tasks for (int i = 0; i < config.numberOfLayers(); i++) { tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qbias", qBiasWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kbias", kvBiasWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vbias", kvBiasWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_qkv_bias", fusedQKVBiasWorkerNorm); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_qkv_bias", fusedQKVBiasWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", 
configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); @@ -151,79 +149,252 @@ List setupFFNLayered() { // @formatter:off /** - * Setup a single transformer layer for Qwen2 with GQA + * Transformer Layer Task Flow (Qwen2FP16FFNLayers - Optimized) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (scale factor for RMSNorm) + * └────────┬────────┘ + * │ + * ▼ + * ┌─────────────────────────┐ + * │ attn_rms_qkv_projection │──▶ wrapQ, wrapK, wrapV (FP32) + * └───────────┬─────────────┘ (fused: RMS apply + Q/K/V matmuls) + * │ + * ▼ + * ┌────────────────┐ + * │ fused_qkv_bias │──▶ wrapQ, wrapK, wrapV += biases + * └───────┬────────┘ (fused: Q + K + V bias addition) + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ (fused: RoPE rotation + cache write) + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN (scale factor) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ ffn_rms_finalize │──▶ tempFFN (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3) + * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU) + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 9 tasks (NVIDIA) / 10 tasks (non-NVIDIA) + * Previous: 12 tasks + * Reduction: 3 tasks eliminated (25% fewer kernel launches) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state with residual connections + * + * Key Fusion Points (vs previous 12 tasks): + * • fused_qkv_bias: Fused Q + K + V bias addition (3→1 kernel) + * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU + * (eliminates separate mapContextFFN kernel) + * + * Qwen2-Specific: + * • GQA: nHeads (Q) != nHeadKv (K/V), with kvMul = nHeads / nHeadKv + * • Bias terms: Q, K, V projections include bias (unlike Qwen3) + * • No Q/K RMSNorm: Unlike Qwen3, Qwen2 doesn't normalize Q/K after projection + * */ TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) { var taskGraphName = "layer_" + layerIndex; - TaskGraph unifiedLayer = new TaskGraph(taskGraphName); unifiedLayer.consumeFromDevice(state.wrapX); - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - 
weights.rms_att_weightLayered[layerIndex].asFloatArray(), // - weights.wqLayered[layerIndex].asHalfFloatArray(), // - weights.wkLayered[layerIndex].asHalfFloatArray(), // - weights.wvLayered[layerIndex].asHalfFloatArray(), // - weights.woLayered[layerIndex].asHalfFloatArray(), // - weights.q_biasLayered[layerIndex].asFloatArray(), // - weights.k_biasLayered[layerIndex].asFloatArray(), // - weights.v_biasLayered[layerIndex].asFloatArray(), // - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // - weights.w1Layered[layerIndex].asHalfFloatArray(), // - weights.w2Layered[layerIndex].asHalfFloatArray(), // - weights.w3Layered[layerIndex].asHalfFloatArray()); // - unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); // - - unifiedLayer.task("attn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.temp, qwen2State.wrapX, config.dim(), config.rmsNormEps(), qwen2State.localSize); - unifiedLayer.task("attn_rms_qkv_projection", Qwen3Kernels::fusedRmsNormQKVMatmul, context, - qwen2State.wrapX, // input - qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, // outputs - weights.rms_att_weightLayered[layerIndex].asFloatArray(), // rms weights - qwen2State.temp, // scale + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Attention weights + weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqLayered[layerIndex].asHalfFloatArray(), weights.wkLayered[layerIndex].asHalfFloatArray(), weights.wvLayered[layerIndex].asHalfFloatArray(), - config.dim(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + weights.woLayered[layerIndex].asHalfFloatArray(), + // Qwen2-specific bias terms + weights.q_biasLayered[layerIndex].asFloatArray(), + weights.k_biasLayered[layerIndex].asFloatArray(), + weights.v_biasLayered[layerIndex].asFloatArray(), + // FFN weights + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + // ═══════════════════════════════════════════════════════════════════════ + // ATTENTION BLOCK + // ═══════════════════════════════════════════════════════════════════════ + + // RMS Normalization - compute scale factor + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + qwen2State.temp, // output: scale factor + qwen2State.wrapX, // input: hidden state + config.dim(), // dimension + config.rmsNormEps(), // epsilon + qwen2State.localSize); // local memory size + + // Fused RMS Apply + QKV Projection + unifiedLayer.task("attn_rms_qkv_projection", + Qwen3Kernels::fusedRmsNormQKVMatmul, + context, + qwen2State.wrapX, // input: raw hidden state (FP32) + qwen2State.wrapQ, // output: Q vectors + qwen2State.wrapK, // output: K vectors + qwen2State.wrapV, // output: V vectors + weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights + qwen2State.temp, // RMS scale factor from reduction + weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq + weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk + weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv + config.dim(), // input dimension + config.dim(), // Q output dimension + config.kvDim(), // K/V output dimension (GQA) + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Fused Q/K/V Bias Addition (3→1 kernel fusion) + 
unifiedLayer.task("fused_qkv_bias", + TransformerComputeKernelsLayered::fusedQKvBiasAddition, + context, + qwen2State.wrapQ, // Q (in/out) + qwen2State.wrapK, // K (in/out) + weights.q_biasLayered[layerIndex].asFloatArray(), // Q bias + qwen2State.wrapV, // V (in/out) + weights.k_biasLayered[layerIndex].asFloatArray(), // K bias + weights.v_biasLayered[layerIndex].asFloatArray(), // V bias + config.dim(), // dimQ + config.kvDim()); // dimKV + + // Fused RoPE Rotation + KV Cache Write + unifiedLayer.task("rope_and_kv_cache", + Qwen3Kernels::ropeRotationWithCacheCopy, + context, + qwen2State.positionHolder, // current sequence position + qwen2State.wrapQ, // Q (rotated in-place) + qwen2State.wrapK, // K (rotated in-place) + qwen2State.wrapV, // V (copied to cache) + qwen2State.wrapKeyCache, // key cache (write) + qwen2State.wrapValueCache, // value cache (write) + config.numberOfKeyValueHeads(), // nHeadKv + config.headSize(), // per-head dimension + config.kvDim(), // kvDim + layerIndex, // layer offset + config.contextLength()); // max sequence length + + // Flash Attention + unifiedLayer.task("attention", + Qwen2Kernels::processHeadsFlashAttention, + context, + qwen2State.wrapQ, // query vectors + qwen2State.wrapKeyCache, // key cache + qwen2State.wrapValueCache, // value cache + qwen2State.wrapXb, // output: attention result + config.numberOfHeads(), // nHeads + config.headSize(), // headSize + config.kvDim(), // kvDim + config.kvMul(), // kvMul (nHeads / nHeadKv) + qwen2State.positionHolder, // position + layerIndex, // layer index + config.contextLength()); // context length + + // Output Projection with Residual + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, + qwen2State.wrapXb, // input: attention output + qwen2State.wrapX, // output: wrapX += Wo · wrapXb + weights.woLayered[layerIndex].asHalfFloatArray(), // Wo + config.dim(), // input dim + config.dim(), // output dim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // ═══════════════════════════════════════════════════════════════════════ + // FFN BLOCK + // ═══════════════════════════════════════════════════════════════════════ + + // RMS Normalization - compute scale factor + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + qwen2State.tempFFN, // output: scale factor + qwen2State.wrapX, // input: hidden state + config.dim(), // dimension + config.rmsNormEps(), // epsilon + qwen2State.localSize); // local memory size + + // Final normalization (non-NVIDIA only) + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + qwen2State.tempFFN, // scale factor (in/out) + config.dim(), // dimension + config.rmsNormEps()); // epsilon + } + + // Fused RMS Apply + Gate/Up Projection + SiLU + GLU + // (Replaces mapContextFFN + fusedFeedForwardWithSiLUAndGLUActivation) + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, + context, + qwen2State.wrapX, // input: raw hidden state (FP32) + qwen2State.wrapHb, // output: SiLU(x·W1) ⊙ (x·W3) + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights + qwen2State.tempFFN, // RMS scale factor + weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 (gate) + weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 (up) + config.dim(), // input dimension + config.hiddenDim(), // hidden dimension + 
LOCAL_WORK_GROUP_SIZE_ALLOC); - .task("qbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim()) - .task("kbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapK, weights.k_biasLayered[layerIndex].asFloatArray(), config.kvDim()) - .task("vbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim()); - unifiedLayer.task("rope_and_kv_cache", - Qwen3Kernels::ropeRotationWithCacheCopy, + // Down Projection with Residual + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - qwen2State.positionHolder, // current sequence position - qwen2State.wrapQ, // Q (rotated in-place) - qwen2State.wrapK, // K (rotated in-place) - qwen2State.wrapV, // V (unchanged, copied to cache) - qwen2State.wrapKeyCache, // key cache (write) - qwen2State.wrapValueCache, // value cache (write) - config.numberOfKeyValueHeads(), // nHeadKv - config.headSize(), // per-head dimension - config.kvDim(), // kvDim after group reduction - layerIndex, // layer offset - config.contextLength()) // max sequence length - .task("attention", Qwen2Kernels::processHeadsFlashAttention, context, qwen2State.wrapQ, qwen2State.wrapKeyCache, qwen2State.wrapValueCache, qwen2State.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), qwen2State.positionHolder, layerIndex, config.contextLength()) - .task("attn_output_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapXb, qwen2State.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), - config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("ffn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.tempFFN, qwen2State.wrapX, config.dim(), config.rmsNormEps(), - qwen2State.localSize) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - qwen2State.tempFFN) - .task("rms_ffn_gate_up", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen2State.wrapXb, qwen2State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), - weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) -// unifiedLayer.task("rms_ffn_gate_up", TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, -// context, -// qwen2State.wrapXb, // input: raw hidden state (FP32/FP16 as appropriate) -// qwen2State.wrapHb, // output: SiLU(x·W1) ⊙ (x·W3) -// weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights -// qwen2State.tempFFN, // RMS scale factor (can also be computed inside) -// weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 (gate) -// weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 (up) -// config.dim(), // input dimension -// config.hiddenDim(), // hidden dimension -// LOCAL_WORK_GROUP_SIZE_ALLOC) // local work size - .task("ffn_down_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapHb, qwen2State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), - config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); + qwen2State.wrapHb, // input: FFN intermediate + qwen2State.wrapX, // output: wrapX += W2 · wrapHb + 
weights.w2Layered[layerIndex].asHalfFloatArray(), // W2 (down) + config.hiddenDim(), // input dim + config.dim(), // output dim + LOCAL_WORK_GROUP_SIZE_ALLOC) + .persistOnDevice(state.wrapX); return unifiedLayer; } From e7d79c9140d10e18ace42eaa16abcea302f10b3c Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 22:34:49 +0200 Subject: [PATCH 23/42] Add support for HalfFloatArray in Phi3State and initialize FP16 wrapper arrays --- .../java/org/beehive/gpullama3/inference/state/Phi3State.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java index d29ba130..6a186a03 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java @@ -5,6 +5,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.phi3.Phi3Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -81,6 +82,8 @@ protected StateFields createStateFields(Configuration config) { // TornadoVM wrapper arrays for GPU acceleration fields.wrapX = new FloatArray(dim); fields.wrapXb = new FloatArray(dim); + fields.wrapXFP16 = new HalfFloatArray(dim); + fields.wrapXbFP16 = new HalfFloatArray(dim); fields.wrapXb2 = new FloatArray(dim); fields.wrapHb = new FloatArray(2 * hiddenDim); fields.wrapHb2 = new FloatArray(hiddenDim); From 02b1a2c89eca33c9ec7940b3ded500c4c347e8e1 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 22:40:06 +0200 Subject: [PATCH 24/42] Add `splitQKV` and `splitGateUpSiLU` worker grids to Phi3 FP16 FFN layers and update grid scheduler configuration --- .../tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 75f9f531..b205fbee 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -89,8 +89,16 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { int ffnDownGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid ffnDownWorker = WorkerGridFactory.genericWorker(ffnDownGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid splitQKVWorker = WorkerGridFactory.genericWorker(opSize, 128); + + // SplitGateUpAndSiLU worker + WorkerGrid splitGateUpSiLUWorker = WorkerGridFactory.genericWorker(config.hiddenDim(), 128); + + // Map workers to tasks for each layer for (int i = 0; i < config.numberOfLayers(); i++) { + gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker); gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); gridScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", matmulQkvRowMajorWorker); From 428e5cc02f5dd4adc27232b12565ab0b4831e53d Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 22:44:26 +0200 Subject: [PATCH 25/42] Refactor Phi3 FP16 FFN layers: 
replace `createRoPEWorker` with generic worker grid, update RoPE task configuration, and streamline layer setup. --- .../tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index b205fbee..0099e92d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -70,7 +70,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { // RoPE worker (2D: heads x embedding_head/2) int ic = config.headSize() / 2; - WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), config.headSize()); +// WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), config.headSize()); // Copy to cache worker WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.kvDim(), 32); @@ -94,6 +94,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { // SplitGateUpAndSiLU worker WorkerGrid splitGateUpSiLUWorker = WorkerGridFactory.genericWorker(config.hiddenDim(), 128); + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); // Map workers to tasks for each layer for (int i = 0; i < config.numberOfLayers(); i++) { From 6c1ac6f321d86406bf7dd397b55ede2f8ec3cd8c Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 22:54:43 +0200 Subject: [PATCH 26/42] Add Phi3-specific fused kernels for RMSNorm+QKV and RMSNorm+Gate/Up, update Phi3 FP16 FFN layers with optimized worker grid configurations, fused workflows for attention and FFN blocks, and detailed task flow documentation. --- .../tornadovm/kernels/Phi3Kernels.java | 195 ++++++++ .../layers/type/fp16/Phi3FP16FFNLayers.java | 469 ++++++++++-------- .../layers/type/q8_0/Phi3Q8_0FFNLayers.java | 1 + 3 files changed, 470 insertions(+), 195 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java new file mode 100644 index 00000000..6c02cb17 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java @@ -0,0 +1,195 @@ +package org.beehive.gpullama3.tornadovm.kernels; + +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +/** + * Phi3Kernels: Optimized GPU kernels for Phi3 model family. + * + *

+ * <p>Key differences from Qwen/Llama kernels:</p>
+ * <ul>
+ *   <li>Generic fused RMS + matmul (single output matrix)</li>
+ *   <li>Phi3 RoPE with headSize/2 offset pattern</li>
+ *   <li>Combined gate/up structure support</li>
+ * </ul>
+ */ +public class Phi3Kernels { + + /** + * Fused RMSNorm apply + single matrix-vector multiplication. + * + *

+ * <p>Combines RMS normalization application with a generic matmul in one kernel,
+ * reducing memory bandwidth by avoiding intermediate storage.</p>
+ *
+ * <p>Formula: output[row] = sum_j(W[row,j] * rmsWeight[j] * scale * x[j])</p>
+ *
+ * <p>Use cases:</p>
+ * <ul>
+ *   <li>Phi3 combined QKV projection (output = wqkv · RMSNorm(x))</li>
+ *   <li>Phi3 combined gate/up projection (output = wUp · RMSNorm(x))</li>
+ *   <li>Any single-matrix projection after RMSNorm</li>
+ * </ul>
+ * + * @param context Kernel execution context + * @param x Input hidden state (FP32) [dim] + * @param output Output buffer (FP32) [outputDim] + * @param rmsWeights RMS normalization weights (FP32) [dim] + * @param rmsScale Precomputed RMS scale factor [1] (from reduction kernel) + * @param w Weight matrix (FP16) [outputDim × dim] + * @param inputDim Input dimension (dim) + * @param outputDim Output dimension + * @param localWorkGroupSize Local work group size for reduction + */ + public static void fusedRmsNormMatmul( + KernelContext context, + FloatArray x, // input (FP32) + FloatArray output, // output (FP32) + FloatArray rmsWeights, // RMS norm weights + FloatArray rmsScale, // temp[0] = scale factor + HalfFloatArray w, // weight matrix + int inputDim, // input dimension + int outputDim, // output dimension + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= outputDim) { + return; + } + + float scale = rmsScale.get(0); + + // Allocate shared memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + int rowOffset = rowId * inputDim; + + // Each thread computes partial dot product with inline normalization + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum += w.get(rowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + // Thread 0 writes final result + if (localId == 0) { + output.set(rowId, localSum[0]); + } + } + + /** + * Phi3 RoPE rotation with fused KV cache copy. + * + *

+ * <p>Phi3 uses a different RoPE pattern than Llama/Qwen:</p>
+ * <ul>
+ *   <li>Pairs elements with offset headSize/2 (not adjacent pairs)</li>
+ *   <li>Each thread processes one dimension pair across all heads</li>
+ *   <li>Iterates over heads internally</li>
+ * </ul>
+ *
+ * <p>This fused kernel combines:</p>
+ * <ul>
+ *   <li>Phi3-style RoPE rotation for Q and K</li>
+ *   <li>Direct cache write for rotated K</li>
+ *   <li>Direct cache copy for V (no rotation)</li>
+ * </ul>
+ * + * @param context Kernel execution context + * @param positionHolder Current position in sequence [1] + * @param sq Query vectors (in/out, rotated) [dim] + * @param sk Key vectors (in/out, rotated) [kvDim] + * @param sv Value vectors (in only) [kvDim] + * @param keyCache Key cache (out) [layers × contextLength × kvDim] + * @param valueCache Value cache (out) [layers × contextLength × kvDim] + * @param nHeadKv Number of KV heads + * @param headSize Dimension per head + * @param kvDim Total KV dimension (nHeadKv × headSize) + * @param layer Current layer index + * @param contextLength Maximum sequence length + */ + public static void ropeRotationWithCacheCopyPhi3( + KernelContext context, + IntArray positionHolder, + FloatArray sq, // Q vector (in/out) + FloatArray sk, // K vector (in/out) + FloatArray sv, // V vector (in only) + FloatArray keyCache, // Key cache (out) + FloatArray valueCache, // Value cache (out) + int nHeadKv, + int headSize, + int kvDim, + int layer, + int contextLength) { + + int idx = context.globalIdx; + int dimHalf = headSize / 2; + + // Each thread processes one dimension pair + if (idx >= dimHalf) { + return; + } + + int pos = positionHolder.get(0); + int cacheOffset = layer * contextLength * kvDim + pos * kvDim; + + // Calculate frequency for this dimension + float freq = 1.0f / TornadoMath.pow(10000.0f, (float) (idx * 2) / (float) headSize); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Process Q: all heads (dim = nHeads × headSize) + int totalDimQ = sq.getSize(); + for (int base = 0; base < totalDimQ; base += headSize) { + if (base + idx >= totalDimQ || base + idx + dimHalf >= totalDimQ) { + break; + } + + // Rotate Q with offset pattern + float v0 = sq.get(base + idx); + float v1 = sq.get(base + idx + dimHalf); + sq.set(base + idx, v0 * fcr - v1 * fci); + sq.set(base + idx + dimHalf, v0 * fci + v1 * fcr); + } + + // Process K: only kvDim elements, with cache write + for (int base = 0; base < kvDim; base += headSize) { + if (base + idx >= kvDim || base + idx + dimHalf >= kvDim) { + break; + } + + // Rotate K with offset pattern + float k0 = sk.get(base + idx); + float k1 = sk.get(base + idx + dimHalf); + float rotated0 = k0 * fcr - k1 * fci; + float rotated1 = k0 * fci + k1 * fcr; + + // Write rotated K back + sk.set(base + idx, rotated0); + sk.set(base + idx + dimHalf, rotated1); + + // Fused cache write for K + keyCache.set(cacheOffset + base + idx, rotated0); + keyCache.set(cacheOffset + base + idx + dimHalf, rotated1); + + // Fused cache copy for V (no rotation needed) + valueCache.set(cacheOffset + base + idx, sv.get(base + idx)); + valueCache.set(cacheOffset + base + idx + dimHalf, sv.get(base + idx + dimHalf)); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 0099e92d..4623da38 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -3,6 +3,7 @@ import org.beehive.gpullama3.inference.state.Phi3State; import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; import org.beehive.gpullama3.model.phi3.Phi3Configuration; +import org.beehive.gpullama3.tornadovm.kernels.Phi3Kernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import 
org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; @@ -19,40 +20,27 @@ /** * Phi3FP16FFNLayers: FP16 FFN layers for Phi3 with Group Query Attention (GQA) support. * - * Key Differences from Qwen2/Qwen3: - * - Uses combined QKV matrix (wqkv) instead of separate Q, K, V matrices - * - Includes splitQKV task to separate combined buffer - * - Uses ropeRotationPhi3 kernel for position embeddings - * - FFN uses single wUp matrix that outputs both Gate and Up (2 * hiddenDim) - * - Includes splitGateUpAndSiLU task for FFN activation - * - Uses wDown for final FFN projection - * - No Q, K, V bias terms + * Key Differences from Qwen2/Qwen3: - Uses combined QKV matrix (wqkv) instead of separate Q, K, V matrices - Includes splitQKV task to separate combined buffer - Uses ropeRotationPhi3 kernel for + * position embeddings - FFN uses single wUp matrix that outputs both Gate and Up (2 * hiddenDim) - Includes splitGateUpAndSiLU task for FFN activation - Uses wDown for final FFN projection - No Q, K, + * V bias terms * * Works directly with Phi3State to access and mutate Phi3-specific state fields. */ public class Phi3FP16FFNLayers extends AbstractFFNLayers { - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; - List ffnLayerTaskGraphs; - // Typed references to Phi3-specific state and config private final Phi3State phi3State; private final Phi3Configuration phi3Config; - // Phi3-specific dimension for combined QKV buffer private final int opSize; + TaskGraph ffnLayerTaskGraph; + GridScheduler scheduler; + List ffnLayerTaskGraphs; public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) { - super(taskGraphName, state, weights, config,schedulerType); + super(taskGraphName, state, weights, config, schedulerType); this.phi3State = state; this.phi3Config = config; - - // Ensure we have Phi3-specific weights - if (!(weights instanceof Phi3TornadoWeights phi3Weights)) { - throw new IllegalArgumentException("Phi3FP16FFNLayers requires Phi3TornadoWeights with TornadoTensor layout"); - } - // Calculate opSize for combined QKV buffer // opSize = num_heads * head_dim + 2 * (num_key_value_heads * head_dim) = dim + 2 * kvDim this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); @@ -64,55 +52,49 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { // RMS norm worker WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize); - // Combined QKV matmul worker - int matmulQkvGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid matmulQkvRowMajorWorker = WorkerGridFactory.genericWorker(matmulQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + // Fused RMS + QKV matmul worker + int fusedQkvGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - // RoPE worker (2D: heads x embedding_head/2) - int ic = config.headSize() / 2; -// WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), config.headSize()); + // SplitQKV worker + WorkerGrid splitQKVWorker = WorkerGridFactory.genericWorker(opSize, 128); - // Copy to cache worker - WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.kvDim(), 32); + // Fused RoPE + cache copy worker (Phi3 uses dim/2 pattern) + WorkerGrid ropeWorker = 
WorkerGridFactory.genericWorker(config.dim() / 2, 128); // Parallel attention worker WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); - // Matmul1 worker (output projection) + // Output projection worker int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC); - // FFN workers - int ffnUpGlobal = (2 * config.hiddenDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid ffnUpWorker = WorkerGridFactory.genericWorker(ffnUpGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - - int ffnDownGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid ffnDownWorker = WorkerGridFactory.genericWorker(ffnDownGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - - WorkerGrid splitQKVWorker = WorkerGridFactory.genericWorker(opSize, 128); + // Fused RMS + FFN gate/up worker + int fusedFFNGlobal = (2 * config.hiddenDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedFFNWorker = WorkerGridFactory.genericWorker(fusedFFNGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); // SplitGateUpAndSiLU worker WorkerGrid splitGateUpSiLUWorker = WorkerGridFactory.genericWorker(config.hiddenDim(), 128); - WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); + // FFN down projection worker + int ffnDownGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid ffnDownWorker = WorkerGridFactory.genericWorker(ffnDownGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - // Map workers to tasks for each layer + // Map workers to tasks for each layer (in task execution order) for (int i = 0; i < config.numberOfLayers(); i++) { + // === Attention Block === + gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_matmul", fusedQkvWorker); gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker); + // === FFN Block === + gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNWorker); gridScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", matmulQkvRowMajorWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker); - gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".wGateUp", ffnUpWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".wDown", ffnDownWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", ffnDownWorker); } - return gridScheduler; } @@ -140,11 +122,6 @@ public List getFfnLayerTaskGraphs() { */ List setupFFNLayered() { List ffnGraphs = new 
ArrayList<>(); - - // Initialize buffers using Phi3State directly - phi3State.temp.init(0.0f); - phi3State.tempFFN.init(0.0f); - for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSinglePhi3FFNLayer((Phi3TornadoWeights) weights, layerIndex); if (layerIndex == phi3Config.numberOfLayers() - 1) { @@ -152,160 +129,257 @@ List setupFFNLayered() { } ffnGraphs.add(ffnLayer.snapshot()); } - return ffnGraphs; } + // @formatter:off /** - * Setup a single transformer layer for Phi3 with combined QKV and gate/up FFN + * Transformer Layer Task Flow (Phi3FP16FFNLayers - Optimized) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (scale factor for RMSNorm) + * └────────┬────────┘ + * │ + * ▼ + * ┌─────────────────────┐ + * │ attn_rms_qkv_matmul │──▶ wrapQkv (combined Q+K+V, FP32) + * └──────────┬──────────┘ (fused: RMS apply + QKV matmul) + * │ + * ▼ + * ┌──────────┐ + * │ splitQKV │──▶ wrapQ, wrapK, wrapV (separated) + * └────┬─────┘ + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ (fused: Phi3 RoPE + cache write) + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN (scale factor) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ ffn_rms_finalize │──▶ tempFFN (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = RMSNorm(x) · wUp [2×hiddenDim] + * └────────┬────────┘ (fused: RMS apply + gate/up matmul) + * │ + * ▼ + * ┌────────────┐ + * │ gateUpSiLU │──▶ wrapHbU = SiLU(gate) ⊙ up + * └──────┬─────┘ + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += wDown · wrapHbU (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 10 tasks (NVIDIA) / 11 tasks (non-NVIDIA) + * Previous: 13 tasks + * Reduction: 3 tasks eliminated (23% fewer kernel launches) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state with residual connections + * + * Key Fusion Points (vs previous 13 tasks): + * • attn_rms_qkv_matmul: Fused RMS apply + combined QKV matmul (2→1 kernel) + * • rope_and_kv_cache: Fused Phi3 RoPE rotation + cache write (2→1 kernel) + * • rms_ffn_gate_up: Fused RMS apply + gate/up matmul (2→1 kernel) + * + * Phi3-Specific: + * • Combined wqkv: Single [opSize × dim] matrix for Q+K+V projection + * • Split after projection: splitQKV separates combined buffer + * • Phi3 RoPE: Uses headSize/2 offset pattern (different from Llama/Qwen) + * • Combined wUp: Single [2×hiddenDim × dim] matrix for gate+up + * • Split 
after SiLU: gateUpSiLU separates and applies activation + * */ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { + var taskGraphName = "layer_" + layerIndex; - TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex); + TaskGraph unifiedLayer = new TaskGraph(taskGraphName); unifiedLayer.consumeFromDevice(phi3State.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - // Copy-in weights per layer for batched-layered layout + // Attention weights weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqkvLayered[layerIndex].asHalfFloatArray(), weights.woLayered[layerIndex].asHalfFloatArray(), + // FFN weights weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), weights.wUpLayered[layerIndex].asHalfFloatArray(), - weights.wDownLayered[layerIndex].asHalfFloatArray() - ); + weights.wDownLayered[layerIndex].asHalfFloatArray()); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - // RMSNorm for attention input - unifiedLayer.task("reductionsOneBlock", - TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, - phi3State.temp, - phi3State.wrapX, - phi3Config.dim(), - phi3Config.rmsNormEps(), - phi3State.localSize) - .task("mapContext", - TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, - context, - phi3State.wrapXb, - phi3State.wrapX, - weights.rms_att_weightLayered[layerIndex].asFloatArray(), - phi3State.temp); - - // Combined QKV projection - unifiedLayer.task("qkvmatmul", - TransformerComputeKernelsLayered::matrixVectorGeneric, - context, - phi3State.wrapXb, - phi3State.wrapQkv, - weights.wqkvLayered[layerIndex].asHalfFloatArray(), - phi3Config.dim(), - opSize, - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("splitQKV", - TransformerComputeKernelsLayered::splitQKV, - phi3State.wrapQkv, - phi3State.wrapQ, - phi3State.wrapK, - phi3State.wrapV, - phi3Config.dim(), - phi3Config.headSize() * phi3Config.numberOfKeyValueHeads()); - - // RoPE rotation (Phi3-specific kernel) - unifiedLayer.task("rope", - TransformerComputeKernelsLayered::ropeRotationPhi3, + // ═══════════════════════════════════════════════════════════════════════ + // ATTENTION BLOCK + // ═══════════════════════════════════════════════════════════════════════ + + // RMS Normalization - compute scale factor + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, - phi3State.positionHolder, - phi3State.wrapQ, - phi3State.wrapK, - phi3Config.kvDim(), - phi3Config.headSize()); + phi3State.temp, // output: scale factor + phi3State.wrapX, // input: hidden state + phi3Config.dim(), // dimension + phi3Config.rmsNormEps(), // epsilon + phi3State.localSize); // local memory size + + // Fused RMS Apply + QKV Projection (combined matrix) + unifiedLayer.task("attn_rms_qkv_matmul", + Phi3Kernels::fusedRmsNormMatmul, + context, + phi3State.wrapX, // input: raw hidden state (FP32) + phi3State.wrapQkv, // output: combined Q+K+V + weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights + phi3State.temp, // RMS scale factor from reduction + weights.wqkvLayered[layerIndex].asHalfFloatArray(), // Wqkv [opSize × dim] + phi3Config.dim(), // input dimension + opSize, // output dimension (Q + K + V) + LOCAL_WORK_GROUP_SIZE_ALLOC); - // Copy to caches - unifiedLayer.task("copyToCaches", - TransformerComputeKernelsLayered::copyToCache, - phi3State.wrapKeyCache, + // Split combined QKV into separate Q, K, V buffers + unifiedLayer.task("splitQKV", + 
TransformerComputeKernelsLayered::splitQKV, + phi3State.wrapQkv, + phi3State.wrapQ, phi3State.wrapK, - phi3State.wrapValueCache, phi3State.wrapV, - phi3State.positionHolder, - phi3Config.kvDim(), - layerIndex, - phi3Config.contextLength()); + phi3Config.dim(), + phi3Config.headSize() * phi3Config.numberOfKeyValueHeads()); - // Parallel attention - unifiedLayer.task("parallel-attention", + // Fused Phi3 RoPE Rotation + KV Cache Write + unifiedLayer.task("rope_and_kv_cache", + Phi3Kernels::ropeRotationWithCacheCopyPhi3, + context, + phi3State.positionHolder, // current position + phi3State.wrapQ, // Q vectors (in/out, rotated) + phi3State.wrapK, // K vectors (in/out, rotated) + phi3State.wrapV, // V vectors (in only) + phi3State.wrapKeyCache, // key cache (out) + phi3State.wrapValueCache, // value cache (out) + phi3Config.numberOfKeyValueHeads(), // nHeadKv + phi3Config.headSize(), // head dimension + phi3Config.kvDim(), // kvDim + layerIndex, // layer index for cache offset + phi3Config.contextLength()); // max sequence length + + // Flash Attention + unifiedLayer.task("attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context, - phi3State.wrapQ, - phi3State.wrapKeyCache, - phi3State.wrapValueCache, - phi3State.wrapXb, - phi3Config.numberOfHeads(), - phi3Config.headSize(), - phi3Config.kvDim(), - phi3Config.kvMul(), - phi3State.positionHolder, - layerIndex, - phi3Config.contextLength()); - - // Output projection - unifiedLayer.task("matmul1", + phi3State.wrapQ, // query vectors + phi3State.wrapKeyCache, // key cache + phi3State.wrapValueCache, // value cache + phi3State.wrapXb, // output: attention result + phi3Config.numberOfHeads(), // nHeads + phi3Config.headSize(), // headSize + phi3Config.kvDim(), // kvDim + phi3Config.kvMul(), // kvMul (nHeads / nHeadKv) + phi3State.positionHolder, // position + layerIndex, // layer index + phi3Config.contextLength()); // context length + + // Output Projection with Residual + unifiedLayer.task("attn_output_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - phi3State.wrapXb, - phi3State.wrapX, - weights.woLayered[layerIndex].asHalfFloatArray(), - phi3Config.dim(), - phi3Config.dim(), + phi3State.wrapXb, // input: attention output + phi3State.wrapX, // output: wrapX += Wo · wrapXb + weights.woLayered[layerIndex].asHalfFloatArray(), // Wo [dim × dim] + phi3Config.dim(), // input dim + phi3Config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); - // FFN section: RMSNorm - unifiedLayer.task("reductionsOneBlockFFN", - TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, - phi3State.tempFFN, - phi3State.wrapX, - phi3Config.dim(), - phi3Config.rmsNormEps(), - phi3State.localSize) - .task("mapContextFFN", - TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, - context, - phi3State.wrapXb, - phi3State.wrapX, - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - phi3State.tempFFN); - - // FFN: combined Up and Gate projection (outputs 2 * hiddenDim) - unifiedLayer.task("wGateUp", - TransformerComputeKernelsLayered::matrixVectorGeneric, - context, - phi3State.wrapXb, - phi3State.wrapHb, - weights.wUpLayered[layerIndex].asHalfFloatArray(), - phi3Config.dim(), - 2 * phi3Config.hiddenDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("gateUpSiLU", - TransformerComputeKernelsLayered::splitGateUpAndSiLU, - phi3State.wrapHb, - phi3State.wrapHbG, - phi3State.wrapHbU, - phi3Config.hiddenDim()); - - // FFN: Down projection with residual - unifiedLayer.task("wDown", + // 
═══════════════════════════════════════════════════════════════════════ + // FFN BLOCK + // ═══════════════════════════════════════════════════════════════════════ + + // RMS Normalization - compute scale factor + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + phi3State.tempFFN, // output: scale factor + phi3State.wrapX, // input: hidden state + phi3Config.dim(), // dimension + phi3Config.rmsNormEps(), // epsilon + phi3State.localSize); // local memory size + + // Final normalization (non-NVIDIA only) + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + phi3State.tempFFN, // scale factor (in/out) + phi3Config.dim(), // dimension + phi3Config.rmsNormEps()); // epsilon + } + + // Fused RMS Apply + Gate/Up Projection (combined output) + unifiedLayer.task("rms_ffn_gate_up", + Phi3Kernels::fusedRmsNormMatmul, + context, + phi3State.wrapX, // input: raw hidden state (FP32) + phi3State.wrapHb, // output: gate + up combined [2×hiddenDim] + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights + phi3State.tempFFN, // RMS scale factor + weights.wUpLayered[layerIndex].asHalfFloatArray(), // wUp [2×hiddenDim × dim] + phi3Config.dim(), // input dimension + 2 * phi3Config.hiddenDim(), // output dimension + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Split Gate/Up and apply SiLU activation + unifiedLayer.task("gateUpSiLU", + TransformerComputeKernelsLayered::splitGateUpAndSiLU, + phi3State.wrapHb, // input: gate + up combined + phi3State.wrapHbG, // output: SiLU(gate) (intermediate) + phi3State.wrapHbU, // output: SiLU(gate) ⊙ up + phi3Config.hiddenDim()); // hidden dimension + + // Down Projection with Residual + unifiedLayer.task("ffn_down_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - phi3State.wrapHbU, - phi3State.wrapX, - weights.wDownLayered[layerIndex].asHalfFloatArray(), - phi3Config.hiddenDim(), - phi3Config.dim(), + phi3State.wrapHbU, // input: FFN intermediate + phi3State.wrapX, // output: wrapX += wDown · wrapHbU + weights.wDownLayered[layerIndex].asHalfFloatArray(), // wDown [dim × hiddenDim] + phi3Config.hiddenDim(), // input dim + phi3Config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice( - phi3State.wrapX - ); + .persistOnDevice(phi3State.wrapX); + return unifiedLayer; } @@ -313,26 +387,31 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { * Configure data transfers for first and subsequent layers */ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { - // First layer: Transfer initial data to device (one-time transfer) if (layerIndex == 0) { - // Transfer all attention-related data: query, key, value matrices and their caches - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); // - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, // - phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); // + // First layer: Transfer temporary buffers and state every execution + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + phi3State.positionHolder); + // First execution: allocate workspace buffers + 
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, phi3State.wrapXb, phi3State.wrapXb2, + phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, + phi3State.wrapKeyCache, phi3State.wrapValueCache, + phi3State.wrapAtt, phi3State.wrapHb, + phi3State.temp, phi3State.tempFFN, + phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); } else { - // Subsequent layers: Consume data already on device from previous layer - unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, // - state.positionHolder, // / + // Subsequent layers: Consume data from previous layer + unifiedLayer.consumeFromDevice( + context, phi3State.wrapXb, phi3State.wrapXb2, + phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, + phi3State.wrapKeyCache, phi3State.wrapValueCache, + phi3State.wrapAtt, phi3State.wrapHb, + phi3State.positionHolder, + phi3State.temp, phi3State.tempFFN, phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); } return unifiedLayer; } + // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index d4328a1d..63981c55 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -69,6 +69,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); WorkerGrid splitGateUpSiLUWorker = WorkerGridFactory.genericWorker(config.hiddenDim(), 128); WorkerGrid splitQKVWorker = WorkerGridFactory.genericWorker(opSize, 128); + for (int i = 0; i < config.numberOfLayers(); i++) { tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", qkvDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker); From ed74652176fff0f7a403577ecea31dbd986026ad Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 23:03:44 +0200 Subject: [PATCH 27/42] Replace `splitQKV` kernel with `fusedRmsNormQKVMatmulDirect`, refactor Phi3 FP16 FFN layers to consolidate QKV projection tasks, and update worker grid/task configurations. --- .../tornadovm/kernels/Phi3Kernels.java | 91 +++++++++++++++++++ .../layers/type/fp16/Phi3FP16FFNLayers.java | 63 ++++++++----- 2 files changed, 132 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java index 6c02cb17..18003cc4 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java @@ -192,4 +192,95 @@ public static void ropeRotationWithCacheCopyPhi3( valueCache.set(cacheOffset + base + idx + dimHalf, sv.get(base + idx + dimHalf)); } } + + /** + * Fused RMSNorm apply + QKV projection with direct output to separate Q, K, V buffers. + * + *

+ * <p>Eliminates the need for a separate splitQKV kernel by routing outputs
+ * directly based on row index:</p>
+ * <ul>
+ *   <li>Rows [0, dim): Q projection</li>
+ *   <li>Rows [dim, dim+kvDim): K projection</li>
+ *   <li>Rows [dim+kvDim, dim+2*kvDim): V projection</li>
+ * </ul>
+ *
+ * <p>Formula: output[row] = sum_j(Wqkv[row,j] * rmsWeight[j] * scale * x[j])</p>
+ * + * @param context Kernel execution context + * @param x Input hidden state (FP32) [dim] + * @param q Output Q buffer (FP32) [dim] + * @param k Output K buffer (FP32) [kvDim] + * @param v Output V buffer (FP32) [kvDim] + * @param rmsWeights RMS normalization weights (FP32) [dim] + * @param rmsScale Precomputed RMS scale factor [1] + * @param wqkv Combined QKV weight matrix (FP16) [opSize × dim] + * @param dim Model dimension (Q output size) + * @param kvDim KV dimension (K/V output size) + * @param localWorkGroupSize Local work group size for reduction + */ + public static void fusedRmsNormQKVMatmulDirect( + KernelContext context, + FloatArray x, // input (FP32) + FloatArray q, // output Q (FP32) + FloatArray k, // output K (FP32) + FloatArray v, // output V (FP32) + FloatArray rmsWeights, // RMS norm weights + FloatArray rmsScale, // temp[0] = scale factor + HalfFloatArray wqkv, // combined QKV weight matrix + int dim, // input dim and Q output dim + int kvDim, // K/V output dim + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Total rows = dim (Q) + kvDim (K) + kvDim (V) + int totalRows = dim + 2 * kvDim; + if (rowId >= totalRows) { + return; + } + + float scale = rmsScale.get(0); + + // Allocate shared memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + int rowOffset = rowId * dim; + + // Each thread computes partial dot product with inline normalization + float partialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum += wqkv.get(rowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + // Thread 0 writes to appropriate output buffer + if (localId == 0) { + float result = localSum[0]; + + if (rowId < dim) { + // Q projection: rows [0, dim) + q.set(rowId, result); + } else if (rowId < dim + kvDim) { + // K projection: rows [dim, dim+kvDim) + int kIdx = rowId - dim; + k.set(kIdx, result); + } else { + // V projection: rows [dim+kvDim, dim+2*kvDim) + int vIdx = rowId - dim - kvDim; + v.set(vIdx, result); + } + } + } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 4623da38..1df3cc8e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -79,13 +79,17 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { // FFN down projection worker int ffnDownGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid ffnDownWorker = WorkerGridFactory.genericWorker(ffnDownGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + // Same worker as before - total rows = dim + 2*kvDim = opSize + // Remove: gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker); // Map workers to tasks for each layer (in task execution order) for (int i = 0; i < config.numberOfLayers(); i++) { // === Attention Block === gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); - 
gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_matmul", fusedQkvWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker); +// gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_matmul", fusedQkvWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQkvWorker); + +// gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker); gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker); gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker); @@ -257,29 +261,44 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3Config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size - // Fused RMS Apply + QKV Projection (combined matrix) - unifiedLayer.task("attn_rms_qkv_matmul", - Phi3Kernels::fusedRmsNormMatmul, +// // Fused RMS Apply + QKV Projection (combined matrix) +// unifiedLayer.task("attn_rms_qkv_matmul", +// Phi3Kernels::fusedRmsNormMatmul, +// context, +// phi3State.wrapX, // input: raw hidden state (FP32) +// phi3State.wrapQkv, // output: combined Q+K+V +// weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights +// phi3State.temp, // RMS scale factor from reduction +// weights.wqkvLayered[layerIndex].asHalfFloatArray(), // Wqkv [opSize × dim] +// phi3Config.dim(), // input dimension +// opSize, // output dimension (Q + K + V) +// LOCAL_WORK_GROUP_SIZE_ALLOC); +// +// // Split combined QKV into separate Q, K, V buffers +// unifiedLayer.task("splitQKV", +// TransformerComputeKernelsLayered::splitQKV, +// phi3State.wrapQkv, +// phi3State.wrapQ, +// phi3State.wrapK, +// phi3State.wrapV, +// phi3Config.dim(), +// phi3Config.headSize() * phi3Config.numberOfKeyValueHeads()); + + // AFTER: 1 task + unifiedLayer.task("attn_rms_qkv_projection", + Phi3Kernels::fusedRmsNormQKVMatmulDirect, context, - phi3State.wrapX, // input: raw hidden state (FP32) - phi3State.wrapQkv, // output: combined Q+K+V - weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights - phi3State.temp, // RMS scale factor from reduction - weights.wqkvLayered[layerIndex].asHalfFloatArray(), // Wqkv [opSize × dim] - phi3Config.dim(), // input dimension - opSize, // output dimension (Q + K + V) + phi3State.wrapX, // input + phi3State.wrapQ, // output Q + phi3State.wrapK, // output K + phi3State.wrapV, // output V + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + phi3State.temp, // RMS scale + weights.wqkvLayered[layerIndex].asHalfFloatArray(), + phi3Config.dim(), // dim + phi3Config.kvDim(), // kvDim LOCAL_WORK_GROUP_SIZE_ALLOC); - // Split combined QKV into separate Q, K, V buffers - unifiedLayer.task("splitQKV", - TransformerComputeKernelsLayered::splitQKV, - phi3State.wrapQkv, - phi3State.wrapQ, - phi3State.wrapK, - phi3State.wrapV, - phi3Config.dim(), - phi3Config.headSize() * phi3Config.numberOfKeyValueHeads()); - // Fused Phi3 RoPE Rotation + KV Cache Write unifiedLayer.task("rope_and_kv_cache", Phi3Kernels::ropeRotationWithCacheCopyPhi3, From 8b52fbe9b8fb32bb655b8db4173be7bca50a44af Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 23:05:52 +0200 Subject: [PATCH 28/42] Remove unused `splitQKV` and RMS Apply+QKV Projection kernels, update Phi3 FP16 FFN layers to streamline task configuration and clean up commented code. 
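
The consolidated attn_rms_qkv_projection task routes each row of the combined
Wqkv matmul straight to its destination buffer, so the old splitQKV pass is now
dead code. A minimal CPU-side sketch of that routing logic (illustrative only;
routeQkvRows and dotRow are hypothetical stand-ins for the kernel's
per-workgroup reduction, not code in this patch):

    // Sketch: how fusedRmsNormQKVMatmulDirect maps row indices to Q/K/V.
    // totalRows = dim + 2 * kvDim == opSize.
    static void routeQkvRows(float[] q, float[] k, float[] v, int dim, int kvDim,
                             java.util.function.IntToDoubleFunction dotRow) {
        for (int row = 0; row < dim + 2 * kvDim; row++) {
            float result = (float) dotRow.applyAsDouble(row); // workgroup dot product
            if (row < dim) {
                q[row] = result;               // rows [0, dim): Q projection
            } else if (row < dim + kvDim) {
                k[row - dim] = result;         // rows [dim, dim+kvDim): K projection
            } else {
                v[row - dim - kvDim] = result; // rows [dim+kvDim, dim+2*kvDim): V projection
            }
        }
    }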
--- .../layers/type/fp16/Phi3FP16FFNLayers.java | 27 ------------------- 1 file changed, 27 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 1df3cc8e..608e4c7b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -86,10 +86,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { for (int i = 0; i < config.numberOfLayers(); i++) { // === Attention Block === gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); -// gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_matmul", fusedQkvWorker); gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQkvWorker); - -// gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker); gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker); gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker); @@ -261,30 +258,6 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3Config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size -// // Fused RMS Apply + QKV Projection (combined matrix) -// unifiedLayer.task("attn_rms_qkv_matmul", -// Phi3Kernels::fusedRmsNormMatmul, -// context, -// phi3State.wrapX, // input: raw hidden state (FP32) -// phi3State.wrapQkv, // output: combined Q+K+V -// weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights -// phi3State.temp, // RMS scale factor from reduction -// weights.wqkvLayered[layerIndex].asHalfFloatArray(), // Wqkv [opSize × dim] -// phi3Config.dim(), // input dimension -// opSize, // output dimension (Q + K + V) -// LOCAL_WORK_GROUP_SIZE_ALLOC); -// -// // Split combined QKV into separate Q, K, V buffers -// unifiedLayer.task("splitQKV", -// TransformerComputeKernelsLayered::splitQKV, -// phi3State.wrapQkv, -// phi3State.wrapQ, -// phi3State.wrapK, -// phi3State.wrapV, -// phi3Config.dim(), -// phi3Config.headSize() * phi3Config.numberOfKeyValueHeads()); - - // AFTER: 1 task unifiedLayer.task("attn_rms_qkv_projection", Phi3Kernels::fusedRmsNormQKVMatmulDirect, context, From 977f0baad7e7d1d67616f4efa88530563a795f46 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 23:12:14 +0200 Subject: [PATCH 29/42] Add `fusedRmsNormFFNGateUpSiLU` kernel to optimize Phi3 FFN flow, replace `rms_ffn_gate_up` and `gateUpSiLU` tasks with a single fused task, streamline task graph and update documentation. --- .../tornadovm/kernels/Phi3Kernels.java | 94 ++++++++++ .../layers/type/fp16/Phi3FP16FFNLayers.java | 166 +++++------------- 2 files changed, 142 insertions(+), 118 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java index 18003cc4..d45a62c4 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java @@ -283,4 +283,98 @@ public static void fusedRmsNormQKVMatmulDirect( } } } + /** + * Fused RMSNorm apply + Gate/Up projection + SiLU + GLU in one kernel. + * + *

+ * <p>Eliminates the need for a separate gateUpSiLU kernel by computing both
+ * gate and up projections per workgroup and applying the activation inline.</p>
+ *
+ * <p>For each output index i:</p>
+ * <ul>
+ *   <li>gate[i] = dot(wUp[i], RMSNorm(x))</li>
+ *   <li>up[i] = dot(wUp[hiddenDim + i], RMSNorm(x))</li>
+ *   <li>output[i] = SiLU(gate[i]) × up[i]</li>
+ * </ul>
+ * + * @param context Kernel execution context + * @param x Input hidden state (FP32) [dim] + * @param output Output buffer (FP32) [hiddenDim] - final FFN result + * @param rmsWeights RMS normalization weights (FP32) [dim] + * @param rmsScale Precomputed RMS scale factor [1] + * @param wUp Combined gate+up weight matrix (FP16) [2×hiddenDim × dim] + * @param dim Input dimension + * @param hiddenDim Hidden dimension (output size) + * @param localWorkGroupSize Local work group size for reduction + */ + public static void fusedRmsNormFFNGateUpSiLU( + KernelContext context, + FloatArray x, // input (FP32) + FloatArray output, // output (FP32) [hiddenDim] + FloatArray rmsWeights, // RMS norm weights + FloatArray rmsScale, // temp[0] = scale factor + HalfFloatArray wUp, // combined gate+up weights [2×hiddenDim × dim] + int dim, // input dimension + int hiddenDim, // output dimension + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= hiddenDim) { + return; + } + + float scale = rmsScale.get(0); + + // Allocate shared memory for reduction + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + // === Compute GATE (row i) === + int gateRowOffset = rowId * dim; + + float gatePartialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + gatePartialSum += wUp.get(gateRowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = gatePartialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + float gateResult = localSum[0]; + + // === Compute UP (row hiddenDim + i) === + int upRowOffset = (hiddenDim + rowId) * dim; + + float upPartialSum = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normalized = rmsWeights.get(j) * scale * x.get(j); + upPartialSum += wUp.get(upRowOffset + j).getFloat32() * normalized; + } + + localSum[localId] = upPartialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + float upResult = localSum[0]; + + // === Apply SiLU(gate) × up === + if (localId == 0) { + float silu = gateResult / (1.0f + TornadoMath.exp(-gateResult)); + output.set(rowId, silu * upResult); + } + } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 608e4c7b..a10e9dc0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -81,8 +81,6 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { WorkerGrid ffnDownWorker = WorkerGridFactory.genericWorker(ffnDownGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); // Same worker as before - total rows = dim + 2*kvDim = opSize - // Remove: gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker); - // Map workers to tasks for each layer (in task execution order) for (int i = 0; i < config.numberOfLayers(); i++) { // === Attention Block === gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); @@ -92,8 
+90,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker); // === FFN Block === gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_silu", fusedFFNWorker); gridScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", ffnDownWorker); } return gridScheduler; @@ -135,7 +132,7 @@ List setupFFNLayered() { // @formatter:off /** - * Transformer Layer Task Flow (Phi3FP16FFNLayers - Optimized) + * Transformer Layer Task Flow (Phi3FP16FFNLayers - Fully Optimized) * * ══════════════════════════════════════════════════════════════════════════════ * ATTENTION BLOCK @@ -149,16 +146,11 @@ List setupFFNLayered() { * └────────┬────────┘ * │ * ▼ - * ┌─────────────────────┐ - * │ attn_rms_qkv_matmul │──▶ wrapQkv (combined Q+K+V, FP32) - * └──────────┬──────────┘ (fused: RMS apply + QKV matmul) - * │ - * ▼ - * ┌──────────┐ - * │ splitQKV │──▶ wrapQ, wrapK, wrapV (separated) - * └────┬─────┘ - * │ - * ▼ + * ┌────────────────────────┐ + * │ attn_rms_qkv_projection│──▶ wrapQ, wrapK, wrapV (direct output) + * └───────────┬────────────┘ (fused: RMS apply + QKV matmul + split) + * │ + * ▼ * ┌───────────────────┐ ┌─────────────────────────────────────┐ * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ * └─────────┬─────────┘ └─────────────────────────────────────┘ @@ -188,14 +180,9 @@ List setupFFNLayered() { * └────────┬─────────┘ * │ * ▼ - * ┌─────────────────┐ - * │ rms_ffn_gate_up │──▶ wrapHb = RMSNorm(x) · wUp [2×hiddenDim] - * └────────┬────────┘ (fused: RMS apply + gate/up matmul) - * │ - * ▼ - * ┌────────────┐ - * │ gateUpSiLU │──▶ wrapHbU = SiLU(gate) ⊙ up - * └──────┬─────┘ + * ┌──────────────┐ + * │ rms_ffn_silu │──▶ wrapHbU = SiLU(RMSNorm(x)·Wgate) ⊙ (RMSNorm(x)·Wup) + * └──────┬───────┘ (fused: RMS apply + gate/up matmul + SiLU + GLU) * │ * ▼ * ┌──────────────┐ @@ -207,41 +194,37 @@ List setupFFNLayered() { * * ══════════════════════════════════════════════════════════════════════════════ * - * Task Count: 10 tasks (NVIDIA) / 11 tasks (non-NVIDIA) - * Previous: 13 tasks - * Reduction: 3 tasks eliminated (23% fewer kernel launches) + * Task Count: 8 tasks (NVIDIA) / 9 tasks (non-NVIDIA) + * Original: 13 tasks + * Reduction: 5 tasks eliminated (38% fewer kernel launches) * * Data Flow Summary: * Input: wrapX (FP32) - hidden state from previous layer * Output: wrapX (FP32) - updated hidden state with residual connections * - * Key Fusion Points (vs previous 13 tasks): - * • attn_rms_qkv_matmul: Fused RMS apply + combined QKV matmul (2→1 kernel) - * • rope_and_kv_cache: Fused Phi3 RoPE rotation + cache write (2→1 kernel) - * • rms_ffn_gate_up: Fused RMS apply + gate/up matmul (2→1 kernel) + * Key Fusion Points (vs original 13 tasks): + * • attn_rms_qkv_projection: Fused RMS apply + QKV matmul + direct split (3→1 kernel) + * • rope_and_kv_cache: Fused Phi3 RoPE rotation + cache write (2→1 kernel) + * • rms_ffn_silu: Fused RMS apply + gate/up matmul + SiLU + GLU (3→1 kernel) * * Phi3-Specific: * • Combined wqkv: Single [opSize × dim] matrix for Q+K+V projection - * • Split after projection: splitQKV separates combined buffer + * • Direct QKV output: No intermediate buffer, routes by row index * • Phi3 RoPE: Uses headSize/2 offset pattern (different 
from Llama/Qwen) * • Combined wUp: Single [2×hiddenDim × dim] matrix for gate+up - * • Split after SiLU: gateUpSiLU separates and applies activation + * • Inline SiLU+GLU: No intermediate wrapHb buffer needed * */ + // @formatter:on TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { var taskGraphName = "layer_" + layerIndex; - - TaskGraph unifiedLayer = new TaskGraph(taskGraphName); + var unifiedLayer = new TaskGraph(taskGraphName); unifiedLayer.consumeFromDevice(phi3State.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights - weights.rms_att_weightLayered[layerIndex].asFloatArray(), - weights.wqkvLayered[layerIndex].asHalfFloatArray(), - weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqkvLayered[layerIndex].asHalfFloatArray(), weights.woLayered[layerIndex].asHalfFloatArray(), // FFN weights - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - weights.wUpLayered[layerIndex].asHalfFloatArray(), - weights.wDownLayered[layerIndex].asHalfFloatArray()); + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), weights.wUpLayered[layerIndex].asHalfFloatArray(), weights.wDownLayered[layerIndex].asHalfFloatArray()); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); // ═══════════════════════════════════════════════════════════════════════ @@ -249,34 +232,23 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { // ═══════════════════════════════════════════════════════════════════════ // RMS Normalization - compute scale factor - unifiedLayer.task("attn_rms_reduce", - TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, - phi3State.temp, // output: scale factor + unifiedLayer.task("attn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, phi3State.temp, // output: scale factor phi3State.wrapX, // input: hidden state phi3Config.dim(), // dimension phi3Config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size - unifiedLayer.task("attn_rms_qkv_projection", - Phi3Kernels::fusedRmsNormQKVMatmulDirect, - context, - phi3State.wrapX, // input + unifiedLayer.task("attn_rms_qkv_projection", Phi3Kernels::fusedRmsNormQKVMatmulDirect, context, phi3State.wrapX, // input phi3State.wrapQ, // output Q phi3State.wrapK, // output K phi3State.wrapV, // output V - weights.rms_att_weightLayered[layerIndex].asFloatArray(), - phi3State.temp, // RMS scale - weights.wqkvLayered[layerIndex].asHalfFloatArray(), - phi3Config.dim(), // dim + weights.rms_att_weightLayered[layerIndex].asFloatArray(), phi3State.temp, // RMS scale + weights.wqkvLayered[layerIndex].asHalfFloatArray(), phi3Config.dim(), // dim phi3Config.kvDim(), // kvDim LOCAL_WORK_GROUP_SIZE_ALLOC); // Fused Phi3 RoPE Rotation + KV Cache Write - unifiedLayer.task("rope_and_kv_cache", - Phi3Kernels::ropeRotationWithCacheCopyPhi3, - context, - phi3State.positionHolder, // current position + unifiedLayer.task("rope_and_kv_cache", Phi3Kernels::ropeRotationWithCacheCopyPhi3, context, phi3State.positionHolder, // current position phi3State.wrapQ, // Q vectors (in/out, rotated) phi3State.wrapK, // K vectors (in/out, rotated) phi3State.wrapV, // V vectors (in only) @@ -289,10 +261,7 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3Config.contextLength()); // max sequence length // Flash Attention - unifiedLayer.task("attention", - 
TransformerComputeKernelsLayered::processHeadsFlashAttention, - context, - phi3State.wrapQ, // query vectors + unifiedLayer.task("attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context, phi3State.wrapQ, // query vectors phi3State.wrapKeyCache, // key cache phi3State.wrapValueCache, // value cache phi3State.wrapXb, // output: attention result @@ -305,10 +274,7 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3Config.contextLength()); // context length // Output Projection with Residual - unifiedLayer.task("attn_output_proj", - TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, - context, - phi3State.wrapXb, // input: attention output + unifiedLayer.task("attn_output_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, phi3State.wrapXb, // input: attention output phi3State.wrapX, // output: wrapX += Wo · wrapXb weights.woLayered[layerIndex].asHalfFloatArray(), // Wo [dim × dim] phi3Config.dim(), // input dim @@ -320,10 +286,7 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { // ═══════════════════════════════════════════════════════════════════════ // RMS Normalization - compute scale factor - unifiedLayer.task("ffn_rms_reduce", - TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, - phi3State.tempFFN, // output: scale factor + unifiedLayer.task("ffn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, phi3State.tempFFN, // output: scale factor phi3State.wrapX, // input: hidden state phi3Config.dim(), // dimension phi3Config.rmsNormEps(), // epsilon @@ -331,46 +294,25 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { // Final normalization (non-NVIDIA only) if (shouldUseFinalNormalization()) { - unifiedLayer.task("ffn_rms_finalize", - TransformerComputeKernelsLayered::reductionFinalNormalization, - context, - phi3State.tempFFN, // scale factor (in/out) + unifiedLayer.task("ffn_rms_finalize", TransformerComputeKernelsLayered::reductionFinalNormalization, context, phi3State.tempFFN, // scale factor (in/out) phi3Config.dim(), // dimension phi3Config.rmsNormEps()); // epsilon } - // Fused RMS Apply + Gate/Up Projection (combined output) - unifiedLayer.task("rms_ffn_gate_up", - Phi3Kernels::fusedRmsNormMatmul, - context, - phi3State.wrapX, // input: raw hidden state (FP32) - phi3State.wrapHb, // output: gate + up combined [2×hiddenDim] - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights - phi3State.tempFFN, // RMS scale factor - weights.wUpLayered[layerIndex].asHalfFloatArray(), // wUp [2×hiddenDim × dim] - phi3Config.dim(), // input dimension - 2 * phi3Config.hiddenDim(), // output dimension + unifiedLayer.task("rms_ffn_silu", Phi3Kernels::fusedRmsNormFFNGateUpSiLU, context, phi3State.wrapX, // input + phi3State.wrapHbU, // output (direct to final FFN buffer) + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), phi3State.tempFFN, // RMS scale + weights.wUpLayered[layerIndex].asHalfFloatArray(), phi3Config.dim(), // input dim + phi3Config.hiddenDim(), // output dim (hiddenDim, not 2×hiddenDim!) 
LOCAL_WORK_GROUP_SIZE_ALLOC); - // Split Gate/Up and apply SiLU activation - unifiedLayer.task("gateUpSiLU", - TransformerComputeKernelsLayered::splitGateUpAndSiLU, - phi3State.wrapHb, // input: gate + up combined - phi3State.wrapHbG, // output: SiLU(gate) (intermediate) - phi3State.wrapHbU, // output: SiLU(gate) ⊙ up - phi3Config.hiddenDim()); // hidden dimension - // Down Projection with Residual - unifiedLayer.task("ffn_down_proj", - TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, - context, - phi3State.wrapHbU, // input: FFN intermediate - phi3State.wrapX, // output: wrapX += wDown · wrapHbU - weights.wDownLayered[layerIndex].asHalfFloatArray(), // wDown [dim × hiddenDim] - phi3Config.hiddenDim(), // input dim - phi3Config.dim(), // output dim - LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice(phi3State.wrapX); + unifiedLayer.task("ffn_down_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, phi3State.wrapHbU, // input: FFN intermediate + phi3State.wrapX, // output: wrapX += wDown · wrapHbU + weights.wDownLayered[layerIndex].asHalfFloatArray(), // wDown [dim × hiddenDim] + phi3Config.hiddenDim(), // input dim + phi3Config.dim(), // output dim + LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(phi3State.wrapX); return unifiedLayer; } @@ -381,26 +323,14 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { if (layerIndex == 0) { // First layer: Transfer temporary buffers and state every execution - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, - phi3State.positionHolder); + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, phi3State.positionHolder); // First execution: allocate workspace buffers - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, phi3State.wrapXb, phi3State.wrapXb2, - phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, - phi3State.wrapKeyCache, phi3State.wrapValueCache, - phi3State.wrapAtt, phi3State.wrapHb, - phi3State.temp, phi3State.tempFFN, - phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, context, phi3State.wrapXb, phi3State.wrapXb2, phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, phi3State.wrapKeyCache, + phi3State.wrapValueCache, phi3State.wrapAtt, phi3State.wrapHb, phi3State.temp, phi3State.tempFFN, phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); } else { // Subsequent layers: Consume data from previous layer - unifiedLayer.consumeFromDevice( - context, phi3State.wrapXb, phi3State.wrapXb2, - phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, - phi3State.wrapKeyCache, phi3State.wrapValueCache, - phi3State.wrapAtt, phi3State.wrapHb, - phi3State.positionHolder, - phi3State.temp, phi3State.tempFFN, - phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); + unifiedLayer.consumeFromDevice(context, phi3State.wrapXb, phi3State.wrapXb2, phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, phi3State.wrapKeyCache, phi3State.wrapValueCache, + phi3State.wrapAtt, phi3State.wrapHb, phi3State.positionHolder, phi3State.temp, phi3State.tempFFN, phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); } return unifiedLayer; } From 7e1903269aaad0f134a175a971b65dfc83009125 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 23:13:19 +0200 Subject: [PATCH 30/42] Remove unused `splitQKV` and `splitGateUpSiLU` workers, clean up commented code, and streamline Phi3 FP16 
FFN layer configurations. --- .../tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index a10e9dc0..07efb32b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -41,8 +41,6 @@ public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeigh super(taskGraphName, state, weights, config, schedulerType); this.phi3State = state; this.phi3Config = config; - // Calculate opSize for combined QKV buffer - // opSize = num_heads * head_dim + 2 * (num_key_value_heads * head_dim) = dim + 2 * kvDim this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); ffnLayerTaskGraphs = setupFFNLayered(); } @@ -56,9 +54,6 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { int fusedQkvGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - // SplitQKV worker - WorkerGrid splitQKVWorker = WorkerGridFactory.genericWorker(opSize, 128); - // Fused RoPE + cache copy worker (Phi3 uses dim/2 pattern) WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); @@ -73,9 +68,6 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { int fusedFFNGlobal = (2 * config.hiddenDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid fusedFFNWorker = WorkerGridFactory.genericWorker(fusedFFNGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - // SplitGateUpAndSiLU worker - WorkerGrid splitGateUpSiLUWorker = WorkerGridFactory.genericWorker(config.hiddenDim(), 128); - // FFN down projection worker int ffnDownGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid ffnDownWorker = WorkerGridFactory.genericWorker(ffnDownGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); From 1e4640589424fd6f7ae8100ca680cc97e7e59e4d Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 23:14:18 +0200 Subject: [PATCH 31/42] Refactor Phi3 FP16 FFN layer task graph: improve readability by adjusting line breaks in data transfer logic and disabling formatter for consistent formatting. 
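For reviewers: `// @formatter:off` / `// @formatter:on` fence a region so IDE and build formatters leave the hand-aligned TaskGraph argument lists alone; an `off` without a matching `on` leaves everything to the end of the file unformatted. A minimal sketch of the guard pattern (illustrative only; the buffer names here are placeholders, not part of this diff):

    // @formatter:off
    TaskGraph g = new TaskGraph("layer_0");
    g.transferToDevice(DataTransferMode.FIRST_EXECUTION,
            rmsWeightsFp32,    // FP32 norm weights for this layer
            qkvWeightsFp16);   // FP16 fused QKV weight matrix
    // @formatter:on

This is why the hunk below flips the stray `@formatter:on` above `setupSinglePhi3FFNLayer` to `@formatter:off`.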
--- .../tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 07efb32b..1fb8aeb1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -207,16 +207,20 @@ List setupFFNLayered() { * • Inline SiLU+GLU: No intermediate wrapHb buffer needed * */ - // @formatter:on + // @formatter:off TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { var taskGraphName = "layer_" + layerIndex; var unifiedLayer = new TaskGraph(taskGraphName); unifiedLayer.consumeFromDevice(phi3State.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights - weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqkvLayered[layerIndex].asHalfFloatArray(), weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqkvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), // FFN weights - weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), weights.wUpLayered[layerIndex].asHalfFloatArray(), weights.wDownLayered[layerIndex].asHalfFloatArray()); + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.wUpLayered[layerIndex].asHalfFloatArray(), + weights.wDownLayered[layerIndex].asHalfFloatArray()); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); // ═══════════════════════════════════════════════════════════════════════ From 7c63dc4cafc45028f87999fced0019204a6274ba Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 4 Dec 2025 23:19:34 +0200 Subject: [PATCH 32/42] Refactor LogitsFP16Layer: streamline task graph setup, consolidate grid scheduler logic, and improve readability by adjusting formatting. --- .../layers/type/fp16/LogitsFP16Layer.java | 24 +++++-------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index 3ed5e444..713d4437 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -29,19 +29,14 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration super(name, state, weights, config); this.lastTaskGraphID = lastTaskGraphID; state.tempLogits.clear(); - var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); this.schedulerType = schedulerType; } - - /** - * Builds the logits computation graph. 
- */
+    // @formatter:off
     private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
-        TaskGraph logits = new TaskGraph("logits");
-
+        var logits = new TaskGraph("logits");
         // === Data Setup ===
         logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
         logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
@@ -97,24 +92,17 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con
         logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
         return logits;
     }
-
+    // @formatter:on
     @Override
     public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
-        WorkerGrid logitsRMS;
-        if (weights instanceof Qwen2TornadoWeights) {
-            logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32);
-        } else {
-            logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
-        }
-
+        WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256);
         var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
-        WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
+        var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
         vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
-
-        tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
         tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS);
         tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS);
+        tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
         return tornadoForwardScheduler;
     }

From 1a98725f3005dbb7eafb956b349a085a8abe4206 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Sun, 7 Dec 2025 15:56:27 +0200
Subject: [PATCH 33/42] Refactor `TransformerComputeKernelsLayered`: route
 `matrixVectorGeneric` through a new scalar `matrixVectorRowMajorOptimizedSingle`
 reduction, keep the old path as `matrixVectorRowMajorOptimizedXX`, add a
 4-way-unrolled HalfFloat variant, and drop the unused `xNorm` local
 allocation in `fusedRmsNormFFNGateUp`.
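All of the row-major kernels touched here share one scheduling contract: a workgroup owns one output row, each thread strides across that row accumulating a partial dot product, and a log2(localSize) tree reduction in local memory combines the partials. A condensed FP32 sketch of the reduction step (assuming localSize is a power of two and `partial` already holds the thread's strided sum):

    float[] localSum = context.allocateFloatLocalArray(localSize);
    localSum[localId] = partial;                    // publish this thread's piece
    context.localBarrier();
    for (int s = localSize / 2; s > 0; s >>= 1) {
        if (localId < s) {
            localSum[localId] += localSum[localId + s];   // pairwise combine
        }
        context.localBarrier();                     // lockstep between rounds
    }
    // localSum[0] now holds the full row dot product; thread 0 writes it out

The new `matrixVectorRowMajorOptimizedSingle` keeps its accumulator in `HalfFloat` through the strided loop and only converts when publishing the partial, trading a little precision for fewer per-element conversions.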
--- .../TransformerComputeKernelsLayered.java | 93 ++++++++++++++++++- 1 file changed, 90 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index 7bd1da29..ce69883d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -60,7 +60,6 @@ public static void fusedRmsNormFFNGateUp( float scale = rmsScale.get(0); // Allocate shared memory for normalized input (reused for both W1 and W3) - float[] xNorm = context.allocateFloatLocalArray(localWorkGroupSize); float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); int rowOffsetW1 = rowId * dim; @@ -1160,7 +1159,7 @@ public static void matrixVectorGeneric( if (rowId >= dim0) { return; } - float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, dim1); + float sum = matrixVectorRowMajorOptimizedSingle(context, localSize, x, w, dim1); // Thread 0 in each workgroup writes the final result if (localId == 0) { @@ -1489,7 +1488,7 @@ public static float matrixVectorRowMajorOptimizedF(KernelContext context, int lo return localSum[0]; } - public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, + public static float matrixVectorRowMajorOptimizedXX(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { int rowId = context.groupIdx; int localId = context.localIdx; @@ -1539,6 +1538,94 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc return localSum[0]; } + public static float matrixVectorRowMajorOptimizedSingle(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + HalfFloat partialSum = new HalfFloat(0f); + for (int j = localId; j < n; j += localSize) { + int matrixIdx = rowOffset + j; + HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j)); + partialSum = HalfFloat.add(partialSum, mul); + } + + + // Store partial sum in local memory + localSum[localId] = partialSum.getHalfFloatValue(); + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + return localSum[0]; + } + + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, + HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + // Accumulate in HalfFloat to avoid conversions in inner loop + HalfFloat sum0 = new HalfFloat(0f); + HalfFloat sum1 = new HalfFloat(0f); + HalfFloat sum2 = new HalfFloat(0f); + HalfFloat sum3 = new HalfFloat(0f); + + int stride = localSize; + int stride2 = localSize << 1; + int stride3 = localSize * 3; + int stride4 = localSize << 2; + + int j = localId; + int limit = n - stride3; + + for (; j < limit; j += stride4) { + int base = rowOffset + j; + + // Stay in HalfFloat - no getFloat32() calls + HalfFloat x0 = x.get(j); + HalfFloat 
x1 = x.get(j + stride);
+            HalfFloat x2 = x.get(j + stride2);
+            HalfFloat x3 = x.get(j + stride3);
+
+            sum0 = HalfFloat.add(sum0, HalfFloat.mult(w.get(base), x0));
+            sum1 = HalfFloat.add(sum1, HalfFloat.mult(w.get(base + stride), x1));
+            sum2 = HalfFloat.add(sum2, HalfFloat.mult(w.get(base + stride2), x2));
+            sum3 = HalfFloat.add(sum3, HalfFloat.mult(w.get(base + stride3), x3));
+        }
+
+        // Cleanup loop
+        for (; j < n; j += stride) {
+            sum0 = HalfFloat.add(sum0, HalfFloat.mult(w.get(rowOffset + j), x.get(j)));
+        }
+
+        // Convert to float32 only at the end for reduction
+        localSum[localId] = sum0.getFloat32() + sum1.getFloat32() + sum2.getFloat32() + sum3.getFloat32();
+        context.localBarrier();
+
+        for (int s = localSize >> 1; s > 0; s >>= 1) {
+            if (localId < s) {
+                localSum[localId] += localSum[localId + s];
+            }
+            context.localBarrier();
+        }
+
+        return localSum[0];
+    }
+
     public static void fusedQKVMatmul(
             KernelContext context,
             HalfFloatArray x,          // input (read once!)

From d1ec40845fa75cc8e53430596881038183c98739 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Sun, 7 Dec 2025 16:14:05 +0200
Subject: [PATCH 34/42] Refactor `TransformerComputeKernelsLayered`: rename
 `matrixVectorRowMajorOptimizedSingle` to `matrixVectorRowMajorOptimized`,
 rename the 4-way-unrolled variant to `matrixVectorRowMajorOptimiz` to avoid
 the resulting signature clash, and update the call site in
 `matrixVectorGeneric`.

---
 .../tornadovm/kernels/TransformerComputeKernelsLayered.java | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
index ce69883d..1906eda5 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
@@ -1159,7 +1159,7 @@ public static void matrixVectorGeneric(
         if (rowId >= dim0) {
             return;
         }
-        float sum = matrixVectorRowMajorOptimizedSingle(context, localSize, x, w, dim1);
+        float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, dim1);

         // Thread 0 in each workgroup writes the final result
         if (localId == 0) {
@@ -1538,7 +1538,7 @@ public static float matrixVectorRowMajorOptimizedXX(KernelContext context, int l
         return localSum[0];
     }

-    public static float matrixVectorRowMajorOptimizedSingle(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
+    public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
         int rowId = context.groupIdx;
         int localId = context.localIdx;
@@ -1570,7 +1570,7 @@ public static float matrixVectorRowMajorOptimizedSingle(KernelContext context, i
         return localSum[0];
     }

-    public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize,
+    public static float matrixVectorRowMajorOptimiz(KernelContext context, int localSize,
         HalfFloatArray x, HalfFloatArray w, int n) {
         int rowId = context.groupIdx;
         int localId = context.localIdx;

From d1ff213a7a17195724c3372fe22b2b05a15545f2 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Sun, 7 Dec 2025 16:18:59 +0200
Subject: [PATCH 35/42] Refactor `TransformerComputeKernelsLayered`: revert the
 previous rename: restore the scalar reduction as
 `matrixVectorRowMajorOptimizedSingle` (including its call site in
 `matrixVectorGeneric`) and give the unrolled variant back its full
 `matrixVectorRowMajorOptimized` name.
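For context on the variant this series settles on: the unrolled reduction batches four strided columns per iteration and keeps four independent `HalfFloat` accumulators (`sum0..sum3`) so the adds do not form one long dependency chain. Restating the diff's `limit = n - stride3` guard with concrete numbers, for localSize = 128 thread t touches (illustrative arithmetic, not code from the diff):

    // pass 0: columns t, t+128, t+256, t+384
    // pass 1: columns t+512, t+640, t+768, t+896
    int limit = n - 3 * localSize;   // last start index where a full 4-batch fits
    // the scalar cleanup loop then covers any remaining columns j < n

so the `limit` guard plus the cleanup loop together visit every column exactly once.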
--- .../tornadovm/kernels/TransformerComputeKernelsLayered.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index 1906eda5..ce69883d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -1159,7 +1159,7 @@ public static void matrixVectorGeneric( if (rowId >= dim0) { return; } - float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, dim1); + float sum = matrixVectorRowMajorOptimizedSingle(context, localSize, x, w, dim1); // Thread 0 in each workgroup writes the final result if (localId == 0) { @@ -1538,7 +1538,7 @@ public static float matrixVectorRowMajorOptimizedXX(KernelContext context, int l return localSum[0]; } - public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + public static float matrixVectorRowMajorOptimizedSingle(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { int rowId = context.groupIdx; int localId = context.localIdx; @@ -1570,7 +1570,7 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc return localSum[0]; } - public static float matrixVectorRowMajorOptimiz(KernelContext context, int localSize, + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { int rowId = context.groupIdx; int localId = context.localIdx; From 2c0c55c6828981522e991a9cdcb36f8a92e16497 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Mon, 8 Dec 2025 13:05:55 +0200 Subject: [PATCH 36/42] Fix import --- .../tornadovm/kernels/TransformerComputeKernelsLayered.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index 3d14e326..35fa093c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -4,6 +4,7 @@ import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.Int8Array; From bbdffd354af393c49ee84e526b22cca5c48335ab Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Mon, 8 Dec 2025 15:54:30 +0200 Subject: [PATCH 37/42] Refactor LogitsQ8_0Layer: simplify grid scheduler setup, consolidate task graph logic, and improve readability by adjusting formatting and renaming worker grid identifiers. 
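A note that applies to every rename in this patch: `GridScheduler` entries are keyed by `"<taskGraphName>.<taskName>"`, so each task rename must be mirrored in `updateGridScheduler`, as in:

    tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
    tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS);

If a key goes stale after a rename, the task still runs but (as far as we can tell) silently falls back to default scheduling rather than failing loudly, so these pairs are worth checking together in review.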
--- .../layers/type/q8_0/LogitsQ8_0Layer.java | 82 ++++++++++++------- 1 file changed, 54 insertions(+), 28 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java index c8c5a753..20782b0e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java @@ -2,8 +2,8 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; @@ -13,12 +13,9 @@ import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.SequencedCollection; - public class LogitsQ8_0Layer extends AbstractLayer { private String lastTaskGraphID; @@ -38,39 +35,68 @@ public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Confi @Override public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid logitsRMS; - if (weights instanceof Qwen2TornadoWeights) { - logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32); - } else { - logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); - } - + var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 
32 : 256); var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); + var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - - tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); + tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); return tornadoForwardScheduler; } + // @formatter:off private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { - TaskGraph logits = new TaskGraph("logits"); - logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.asByteArray(), - weights.rms_final_weight_as_floatArray) - .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (schedulerType == SchedulerType.NON_NVIDIA) { - logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); - } - logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) - .task("projection", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, // - context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asByteArray(), // - config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS) // - .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + var logits = new TaskGraph("logits"); + // === Data Setup === + logits.consumeFromDevice(lastTaskGraphID, state.wrapX); + logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); + logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, // + state.wrapLogits, // + weights.wclsByteArray.asByteArray(), // + weights.rms_final_weight_as_floatArray); + + // === Final RMS Normalization === + logits.task("rms_reduce", + TransformerComputeKernels::reductionOneBlockWithLayer, + context, + state.tempLogits, // output: partial sums + final scale factor + state.wrapX, // input: hidden state + config.dim(), // dimension + config.rmsNormEps(), // epsilon for numerical stability + state.localSize); // local workgroup size + + if (schedulerType == SchedulerType.NON_NVIDIA) { + logits.task("rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.tempLogits, + config.dim(), + config.rmsNormEps()); + } + logits.task("mapContextLogits", + TransformerComputeKernels::reductionOneBlock2WithLogits, + context, + state.wrapX, + weights.rms_final_weight_as_floatArray.asFloatArray(), + state.tempLogits); + + // === Vocabulary vocab_proj === + logits.task("vocab_proj", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, // + context, + state.wrapX, + state.wrapLogits, + weights.wclsByteArray.asByteArray(), + config.dim(), + config.vocabularySize(), 
+ LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); + + // === Transfer Results to Host === + logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; } + // @formatter:on @Override public GridScheduler getGridScheduler() { From e2d8820bdfcc250ac3550f8f87787c944ccc0d12 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Mon, 8 Dec 2025 17:11:26 +0200 Subject: [PATCH 38/42] Refactor logits handling: replace `init` calls with `clear` for tensor resets, streamline state clearing, and adjust data transfer logic in logits task graphs. --- .../org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java | 5 +++-- .../tornadovm/layers/type/fp16/LogitsFP16Layer.java | 2 +- .../tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java | 1 - 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 293d2c0c..42fef642 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -141,7 +141,8 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) .execute(); } - + state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f + state.wrapLogits.clear(); // Clear the output logits tensor -> set to 0f // 3. Execute the final graph that projects the last hidden state to output logits executionPlan.withGraph(getFinalLogitsGraphIndex()) .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) @@ -179,7 +180,7 @@ private int getFinalLogitsGraphIndex() { /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer. 
public void forceCopyInReadOnlyDataLayered() {
         // Execute all TornadoVM graphs
-        state.wrapX.init(0.0f);
+        state.wrapX.clear();
         state.positionHolder.init(0);

         // Execute activation update graph
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
index 713d4437..3b9e009f 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
@@ -28,7 +28,6 @@ public class LogitsFP16Layer extends AbstractLayer {
     public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
         super(name, state, weights, config);
         this.lastTaskGraphID = lastTaskGraphID;
-        state.tempLogits.clear();
         var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor");
         this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config);
         this.schedulerType = schedulerType;
@@ -39,6 +38,7 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con
         var logits = new TaskGraph("logits");
         // === Data Setup ===
         logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
+        logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
         logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
             // Kernel context
             context,
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
index 20782b0e..d54bb3ef 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
@@ -27,7 +27,6 @@ public class LogitsQ8_0Layer extends AbstractLayer {
     public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
         super(taskGraphName, state, weights, config);
         this.lastTaskGraphID = lastTaskGraphID;
-        state.tempLogits.init(0.0f);
         var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ8_0Layer", "TornadoTensor");
         this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config);
         this.schedulerType = schedulerType;

From fee6ea49534315e2f9e3c518d065a251aae3fc48 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Mon, 8 Dec 2025 17:32:10 +0200
Subject: [PATCH 39/42] Refactor FFN layer task graphs: move the `temp` and
 `tempFFN` reduction buffers from one-time FIRST_EXECUTION transfers to
 per-token EVERY_EXECUTION transfers, consolidate transfer flow, and adjust
 formatting for better readability.
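The behavioral effect of this move is easiest to see side by side: FIRST_EXECUTION buffers are uploaded once and stay device-resident across tokens, while EVERY_EXECUTION buffers are re-uploaded on each graph execution, which is exactly what the freshly cleared `temp`/`tempFFN` reduction scratch now needs. Sketch of the resulting split (argument lists abbreviated to the buffers in question; the real calls carry more state):

    unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
            state.positionHolder,          // current token position
            state.temp, state.tempFFN);    // reduction scratch, cleared per token
    unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
            context,
            state.wrapKeyCache, state.wrapValueCache);   // persistent KV cache

This pairs with the `state.temp.clear()` / `state.tempFFN.clear()` calls added to `tornadoVMForwardExecuteLayered` in the first hunk below.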
--- .../tornadovm/TornadoVMMasterPlan.java | 4 +++- .../layers/type/fp16/LlamaFP16FFNLayers.java | 13 ++++++------- .../layers/type/fp16/LogitsFP16Layer.java | 2 +- .../layers/type/fp16/Phi3FP16FFNLayers.java | 17 ++++++++++++----- .../layers/type/fp16/Qwen2FP16FFNLayers.java | 9 +++++---- .../layers/type/fp16/Qwen3FP16FFNLayers.java | 6 +++--- 6 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 42fef642..eadd2e68 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -1,9 +1,9 @@ package org.beehive.gpullama3.tornadovm; -import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; @@ -133,6 +133,8 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { // Set the position in the state object (used by attention layers) state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); // 2. Execute each transformer layer graph sequentially // Each graph computes attention and feed-forward transformations for one layer diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index b0634dbb..8d105e89 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -293,7 +293,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { if (layerIndex == 0) { // First layer: Transfer initial data to device (one-time transfer) - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder); + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, + state.temp, state.tempFFN + ); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Kernel context context, @@ -304,9 +307,7 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye // KV cache state.wrapKeyCache, state.wrapValueCache, // Attention & FFN buffers - state.wrapAtt, state.wrapHb, state.wrapXbFP16, - // Reduction temporaries - state.temp, state.tempFFN); + state.wrapAtt, state.wrapHb, state.wrapXbFP16); } else { // Subsequent layers: Consume data already on device from previous layer unifiedLayer.consumeFromDevice( @@ -321,9 +322,7 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye // Attention & FFN buffers state.wrapAtt, state.wrapHb, // Position & misc - state.positionHolder, state.wrapXbFP16, - // Reduction temporaries - state.temp, state.tempFFN); + state.positionHolder, state.wrapXbFP16); } return unifiedLayer; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index 3b9e009f..d2a81407 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -28,9 +28,9 @@ public class LogitsFP16Layer extends AbstractLayer { public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { super(name, state, weights, config); this.lastTaskGraphID = lastTaskGraphID; + this.schedulerType = schedulerType; var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); - this.schedulerType = schedulerType; } // @formatter:off diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 1fb8aeb1..c3b9ddc6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -319,14 +319,21 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { if (layerIndex == 0) { // First layer: Transfer temporary buffers and state every execution - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, phi3State.positionHolder); + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, phi3State.positionHolder, + phi3State.temp, phi3State.tempFFN); // First execution: allocate workspace buffers - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, context, phi3State.wrapXb, phi3State.wrapXb2, phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, phi3State.wrapKeyCache, - phi3State.wrapValueCache, phi3State.wrapAtt, phi3State.wrapHb, phi3State.temp, phi3State.tempFFN, phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, phi3State.wrapXb, phi3State.wrapXb2, + phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, + phi3State.wrapKeyCache, phi3State.wrapValueCache, + phi3State.wrapAtt, phi3State.wrapHb, phi3State.wrapHbG, + phi3State.wrapHbU, phi3State.wrapQkv); } else { // Subsequent layers: Consume data from previous layer - unifiedLayer.consumeFromDevice(context, phi3State.wrapXb, phi3State.wrapXb2, phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, phi3State.wrapKeyCache, phi3State.wrapValueCache, - phi3State.wrapAtt, phi3State.wrapHb, phi3State.positionHolder, phi3State.temp, phi3State.tempFFN, phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); + unifiedLayer.consumeFromDevice(context, phi3State.wrapXb, phi3State.wrapXb2, + phi3State.wrapQ, phi3State.wrapK, phi3State.wrapV, phi3State.wrapKeyCache, + phi3State.wrapValueCache, phi3State.wrapAtt, phi3State.wrapHb, phi3State.positionHolder, + phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); } return unifiedLayer; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index b6bfbb28..aded0c81 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -398,7 +398,6 @@ TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) return unifiedLayer; } - // @formatter:on /** * Configure data transfers for first and subsequent layers @@ -407,12 +406,13 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye // First layer: Transfer initial data to device (one-time transfer) if (layerIndex == 0) { // Transfer all attention-related data: query, key, value matrices and their caches - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen2State.positionHolder); // + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen2State.positionHolder, + qwen2State.temp, qwen2State.tempFFN); // unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // context, qwen2State.wrapXb, qwen2State.wrapXb2, // qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, // qwen2State.wrapKeyCache, qwen2State.wrapValueCache, // - qwen2State.wrapAtt, qwen2State.wrapHb, qwen2State.temp, qwen2State.tempFFN); // + qwen2State.wrapAtt, qwen2State.wrapHb); // } else { // Subsequent layers: Consume data already on device from previous layer unifiedLayer.consumeFromDevice( // @@ -420,10 +420,11 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, // qwen2State.wrapKeyCache, qwen2State.wrapValueCache, // qwen2State.wrapAtt, qwen2State.wrapHb, // - qwen2State.positionHolder, qwen2State.temp, qwen2State.tempFFN // + qwen2State.positionHolder // ); } return unifiedLayer; } + // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 27565dca..453a4d7c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -400,20 +400,20 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye if (layerIndex == 0) { // First layer: Transfer temporary buffers and QKV state every execution unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.positionHolder); - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION); + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.temp, qwen3State.tempFFN); // First execution: allocate workspace buffers unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // context, qwen3State.wrapXb, qwen3State.wrapXb2, // qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, // qwen3State.wrapKeyCache, qwen3State.wrapValueCache, // - qwen3State.wrapAtt, qwen3State.wrapHb, qwen3State.temp, qwen3State.tempFFN); + qwen3State.wrapAtt, qwen3State.wrapHb ); } else { // Subsequent layers: Consume data from previous layer unifiedLayer.consumeFromDevice(context, qwen3State.wrapXb, qwen3State.wrapXb2, // qwen3State.wrapQ, qwen3State.wrapK, // qwen3State.wrapV, qwen3State.wrapKeyCache, // qwen3State.wrapValueCache, qwen3State.wrapAtt, // - qwen3State.wrapHb, qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN); // + qwen3State.wrapHb, qwen3State.positionHolder); // } return unifiedLayer; From 848eae69df4ebdccd58faf461492cd20debcc37d Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Tue, 9 Dec 2025 16:36:36 +0200 Subject: [PATCH 40/42] Update 
Tornado dependencies to version 2.1.0 in `pom.xml`. --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index 814df7f5..a59830ef 100644 --- a/pom.xml +++ b/pom.xml @@ -54,12 +54,12 @@ io.github.beehive-lab tornado-api - 2.0.1-dev + 2.1.0 io.github.beehive-lab tornado-runtime - 2.0.1-dev + 2.1.0 From 09cefc157940a98b618128dcfe0ae213c5ca5118 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Tue, 9 Dec 2025 17:46:38 +0200 Subject: [PATCH 41/42] Refactor FFN layer setup: streamline task graph configuration, consolidate data transfer logic, replace repetitive patterns with reusable methods, and enhance code readability with improved formatting and comments. --- .../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 155 ++++++++++++------ 1 file changed, 102 insertions(+), 53 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index 385b626f..e135cc52 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -42,64 +42,94 @@ public ImmutableTaskGraph getImmutableTaskGraph() { } List setupFFNLayered() { - state.temp.init(0.0f); - state.tempFFN.init(0.0f); - var numLayers = config.numberOfLayers(); - - return IntStream.range(0, numLayers).mapToObj(i -> { + return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); - if (i == numLayers - 1) { + if (i == config.numberOfLayers() - 1) { setupLastID(ffnLayer.getTaskGraphName()); } return ffnLayer.snapshot(); }).toList(); } + // @formatter:off TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + // === Data Setup === unifiedLayer.consumeFromDevice(state.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //Copy-in weights per layer for batched-layered layout weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqLayered[layerIndex].asByteArray(), weights.wkLayered[layerIndex].asByteArray(), - weights.wvLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), weights.woLayered[layerIndex].asByteArray(), weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - weights.w1Layered[layerIndex].asByteArray(), - weights.w2Layered[layerIndex].asByteArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w2Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray()); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, - config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, 
context, state.wrapXb, state.wrapQ, - weights.wqLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapK, - weights.wkLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapV, - weights.wvLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), - layerIndex, config.contextLength()); - configureAttention(unifiedLayer, layerIndex); - unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapXb, state.wrapX, - weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + + // === Attention Block === + // RMS Normalization + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, state.wrapXb, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) + + + + .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asByteArray(), config.dim(), + config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].asByteArray(), config.dim(), + config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].asByteArray(), config.dim(), + config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + +// .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) +// .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), +// layerIndex, config.contextLength()); + + // RoPE + KV Cache + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, // Q (in/out) + state.wrapK, // K (in/out) + state.wrapV, // V (in only) + state.wrapKeyCache, // Key cache (out) + state.wrapValueCache, // Value cache (out) + config.kvDim(), + config.headSize(), + layerIndex, + config.contextLength()); + + configureAttention(unifiedLayer, layerIndex); + unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, 
context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asByteArray(), + config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, - config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte, context, state.wrapXb, state.wrapHb, - weights.w1Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray(), config.dim(), config.hiddenDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapHb, state.wrapX, - weights.w2Layered[layerIndex].asByteArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.tempFFN).task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte, context, state.wrapXb, state.wrapHb, + weights.w1Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asByteArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); return unifiedLayer; } @@ -107,15 +137,20 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye // First layer: Transfer initial data to device (one-time transfer) if (layerIndex == 0) { // Transfer all attention-related data: query, key, value matrices and their caches - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, + state.temp, state.tempFFN); // unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - context, state.wrapXb, state.wrapXb2, // + context, + state.wrapXb, state.wrapXb2, // state.wrapQ, state.wrapK, state.wrapV, // state.wrapKeyCache, state.wrapValueCache, // state.wrapAtt, state.wrapHb); // } else { // Subsequent layers: Consume data already on device from previous layer - unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // + unifiedLayer.consumeFromDevice( + context, + state.wrapXb, state.wrapXb2, // state.wrapQ, state.wrapK, state.wrapV, // state.wrapKeyCache, state.wrapValueCache, // state.wrapAtt, state.wrapHb, // @@ -141,20 +176,24 @@ public 
GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); + for (int i = 0; i < config.numberOfLayers(); i++) { + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); +// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); +// tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); } return tornadoForwardScheduler; } @@ -165,15 +204,25 @@ public List getFfnLayerTaskGraphs() { private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { if (schedulerType == SchedulerType.NVIDIA) { - return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, - context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), - state.positionHolder, layerIndex, config.contextLength()); + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + state.wrapQ, state.wrapKeyCache, + state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, + config.contextLength()); } else { - return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel, - state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), - state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsParallel, + 
state.wrapQ, state.wrapKeyCache, + state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), config.contextLength(), + state.positionHolder, state.wrapAtt, layerIndex, + config.contextLength()); } } + // @formatter:on } From 190c2d4e8e78611648ce2744203c421ca3e8526d Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Tue, 9 Dec 2025 18:27:39 +0200 Subject: [PATCH 42/42] Refactor `TransformerComputeKernelsLayered` and `LlamaQ8_0FFNLayers`: condense parameter lists, streamline formatting, optimize fused operations, and introduce Q8 fused methods for enhanced performance. --- .../TransformerComputeKernelsLayered.java | 588 ++++++++++++++---- .../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 223 +++++-- 2 files changed, 631 insertions(+), 180 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index 35fa093c..16f8c3b6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -18,11 +18,7 @@ public class TransformerComputeKernelsLayered { public TransformerComputeKernelsLayered() { } - public static void fusedQKvBiasAddition( - KernelContext context, - FloatArray q_out, FloatArray k_out, FloatArray qBias, - FloatArray v_out, FloatArray kBias, FloatArray vBias, - int dimQ, int dimKV) { + public static void fusedQKvBiasAddition(KernelContext context, FloatArray q_out, FloatArray k_out, FloatArray qBias, FloatArray v_out, FloatArray kBias, FloatArray vBias, int dimQ, int dimKV) { int gid = context.globalIdx; @@ -38,16 +34,11 @@ public static void fusedQKvBiasAddition( } } - - public static void fusedRmsNormFFNGateUp( - KernelContext context, - FloatArray x, // raw input (FP32) + public static void fusedRmsNormFFNGateUp(KernelContext context, FloatArray x, // raw input (FP32) FloatArray hb, // output FloatArray rmsWeights, // RMS norm weights FloatArray rmsScale, // temp[0] = scale factor - HalfFloatArray w1, - HalfFloatArray w3, - int dim, // input dimension + HalfFloatArray w1, HalfFloatArray w3, int dim, // input dimension int hiddenDim, // output dimension int localWorkGroupSize) { @@ -232,25 +223,16 @@ public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, Float } /** - * Fused RoPE rotation with KV cache copy. - * Eliminates separate copyToCaches kernel. + * Fused RoPE rotation with KV cache copy. Eliminates separate copyToCaches kernel. 
* - * - Rotates Q (full dim) - * - Rotates K and writes directly to keyCache - * - Copies V directly to valueCache (no rotation needed) + * - Rotates Q (full dim) - Rotates K and writes directly to keyCache - Copies V directly to valueCache (no rotation needed) */ - public static void ropeRotationWithCacheCopy( - KernelContext context, - IntArray positionHolder, - FloatArray sq, // Q vector (in/out) + public static void ropeRotationWithCacheCopy(KernelContext context, IntArray positionHolder, FloatArray sq, // Q vector (in/out) FloatArray sk, // K vector (in/out) FloatArray sv, // V vector (in only) FloatArray keyCache, // Key cache (out) FloatArray valueCache, // Value cache (out) - int kvDim, - int headSize, - int layer, - int contextLength) { + int kvDim, int headSize, int layer, int contextLength) { int i = context.globalIdx * 2; int pos = positionHolder.get(0); @@ -613,19 +595,9 @@ public static void processHeadsFlashAttention(KernelContext context, FloatArray } } - public static void processHeadsFlashAttentionOptV2( - KernelContext context, - FloatArray q, - FloatArray key_cache, - FloatArray value_cache, - FloatArray xb, - int nHeads, - int headSize, // NOTE: Still used for logic, but not for allocation size - int kvDim, - int kvMul, - IntArray positionHolder, - int layer, - int contextLength) { + public static void processHeadsFlashAttentionOptV2(KernelContext context, FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, + // NOTE: Still used for logic, but not for allocation size + int kvDim, int kvMul, IntArray positionHolder, int layer, int contextLength) { // --- STATIC CONSTANTS FOR OPENCL ALLOCATIONS --- // These must be large enough to handle the maximum expected values for @@ -807,6 +779,7 @@ public static void processHeadsFlashAttentionOptV2( xb.set(baseOffset + i, output[i] * normFactor); } } + /** * Same as processHeadsFlashAttention but with some optimizations that seem to lower attention's execution time, especially in larger models. 
*/ @@ -1252,8 +1225,8 @@ public static void matrixVectorGenericWithResidual(KernelContext context, HalfFl * Work group size */ - - public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, HalfFloatArray x, HalfFloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) { + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, HalfFloatArray x, HalfFloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, + int localWorkGroupSize) { // One row per workgroup (not per thread) int rowId = context.groupIdx; int localId = context.localIdx; @@ -1273,7 +1246,6 @@ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext contex } } - public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) { // One row per workgroup (not per thread) int rowId = context.groupIdx; @@ -1437,7 +1409,6 @@ public static HalfFloat matrixVectorRowMajorOptimizedFHF(KernelContext context, // partialSum = HalfFloat.add(partialSum, mul); } - // Store partial sum in local memory localSum[localId] = new HalfFloat(partialSum); context.localBarrier(); @@ -1446,7 +1417,7 @@ public static HalfFloat matrixVectorRowMajorOptimizedFHF(KernelContext context, for (int stride = localSize / 2; stride > 0; stride >>= 1) { if (localId < stride) { localSum[localId] = HalfFloat.add(localSum[localId], localSum[localId + stride]); -// localSum[localId] += localSum[localId + stride]; + // localSum[localId] += localSum[localId + stride]; } context.localBarrier(); } @@ -1465,15 +1436,14 @@ public static float matrixVectorRowMajorOptimizedF(KernelContext context, int lo // Each thread calculates partial dot product float partialSum = 0.0f; -// HalfFloat partialSum = new HalfFloat(0f); + // HalfFloat partialSum = new HalfFloat(0f); for (int j = localId; j < n; j += localSize) { int matrixIdx = rowOffset + j; -// HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j)); + // HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j)); partialSum += w.get(matrixIdx).getFloat32() * x.get(j).getFloat32(); -// partialSum = HalfFloat.add(partialSum, mul); + // partialSum = HalfFloat.add(partialSum, mul); } - // Store partial sum in local memory localSum[localId] = partialSum; context.localBarrier(); @@ -1489,8 +1459,7 @@ public static float matrixVectorRowMajorOptimizedF(KernelContext context, int lo return localSum[0]; } - public static float matrixVectorRowMajorOptimizedXX(KernelContext context, int localSize, - HalfFloatArray x, HalfFloatArray w, int n) { + public static float matrixVectorRowMajorOptimizedXX(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { int rowId = context.groupIdx; int localId = context.localIdx; float[] localSum = context.allocateFloatLocalArray(localSize); @@ -1539,40 +1508,38 @@ public static float matrixVectorRowMajorOptimizedXX(KernelContext context, int l return localSum[0]; } - public static float matrixVectorRowMajorOptimizedSingle(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { - int rowId = context.groupIdx; - int localId = context.localIdx; - - // Allocate local memory for reduction - float[] localSum = context.allocateFloatLocalArray(localSize); + public static float matrixVectorRowMajorOptimizedSingle(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; 
+ int localId = context.localIdx; - int rowOffset = rowId * n; + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); - HalfFloat partialSum = new HalfFloat(0f); - for (int j = localId; j < n; j += localSize) { - int matrixIdx = rowOffset + j; - HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j)); - partialSum = HalfFloat.add(partialSum, mul); - } + int rowOffset = rowId * n; + HalfFloat partialSum = new HalfFloat(0f); + for (int j = localId; j < n; j += localSize) { + int matrixIdx = rowOffset + j; + HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j)); + partialSum = HalfFloat.add(partialSum, mul); + } - // Store partial sum in local memory - localSum[localId] = partialSum.getHalfFloatValue(); - context.localBarrier(); + // Store partial sum in local memory + localSum[localId] = partialSum.getHalfFloatValue(); + context.localBarrier(); - // Parallel reduction within workgroup - for (int stride = localSize / 2; stride > 0; stride >>= 1) { - if (localId < stride) { - localSum[localId] += localSum[localId + stride]; - } - context.localBarrier(); + // Parallel reduction within workgroup + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; } - - return localSum[0]; + context.localBarrier(); } - public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, - HalfFloatArray x, HalfFloatArray w, int n) { + return localSum[0]; + } + + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { int rowId = context.groupIdx; int localId = context.localIdx; float[] localSum = context.allocateFloatLocalArray(localSize); @@ -1627,12 +1594,9 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc return localSum[0]; } - public static void fusedQKVMatmul( - KernelContext context, - HalfFloatArray x, // input (read once!) + public static void fusedQKVMatmul(KernelContext context, HalfFloatArray x, // input (read once!) 
FloatArray q, FloatArray k, FloatArray v, // outputs - HalfFloatArray wq, HalfFloatArray wk, HalfFloatArray wv, - int dim, int kvDim, int localWorkGroupSize) { + HalfFloatArray wq, HalfFloatArray wk, HalfFloatArray wv, int dim, int kvDim, int localWorkGroupSize) { int rowId = context.groupIdx; int localId = context.localIdx; @@ -1643,76 +1607,81 @@ public static void fusedQKVMatmul( if (rowId < dim) { // Q projection float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wq, dim); - if (localId == 0) q.set(rowId, sum); + if (localId == 0) { + q.set(rowId, sum); + } } else if (rowId < dim + kvDim) { // K projection int kRow = rowId - dim; float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wk, dim); - if (localId == 0) k.set(kRow, sum); + if (localId == 0) { + k.set(kRow, sum); + } } else { // V projection int vRow = rowId - dim - kvDim; float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, wv, dim); - if (localId == 0) v.set(vRow, sum); + if (localId == 0) { + v.set(vRow, sum); + } } } + public static float matrixVectorRowMajorOptimizedx(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; -public static float matrixVectorRowMajorOptimizedx(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { - int rowId = context.groupIdx; - int localId = context.localIdx; - - // Allocate local memory for reduction - float[] localSum = context.allocateFloatLocalArray(localSize); + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); - int rowOffset = rowId * n; + int rowOffset = rowId * n; - // Each thread calculates partial dot product - UNROLLED BY 4 - float sum0 = 0.0f; - float sum1 = 0.0f; - float sum2 = 0.0f; - float sum3 = 0.0f; + // Each thread calculates partial dot product - UNROLLED BY 4 + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; - int j = localId; - int stride = localSize; - int stride4 = localSize << 2; // localSize * 4 - int limit = n - (stride * 3); // Safe limit for 4 elements + int j = localId; + int stride = localSize; + int stride4 = localSize << 2; // localSize * 4 + int limit = n - (stride * 3); // Safe limit for 4 elements - // Main loop unrolled by 4 with separate accumulators - for (; j < limit; j += stride4) { - int base = rowOffset + j; - int j1 = j + stride; - int j2 = j + (stride << 1); - int j3 = j + stride * 3; + // Main loop unrolled by 4 with separate accumulators + for (; j < limit; j += stride4) { + int base = rowOffset + j; + int j1 = j + stride; + int j2 = j + (stride << 1); + int j3 = j + stride * 3; - sum0 += w.get(base).getFloat32() * x.get(j).getFloat32(); - sum1 += w.get(base + stride).getFloat32() * x.get(j1).getFloat32(); - sum2 += w.get(base + (stride << 1)).getFloat32() * x.get(j2).getFloat32(); - sum3 += w.get(base + stride * 3).getFloat32() * x.get(j3).getFloat32(); - } + sum0 += w.get(base).getFloat32() * x.get(j).getFloat32(); + sum1 += w.get(base + stride).getFloat32() * x.get(j1).getFloat32(); + sum2 += w.get(base + (stride << 1)).getFloat32() * x.get(j2).getFloat32(); + sum3 += w.get(base + stride * 3).getFloat32() * x.get(j3).getFloat32(); + } - // Handle remainder - for (; j < n; j += stride) { - sum0 += w.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); - } + // Handle remainder + for (; j < n; j += stride) { + sum0 += w.get(rowOffset + j).getFloat32() * 
x.get(j).getFloat32(); + } - // Combine accumulators (tree reduction for better precision) - float partialSum = (sum0 + sum1) + (sum2 + sum3); + // Combine accumulators (tree reduction for better precision) + float partialSum = (sum0 + sum1) + (sum2 + sum3); - // Store partial sum in local memory - localSum[localId] = partialSum; - context.localBarrier(); + // Store partial sum in local memory + localSum[localId] = partialSum; + context.localBarrier(); - // Parallel reduction within workgroup - for (int s = localSize >> 1; s > 0; s >>= 1) { - if (localId < s) { - localSum[localId] += localSum[localId + s]; + // Parallel reduction within workgroup + for (int s = localSize >> 1; s > 0; s >>= 1) { + if (localId < s) { + localSum[localId] += localSum[localId + s]; + } + context.localBarrier(); } - context.localBarrier(); - } - return localSum[0]; -} + return localSum[0]; + } // Second kernel - Combines partial sums and computes final normalization public static void reductionFinalNormalization(KernelContext context, FloatArray output, int size, float ermsNorm) { @@ -2007,10 +1976,7 @@ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext contex } } - public static void fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte(KernelContext context, FloatArray x, FloatArray hb, - ByteArray w1, - ByteArray w3, - int n, int d, int localWorkGroupSize) { + public static void fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte(KernelContext context, FloatArray x, FloatArray hb, ByteArray w1, ByteArray w3, int n, int d, int localWorkGroupSize) { // One row per workgroup (not per thread) int rowId = context.groupIdx; int localId = context.localIdx; @@ -2075,4 +2041,364 @@ public static void processHeadsParallel(FloatArray q, FloatArray key_cache, Floa } } + /** + * Fused Q/K/V matrix-vector multiplication for Q8_0 quantized weights. Reduces kernel launch overhead and improves input vector cache utilization. 
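+     * <p>Q8_0 layout assumed throughout: each 34-byte block packs an FP16 scale
+     * followed by 32 signed int8 quants, so dequantizing weight {@code j} of a row
+     * reduces to the following sketch (mirrors the index math in the body below):
+     * <pre>{@code
+     * int blockIdx       = j / 32;                            // block within the row
+     * int withinBlockIdx = j % 32;                            // quant within the block
+     * int base           = (rowBlockOffset + blockIdx) * 34;  // Q8_0_BLOCK_BYTES
+     * float scale        = w.getHalfFloat(base).getFloat32(); // FP16 scale at offset 0
+     * float wj           = w.get(base + 2 + withinBlockIdx) * scale; // dequantized weight
+     * }</pre>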
+ * + * Workgroup assignment: - rowId [0, dim): Q projection - rowId [dim, dim+kvDim): K projection - rowId [dim+kvDim, dim+2*kvDim): V projection + */ + public static void fusedQKVMatmulQ8(KernelContext context, FloatArray x, FloatArray q, FloatArray k, FloatArray v, ByteArray wq, ByteArray wk, ByteArray wv, int dim, int kvDim, + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + int blockSize = 32; + final int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (dim + blockSize - 1) / blockSize; + + float[] localSums = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowId < dim) { + // ========== Q projection ========== + int rowBlockOffset = rowId * blocksPerRow; + + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wq.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + byte quant1 = wq.get(quantsOffset); + byte quant2 = wq.get(quantsOffset + 1); + byte quant3 = wq.get(quantsOffset + 2); + byte quant4 = wq.get(quantsOffset + 3); + + partialSum1 += ((float) quant1 * scaleFloat) * x.get(j); + partialSum2 += ((float) quant2 * scaleFloat) * x.get(j + 1); + partialSum3 += ((float) quant3 * scaleFloat) * x.get(j + 2); + partialSum4 += ((float) quant4 * scaleFloat) * x.get(j + 3); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wq.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + byte quant = wq.get(blockByteOffset + 2 + withinBlockIdx); + partialSum += ((float) quant * scaleFloat) * x.get(j); + } + + localSums[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + q.set(rowId, localSums[0]); + } + + } else if (rowId < dim + kvDim) { + // ========== K projection ========== + int kRow = rowId - dim; + int rowBlockOffset = kRow * blocksPerRow; + + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wk.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + byte quant1 = wk.get(quantsOffset); + byte quant2 = wk.get(quantsOffset + 1); + byte quant3 = wk.get(quantsOffset + 2); + byte quant4 = wk.get(quantsOffset + 3); + + partialSum1 += ((float) quant1 * scaleFloat) * x.get(j); + partialSum2 += ((float) quant2 * scaleFloat) * x.get(j + 1); + partialSum3 += ((float) quant3 * scaleFloat) * x.get(j + 2); + partialSum4 += ((float) quant4 * scaleFloat) * x.get(j + 3); + } + + float partialSum = partialSum1 + 
partialSum2 + partialSum3 + partialSum4; + + for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wk.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + byte quant = wk.get(blockByteOffset + 2 + withinBlockIdx); + partialSum += ((float) quant * scaleFloat) * x.get(j); + } + + localSums[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + k.set(kRow, localSums[0]); + } + + } else if (rowId < dim + 2 * kvDim) { + // ========== V projection ========== + int vRow = rowId - dim - kvDim; + int rowBlockOffset = vRow * blocksPerRow; + + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wv.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + byte quant1 = wv.get(quantsOffset); + byte quant2 = wv.get(quantsOffset + 1); + byte quant3 = wv.get(quantsOffset + 2); + byte quant4 = wv.get(quantsOffset + 3); + + partialSum1 += ((float) quant1 * scaleFloat) * x.get(j); + partialSum2 += ((float) quant2 * scaleFloat) * x.get(j + 1); + partialSum3 += ((float) quant3 * scaleFloat) * x.get(j + 2); + partialSum4 += ((float) quant4 * scaleFloat) * x.get(j + 3); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat scale = wv.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + byte quant = wv.get(blockByteOffset + 2 + withinBlockIdx); + partialSum += ((float) quant * scaleFloat) * x.get(j); + } + + localSums[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + v.set(vRow, localSums[0]); + } + } + } + + /** + * Fully fused RMS normalization + FFN W1/W3 matmul with SiLU/GLU for Q8_0 weights. + * Each workgroup redundantly computes RMS scale to avoid cross-workgroup sync. 
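+     * <p>Per output row {@code rowId}, one workgroup computes the equivalent of the
+     * sketch below ({@code dequant} is a hypothetical helper standing in for the
+     * Q8_0 block lookup; note the epsilon is hardcoded to 1e-5f in the body):
+     * <pre>{@code
+     * float scale = 1f / TornadoMath.sqrt(sumSquares(x) / dim + 1e-5f); // RMS scale
+     * float gate = 0f, up = 0f;
+     * for (int j = 0; j < dim; j++) {
+     *     float xn = rmsWeights.get(j) * scale * x.get(j); // inline RMS-normalized input
+     *     gate += dequant(w1, rowId, j) * xn;              // W1 (gate) row dot product
+     *     up   += dequant(w3, rowId, j) * xn;              // W3 (up) row dot product
+     * }
+     * float silu = gate / (1f + TornadoMath.exp(-gate));
+     * hb.set(rowId, silu * up);                            // GLU: SiLU(gate) * up
+     * }</pre>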
+ */ + public static void fullyFusedRmsNormFFNGateUpQ8( + KernelContext context, + FloatArray x, // raw input (FP32) + FloatArray hb, // output + FloatArray rmsWeights, // RMS norm weights + ByteArray w1, // Q8_0 quantized + ByteArray w3, // Q8_0 quantized + int dim, // input dimension + int hiddenDim, // output dimension + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= hiddenDim) { + return; + } + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + // ========== RMS Norm: Compute scale (each workgroup does this redundantly) ========== + float sumSquares = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float val = x.get(j); + sumSquares += val * val; + } + + localSum[localId] = sumSquares; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + float scale = 1.0f / TornadoMath.sqrt(localSum[0] / dim + 1e-5f); + + // ========== W1 matmul with inline RMS normalization ========== + int blockSize = 32; + final int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (dim + blockSize - 1) / blockSize; + int rowBlockOffset = rowId * blocksPerRow; + + float partialSum1_a = 0.0f; + float partialSum1_b = 0.0f; + float partialSum1_c = 0.0f; + float partialSum1_d = 0.0f; + + for (int j = localId * 4; j < dim - 3; j += localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat w1Scale = w1.getHalfFloat(blockByteOffset); + float w1ScaleFloat = w1Scale.getFloat32(); + + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + byte q1 = w1.get(quantsOffset); + byte q2 = w1.get(quantsOffset + 1); + byte q3 = w1.get(quantsOffset + 2); + byte q4 = w1.get(quantsOffset + 3); + + float norm0 = rmsWeights.get(j) * scale * x.get(j); + float norm1 = rmsWeights.get(j + 1) * scale * x.get(j + 1); + float norm2 = rmsWeights.get(j + 2) * scale * x.get(j + 2); + float norm3 = rmsWeights.get(j + 3) * scale * x.get(j + 3); + + partialSum1_a += ((float) q1 * w1ScaleFloat) * norm0; + partialSum1_b += ((float) q2 * w1ScaleFloat) * norm1; + partialSum1_c += ((float) q3 * w1ScaleFloat) * norm2; + partialSum1_d += ((float) q4 * w1ScaleFloat) * norm3; + } + + float partialSum1 = partialSum1_a + partialSum1_b + partialSum1_c + partialSum1_d; + + for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat w1Scale = w1.getHalfFloat(blockByteOffset); + float w1ScaleFloat = w1Scale.getFloat32(); + + byte quant = w1.get(blockByteOffset + 2 + withinBlockIdx); + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum1 += ((float) quant * w1ScaleFloat) * normalized; + } + + localSum[localId] = partialSum1; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + float result1 = localSum[0]; + + // ========== W3 matmul with inline RMS normalization ========== + float partialSum3_a = 0.0f; + float partialSum3_b = 0.0f; + float partialSum3_c = 0.0f; + float partialSum3_d = 0.0f; + + for (int j = localId * 4; j < dim - 3; j 
+= localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat w3Scale = w3.getHalfFloat(blockByteOffset); + float w3ScaleFloat = w3Scale.getFloat32(); + + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + byte q1 = w3.get(quantsOffset); + byte q2 = w3.get(quantsOffset + 1); + byte q3 = w3.get(quantsOffset + 2); + byte q4 = w3.get(quantsOffset + 3); + + float norm0 = rmsWeights.get(j) * scale * x.get(j); + float norm1 = rmsWeights.get(j + 1) * scale * x.get(j + 1); + float norm2 = rmsWeights.get(j + 2) * scale * x.get(j + 2); + float norm3 = rmsWeights.get(j + 3) * scale * x.get(j + 3); + + partialSum3_a += ((float) q1 * w3ScaleFloat) * norm0; + partialSum3_b += ((float) q2 * w3ScaleFloat) * norm1; + partialSum3_c += ((float) q3 * w3ScaleFloat) * norm2; + partialSum3_d += ((float) q4 * w3ScaleFloat) * norm3; + } + + float partialSum3 = partialSum3_a + partialSum3_b + partialSum3_c + partialSum3_d; + + for (int j = ((dim / 4) * 4) + localId; j < dim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + HalfFloat w3Scale = w3.getHalfFloat(blockByteOffset); + float w3ScaleFloat = w3Scale.getFloat32(); + + byte quant = w3.get(blockByteOffset + 2 + withinBlockIdx); + float normalized = rmsWeights.get(j) * scale * x.get(j); + partialSum3 += ((float) quant * w3ScaleFloat) * normalized; + } + + localSum[localId] = partialSum3; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + float result3 = localSum[0]; + + // ========== SiLU + GLU ========== + if (localId == 0) { + float silu = result1 / (1.0f + TornadoMath.exp(-result1)); + hb.set(rowId, silu * result3); + } + } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index e135cc52..ba1b6a79 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -52,6 +52,93 @@ List setupFFNLayered() { } // @formatter:off + /** + * Transformer Layer Task Flow (LlamaQ8FFNLayers) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (partial sums) + * └────────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ attn_rms_finalize│──▶ temp (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌────────────────┐ + * │ attn_rms_apply │──▶ wrapXb (normalized, FP32) + * └───────┬────────┘ + * │ + * ▼ + * ┌────────────────┐ ┌─────────────────────────────┐ + * │ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │ + * └───────┬────────┘ └─────────────────────────────┘ + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb 
(attention output)
+     *       └─────┬─────┘
+     *             │
+     *             ▼
+     *    ┌──────────────────┐
+     *    │ attn_output_proj │──▶ wrapX += Wo · wrapXb  (residual connection)
+     *    └────────┬─────────┘
+     *             │
+     * ══════════╪═══════════════════════════════════════════════════════════════════
+     *             │                     FFN BLOCK
+     * ══════════╪═══════════════════════════════════════════════════════════════════
+     *             │
+     *             ▼
+     *     ┌────────────────┐
+     *     │ ffn_rms_reduce │──▶ tempFFN (partial sums)
+     *     └───────┬────────┘
+     *             │
+     *             ▼ (optional: NON_NVIDIA only)
+     *    ┌─────────────────┐
+     *    │ ffn_rms_finalize│──▶ tempFFN (final scale)
+     *    └────────┬────────┘
+     *             │
+     *             ▼
+     *    ┌─────────────────┐
+     *    │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3)
+     *    └────────┬────────┘    (fully fused: RMS reduce/apply + W1/W3 matmuls + SiLU + GLU)
+     *             │
+     *             ▼
+     *    ┌──────────────┐
+     *    │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection)
+     *    └──────┬───────┘
+     *           │
+     *           ▼
+     *       wrapX (FP32) ──▶ [next layer or logits]
+     *
+     * ══════════════════════════════════════════════════════════════════════════════
+     *
+     * Task Count: 11 tasks (9 if NVIDIA, skipping the two rms_finalize steps)
+     *
+     * Data Flow Summary:
+     *   Input:  wrapX (FP32) - hidden state from previous layer
+     *   Output: wrapX (FP32) - updated hidden state with residual connections
+     *
+     * Key Fusion Points:
+     *   • qkv_projection: Fused Q/K/V matmuls with Q8 dequantization (3→1 kernel)
+     *   • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel)
+     *   • rms_ffn_gate_up: Fully fused RMS norm + W1/W3 matmuls + SiLU + GLU (5→1 kernel)
+     *
+     * Quantization: Q8_0 format (8-bit weights with block-wise scaling)
+     *
+     */
     TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) {
         var layerTaskGraphName = "layer_" + layerIndex;
         TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
@@ -59,25 +146,25 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         // === Data Setup ===
         unifiedLayer.consumeFromDevice(state.wrapX);
         unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
-                //Copy-in weights per layer for batched-layered layout
+                // Copy-in weights per layer for batched-layered layout (Q8 format)
                 weights.rms_att_weightLayered[layerIndex].asFloatArray(),
                 weights.wqLayered[layerIndex].asByteArray(),
                 weights.wkLayered[layerIndex].asByteArray(),
-                weights.wvLayered[layerIndex].asByteArray(),
+                weights.wvLayered[layerIndex].asByteArray(),
                 weights.woLayered[layerIndex].asByteArray(),
                 weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
-                weights.w1Layered[layerIndex].asByteArray(),
-                weights.w2Layered[layerIndex].asByteArray(),
+                weights.w1Layered[layerIndex].asByteArray(),
+                weights.w2Layered[layerIndex].asByteArray(),
                 weights.w3Layered[layerIndex].asByteArray());
 
         unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
 
         // === Attention Block ===
         // RMS Normalization
-        unifiedLayer.task("attn_rms_reduce",
+        unifiedLayer.task("attn_rms_reduce",
                 TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
                 context, state.temp, state.wrapX,
                 config.dim(), config.rmsNormEps(), state.localSize);
-
+
         if (shouldUseFinalNormalization()) {
             unifiedLayer.task("attn_rms_finalize",
                     TransformerComputeKernelsLayered::reductionFinalNormalization,
@@ -85,23 +172,24 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         }
 
         unifiedLayer.task("attn_rms_apply",
-                TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
-                context, state.wrapXb, state.wrapX,
-                weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp)
-
-
-
-                .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asByteArray(), config.dim(),
-                        config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].asByteArray(), config.dim(),
-                        config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].asByteArray(), config.dim(),
-                        config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+                TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
+                context, state.wrapXb, state.wrapX,
+                weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);
 
-
-//        .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize())
-//        .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(),
-//                layerIndex, config.contextLength());
+        // QKV Projection (fused with Q8 dequantization)
+        unifiedLayer.task("qkv_projection",
+                TransformerComputeKernelsLayered::fusedQKVMatmulQ8,
+                context,
+                state.wrapXb,                                  // input (FP32)
+                state.wrapQ,                                   // output Q
+                state.wrapK,                                   // output K
+                state.wrapV,                                   // output V
+                weights.wqLayered[layerIndex].asByteArray(),   // Wq (Q8)
+                weights.wkLayered[layerIndex].asByteArray(),   // Wk (Q8)
+                weights.wvLayered[layerIndex].asByteArray(),   // Wv (Q8)
+                config.dim(),                                  // dim
+                config.kvDim(),                                // kvDim
+                LOCAL_WORK_GROUP_SIZE_ALLOC);
 
         // RoPE + KV Cache
         unifiedLayer.task("rope_and_kv_cache",
@@ -118,18 +206,52 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
                 layerIndex,
                 config.contextLength());
 
+        // Attention
         configureAttention(unifiedLayer, layerIndex);
-        unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asByteArray(),
-                config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
+
+        // Output Projection (Wo) with residual (Q8 dequantization)
+        unifiedLayer.task("attn_output_proj",
+                TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
+                context, state.wrapXb, state.wrapX,
+                weights.woLayered[layerIndex].asByteArray(),
+                config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        // === FFN Block ===
+        // RMS Normalization
+        unifiedLayer.task("ffn_rms_reduce",
+                TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+                context, state.tempFFN, state.wrapX,
+                config.dim(), config.rmsNormEps(), state.localSize);
+
         if (shouldUseFinalNormalization()) {
-            unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps());
+            unifiedLayer.task("ffn_rms_finalize",
+                    TransformerComputeKernelsLayered::reductionFinalNormalization,
+                    context, state.tempFFN, config.dim(), config.rmsNormEps());
         }
-        unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
-                state.tempFFN).task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte, context, state.wrapXb, state.wrapHb,
-                        weights.w1Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asByteArray(),
-                        config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX);
+
+        // Fully fused: RMS apply + Gate/Up projections + SiLU + GLU (Q8 dequantization)
+        unifiedLayer.task("rms_ffn_gate_up",
+                TransformerComputeKernelsLayered::fullyFusedRmsNormFFNGateUpQ8,
+                context,
+                state.wrapX,                                               // raw input (FP32)
+                state.wrapHb,                                              // output
+                weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),  // RMS weights
+                weights.w1Layered[layerIndex].asByteArray(),               // W1 (Q8)
+                weights.w3Layered[layerIndex].asByteArray(),               // W3 (Q8)
+                config.dim(),                                              // input dimension
+                config.hiddenDim(),                                        // output dimension
+                LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        // Down projection (W2) with residual (Q8 dequantization)
+        unifiedLayer.task("ffn_down_proj",
+                TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
+                context, state.wrapHb, state.wrapX,
+                weights.w2Layered[layerIndex].asByteArray(),
+                config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        // Keep activation X on device for next layer
+        unifiedLayer.persistOnDevice(state.wrapX);
+
         return unifiedLayer;
     }
 
@@ -162,39 +284,42 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
 
     @Override
     public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
-        WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128);
+        // === Worker Grid Definitions ===
         WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
 
         int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
         WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
 
-        int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
-        WorkerGrid configKvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configKvDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
-
         int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
         WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC);
-        WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize());
-        WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128);
+
+        // Fused QKV: dim rows for Q + kvDim rows for K + kvDim rows for V
+        int fusedQkvGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
 
         WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512);
-        for (int i = 0; i < config.numberOfLayers(); i++) {
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
+        WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize());
 
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker);
-//            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker);
+        // === Per-Layer Grid Assignments (ordered by task graph flow) ===
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            // --- Attention Block ---
+            // RMS Normalization
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
-//            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
 
+            // --- FFN Block ---
+            // RMS Normalization
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
+            // Fused RMS + Gate/Up Projections
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
+            // Down Projection
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
         }
+
         return tornadoForwardScheduler;
     }
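For reference, the fused QKV worker grid sizing above mirrors the kernel's row
partitioning: one workgroup per output row, Q rows first, then the K and V rows
appended after them. A minimal sketch of the mapping (illustrative variable names,
assuming the grid layout defined in updateGridScheduler):

    int totalRows  = dim + 2 * kvDim;                          // Q rows + K rows + V rows
    int globalSize = totalRows * LOCAL_WORK_GROUP_SIZE_ALLOC;  // matches fusedQkvGlobal
    // Inside fusedQKVMatmulQ8, context.groupIdx selects the projection:
    //   [0, dim)                 -> Q, row = groupIdx
    //   [dim, dim + kvDim)       -> K, row = groupIdx - dim
    //   [dim + kvDim, totalRows) -> V, row = groupIdx - dim - kvDim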