diff --git a/external/tornadovm b/external/tornadovm index f6de88c1..6e29a5be 160000 --- a/external/tornadovm +++ b/external/tornadovm @@ -1 +1 @@ -Subproject commit f6de88c150117d17ddc04a749e34f7f4ac4d0429 +Subproject commit 6e29a5be7d5e8a70dc780ad9ec5b140a0a09c9c6 diff --git a/llama-tornado b/llama-tornado index 9c0d6ba8..4f4f695a 100755 --- a/llama-tornado +++ b/llama-tornado @@ -422,7 +422,7 @@ def create_parser() -> argparse.ArgumentParser: ) debug_group.add_argument( "--profiler-dump-dir", - default="/home/mikepapadim/repos/gpu-llama3.java/prof.json", + default="/home/ruiqi/GPULlama3.java/prof.json", help="Directory for profiler output", ) diff --git a/set_paths b/set_paths index fd807c5e..c61d735f 100644 --- a/set_paths +++ b/set_paths @@ -6,10 +6,10 @@ # Resolve root of this project (LLaMA3) and TornadoVM export LLAMA_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm" +export TORNADO_ROOT="/home/ruiqi/TornadoVM_OCL/TornadoVM" # Set the path to TornadoVM SDK binaries -export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk" +export TORNADO_SDK="/home/ruiqi/TornadoVM_OCL/TornadoVM/bin/sdk" # Add TornadoVM and LLaMA bin directories to PATH export PATH="${PATH}:${TORNADO_SDK}:${LLAMA_ROOT}" diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index ac6619f6..1c3403a0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -33,6 +33,7 @@ public TransformerComputeKernelsLayered() { * @param localMemSize * Size of local memory allocation (must match work group size) */ + public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) { int gid = context.globalIdx; int lid = context.localIdx; @@ -80,20 +81,92 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray } /** - * Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization. + * Performs RMS (Root Mean Square) normalization using parallel reduction. It first computes the variance and scaling factor across all work groups, + * then it applies the computed normalization factor to input and weight elements. * + *
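+ * For illustration only, a scalar sketch of the computation this fused kernel parallelizes (same parameter names as documented below; this is not the device code path):
+ * <pre>{@code
+ * float ss = 0.0f;
+ * for (int i = 0; i < size; i++) {
+ *     ss += x.get(i) * x.get(i);                            // sum of squares
+ * }
+ * float factor = 1.0f / TornadoMath.sqrt(ss / size + ermsNorm);
+ * for (int i = 0; i < size; i++) {
+ *     output.set(i, weights.get(i) * (factor * x.get(i)));  // normalize and scale
+ * }
+ * }</pre>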
* Formula: output[i] = weight[i] * (normalizationFactor * x[i]) * + * Algorithm: 1. Each thread computes the square of its input element 2. Each work group performs a parallel reduction of the squares 3. Partial sums are stored per work group 4. Every thread combines all partial + sums and computes the normalization factor 5. Applies the computed normalization factor to the input and weight elements. + * * @param context * Kernel execution context * @param output - * Array for normalized output + * Output array for the normalized result * @param x - * Input values to normalize + * Input array to normalize * @param weights * Weight values for each element * @param temp - * Temporary array containing normalization factor at index 0 + * Temporary array holding one partial sum per work group * @param size + * Number of elements to process + * @param ermsNorm + * Epsilon value squared for numerical stability + * @param localMemSize + * Size of local memory allocation (must match work group size) + */ + + public static void reductionOneBlockWithLayerFuse(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp, int size, float ermsNorm, int localMemSize) { + int gid = context.globalIdx; + int lid = context.localIdx; + int groupId = context.groupIdx; + int groupSize = context.localGroupSizeX; + + // Allocate local memory with the provided size + float[] localX = context.allocateFloatLocalArray(localMemSize); + + // Load input value and compute square + if (gid < size) { + float v = x.get(gid); + localX[lid] = v * v; + } else { + localX[lid] = 0.0f; + } + + // Perform parallel reduction within the work group + for (int stride = (groupSize / 2); stride > 0; stride /= 2) { + context.localBarrier(); + if (lid < stride) { + localX[lid] += localX[lid + stride]; + } + } + + // Each workgroup stores its partial sum in a different location + if (lid == 0) { + // Store the partial sum from each workgroup + temp.set(groupId, localX[0]); + } + + context.globalBarrier(); + + float localss = 0.0f; + int numGroups = (size + groupSize - 1) / groupSize; + for (int i = 0; i < numGroups; i++) { // Accumulate the partial sums from all work groups + localss += temp.get(i); + } + localss /= size; + localss += ermsNorm; + localss = 1.0f / TornadoMath.sqrt(localss); + + if (gid < size) { + float in = x.get(gid); + float w = weights.get(gid); + output.set(gid, w * (localss * in)); + } + } + + /** + * Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization. + *
+ * Formula: output[i] = weight[i] * (normalizationFactor * x[i]) + * + * @param context Kernel execution context + * @param output Array for normalized output + * @param x Input values to normalize + * @param weights Weight values for each element + * @param temp Temporary array containing normalization factor at index 0 */ public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) { int gid = context.globalIdx; @@ -104,25 +177,17 @@ public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray /** * Copies keys and values into the key-value cache for attention computation. Enables efficient access to past key-value pairs during autoregressive generation. - * + *
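+ * An illustrative sketch of the indexing implied by the layout described below; the offset arithmetic and the local names position/loff are assumptions for this example, not code taken from the kernel body:
+ * <pre>{@code
+ * int position = positioNlayer.get(0);
+ * int loff = layer * contextLength * kvDim;            // start of this layer's cache
+ * for (int d = 0; d < kvDim; d++) {
+ *     destKeyCache.set(loff + position * kvDim + d, srcKey.get(d));
+ *     destValueCache.set(loff + position * kvDim + d, srcValue.get(d));
+ * }
+ * }</pre>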
* Cache layout: [layer][position][dimension] - Each layer has its own key and value cache - Each position in sequence has a key and value vector * - * @param destKeyCache - * Destination array for key cache - * @param srcKey - * Source keys to copy - * @param destValueCache - * Destination array for value cache - * @param srcValue - * Source values to copy - * @param positioNlayer - * Array containing current position - * @param kvDim - * Dimension of key/value vectors - * @param layer - * Current transformer layer index - * @param contextLength - * Maximum sequence length + * @param destKeyCache Destination array for key cache + * @param srcKey Source keys to copy + * @param destValueCache Destination array for value cache + * @param srcValue Source values to copy + * @param positioNlayer Array containing current position + * @param kvDim Dimension of key/value vectors + * @param layer Current transformer layer index + * @param contextLength Maximum sequence length */ public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) { @@ -158,21 +223,15 @@ public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArr /** * Applies Rotary Position Encoding (RoPE) to query and key vectors. RoPE rotates pairs of dimensions based on their position in the sequence, enabling the model to learn relative positional * information. - * + *
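+ * A minimal scalar sketch of the rotation described below, using plain java.lang.Math and a hypothetical pair index j; it follows one common RoPE frequency convention and is not necessarily the exact device code:
+ * <pre>{@code
+ * int j = 0;                                                     // hypothetical pair index
+ * int pos = positionHolder.get(0);
+ * int d = (2 * j) % head_size;                                   // dimension index within the head
+ * float theta = pos * (float) Math.pow(10000.0, -d / (double) head_size);
+ * float c = (float) Math.cos(theta), s = (float) Math.sin(theta);
+ * float q0 = sq.get(2 * j), q1 = sq.get(2 * j + 1);
+ * sq.set(2 * j, q0 * c - q1 * s);                                // 2D rotation of the pair
+ * sq.set(2 * j + 1, q0 * s + q1 * c);
+ * }</pre>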
* For each pair of dimensions (2*i, 2*i+1): - Compute rotation angle based on position and frequency - Apply 2D rotation to the pair * - * @param context - * Kernel execution context - * @param positionHolder - * Array containing current position - * @param sq - * Query vectors to rotate - * @param sk - * Key vectors to rotate - * @param kv_dim - * Dimension of key/value vectors - * @param head_size - * Dimension of each attention head + * @param context Kernel execution context + * @param positionHolder Array containing current position + * @param sq Query vectors to rotate + * @param sk Key vectors to rotate + * @param kv_dim Dimension of key/value vectors + * @param head_size Dimension of each attention head */ public static void ropeRotation(KernelContext context, IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) { int i = context.globalIdx * 2; @@ -247,31 +306,20 @@ public static void ropeRotationPhi3(KernelContext context, IntArray positionHold /** * Computes attention for a single head. Implements scaled dot-product attention with softmax normalization. - * + *
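+ * An illustrative scalar version of the steps listed below, without the parallel decomposition; it assumes the common layout where head h reads its query at h*headSize and its key/value group at (h/kvMul)*headSize, and uses a local att[] buffer in place of wrapAtt:
+ * <pre>{@code
+ * float[] att = new float[pos + 1];
+ * int kvHead = (h / kvMul) * headSize;
+ * for (int t = 0; t <= pos; t++) {                      // 1. scores = Q·K / sqrt(headSize)
+ *     float score = 0.0f;
+ *     for (int d = 0; d < headSize; d++) {
+ *         score += allQ.get(h * headSize + d) * key_cache.get((int) (loff + t * kvDim + kvHead + d));
+ *     }
+ *     att[t] = score / (float) Math.sqrt(headSize);
+ * }
+ * // 2. softmax over att[0..pos]: subtract max, exponentiate, normalize
+ * for (int d = 0; d < headSize; d++) {                  // 3. weighted sum of values
+ *     float acc = 0.0f;
+ *     for (int t = 0; t <= pos; t++) {
+ *         acc += att[t] * value_cache.get((int) (loff + t * kvDim + kvHead + d));
+ *     }
+ *     allXb.set(h * headSize + d, acc);
+ * }
+ * }</pre>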
* Steps: 1. Compute attention scores: Q·K / sqrt(head_size) 2. Apply softmax (with max subtraction for numerical stability) 3. Compute weighted sum of values * - * @param allQ - * All query vectors - * @param key_cache - * Cached keys - * @param value_cache - * Cached values - * @param allXb - * Output buffer - * @param h - * Head index to process - * @param headSize - * Dimension per head - * @param kvDim - * Key/value dimension - * @param kvMul - * Key multiplier for grouped attention - * @param loff - * Layer offset in cache - * @param pos - * Current position - * @param wrapAtt - * Attention weights buffer + * @param allQ All query vectors + * @param key_cache Cached keys + * @param value_cache Cached values + * @param allXb Output buffer + * @param h Head index to process + * @param headSize Dimension per head + * @param kvDim Key/value dimension + * @param kvMul Key multiplier for grouped attention + * @param loff Layer offset in cache + * @param pos Current position + * @param wrapAtt Attention weights buffer */ private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos, FloatArray wrapAtt) { @@ -627,23 +675,16 @@ public static void processHeadsFlashAttentionOpt(KernelContext context, FloatArr /** * Performs optimized matrix-vector multiplication where each work group processes one row of the matrix. - * + *
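+ * The scalar computation each work group reproduces for its assigned row (illustrative only, using the parameter names documented below):
+ * <pre>{@code
+ * for (int row = 0; row < d; row++) {
+ *     float sum = 0.0f;
+ *     for (int j = 0; j < n; j++) {
+ *         sum += w.get(row * n + j) * x.get(j);         // row-major dot product
+ *     }
+ *     hb.set(row, sum);
+ * }
+ * }</pre>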
* Algorithm: 1. Each work group handles one output dimension 2. Threads in work group compute partial dot products 3. Parallel reduction yields final row result * - * @param context - * Kernel execution context - * @param x - * Input vector - * @param hb - * Output vector - * @param w - * Weight matrix (row-major) - * @param n - * Input dimension - * @param d - * Output dimension - * @param localWorkGroupSize - * Number of threads per work group + * @param context Kernel execution context + * @param x Input vector + * @param hb Output vector + * @param w Weight matrix (row-major) + * @param n Input dimension + * @param d Output dimension + * @param localWorkGroupSize Number of threads per work group */ public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray hb, FloatArray w, int n, int d, int localWorkGroupSize) { // One row per workgroup (not per thread) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 96acd650..6ddcdd18 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -31,6 +31,7 @@ public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Config public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim()/2, 128); WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + //System.out.println("llama config dim: " + config.dim()); int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); @@ -54,9 +55,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); } @@ -112,13 +111,8 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, weights.w3Layered[layerIndex].asHalfFloatArray()); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); unifiedLayer - .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, - config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, 
weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), + .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp, config.dim(), config.rmsNormEps(), state.localSize); + unifiedLayer.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) @@ -130,12 +124,8 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, configureAttention(unifiedLayer, layerIndex); unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), + .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN, config.dim(), config.rmsNormEps(), state.localSize); + unifiedLayer.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 75f9f531..66628a0c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -63,7 +63,6 @@ public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeigh public GridScheduler updateGridScheduler(GridScheduler 
gridScheduler) { // RMS norm worker WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize); - // Combined QKV matmul worker int matmulQkvGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid matmulQkvRowMajorWorker = WorkerGridFactory.genericWorker(matmulQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index 858848ea..d5400789 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -108,9 +108,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); } @@ -178,10 +176,8 @@ TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) weights.w3Layered[layerIndex].asHalfFloatArray()); // unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); // - unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.temp, qwen2State.wrapX, config.dim(), config.rmsNormEps(), + unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), qwen2State.temp, config.dim(), config.rmsNormEps(), qwen2State.localSize) - .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), - qwen2State.temp) .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), @@ -197,10 +193,8 @@ TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), qwen2State.positionHolder, layerIndex, config.contextLength()) .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapXb, qwen2State.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.tempFFN, 
qwen2State.wrapX, config.dim(), config.rmsNormEps(), + .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), qwen2State.tempFFN, config.dim(), config.rmsNormEps(), qwen2State.localSize) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - qwen2State.tempFFN) .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen2State.wrapXb, qwen2State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapHb, qwen2State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 379921c3..e9302361 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -103,7 +103,6 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { // Map workers to tasks for each layer for (int i = 0; i < config.numberOfLayers(); i++) { gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); gridScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker); gridScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker); @@ -121,7 +120,6 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker); gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker); gridScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker); } @@ -193,9 +191,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) weights.w3Layered[layerIndex].asHalfFloatArray() // ); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.temp, qwen3State.wrapX, // in - qwen3Config.dim(), qwen3Config.rmsNormEps(), qwen3State.localSize).task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen3State.wrapXb, // out - qwen3State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), qwen3State.temp); + unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), qwen3State.temp,// in + qwen3Config.dim(), qwen3Config.rmsNormEps(), qwen3State.localSize); int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); int kvDim0 = nEmbdGqa; @@ -247,11 +244,8 @@ TaskGraph 
setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3Config.dim(), // dim0 = 1024 LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.tempFFN, qwen3State.wrapX, qwen3Config.dim(), - qwen3Config.rmsNormEps(), qwen3State.localSize) - .task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, qwen3State.tempFFN, qwen3Config.dim(), qwen3Config.rmsNormEps()) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - qwen3State.tempFFN); + unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), qwen3State.tempFFN, qwen3Config.dim(), + qwen3Config.rmsNormEps(), qwen3State.localSize); unifiedLayer.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen3State.wrapXb, qwen3State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), weights.w3Layered[layerIndex].asHalfFloatArray(), qwen3Config.dim(), qwen3Config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index 385b626f..e62950fd 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -71,12 +71,7 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, weights.w2Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray()); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, - config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) + unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp, config.dim(), config.rmsNormEps(), state.localSize) .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapK, @@ -89,12 +84,7 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, configureAttention(unifiedLayer, layerIndex); unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapXb, state.wrapX, 
weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, - config.dim(), config.rmsNormEps()); - } - unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) + .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN, config.dim(), config.rmsNormEps(), state.localSize) .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) @@ -150,9 +140,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index 441bbece..668c6dcb 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -77,9 +77,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wDown", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wGateUp", wgetHiddenDimRowMajorWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker); @@ -147,20 +145,15 @@ TaskGraph 
setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex // RMSNorm for attention input unifiedLayer.task("reductionsOneBlock", - TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, - phi3State.temp, - phi3State.wrapX, - phi3Config.dim(), - phi3Config.rmsNormEps(), - phi3State.localSize) - .task("mapContext", - TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, phi3State.wrapXb, phi3State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), - phi3State.temp); + phi3State.temp, + phi3Config.dim(), + phi3Config.rmsNormEps(), + phi3State.localSize); // Combined QKV projection (quantized) unifiedLayer.task("qkvmatmul", @@ -232,20 +225,15 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex // FFN section: RMSNorm unifiedLayer.task("reductionsOneBlockFFN", - TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, - phi3State.tempFFN, - phi3State.wrapX, - phi3Config.dim(), - phi3Config.rmsNormEps(), - phi3State.localSize) - .task("mapContextFFN", - TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, phi3State.wrapXb, phi3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), - phi3State.tempFFN); + phi3State.tempFFN, + phi3Config.dim(), + phi3Config.rmsNormEps(), + phi3State.localSize); // FFN: combined Up and Gate projection (outputs 2 * hiddenDim, quantized) unifiedLayer.task("wGateUp", diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java index d21f3509..1add4843 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java @@ -111,9 +111,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); } @@ -180,10 +178,8 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd ); unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) + unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, 
state.wrapXb, + state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp, config.dim(), config.rmsNormEps(), state.localSize) .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, @@ -203,10 +199,8 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd state.positionHolder, layerIndex, config.contextLength()) .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) + .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, state.wrapXb, + state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN, config.dim(), config.rmsNormEps(), state.localSize) .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index b0f348c2..7bd7527f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -93,7 +93,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) for (int i = 0; i < config.numberOfLayers(); i++) { tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker); @@ -106,7 +105,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker); tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker); 
} @@ -180,11 +178,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) // RMS norm for attention input unifiedLayer.task("reductionsOneBlock", - TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, qwen3State.temp, qwen3State.wrapX, config.dim(), config.rmsNormEps(), qwen3State.localSize) - .task("mapContext", - TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, - context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), qwen3State.temp); + TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, + context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), qwen3State.temp, config.dim(), config.rmsNormEps(), qwen3State.localSize); // QKV projections with Qwen3 GQA dimensions // Q8_0 weights pass both quants and scales @@ -254,11 +249,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) // RMS norm for FFN input unifiedLayer.task("reductionsOneBlockFFN", - TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, qwen3State.tempFFN, qwen3State.wrapX, config.dim(), config.rmsNormEps(), qwen3State.localSize) - .task("mapContextFFN", - TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, - context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), qwen3State.tempFFN); + TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, + context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), qwen3State.tempFFN, config.dim(), config.rmsNormEps(), qwen3State.localSize); // Fused FFN: w1(x) ⊗ w3(x) with SiLU activation (Q8_0 weights) unifiedLayer.task("fused_ffn_w1_w3",