From 73968c329657bd4a78b2d4e57b39c936779cf7e5 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Fri, 26 Dec 2025 13:12:26 +0800 Subject: [PATCH] issue/843: success per_channel_quant_int8 --- .../ops/quant/per_channel_quant_int8.h | 28 ++ .../per_channel_quant_int8/cuda/kernel.cuh | 277 ++++++++++++++ .../ops/quant/per_channel_quant_int8/info.h | 59 +++ .../nvidia/per_channel_quant_int8_nvidia.cu | 118 ++++++ .../nvidia/per_channel_quant_int8_nvidia.cuh | 7 + .../quant/per_channel_quant_int8/operator.cc | 98 +++++ .../per_channel_quant_int8.h | 40 ++ test/infiniop/libinfiniop/op_register.py | 35 ++ test/infiniop/per_channel_quant_int8.py | 194 ++++++++++ test/infiniop/w8a8_per_channel.py | 347 ++++++++++++++++++ xmake.lua | 2 +- xmake/nvidia.lua | 2 +- 12 files changed, 1205 insertions(+), 2 deletions(-) create mode 100644 include/infiniop/ops/quant/per_channel_quant_int8.h create mode 100644 src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh create mode 100644 src/infiniop/ops/quant/per_channel_quant_int8/info.h create mode 100644 src/infiniop/ops/quant/per_channel_quant_int8/nvidia/per_channel_quant_int8_nvidia.cu create mode 100644 src/infiniop/ops/quant/per_channel_quant_int8/nvidia/per_channel_quant_int8_nvidia.cuh create mode 100644 src/infiniop/ops/quant/per_channel_quant_int8/operator.cc create mode 100644 src/infiniop/ops/quant/per_channel_quant_int8/per_channel_quant_int8.h create mode 100644 test/infiniop/per_channel_quant_int8.py create mode 100644 test/infiniop/w8a8_per_channel.py diff --git a/include/infiniop/ops/quant/per_channel_quant_int8.h b/include/infiniop/ops/quant/per_channel_quant_int8.h new file mode 100644 index 000000000..ce21f4556 --- /dev/null +++ b/include/infiniop/ops/quant/per_channel_quant_int8.h @@ -0,0 +1,28 @@ +#ifndef __INFINIOP_PER_CHANNEL_QUANT_INT8_API_H__ +#define __INFINIOP_PER_CHANNEL_QUANT_INT8_API_H__ + +#include "../../operator_descriptor.h" + +typedef InfiniopDescriptor *infiniopPerChannelQuantI8Descriptor_t; + 
+
+__C __export infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t handle,
+                                                                      infiniopPerChannelQuantI8Descriptor_t *desc_ptr,
+                                                                      infiniopTensorDescriptor_t x_packed_desc,
+                                                                      infiniopTensorDescriptor_t x_scale_desc,
+                                                                      infiniopTensorDescriptor_t x_zero_desc,
+                                                                      infiniopTensorDescriptor_t x_desc);
+
+__C __export infiniStatus_t infiniopGetPerChannelQuantI8WorkspaceSize(infiniopPerChannelQuantI8Descriptor_t desc, size_t *size);
+
+__C __export infiniStatus_t infiniopPerChannelQuantI8(infiniopPerChannelQuantI8Descriptor_t desc,
+                                                      void *workspace,
+                                                      size_t workspace_size,
+                                                      void *x_packed,
+                                                      void *x_scale,
+                                                      void *x_zero,
+                                                      const void *x,
+                                                      void *stream);
+
+__C __export infiniStatus_t infiniopDestroyPerChannelQuantI8Descriptor(infiniopPerChannelQuantI8Descriptor_t desc);
+
+#endif
diff --git a/src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh b/src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh
new file mode 100644
index 000000000..a3cbdbe01
--- /dev/null
+++ b/src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh
@@ -0,0 +1,277 @@
+#ifndef __PERCHANNEL_QUANTINT8_KERNEL_CUH__
+#define __PERCHANNEL_QUANTINT8_KERNEL_CUH__
+
+#include <cub/cub.cuh> // NOTE(review): angle-bracket text was stripped in transit; confirm the exact header against the original patch
+// Round to nearest integer, ties away from zero (matches the Python
+// reference's rounding, unlike rintf's round-half-to-even).
+__device__ inline int round_half_away_from_zero(float x) {
+    float ax = fabsf(x);
+    float r = floorf(ax + 0.5f);
+    return (x >= 0.0f) ? (int)r : -(int)r;
+}
+
+// Asymmetric per-row int8 quantization of an M x K tensor.
+// Expected launch: gridDim.x == M (one block per row), blockDim.x == BLOCK_SIZE.
+template <unsigned int BLOCK_SIZE, typename Tdata> // NOTE(review): template parameter list restored after angle-bracket stripping
+__device__ void blockPerChannelQuantI8Kernel(
+    int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
+    int M, int K) {
+    int row = blockIdx.x;
+    int tid = row * K;
+
+    // ---- 1. reduce max ----
+    float local_max = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata, float>( // NOTE(review): template args reconstructed; verify signature
+        x + tid, K);
+
+    __shared__ float global_max_f;
+    if (threadIdx.x == 0) {
+        global_max_f = local_max;
+    }
+    __syncthreads();
+
+    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+
+    // ---- 2. 
reduce min ----
+    float thread_min = __FLT_MAX__;
+    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
+        thread_min = fminf(thread_min, (float)x[tid + ind]);
+    }
+    float local_min = BlockReduce(temp_storage).Reduce(thread_min, cub::Min());
+
+    __shared__ float global_min_f;
+    if (threadIdx.x == 0) {
+        global_min_f = local_min;
+    }
+    __syncthreads();
+
+    // ---- 3. compute scale/zero in float (to match the Python reference) ----
+    float global_max = global_max_f;
+    float global_min = global_min_f;
+
+    float scale = (global_max - global_min) / 255.0f;
+    if (scale < 1e-8f) {
+        scale = 1e-8f;
+    }
+
+    float inv_scale = 1.0f / scale;
+    float zero = -global_min * inv_scale - 128.0f;
+
+    // write back scale/zero; outputs are float — do NOT round-trip through
+    // Tdata, which would lose precision when Tdata is half/bf16
+    x_scale[row] = scale;
+    x_zero[row] = zero;
+
+    // ---- 4. quantize in float with half-away-from-zero rounding (bit-exact vs Python) ----
+    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
+
+        float v = (float)x[tid + ind];
+        float qf = v * inv_scale + zero;
+
+        int q = round_half_away_from_zero(qf);
+
+        if (q > 127) {
+            q = 127;
+        }
+        if (q < -128) {
+            q = -128;
+        }
+
+        x_packed[tid + ind] = (int8_t)q;
+    }
+}
+
+// Symmetric per-row int8 quantization (no zero point), same launch layout.
+template <unsigned int BLOCK_SIZE, typename Tdata> // NOTE(review): template parameter list restored after angle-bracket stripping
+__device__ void blockPerChannelQuantI8SymKernel(
+    int8_t *x_packed, float *x_scale, const Tdata *x,
+    int M, int K) {
+    int row = blockIdx.x;
+    int tid = row * K;
+
+    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+
+    // ---- 2. reduce max of |x| ----
+    float thread_max = -__FLT_MAX__;
+    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
+        thread_max = fmaxf(thread_max, fabsf((float)x[tid + ind]));
+    }
+    float local_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
+
+    __shared__ float global_max_f;
+    if (threadIdx.x == 0) {
+        global_max_f = local_max;
+    }
+    __syncthreads();
+
+    // ---- 3. 
compute scale in float (to match the Python reference) ----
+    float global_max = global_max_f;
+
+    float scale = global_max / 127.0f;
+    if (scale < 1e-8f) {
+        scale = 1e-8f;
+    }
+
+    float inv_scale = 1.0f / scale;
+
+    // write back scale; output is float — do NOT round-trip through Tdata,
+    // which would lose precision when Tdata is half/bf16
+    x_scale[row] = scale;
+
+    // ---- 4. quantize in float with half-away-from-zero rounding (bit-exact vs Python) ----
+    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
+
+        float v = (float)x[tid + ind];
+        float qf = v * inv_scale;
+
+        int q = round_half_away_from_zero(qf);
+
+        if (q > 127) {
+            q = 127;
+        }
+        if (q < -127) { // symmetric range: clamp to [-127, 127]
+            q = -127;
+        }
+
+        x_packed[tid + ind] = (int8_t)q;
+    }
+}
+
+template <typename T> // NOTE(review): template parameter lists here and below restored after angle-bracket stripping
+struct MaxOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return max(a, b);
+    }
+};
+template <typename T>
+struct MinOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return min(a, b);
+    }
+};
+template