From 48aad8b7816ed5bead318d8a7411e9e3fe9f8891 Mon Sep 17 00:00:00 2001
From: PanZezhong
Date: Mon, 10 Feb 2025 14:05:36 +0800
Subject: [PATCH 1/2] Sugon: support DCU inference
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/devices/cuda/common_cuda.h                |  5 +++
 src/ops/causal_softmax/cuda/causal_softmax.cu | 36 +++++++++++++++++--
 src/ops/matmul/cuda/matmul_cuda.cu            | 20 ++++++++++-
 src/ops/random_sample/cuda/random_sample.cu   | 14 ++++----
 src/ops/rearrange/cuda/rearrange.cu           |  9 ++---
 src/ops/rms_norm/cuda/rms_norm.cu             | 13 +++++--
 src/ops/swiglu/cuda/swiglu.cu                 |  2 +-
 xmake.lua                                     | 34 ++++++++++++++++--
 8 files changed, 112 insertions(+), 21 deletions(-)

diff --git a/src/devices/cuda/common_cuda.h b/src/devices/cuda/common_cuda.h
index 0c10122f..d46d45c4 100644
--- a/src/devices/cuda/common_cuda.h
+++ b/src/devices/cuda/common_cuda.h
@@ -1,7 +1,12 @@
 #ifndef __COMMON_CUDA_H__
 #define __COMMON_CUDA_H__

+#ifdef ENABLE_SUGON_DCU
+#define MAX_THREADS_PER_BLOCK 512
+#else
 #define MAX_THREADS_PER_BLOCK 1024
+#endif
+
 #define MAX_WARP_PER_BLOCK 32
 #define WARP_SIZE 32

diff --git a/src/ops/causal_softmax/cuda/causal_softmax.cu b/src/ops/causal_softmax/cuda/causal_softmax.cu
index 09fd1741..7f937edc 100644
--- a/src/ops/causal_softmax/cuda/causal_softmax.cu
+++ b/src/ops/causal_softmax/cuda/causal_softmax.cu
@@ -16,6 +16,12 @@ struct AttentionCausualMask {
     }
 };

+struct MaxOp {
+    __device__ float operator()(const float a, const float b) const {
+        return a > b ? a: b;
+    }
+};
+
 template
 static __device__ void block_padding(
     Tdata *__restrict__ att,
@@ -33,7 +39,12 @@ static __device__ void block_padding(

     __shared__ float max;
     {
+#ifdef ENABLE_SUGON_DCU
+        MaxOp max_op;
+        auto acc = block_op.Reduce(thread_data, max_op, total_seq_len);
+#else
         auto acc = block_op.Reduce(thread_data, cub::Max(), total_seq_len);
+#endif
         if (threadIdx.x == 0) { max = acc; }
     }
     __syncthreads();
@@ -67,7 +78,12 @@ static __device__ void block_folding(

         thread_data[i] = att_idx < total_seq_len && mask(token_idx, seq_len, att_idx, total_seq_len) ?
                              float(att[i]) : -__FLT_MAX__;
+#ifdef ENABLE_SUGON_DCU
+        MaxOp max_op;
+        thread_max = max_op(thread_max, thread_data[i]);
+#else
         thread_max = cub::Max()(thread_max, thread_data[i]);
+#endif
     }

     using BlockOp = cub::BlockReduce;
@@ -76,7 +92,12 @@ static __device__ void block_folding(

     __shared__ float max;
     {
+#ifdef ENABLE_SUGON_DCU
+        MaxOp max_op;
+        auto acc = block_op.Reduce(thread_max, max_op);
+#else
         auto acc = block_op.Reduce(thread_max, cub::Max());
+#endif
         if (threadIdx.x == 0) { max = acc; }
     }
     __syncthreads();
@@ -130,7 +151,7 @@ static __forceinline__ __device__ void folding(
 }

 template
-__global__ void fused_softmax_padding(
+__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void fused_softmax_padding(
     Tdata *__restrict__ att,
     unsigned int const stride_x,
     unsigned int const stride_y,
@@ -140,7 +161,7 @@ __global__ void fused_softmax_padding(
 }

 template
-__global__ void fused_softmax_folding(
+__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void fused_softmax_folding(
     Tdata *__restrict__ att,
     unsigned int const stride_x,
     unsigned int const stride_y,
@@ -152,7 +173,7 @@ __global__ void fused_softmax_folding(
 }

 template
-__global__ void fused_softmax_standard(
+__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void fused_softmax_standard(
     Tdata *__restrict__ att_,
     unsigned int const stride_x,
     unsigned int const stride_y,
@@ -183,7 +204,12 @@ __global__ void fused_softmax_standard(
     __syncthreads();
     // Block reduce max
     {
+#ifdef ENABLE_SUGON_DCU
+        MaxOp max_op;
+        auto acc = block_op.Reduce(partial, max_op);
+#else
         auto acc = block_op.Reduce(partial, cub::Max());
+#endif
         if (threadIdx.x == 0) { max_ = acc; }
     }
     __syncthreads();
@@ -200,7 +226,11 @@ __global__ void fused_softmax_standard(

     // Block reduce sum
     {
+#ifdef ENABLE_SUGON_DCU
+        auto acc = block_op.Sum(partial);
+#else
         auto acc = block_op.Reduce(partial, cub::Sum());
+#endif
         if (threadIdx.x == 0) { sum_ = acc; }
     }
     __syncthreads();
diff --git a/src/ops/matmul/cuda/matmul_cuda.cu b/src/ops/matmul/cuda/matmul_cuda.cu
index a75b164e..f3d130b0 100644
--- a/src/ops/matmul/cuda/matmul_cuda.cu
+++ b/src/ops/matmul/cuda/matmul_cuda.cu
@@ -13,20 +13,38 @@ infiniopStatus_t matmul_cuda(MatmulCudaDescriptor_t desc, void *c, float beta, v
         std::swap(a, b);
     }

+
+
+#ifdef ENABLE_SUGON_DCU
+    float alpha_, beta_;
+#else
     Tdata alpha_, beta_;
+#endif
     cudaDataType a_type, b_type, c_type;
     cublasComputeType_t compute_type;
-
     if constexpr (std::is_same::value) {
+#ifdef ENABLE_SUGON_DCU
+        alpha_ = alpha;
+        beta_ = beta;
+#else
         alpha_ = __float2half(alpha);
         beta_ = __float2half(beta);
+#endif
         a_type = b_type = c_type = CUDA_R_16F;
+#ifdef ENABLE_SUGON_DCU
+        compute_type = CUBLAS_COMPUTE_32F;
+#else
         compute_type = CUBLAS_COMPUTE_16F;
+#endif
     } else {
         alpha_ = alpha;
         beta_ = beta;
         a_type = b_type = c_type = CUDA_R_32F;
+#ifdef ENABLE_SUGON_DCU
+        compute_type = CUBLAS_COMPUTE_32F;
+#else
         compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
+#endif
     }

     auto op_a = info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
diff --git a/src/ops/random_sample/cuda/random_sample.cu b/src/ops/random_sample/cuda/random_sample.cu
index 40761e89..12bc03b2 100644
--- a/src/ops/random_sample/cuda/random_sample.cu
+++ b/src/ops/random_sample/cuda/random_sample.cu
@@ -5,7 +5,7 @@
 #include

 template
-__global__ void softmax(
+__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void softmax(
     T *val_out, int topk,
     float temperature, int voc) {

@@ -29,14 +29,14 @@ __global__ void softmax(
     }
 }

-__global__ void index(uint64_t *key_in, int voc) {
+__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void index(uint64_t *key_in, int voc) {
     int ind = threadIdx.x + blockIdx.x * blockDim.x;
     if (ind < voc) {
         key_in[ind] = static_cast(ind);
     }
 }
 template
-__global__ void random_sample_kernel(uint64_t *result,
+__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void random_sample_kernel(uint64_t *result,
                                      T *val_out,
                                      float random_val,
                                      float topp,
@@ -119,7 +119,9 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace
     uint64_t *key_in = (uint64_t *) keyTmp;
     uint64_t *key_out = key_in + voc;

-    index<<<(voc + 1023) / 1024, 1024, 0, (cudaStream_t) stream>>>(key_in, voc);
+    int block_dim = MAX_THREADS_PER_BLOCK;
+    int num_blocks = ROUND_UP_DIV(voc, block_dim);
+    index<<>>(key_in, voc);
     // Compute the workspace sizes below
     size_t size_radix_sort;
     size_t size_scan;
@@ -134,9 +136,7 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace
                                       voc, (cudaStream_t) stream);// This call stores the sorted values and their indices in val_out and key_out
     // Sorting is done; now apply the softmax transformation
     if (topp > 0 && topk > 1) {
-        int BLOCK_DIM = 1024;
-        int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM;
-        softmax<<>>(val_out, topk,
+        softmax<<>>(val_out, topk,
                     temperature,
                     voc);

diff --git a/src/ops/rearrange/cuda/rearrange.cu b/src/ops/rearrange/cuda/rearrange.cu
index 04651f6b..8f90924c 100644
--- a/src/ops/rearrange/cuda/rearrange.cu
+++ b/src/ops/rearrange/cuda/rearrange.cu
@@ -1,8 +1,9 @@
 #include "../../../devices/cuda/common_cuda.h"
 #include "rearrange.cuh"
+#include "../../utils.h"

 template
-static __global__ void rearrange(
+static __launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void rearrange(
     void *__restrict__ dst,
     int const rsa,
     int const csa,
@@ -35,9 +36,9 @@ void rearrange_nv_gpu(RearrangeCudaDescriptor_t desc, void *y, void const *x, vo
         return;
     }

-    auto warps = 1024 / WARP_SIZE;
-    auto grid = dim3((c + warps - 1) / warps, r);
-    auto block = dim3(WARP_SIZE, (c + grid.x - 1) / grid.x);
+    auto warps = MAX_THREADS_PER_BLOCK / WARP_SIZE;
+    auto grid = dim3(ROUND_UP_DIV(c, warps), r);
+    auto block = dim3(WARP_SIZE, ROUND_UP_DIV(c, grid.x));
     dst_rs /= unit;
     dst_cs /= unit;
     src_rs /= unit;
diff --git a/src/ops/rms_norm/cuda/rms_norm.cu b/src/ops/rms_norm/cuda/rms_norm.cu
index 0dac45f0..aa36f2f0 100644
--- a/src/ops/rms_norm/cuda/rms_norm.cu
+++ b/src/ops/rms_norm/cuda/rms_norm.cu
@@ -6,7 +6,7 @@

 // assert BLOCK_SIZE >= blockDim.x
 template
-static __global__ void rms_norm_padding(
+__launch_bounds__(MAX_THREADS_PER_BLOCK) static __global__ void rms_norm_padding(
     Tdata *__restrict__ o_,
     unsigned int const stride_y,
     Tdata const *__restrict__ x_,
@@ -19,8 +19,11 @@ static __global__ void rms_norm_padding(

     using BlockOp = cub::BlockReduce;
     __shared__ typename BlockOp::TempStorage temp_storage;
+#ifdef ENABLE_SUGON_DCU
+    auto acc = BlockOp(temp_storage).Sum(x * x);
+#else
     auto acc = BlockOp(temp_storage).Reduce(x * x, cub::Sum());
-
+#endif
     __shared__ Tdata rms;
     if (threadIdx.x == 0) {
         rms = Tdata(rsqrtf(acc / float(blockDim.x) + epsilon));
@@ -31,7 +34,7 @@ static __global__ void rms_norm_padding(
 }

 template
-static __global__ void rms_norm_folding(
+__launch_bounds__(MAX_THREADS_PER_BLOCK) static __global__ void rms_norm_folding(
     Tdata *__restrict__ y,
     unsigned int const stride_y,
     Tdata const *__restrict__ x,
@@ -59,7 +62,11 @@ static __global__ void rms_norm_folding(
     {
         using BlockOp = cub::BlockReduce;
         __shared__ typename BlockOp::TempStorage temp_storage;
+#ifdef ENABLE_SUGON_DCU
+        acc = BlockOp(temp_storage).Sum(squared);
+#else
         acc = BlockOp(temp_storage).Reduce(squared, cub::Sum());
+#endif
     }

     __shared__ Tdata rms;
diff --git a/src/ops/swiglu/cuda/swiglu.cu b/src/ops/swiglu/cuda/swiglu.cu
index c02ce186..fdd3f16b 100644
--- a/src/ops/swiglu/cuda/swiglu.cu
+++ b/src/ops/swiglu/cuda/swiglu.cu
@@ -17,7 +17,7 @@ inline int gcd(int a, int b) {
 }

 template
-static __global__ void swiglu(
+static __launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void swiglu(
     Tdata *__restrict__ c,
     int const stride_c,
     Tdata const *__restrict__ a,
diff --git a/xmake.lua b/xmake.lua
index dcb14715..ce8f065a 100644
--- a/xmake.lua
+++ b/xmake.lua
@@ -48,6 +48,14 @@ option("metax-gpu")

 option_end()

+option("sugon-dcu")
+    set_default(false)
+    set_showmenu(true)
+    set_description("Enable or disable Sugon DCU kernel")
+    add_defines("ENABLE_SUGON_DCU")
+    add_defines("ENABLE_NV_GPU")
+option_end()
+
 if is_mode("debug") then
     add_cxflags("-g -O0")
     add_defines("DEBUG_MODE")
@@ -74,9 +82,11 @@ if has_config("cpu") then
 end


-if has_config("nv-gpu") then
-
+if has_config("nv-gpu", "sugon-dcu") then
     add_defines("ENABLE_NV_GPU")
+    if has_config("sugon-dcu") then
+        add_defines("ENABLE_SUGON_DCU")
+    end
     local CUDA_ROOT = os.getenv("CUDA_ROOT") or os.getenv("CUDA_HOME") or os.getenv("CUDA_PATH")
     local CUDNN_ROOT = os.getenv("CUDNN_ROOT") or os.getenv("CUDNN_HOME") or os.getenv("CUDNN_PATH")
     if CUDA_ROOT ~= nil then
@@ -267,6 +277,11 @@ if has_config("metax-gpu") then
 end


+
+toolchain("sugon-dcu-linker")
+    set_toolset("sh", "nvcc")
+toolchain_end()
+
 target("infiniop")

     set_kind("shared")
@@ -276,6 +291,21 @@ target("infiniop")
     if has_config("nv-gpu") then
         add_deps("nv-gpu")
     end
+    if has_config("sugon-dcu") then
+        local builddir = string.format(
+            "build/%s/%s/%s",
+            get_config("plat"),
+            get_config("arch"),
+            get_config("mode")
+        )
+        add_shflags("-s", "-shared", "-fPIC")
+        add_links("cublas", "cudnn", "cudadevrt", "cudart_static", "rt", "pthread", "dl")
+        -- Using -lnv-gpu will fail, manually link the target using full path
+        add_deps("nv-gpu", {inherit = false})
+        add_links(builddir.."/libnv-gpu.a")
+        set_toolchains("sugon-dcu-linker")
+    end
+
     if has_config("cambricon-mlu") then
         add_deps("cambricon-mlu")
     end

From 02970cbad1cbf9d5d83f53f47f4db71f085315ef Mon Sep 17 00:00:00 2001
From: PanZezhong
Date: Mon, 10 Feb 2025 14:26:38 +0800
Subject: [PATCH 2/2] fix: cublas matmul fp16 uses the f32 compute type
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/ops/matmul/cuda/matmul_cuda.cu | 24 ++----------------------
 1 file changed, 2 insertions(+), 22 deletions(-)

diff --git a/src/ops/matmul/cuda/matmul_cuda.cu b/src/ops/matmul/cuda/matmul_cuda.cu
index f3d130b0..fcbc755d 100644
--- a/src/ops/matmul/cuda/matmul_cuda.cu
+++ b/src/ops/matmul/cuda/matmul_cuda.cu
@@ -13,32 +13,12 @@ infiniopStatus_t matmul_cuda(MatmulCudaDescriptor_t desc, void *c, float beta, v
         std::swap(a, b);
     }

-
-
-#ifdef ENABLE_SUGON_DCU
-    float alpha_, beta_;
-#else
-    Tdata alpha_, beta_;
-#endif
     cudaDataType a_type, b_type, c_type;
     cublasComputeType_t compute_type;
     if constexpr (std::is_same::value) {
-#ifdef ENABLE_SUGON_DCU
-        alpha_ = alpha;
-        beta_ = beta;
-#else
-        alpha_ = __float2half(alpha);
-        beta_ = __float2half(beta);
-#endif
         a_type = b_type = c_type = CUDA_R_16F;
-#ifdef ENABLE_SUGON_DCU
         compute_type = CUBLAS_COMPUTE_32F;
-#else
-        compute_type = CUBLAS_COMPUTE_16F;
-#endif
     } else {
-        alpha_ = alpha;
-        beta_ = beta;
         a_type = b_type = c_type = CUDA_R_32F;
 #ifdef ENABLE_SUGON_DCU
         compute_type = CUBLAS_COMPUTE_32F;
@@ -58,7 +38,7 @@ infiniopStatus_t matmul_cuda(MatmulCudaDescriptor_t desc, void *c, float beta, v
         info.m,
         info.n,
         info.k,
-        &alpha_,
+        &alpha,
         a,
         a_type,
         info.a_matrix.ld(),
@@ -67,7 +47,7 @@ infiniopStatus_t matmul_cuda(MatmulCudaDescriptor_t desc, void *c, float beta, v
         b_type,
         info.b_matrix.ld(),
         info.b_matrix.stride,
-        &beta_,
+        &beta,
         c,
         c_type,
         info.c_matrix.ld(),
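
The kernel-launch changes in patch 1 all follow one pattern: cap the block size at MAX_THREADS_PER_BLOCK (512 when ENABLE_SUGON_DCU is defined, 1024 otherwise), annotate the kernel with __launch_bounds__(MAX_THREADS_PER_BLOCK), and derive the grid size with a ceiling division. The sketch below is a minimal, self-contained illustration of that pattern, not code taken from the patch: the kernel and launcher names are invented for the example, and ROUND_UP_DIV is assumed to be the usual ceiling-division macro provided by the utils.h header the patch includes.

#include <cuda_runtime.h>

// 512 threads per block is the DCU limit this patch accommodates; NVIDIA GPUs allow 1024.
#ifdef ENABLE_SUGON_DCU
#define MAX_THREADS_PER_BLOCK 512
#else
#define MAX_THREADS_PER_BLOCK 1024
#endif

// Assumed equivalent of the ROUND_UP_DIV macro from utils.h (ceiling division).
#define ROUND_UP_DIV(x, y) (((x) + (y) - 1) / (y))

// __launch_bounds__ promises the compiler the kernel is never launched with more
// threads than MAX_THREADS_PER_BLOCK, matching the launch configuration below.
__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void scale(float *v, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        v[i] *= s;
    }
}

void launch_scale(float *v, float s, int n, cudaStream_t stream) {
    int block_dim = MAX_THREADS_PER_BLOCK;        // never hard-code 1024
    int num_blocks = ROUND_UP_DIV(n, block_dim);  // grid covers all n elements
    scale<<<num_blocks, block_dim, 0, stream>>>(v, s, n);
}

With the option added in xmake.lua, a DCU build would presumably be configured with something like "xmake f --sugon-dcu=true" before running xmake.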