diff --git a/src/parallax/metal/extensions/CMakelists.txt b/src/parallax/metal/extensions/CMakelists.txt new file mode 100755 index 00000000..06b34f31 --- /dev/null +++ b/src/parallax/metal/extensions/CMakelists.txt @@ -0,0 +1,83 @@ +cmake_minimum_required(VERSION 3.27) + +project(_ext LANGUAGES CXX) + +# ----------------------------- Setup ----------------------------- +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON) + +# ----------------------------- Dependencies ----------------------------- +find_package( + Python 3.10 + COMPONENTS Interpreter Development.Module + REQUIRED) +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE nanobind_ROOT) +find_package(nanobind CONFIG REQUIRED) + +execute_process( + COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE MLX_ROOT) +find_package(MLX CONFIG REQUIRED) + +# ----------------------------- Extensions ----------------------------- + +# Add library +add_library(parallax_ext) + +# Add sources +target_sources( + parallax_ext + PUBLIC ${CMAKE_CURRENT_LIST_DIR}/paged_attention/paged_attention.cpp) + +# Add include headers +target_include_directories(parallax_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}) + +# Link to mlx +target_link_libraries(parallax_ext PUBLIC mlx) + +# ----------------------------- Metal ----------------------------- + +# Build metallib +if(MLX_BUILD_METAL) + mlx_build_metallib( + TARGET + parallax_ext_metallib + TITLE + parallax_ext + SOURCES + ${CMAKE_CURRENT_LIST_DIR}/paged_attention/utils.metal + ${CMAKE_CURRENT_LIST_DIR}/paged_attention/float8.metal + ${CMAKE_CURRENT_LIST_DIR}/paged_attention/paged_attention.metal + ${CMAKE_CURRENT_LIST_DIR}/paged_attention/reshape_and_cache.metal + INCLUDE_DIRS + ${PROJECT_SOURCE_DIR} + ${MLX_INCLUDE_DIRS} + OUTPUT_DIRECTORY + ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) + + add_dependencies(parallax_ext parallax_ext_metallib) + +endif() + +# ----------------------------- Python Bindings ----------------------------- +nanobind_add_module( + _ext + NB_STATIC + STABLE_ABI + LTO + NOMINSIZE + NB_DOMAIN + mlx + ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp) +target_link_libraries(_ext PRIVATE parallax_ext) + +if(BUILD_SHARED_LIBS) + target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path) +endif() diff --git a/src/parallax/metal/extensions/bindings.cpp b/src/parallax/metal/extensions/bindings.cpp new file mode 100755 index 00000000..e77f7571 --- /dev/null +++ b/src/parallax/metal/extensions/bindings.cpp @@ -0,0 +1,43 @@ +#include +#include + +#include "paged_attention/paged_attention.h" + +namespace nb = nanobind; +using namespace nb::literals; + +NB_MODULE(_ext, m) { + m.doc() = "vLLM PagedAttentionV1"; + + m.def( + "PagedAttentionV1", + ¶llax_ext::paged_attention_v1, + "query"_a, + "key_cache"_a, + "value_cache"_a, + "block_tables"_a, + "seq_lens"_a, + "num_kv_heads"_a, + "block_size"_a, + "max_seq_len"_a, + "scale"_a, + nb::kw_only(), + "stream"_a = nb::none(), + R"( + vLLM PagedAttentionV1 operation + + Args: + query (array): Input array [num_seqs, num_heads, head_size]. + key_cache (array): Input array [num_blocks, num_heads, head_size/x, block_size, x]. + value_cache (array): Input array [num_blocks, num_heads, head_size, block_size]. + block_tables (array): Input array [num_seqs, max_num_blocks_per_seq]. + seq_lens (array): Input array [num_seqs]. 
+ num_kv_heads (int): Input parameter. + block_size (int): Input parameter. + max_seq_len (int): Input parameter. + scale (float): Input parameter. + + Returns: + array: ``Paged attention result`` + )"); +} diff --git a/src/parallax/metal/extensions/ops/__init__.py b/src/parallax/metal/extensions/ops/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/src/parallax/metal/extensions/paged_attention/float8.metal b/src/parallax/metal/extensions/paged_attention/float8.metal new file mode 100644 index 00000000..9e6f33eb --- /dev/null +++ b/src/parallax/metal/extensions/paged_attention/float8.metal @@ -0,0 +1,122 @@ +#include +using namespace metal; + +// Helpers ------------------------------------------------------------ +static inline uint as_bits(float x) { return as_type(x); } +static inline float from_bits(uint b) { return as_type(b); } + +// ------------------------------------------------------------------- +// FP8 E4M3 (bias = 7) +// ------------------------------------------------------------------- +inline float fp8_e4m3_to_float(uchar v) { + const uint s = v >> 7; + const uint exp = (v >> 3) & 0xF; + const uint man = v & 0x7; + + if (exp == 0) { // zero / sub-normal + if (man == 0) + return s ? -0.f : 0.f; + const float m = float(man) / 8.f; // already scaled by 2^-3 + float val = ldexp(m, 1 - 7); // 2^(1-bias) = 2^-6 + return s ? -val : val; + } + + if (exp == 0xF) { // Inf / NaN (E4M3FN keeps only NaN) + if (man != 0) + return NAN; + return s ? -INFINITY : INFINITY; + } + + const float m = 1.f + float(man) / 8.f; + float val = ldexp(m, int(exp) - 7); + return s ? -val : val; +} + +// ------------------------------------------------------------------- +// FP8 E5M2 (bias = 15) +// ------------------------------------------------------------------- +inline float fp8_e5m2_to_float(uchar v) { + const uint s = v >> 7; + const uint exp = (v >> 2) & 0x1F; + const uint man = v & 0x3; + + if (exp == 0) { + if (man == 0) + return s ? -0.f : 0.f; + const float m = float(man) / 4.f; + float val = ldexp(m, 1 - 15); // 2^(1-bias) = 2^-14 + return s ? -val : val; + } + + if (exp == 0x1F) { + if (man != 0) + return NAN; + return s ? -INFINITY : INFINITY; + } + + const float m = 1.f + float(man) / 4.f; + float val = ldexp(m, int(exp) - 15); + return s ? 
-val : val; +} + +// ------------------------------------------------------------------- +// Encoding helpers (round-to-nearest-even, gradual under-flow, sat-to-∞) +// ------------------------------------------------------------------- +namespace detail { +template +inline uchar fp32_to_fp8(float f) { + const uint bits = as_bits(f); + const uint s = bits >> 31; + const uint abs = bits & 0x7FFFFFFF; + + // NaN propagates, Inf saturates + if (abs >= 0x7F800000u) { + return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS) | + (abs != 0x7F800000u)); + } + + int e = int((abs >> 23) & 0xFF) - 127; // unbiased exponent + uint m = abs & 0x7FFFFFu; // 23-bit mantissa + const int EXP_MAX = (1 << EXP_BITS) - 2; // last finite exponent + + // ---------- Normal path ------------------------------------------------- + int e_fp8 = e + BIAS; + if (e_fp8 >= 1 && e_fp8 <= EXP_MAX) { + // round-to-nearest-even + const int shift = 23 - MAN_BITS; + uint mant = m >> shift; + const uint lsb = mant & 1u; + const uint round = (m >> (shift - 1)) & 1u; + const uint sticky = (m & ((1u << (shift - 1)) - 1u)) != 0u; + mant += (round & (sticky | lsb)); + if (mant >> MAN_BITS) { // mantissa overflow + mant = 0; + ++e_fp8; + if (e_fp8 > EXP_MAX) + return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS)); // ∞ + } + return uchar((s << 7) | (uint(e_fp8) << MAN_BITS) | + (mant & ((1u << MAN_BITS) - 1u))); + } + + // ---------- Sub-normal / under-flow ------------------------------------ + if (e_fp8 < 1 - MAN_BITS) // too small -> ±0 + return uchar(s << 7); + + // shift so that exponent becomes 1 + int rshift = (1 - e_fp8) + (23 - MAN_BITS); + uint mant = (0x800000u | m); // implicit 1 + uint rounded = (mant + (1u << (rshift - 1))) >> rshift; + if (rounded == 0) + return uchar(s << 7); // rounds to zero + + return uchar((s << 7) | (rounded & ((1u << MAN_BITS) - 1u))); +} +} // namespace detail + +inline uchar float_to_fp8_e4m3(float f) { + return detail::fp32_to_fp8<4, 3, 7>(f); +} +inline uchar float_to_fp8_e5m2(float f) { + return detail::fp32_to_fp8<5, 2, 15>(f); +} diff --git a/src/parallax/metal/extensions/paged_attention/paged_attention.cpp b/src/parallax/metal/extensions/paged_attention/paged_attention.cpp new file mode 100755 index 00000000..d6f65424 --- /dev/null +++ b/src/parallax/metal/extensions/paged_attention/paged_attention.cpp @@ -0,0 +1,126 @@ +#include +#include +#include + +#include "paged_attention.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" + +namespace parallax_ext { + +mx::array paged_attention_v1( + const mx::array& query, // [num_seqs, num_heads, head_size] + const mx::array& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const mx::array& value_cache, // [num_blocks, num_heads, head_size, block_size] + const mx::array& block_tables, // [num_seqs, max_num_blocks_per_seq] + const mx::array& seq_lens, // [num_seqs] + const int64_t num_kv_heads, + const int64_t block_size, + const int64_t max_seq_len, + const float scale, + mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation +) { + out_dtype = query.dtype(); + out_shape = query.shape(); + const std::vector inputs = {query, key_cache, value_cache, block_tables, seq_lens}; + // Construct the array as the output of the PagedAttentionV1 primitive + return mx::array( + /* const std::vector& shape = */ out_shape, + /* Dtype dtype = */ out_dtype, + /* std::unique_ptr primitive = */ + std::make_shared(to_stream(s), num_kv_heads, block_size, max_seq_len, scale), + /* 
const std::vector& inputs = */ inputs); +} + +/** Evaluate primitive on GPU */ +void PagedAttentionV1::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + // Prepare inputs + assert(inputs.size() == 5); + auto& q = inputs[0]; + auto& k = inputs[1]; + auto& v = inputs[2]; + auto& block_tables = inputs[3]; + auto& seq_lens = inputs[4]; + auto& out = outputs[0]; + + // Each primitive carries the stream it should execute on + // and each stream carries its device identifiers + auto& s = stream(); + // We get the needed metal device using the stream + auto& d = metal::device(s.device); + + // Allocate output memory + out.set_data(allocator::malloc(out.nbytes())); + + // Set kernel paramas + const int num_threads = 256; + const int num_simd_lanes = 32; + const int partition_size = 0; // v1 doesn't use partitioning + + // Resolve name of kernel + std::string kname; + kname = "paged_attention_" + type_to_name(out); + kname += "_cache_" + type_to_name(k); + kname += "_hs" + std::to_string(num_kv_heads_); + kname += "_bs" + std::to_string(block_size_); + kname += "_nt" + std::to_string(num_threads); + kname += "_nsl" + std::to_string(num_simd_lanes); + kname += "_ps" + std::to_string(partition_size); + + // Load the metal library + auto lib = d.get_library("parallax_ext", current_binary_dir()); + + // Make a kernel from this metal library + auto kernel = d.get_kernel(kname, lib); + + // Prepare to encode kernel + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + // Calculate parameters + float softcapping_ = 1.0; // hard code for not use + const int64_t num_seqs = q.shape(0); + const int64_t num_heads = q.shape(1); + const int64_t max_num_blocks_per_seq = block_tables.shape(1); + int32_t q_stride = static_cast(q.strides(0)); + int32_t kv_block_stride = static_cast(k.strides(0)); + int32_t kv_head_stride = static_cast(k.strides(1)); + + // Encode arrays to kernel + // Skip exp_sums and max_logits for v1 (buffers 0, 1) + compute_encoder.set_output_array(out, 2); + compute_encoder.set_input_array(q, 3); + compute_encoder.set_input_array(k, 4); + compute_encoder.set_input_array(v, 5); + // Skip k_scale and v_scale for non-fp8 (buffers 6, 7) + compute_encoder.set_bytes(num_kv_heads_, 8); + compute_encoder.set_bytes(scale_, 9); + compute_encoder.set_bytes(softcapping_, 10); + compute_encoder.set_input_array(block_tables, 11); + compute_encoder.set_input_array(seq_lens, 12); + compute_encoder.set_bytes(max_num_blocks_per_seq, 13); + // Skip alibi_slopes (buffer 14) + compute_encoder.set_bytes(q_stride, 14); + compute_encoder.set_bytes(kv_block_stride, 15); + compute_encoder.set_bytes(kv_head_stride, 16); + + // Dispatch configuration + // Grid: (num_heads, num_seqs, 1) - no partitioning for v1 + MTL::Size grid = MTLSizeMake(num_heads, num_seqs, 1); + MTL::Size threadgroup = MTLSizeMake(num_threads, 1, 1); + + // Launch the grid with the given number of threads divided among + // the given threadgroups + compute_encoder.dispatch_threads(grid, threadgroup); +} + +/** Equivalence check **/ +bool PagedAttentionV1::is_equivalent(const mx::Primitive& other) const { + const PagedAttentionV1& r_other = static_cast(other); + return num_kv_heads_ == r_other.num_kv_heads_ && block_size_ == r_other.block_size_ && + max_seq_len_ == r_other.max_seq_len_ && scale_ == r_other.scale_; +} + +} // namespace parallax_ext diff --git a/src/parallax/metal/extensions/paged_attention/paged_attention.h 
b/src/parallax/metal/extensions/paged_attention/paged_attention.h new file mode 100755 index 00000000..27456d0e --- /dev/null +++ b/src/parallax/metal/extensions/paged_attention/paged_attention.h @@ -0,0 +1,48 @@ +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mx = mlx::core; + +namespace parallax_ext { + +mx::array paged_attention_v1( + const mx::array& query, // [num_seqs, num_heads, head_size] + const mx::array& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const mx::array& value_cache, // [num_blocks, num_heads, head_size, block_size] + const mx::array& block_tables, // [num_seqs, max_num_blocks_per_seq] + const mx::array& seq_lens, // [num_seqs] + const int64_t num_kv_heads, + const int64_t block_size, + const int64_t max_seq_len, + const float scale, + mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation +); + +class PagedAttentionV1 : public mx::Primitive { + public: + explicit PagedAttentionV1(mx::Stream stream, int64_t num_kv_heads, int64_t block_size, int64_t max_seq_len, float scale) + : mx::Primitive(stream), num_kv_heads_(num_kv_heads), block_size_(block_size), max_seq_len_(max_seq_len), scale_(scale){}; + + // void eval_cpu( + // const std::vector& inputs, + // std::vector& outputs) override; + void eval_gpu( + const std::vector& inputs, + std::vector& outputs) override; + + /** The name of primitive. */ + const char* name() const override { + return "PagedAttentionV1"; + } + + /** Equivalence check **/ + bool is_equivalent(const mx::Primitive& other) const override; + + private: + int64_t num_kv_heads_; + int64_t block_size_; + int64_t max_seq_len_; + float scale_; +}; + +} // namespace parallax_ext diff --git a/src/parallax/metal/extensions/paged_attention/paged_attention.metal b/src/parallax/metal/extensions/paged_attention/paged_attention.metal new file mode 100644 index 00000000..abc6cc3d --- /dev/null +++ b/src/parallax/metal/extensions/paged_attention/paged_attention.metal @@ -0,0 +1,1401 @@ +// Updated from MLX commit has f70764a + +#include "./utils.metal" +#include "./float8.metal" +#include +#include + +using namespace metal; + +// ========================================== Generic vector types + +// A vector type to store Q, K, V elements. +template struct Vec {}; + +// A vector type to store FP32 accumulators. +template struct FloatVec {}; + +// Template vector operations. +template inline Acc mul(A a, B b); + +template inline float sum(T v); + +template inline float dot(T a, T b) { + return sum(mul(a, b)); +} + +template inline float dot(T a, T b) { + return sum(mul(a, b)); +} + +// FP32 vector data types. 
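Stepping back from the kernel for a moment: paged_attention_v1 (declared in paged_attention.h above) returns an array with the query's shape and dtype, and the layout comments pin down how the other inputs must relate to it. Below is a standalone C++ sketch of the kind of up-front validation the host op could do; the helper name and error messages are hypothetical and not part of the extension, and `x` is assumed to be the key-cache packing factor (16 / sizeof(cache element)).

// Illustrative host-side sanity checks for the documented input layouts.
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;

void check_paged_attention_shapes(
    const Shape& query,        // [num_seqs, num_heads, head_size]
    const Shape& key_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const Shape& value_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
    const Shape& block_tables, // [num_seqs, max_num_blocks_per_seq]
    const Shape& seq_lens,     // [num_seqs]
    int64_t num_kv_heads, int64_t block_size) {
  if (query.size() != 3 || key_cache.size() != 5 || value_cache.size() != 4 ||
      block_tables.size() != 2 || seq_lens.size() != 1)
    throw std::invalid_argument("unexpected rank");
  const int64_t num_seqs = query[0], head_size = query[2];
  const int64_t x = key_cache[4];
  if (block_tables[0] != num_seqs || seq_lens[0] != num_seqs)
    throw std::invalid_argument("num_seqs mismatch");
  if (key_cache[1] != num_kv_heads || value_cache[1] != num_kv_heads)
    throw std::invalid_argument("num_kv_heads mismatch");
  if (key_cache[2] * x != head_size || value_cache[2] != head_size)
    throw std::invalid_argument("head_size mismatch");
  if (key_cache[3] != block_size || value_cache[3] != block_size)
    throw std::invalid_argument("block_size mismatch");
}

int main() {
  // fp16 cache (x = 8): 2 sequences, 8 heads, head_size 128, 64 blocks of 16 tokens.
  check_paged_attention_shapes({2, 8, 128}, {64, 8, 16, 16, 8},
                               {64, 8, 128, 16}, {2, 12}, {2},
                               /*num_kv_heads=*/8, /*block_size=*/16);
  std::puts("shapes ok");
}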
+struct Float8_ { + float4 x; + float4 y; +}; + +template <> struct Vec { + using Type = float; +}; +template <> struct Vec { + using Type = float2; +}; +template <> struct Vec { + using Type = float4; +}; +template <> struct Vec { + using Type = Float8_; +}; + +template <> struct FloatVec { + using Type = float; +}; +template <> struct FloatVec { + using Type = float2; +}; +template <> struct FloatVec { + using Type = float4; +}; +template <> struct FloatVec { + using Type = Float8_; +}; + +template <> inline float mul(float a, float b) { return a * b; } + +template <> inline float2 mul(float2 a, float2 b) { return a * b; } + +template <> inline float4 mul(float4 a, float4 b) { return a * b; } + +template <> inline Float8_ mul(Float8_ a, Float8_ b) { + Float8_ c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> inline float sum(float a) { return a; } + +template <> inline float sum(float2 a) { return a.x + a.y; } + +template <> inline float sum(float4 a) { return a.x + a.y + a.z + a.w; } + +template <> inline float sum(Float8_ a) { return sum(a.x) + sum(a.y); } + +inline Float8_ fma(Float8_ a, Float8_ b, Float8_ c) { + Float8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread float &dst, float src) { dst = src; } +inline void from_float(thread float2 &dst, float2 src) { dst = src; } +inline void from_float(thread float4 &dst, float4 src) { dst = src; } +inline void from_float(thread Float8_ &dst, Float8_ src) { dst = src; } + +// BF16 vector data types. +// #if defined(__HAVE_BFLOAT__) + +// struct Bfloat8_ { +// bfloat4 x; +// bfloat4 y; +// }; + +// template<> +// struct Vec { +// using Type = bfloat; +// }; +// template<> +// struct Vec { +// using Type = bfloat2; +// }; +// template<> +// struct Vec { +// using Type = bfloat4; +// }; +// template<> +// struct Vec { +// using Type = Bfloat8_; +// }; + +// template<> +// struct FloatVec { +// using Type = float; +// }; +// template<> +// struct FloatVec { +// using Type = float2; +// }; +// template<> +// struct FloatVec { +// using Type = float4; +// }; +// template<> +// struct FloatVec { +// using Type = Float8_; +// }; + +// template<> +// inline float mul(bfloat a, bfloat b) { +// return (float)a * (float)b; +// } +// template<> +// inline bfloat mul(bfloat a, bfloat b) { +// return a*b; +// } + +// template<> +// inline float2 mul(bfloat2 a, bfloat2 b) { +// return (float2)a * (float2)b; +// } +// template<> +// inline bfloat2 mul(bfloat2 a, bfloat2 b) { +// return a * b; +// } + +// template<> +// inline float4 mul(bfloat4 a, bfloat4 b) { +// return (float4)a * (float4)b; +// } +// template<> +// inline bfloat4 mul(bfloat4 a, bfloat4 b) { +// return a * b; +// } + +// template<> +// inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) { +// Float8_ c; +// c.x = mul(a.x, b.x); +// c.y = mul(a.y, b.y); +// return c; +// } +// template<> +// inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) { +// Bfloat8_ c; +// c.x = mul(a.x, b.x); +// c.y = mul(a.y, b.y); +// return c; +// } + +// template<> +// inline float sum(bfloat a) { +// return (float)a; +// } + +// template<> +// inline float sum(bfloat2 a) { +// return (float)a.x + (float)a.y; +// } + +// template<> +// inline float sum(bfloat4 a) { +// return sum(a.x) + sum(a.y); +// } + +// template<> +// inline float sum(Bfloat8_ a) { +// return sum(a.x) + sum(a.y); +// } + +// inline float fma(bfloat a, bfloat b, float c) { +// return (float)a * (float)b + c; +// } + +// inline float2 fma(bfloat2 a, bfloat2 b, float2 
c) { +// return (float2)a * (float2)b + c; +// } + +// inline float4 fma(bfloat4 a, bfloat4 b, float4 c) { +// return (float4)a * (float4)b + c; +// } + +// inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) { +// Float8_ res; +// res.x = fma((float4)a.x, (float4)b.x, (float4)c.x); +// res.y = fma((float4)a.y, (float4)b.y, (float4)c.y); +// return res; +// } +// inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) { +// Bfloat8_ res; +// res.x = (bfloat4)fma((float4)a.x, (float4)b.x, (float4)c.x); +// res.y = (bfloat4)fma((float4)a.y, (float4)b.x, (float4)c.y); +// return c; +// } + +// inline void from_float(thread bfloat& dst, float src) { +// dst = static_cast(src); +// } +// inline void from_float(thread bfloat2& dst, float2 src) { +// dst.x = static_cast(src.x); +// dst.y = static_cast(src.y); +// } +// inline void from_float(thread bfloat4& dst, float4 src) { +// dst.x = static_cast(src.x); +// dst.y = static_cast(src.y); +// dst.z = static_cast(src.z); +// dst.w = static_cast(src.w); +// } +// inline void from_float(thread Bfloat8_& dst, Float8_ src) { +// bfloat4 x; +// bfloat4 y; +// from_float(x, src.x); +// from_float(y, src.y); +// dst.x = x; +// dst.y = y; +// } + +// #else + +struct Bfloat2_ { + bfloat16_t x; + bfloat16_t y; +}; + +struct Bfloat4_ { + Bfloat2_ x; + Bfloat2_ y; +}; + +struct Bfloat8_ { + Bfloat4_ x; + Bfloat4_ y; +}; + +template <> struct Vec { + using Type = bfloat16_t; +}; +template <> struct Vec { + using Type = Bfloat2_; +}; +template <> struct Vec { + using Type = Bfloat4_; +}; +template <> struct Vec { + using Type = Bfloat8_; +}; + +template <> struct FloatVec { + using Type = float; +}; +template <> struct FloatVec { + using Type = float2; +}; +template <> struct FloatVec { + using Type = float4; +}; +template <> struct FloatVec { + using Type = Float8_; +}; + +template <> inline float mul(bfloat16_t a, bfloat16_t b) { + return (float)a * (float)b; +} +template <> inline bfloat16_t mul(bfloat16_t a, bfloat16_t b) { return a * b; } + +template <> inline float2 mul(Bfloat2_ a, Bfloat2_ b) { + float2 a_f((float)a.x, (float)a.y); + float2 b_f((float)b.x, (float)b.y); + return a_f * b_f; +} +template <> inline Bfloat2_ mul(Bfloat2_ a, Bfloat2_ b) { + Bfloat2_ c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> inline float4 mul(Bfloat4_ a, Bfloat4_ b) { + float2 x = mul(a.x, b.x); + float2 y = mul(a.y, b.y); + float4 c; + c.x = x.x; + c.y = x.y; + c.z = y.x; + c.w = y.y; + return c; +} +template <> inline Bfloat4_ mul(Bfloat4_ a, Bfloat4_ b) { + Bfloat4_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) { + Float8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} +template <> inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) { + Bfloat8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> inline float sum(bfloat16_t a) { return (float)a; } + +template <> inline float sum(Bfloat2_ a) { return (float)a.x + (float)a.y; } + +template <> inline float sum(Bfloat4_ a) { return sum(a.x) + sum(a.y); } + +template <> inline float sum(Bfloat8_ a) { return sum(a.x) + sum(a.y); } + +inline float fma(bfloat16_t a, bfloat16_t b, float c) { + return (float)a * (float)b + c; +} +inline bfloat16_t fma(bfloat16_t a, bfloat16_t b, bfloat16_t c) { + return a * b + c; +} + +inline float2 fma(Bfloat2_ a, Bfloat2_ b, float2 c) { + float2 a_f((float)a.x, (float)a.y); + float2 b_f((float)b.x, (float)b.y); + return a_f * b_f + c; +} +inline Bfloat2_ 
fma(Bfloat2_ a, Bfloat2_ b, Bfloat2_ c) { + Bfloat2_ res; + res.x = a.x * b.x + c.x; + res.y = a.y * b.y + c.y; + return res; +} + +inline float4 fma(Bfloat4_ a, Bfloat4_ b, float4 c) { + float4 res; + res.x = fma(a.x.x, b.x.x, c.x); + res.y = fma(a.x.y, b.x.y, c.y); + res.z = fma(a.y.x, b.y.x, c.z); + res.w = fma(a.y.y, b.y.y, c.w); + return res; +} +inline Bfloat4_ fma(Bfloat4_ a, Bfloat4_ b, Bfloat4_ c) { + Bfloat4_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) { + float4 x = fma(a.x, b.x, c.x); + float4 y = fma(a.y, b.y, c.y); + Float8_ res; + res.x = x; + res.y = y; + return res; +} +inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) { + Bfloat8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread bfloat16_t &dst, float src) { + dst = static_cast(src); +} +inline void from_float(thread Bfloat2_ &dst, float2 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); +} +inline void from_float(thread Bfloat4_ &dst, float4 src) { + dst.x.x = static_cast(src.x); + dst.x.y = static_cast(src.y); + dst.y.x = static_cast(src.z); + dst.y.y = static_cast(src.w); +} +inline void from_float(thread Bfloat8_ &dst, Float8_ src) { + Bfloat4_ x; + Bfloat4_ y; + from_float(x, src.x); + from_float(y, src.y); + dst.x = x; + dst.y = y; +} + +// #endif + +// FP16 vector data types. +struct Half8_ { + half4 x; + half4 y; +}; + +template <> struct Vec { + using Type = half; +}; +template <> struct Vec { + using Type = half2; +}; +template <> struct Vec { + using Type = half4; +}; +template <> struct Vec { + using Type = Half8_; +}; + +template <> struct FloatVec { + using Type = float; +}; +template <> struct FloatVec { + using Type = float2; +}; +template <> struct FloatVec { + using Type = float4; +}; +template <> struct FloatVec { + using Type = Float8_; +}; + +template <> inline float mul(half a, half b) { return (float)a * (float)b; } +template <> inline half mul(half a, half b) { return a * b; } + +template <> inline float2 mul(half2 a, half2 b) { + return (float2)a * (float2)b; +} +template <> inline half2 mul(half2 a, half2 b) { return a * b; } + +template <> inline float4 mul(half4 a, half4 b) { + return (float4)a * (float4)b; +} +template <> inline half4 mul(half4 a, half4 b) { return a * b; } + +template <> inline Float8_ mul(Half8_ a, Half8_ b) { + float4 x = mul(a.x, b.x); + float4 y = mul(a.y, b.y); + Float8_ c; + c.x = x; + c.y = y; + return c; +} +template <> inline Half8_ mul(Half8_ a, Half8_ b) { + Half8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> inline float sum(half a) { return (float)a; } + +template <> inline float sum(half2 a) { return (float)a.x + (float)a.y; } + +template <> inline float sum(half4 a) { return a.x + a.y + a.z + a.w; } + +template <> inline float sum(Half8_ a) { return sum(a.x) + sum(a.y); } + +inline float fma(half a, half b, float c) { return (float)a * (float)b + c; } + +inline float2 fma(half2 a, half2 b, float2 c) { + return (float2)a * (float2)b + c; +} + +inline float4 fma(half4 a, half4 b, float4 c) { + return (float4)a * (float4)b + c; +} + +inline Float8_ fma(Half8_ a, Half8_ b, Float8_ c) { + float4 x = fma(a.x, b.x, c.x); + float4 y = fma(a.y, b.y, c.y); + Float8_ res; + res.x = x; + res.y = y; + return res; +} +inline Half8_ fma(Half8_ a, Half8_ b, Half8_ c) { + Half8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + 
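The mul/sum/fma specializations above implement a small trait scheme: Vec<T, N> names an N-lane storage type, FloatVec names its FP32 accumulator, and dot(a, b) is sum(mul(a, b)). Here is a standalone C++ sketch of the same pattern with std::array standing in for the Metal vector types; all names in it are illustrative and not part of the extension.

// Illustrative: the Vec / FloatVec / dot trait pattern in plain C++.
#include <array>
#include <cstddef>
#include <cstdio>

template <typename T, std::size_t N> struct Vec { using Type = std::array<T, N>; };
template <typename V> struct FloatVec;
template <typename T, std::size_t N> struct FloatVec<std::array<T, N>> {
  using Type = std::array<float, N>;  // FP32 accumulator for an N-lane vector
};

template <typename T, std::size_t N>
typename FloatVec<std::array<T, N>>::Type mul(const std::array<T, N>& a,
                                              const std::array<T, N>& b) {
  typename FloatVec<std::array<T, N>>::Type c{};
  for (std::size_t i = 0; i < N; ++i) c[i] = float(a[i]) * float(b[i]);
  return c;
}

template <std::size_t N> float sum(const std::array<float, N>& v) {
  float s = 0.f;
  for (float x : v) s += x;
  return s;
}

// dot(a, b) == sum(mul(a, b)), matching the generic definition above.
template <typename T, std::size_t N>
float dot(const std::array<T, N>& a, const std::array<T, N>& b) {
  return sum(mul(a, b));
}

int main() {
  Vec<float, 4>::Type q{1.f, 2.f, 3.f, 4.f};
  Vec<float, 4>::Type k{4.f, 3.f, 2.f, 1.f};
  std::printf("dot = %g\n", dot(q, k)); // 1*4 + 2*3 + 3*2 + 4*1 = 20
}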
+inline void from_float(thread half &dst, float src) { + dst = static_cast(src); +} +inline void from_float(thread half2 &dst, float2 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); +} +inline void from_float(thread half4 &dst, float4 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); + dst.z = static_cast(src.z); + dst.w = static_cast(src.w); +} +inline void from_float(thread Half8_ &dst, Float8_ src) { + half4 x; + half4 y; + from_float(x, src.x); + from_float(y, src.y); + dst.x = x; + dst.y = y; +} + +// ========================================== FP8 (uchar) vector data types. + +// 8‑lane uchar vector – Metal only provides up to uchar4, so build our own. +struct Uchar8_ { + uchar4 x; + uchar4 y; +}; + +// Vec specialisations so Vec::Type resolves correctly. +template <> struct Vec { + using Type = uchar; +}; +template <> struct Vec { + using Type = uchar2; +}; +template <> struct Vec { + using Type = uchar4; +}; +template <> struct Vec { + using Type = Uchar8_; +}; + +// General case: not uchar +template inline constexpr bool is_uchar() { return false; } + +// Specialization: T is uchar +template <> inline constexpr bool is_uchar() { return true; } + +// Generic fallback – will fail to compile if a required specialisation is +// missing. +template +inline Vec fp8_convert(const thread Quant_vec &, float scale) { + static_assert(sizeof(Vec) == 0, "Missing fp8_convert specialisation"); +} + +// ========================================== FP8 → float/half/bfloat +inline float __dequant_single(uchar v, float scale) { + return fp8_e4m3_to_float(v) * scale; +} + +// ---- 1‑lane ---- +template <> +inline float fp8_convert(const thread uchar &in, float scale) { + return __dequant_single(in, scale); +} +template <> +inline half fp8_convert(const thread uchar &in, float scale) { + return half(__dequant_single(in, scale)); +} +template <> +inline bfloat16_t fp8_convert(const thread uchar &in, + float scale) { + return bfloat16_t(__dequant_single(in, scale)); +} + +// ---- 2‑lane ---- +template <> +inline float2 fp8_convert(const thread uchar2 &in, + float scale) { + return float2(__dequant_single(in.x, scale), __dequant_single(in.y, scale)); +} +template <> +inline half2 fp8_convert(const thread uchar2 &in, float scale) { + half2 out; + out.x = half(__dequant_single(in.x, scale)); + out.y = half(__dequant_single(in.y, scale)); + return out; +} +template <> +inline Bfloat2_ fp8_convert(const thread uchar2 &in, + float scale) { + Bfloat2_ out; + out.x = bfloat16_t(__dequant_single(in.x, scale)); + out.y = bfloat16_t(__dequant_single(in.y, scale)); + return out; +} + +// ---- 4‑lane ---- +template <> +inline float4 fp8_convert(const thread uchar4 &in, + float scale) { + return float4(__dequant_single(in.x, scale), __dequant_single(in.y, scale), + __dequant_single(in.z, scale), __dequant_single(in.w, scale)); +} +template <> +inline half4 fp8_convert(const thread uchar4 &in, float scale) { + half4 out; + out.x = half(__dequant_single(in.x, scale)); + out.y = half(__dequant_single(in.y, scale)); + out.z = half(__dequant_single(in.z, scale)); + out.w = half(__dequant_single(in.w, scale)); + return out; +} +template <> +inline Bfloat4_ fp8_convert(const thread uchar4 &in, + float scale) { + Bfloat4_ out; + out.x.x = bfloat16_t(__dequant_single(in.x, scale)); + out.x.y = bfloat16_t(__dequant_single(in.y, scale)); + out.y.x = bfloat16_t(__dequant_single(in.z, scale)); + out.y.y = bfloat16_t(__dequant_single(in.w, scale)); + return out; +} + +// ---- 8‑lane ---- +template 
<> +inline Float8_ fp8_convert(const thread Uchar8_ &in, + float scale) { + Float8_ out; + out.x = + float4(__dequant_single(in.x.x, scale), __dequant_single(in.x.y, scale), + __dequant_single(in.x.z, scale), __dequant_single(in.x.w, scale)); + out.y = + float4(__dequant_single(in.y.x, scale), __dequant_single(in.y.y, scale), + __dequant_single(in.y.z, scale), __dequant_single(in.y.w, scale)); + return out; +} +template <> +inline Half8_ fp8_convert(const thread Uchar8_ &in, + float scale) { + Half8_ out; + out.x = half4(half(__dequant_single(in.x.x, scale)), + half(__dequant_single(in.x.y, scale)), + half(__dequant_single(in.x.z, scale)), + half(__dequant_single(in.x.w, scale))); + out.y = half4(half(__dequant_single(in.y.x, scale)), + half(__dequant_single(in.y.y, scale)), + half(__dequant_single(in.y.z, scale)), + half(__dequant_single(in.y.w, scale))); + return out; +} +template <> +inline Bfloat8_ fp8_convert(const thread Uchar8_ &in, + float scale) { + Bfloat8_ out; + // first 4 + out.x.x.x = bfloat16_t(__dequant_single(in.x.x, scale)); + out.x.x.y = bfloat16_t(__dequant_single(in.x.y, scale)); + out.x.y.x = bfloat16_t(__dequant_single(in.x.z, scale)); + out.x.y.y = bfloat16_t(__dequant_single(in.x.w, scale)); + // second 4 + out.y.x.x = bfloat16_t(__dequant_single(in.y.x, scale)); + out.y.x.y = bfloat16_t(__dequant_single(in.y.y, scale)); + out.y.y.x = bfloat16_t(__dequant_single(in.y.z, scale)); + out.y.y.y = bfloat16_t(__dequant_single(in.y.w, scale)); + return out; +} + +// ========================================== Dot product utilities + +// TODO(EricLBuehler): optimize with vectorization +template +inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) { + // Compute the parallel products for Q*K^T (treat vector lanes separately). + using A_vec = typename FloatVec::Type; + A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += simd_shuffle_xor(qk, mask); + } + return qk; +} + +template struct Qk_dot { + template + static inline float dot(const threadgroup Vec (&q)[N], + const thread Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +// ========================================== Block sum utility + +// Utility function for attention softmax. +template +inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid, + uint simd_lid) { + // Compute the sum per simdgroup. +#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) { + sum += simd_shuffle_xor(sum, mask); + } + + // Simd leaders store the data to shared memory. + if (simd_lid == 0) { + red_smem[simd_tid] = sum; + } + + // Make sure the data is in shared memory. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // The warps compute the final sums. + if (simd_lid < NUM_WARPS) { + sum = red_smem[simd_lid]; + } + + // Parallel reduction inside the simd group. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += simd_shuffle_xor(sum, mask); + } + + // Broadcast to other threads. + return simd_shuffle(sum, 0); +} + +// ========================================== Paged Attention kernel + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +constant bool use_partitioning [[function_constant(10)]]; +constant bool use_alibi [[function_constant(20)]]; +constant bool use_fp8_scales [[function_constant(30)]]; + +template +[[kernel]] void paged_attention( + device float *exp_sums + [[buffer(0)]], // [num_seqs, num_heads, max_num_partitions] - only used when + // use_partitioning + device float *max_logits + [[buffer(1)]], // [num_seqs, num_heads, max_num_partitions] - only used when + // use_partitioning + device T *out + [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size] + device const T *q [[buffer(3)]], // [num_seqs, num_heads, head_size] + device const CACHE_T *k_cache + [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x] + device const CACHE_T *v_cache + [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size] + const device float *__restrict__ k_scale + [[buffer(6)]], // [1] - only used when use_fp8_scales + const device float *__restrict__ v_scale + [[buffer(7)]], // [1] - only used when use_fp8_scales + const constant int &num_kv_heads [[buffer(8)]], // [num_heads] + const constant float &scale [[buffer(9)]], + const constant float &softcapping [[buffer(10)]], + device const uint32_t *block_tables + [[buffer(11)]], // [num_seqs, max_num_blocks_per_seq] + device const uint32_t *context_lens [[buffer(12)]], // [num_seqs] + const constant int &max_num_blocks_per_seq [[buffer(13)]], + device const float *alibi_slopes + [[buffer(14)]], // [num_heads] - only used when use_alibi + const constant int &q_stride [[buffer(15)]], + const constant int &kv_block_stride [[buffer(16)]], + const constant int &kv_head_stride [[buffer(17)]], + threadgroup char *shared_mem [[threadgroup(0)]], + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], + uint3 threadgroups_per_grid [[threadgroups_per_grid]], + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], + uint simd_tid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + const int seq_idx = threadgroup_position_in_grid.y; + const int partition_idx = threadgroup_position_in_grid.z; + const int max_num_partitions = threadgroups_per_grid.z; + const int thread_idx = thread_position_in_threadgroup.x; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const uint32_t context_len = context_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. 
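The token range computed just below follows from the block range above. The same arithmetic restated as a standalone C++ helper may help when reasoning about the v1 (partition_size == 0) and v2 (partition_size == 512) launches; the helper name is made up for illustration and mirrors DIVIDE_ROUND_UP and MIN.

// Illustrative re-statement of the per-partition block/token range arithmetic.
#include <algorithm>
#include <cstdio>

struct Ranges { int start_block, end_block, start_token, end_token; };

Ranges partition_ranges(int context_len, int block_size, int partition_size,
                        int partition_idx) {
  const int num_context_blocks = (context_len + block_size - 1) / block_size;
  const bool use_partitioning = partition_size > 0;   // v1 passes 0
  const int blocks_per_partition =
      use_partitioning ? partition_size / block_size : num_context_blocks;
  const int start_block =
      use_partitioning ? partition_idx * blocks_per_partition : 0;
  const int end_block =
      std::min(start_block + blocks_per_partition, num_context_blocks);
  const int start_token = start_block * block_size;
  const int end_token = std::min(
      start_token + (end_block - start_block) * block_size, context_len);
  return {start_block, end_block, start_token, end_token};
}

int main() {
  // context_len = 1000, block_size = 16, partition_size = 512, second partition:
  Ranges r = partition_ranges(1000, 16, 512, 1);
  std::printf("blocks [%d, %d), tokens [%d, %d)\n",
              r.start_block, r.end_block, r.start_token, r.end_token);
  // -> blocks [32, 63), tokens [512, 1000)
}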
+ const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(NUM_SIMD_LANES / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, NUM_SIMD_LANES); + constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; + const int warp_idx = simd_tid; + const int lane = simd_lid; + + const int head_idx = threadgroup_position_in_grid.x; + const int num_heads = threadgroups_per_grid.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = !use_alibi ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Quant_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the thread group size is 4, then the first thread in the + // group has 0, 4, 8, ... th vectors of the query, and the second thread has + // 1, 5, 9, ... th vectors of the query, and so on. + const device T *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + threadgroup Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Use fp32 on softmax logits for better accuracy + threadgroup float *logits = reinterpret_cast(shared_mem); + // Workspace for reduction + threadgroup float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(CACHE_T); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const device uint32_t *block_table = + block_tables + seq_idx * max_num_blocks_per_seq; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE: The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by + // large numbers (e.g., kv_block_stride). 
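All of the tiling constants set up above derive from the 16-byte fetch width and the launch configuration; a small standalone C++ check that reproduces them can be handy when adding new dtype or block-size instantiations. The function below is illustrative only (the kernel computes these values at compile time), and its name and struct are not part of the extension.

// Reproduces the derived per-kernel constants for given element sizes,
// matching the MAX / DIVIDE_ROUND_UP formulas used above.
#include <algorithm>
#include <cstdio>

struct Tiling {
  int thread_group_size, vec_size, x, num_warps, tokens_per_thread_group;
};

Tiling tiling(int sizeof_T, int sizeof_CACHE_T, int block_size,
              int num_threads, int num_simd_lanes) {
  Tiling t;
  t.thread_group_size = std::max(num_simd_lanes / block_size, 1);
  t.vec_size = std::max(16 / (t.thread_group_size * sizeof_T), 1);
  t.x = 16 / sizeof_CACHE_T;  // elements per 16-byte key-cache chunk
  t.num_warps = num_threads / num_simd_lanes;
  t.tokens_per_thread_group = (block_size + num_simd_lanes - 1) / num_simd_lanes;
  return t;
}

int main() {
  // half query, half cache, block_size 16, the 256-thread / 32-lane configuration:
  Tiling t = tiling(2, 2, 16, 256, 32);
  std::printf("TGS=%d VEC=%d x=%d warps=%d tok/grp=%d\n",
              t.thread_group_size, t.vec_size, t.x, t.num_warps,
              t.tokens_per_thread_group);
  // -> TGS=2 VEC=4 x=8 warps=8 tok/grp=1
}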
+ const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the thread group size is 4, then the first thread in the + // group has 0, 4, 8, ... th vectors of the key, and the second thread has + // 1, 5, 9, ... th vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * NUM_SIMD_LANES) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const device CACHE_T *k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + + if constexpr (is_uchar()) { + // FP8 support + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8_convert(k_vec_quant, *k_scale); + } else { + // Non-FP8 default + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + + // Apply softcapping + if (softcapping != 1.0) { + qk = precise::tanh(qk / softcapping) * softcapping; + } + + // Add the ALiBi bias if slopes are given. + if (use_alibi && alibi_slope != 0) { + // Compute bias with explicit float precision to minimize precision loss + int position_offset = token_idx - int(context_len) + 1; + float alibi_bias = alibi_slope * float(position_offset); + qk += alibi_bias; + } + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE: It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : max(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = simd_shuffle(qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = exp(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum, + simd_tid, simd_lid); + + // Compute softmax. 
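The inv_sum normalization just below completes a max-subtracted softmax over this partition's logits, including the 1e-6 guard against an all-masked row. A scalar C++ reference of that step, useful for checking the kernel numerically (the function name is illustrative):

// Scalar reference for the softmax the kernel computes over one partition's logits.
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<float> softmax_reference(const std::vector<float>& logits) {
  float qk_max = -FLT_MAX;
  for (float v : logits) qk_max = std::max(qk_max, v);
  std::vector<float> out(logits.size());
  float exp_sum = 0.f;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - qk_max);  // max subtraction for stability
    exp_sum += out[i];
  }
  const float inv_sum = 1.f / (exp_sum + 1e-6f);  // same guard as the kernel
  for (float& v : out) v *= inv_sum;
  return out;
}

int main() {
  auto p = softmax_reference({1.f, 2.f, 3.f});
  std::printf("%f %f %f\n", p[0], p[1], p[2]); // ~0.0900 0.2447 0.6652
}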
+ const float inv_sum = divide(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) { + device float *max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + device float *exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(T), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + using V_quant_vec = typename Vec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = NUM_SIMD_LANES / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE: We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + T zero_value = 0; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE: The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by + // large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + Float_L_vec logits_float_vec = *reinterpret_cast( + logits + token_idx - start_token_idx); + from_float(logits_vec, logits_float_vec); + + const device CACHE_T *v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + // NOTE: When v_vec contains the tokens that are out of the context, + // we should explicitly zero out the values since they may contain NaNs. + // See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + V_vec v_vec; + + if constexpr (is_uchar()) { + // FP8 support + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + v_vec = fp8_convert(v_quant_vec, *v_scale); + } else { + // Non-FP8 default + v_vec = *reinterpret_cast(v_ptr + offset); + } + + if (block_idx == num_context_blocks - 1) { + thread T *v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = + token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. 
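The shuffle loop below folds each simdgroup's accumulators, and the shared-memory pass after it folds across simdgroups, the same two-level scheme as block_sum above. The following is an illustrative sequential C++ model of that reduction, not the parallel code itself; the helper name is made up.

// Sequential model of the two-level reduction: a butterfly (xor-shuffle) fold
// inside each simdgroup, then a second fold over the simdgroup leaders.
#include <cstddef>
#include <cstdio>
#include <vector>

float two_level_sum(const std::vector<float>& per_thread,
                    std::size_t lanes_per_simd) {
  std::vector<float> warp_sums;
  for (std::size_t base = 0; base < per_thread.size(); base += lanes_per_simd) {
    // Stage 1: the value the simd_shuffle_xor folding produces within one simdgroup.
    float s = 0.f;
    for (std::size_t lane = 0; lane < lanes_per_simd; ++lane)
      s += per_thread[base + lane];
    warp_sums.push_back(s);  // lane 0 of each simdgroup stores this to red_smem
  }
  // Stage 2: the first simdgroup reduces the per-warp partial sums.
  float total = 0.f;
  for (float s : warp_sums) total += s;
  return total;
}

int main() {
  std::vector<float> vals(256, 1.0f);  // 256 threads, 32 lanes per simdgroup
  std::printf("%f\n", two_level_sum(vals, 32));  // 256.0
}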
+#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += simd_shuffle_xor(acc, mask); + } + accs[i] = acc; + } + + // NOTE: A barrier is required because the shared memory space for logits + // is reused for the output. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Perform reduction across warps. + threadgroup float *out_smem = + reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + threadgroup float *dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Lower warps update the output. + if (warp_idx < mid) { + const threadgroup float *src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Write the final output. + if (warp_idx == 0) { + device T *out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + *(out_ptr + row_idx) = T(accs[i]); + } + } + } +} + +template +[[kernel]] void paged_attention_v2_reduce( + device T *out [[buffer(0)]], const device float *exp_sums [[buffer(1)]], + const device float *max_logits [[buffer(2)]], + const device T *tmp_out [[buffer(3)]], + device uint32_t *context_lens [[buffer(4)]], + const constant int &max_num_partitions [[buffer(5)]], + threadgroup char *shared_mem [[threadgroup(0)]], + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], + uint3 threadgroups_per_grid [[threadgroups_per_grid]], + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], + uint3 threads_per_threadgroup [[threads_per_threadgroup]], + uint simd_tid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + const int num_heads = threadgroups_per_grid.x; + const int head_idx = threadgroup_position_in_grid.x; + const int seq_idx = threadgroup_position_in_grid.y; + const uint32_t context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + device T *out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const device T *tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE; + i += threads_per_threadgroup.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. 
+ return; + } + + constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; + const int warp_idx = simd_tid; + const int lane = simd_lid; + + // Workspace for reduction. + threadgroup float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + threadgroup float *shared_max_logits = + reinterpret_cast(shared_mem); + const device float *max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = thread_position_in_threadgroup.x; i < num_partitions; + i += threads_per_threadgroup.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = max(max_logit, l); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) { + max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = simd_shuffle(max_logit, 0); + + // Load rescaled exp sums to shared memory. + threadgroup float *shared_exp_sums = reinterpret_cast( + shared_mem + sizeof(float) * num_partitions); + const device float *exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = thread_position_in_threadgroup.x; i < num_partitions; + i += threads_per_threadgroup.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * exp(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + global_exp_sum = block_sum( + &red_smem[NUM_WARPS], global_exp_sum, simd_tid, simd_lid); + const float inv_global_exp_sum = divide(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. 
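The aggregation loop that follows is the usual log-sum-exp combination: each partition's exp-sum is rescaled by exp(local_max - global_max), and the locally normalized partial outputs are averaged with those weights. A scalar C++ reference (illustrative names, one value per partition) showing that this reproduces the global softmax-weighted result:

// Scalar reference for the v2 reduce step: combine per-partition partial outputs,
// each computed with its own local softmax, into the global softmax result.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

float combine_partitions(const std::vector<float>& partial_out,  // per-partition output
                         const std::vector<float>& max_logits,   // per-partition max logit
                         const std::vector<float>& exp_sums) {   // per-partition exp sum
  float global_max = -1e30f;
  for (float m : max_logits) global_max = std::max(global_max, m);
  float global_exp_sum = 0.f;
  std::vector<float> w(exp_sums.size());
  for (std::size_t i = 0; i < exp_sums.size(); ++i) {
    w[i] = exp_sums[i] * std::exp(max_logits[i] - global_max);  // rescaled exp sum
    global_exp_sum += w[i];
  }
  const float inv = 1.f / (global_exp_sum + 1e-6f);
  float acc = 0.f;
  for (std::size_t i = 0; i < partial_out.size(); ++i)
    acc += partial_out[i] * w[i] * inv;
  return acc;
}

int main() {
  // Two partitions of one token each: logits {1, 3}, values {10, 20}.
  // Each partition's locally normalized partial output is just its value.
  float out = combine_partitions({10.f, 20.f}, {1.f, 3.f}, {1.f, 1.f});
  std::printf("%f\n", out);  // ~18.81, the global softmax-weighted average
}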
+  const device T *tmp_out_ptr =
+      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      head_idx * max_num_partitions * HEAD_SIZE;
+  device T *out_ptr =
+      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+  for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
+       i += NUM_THREADS) {
+    float acc = 0.0f;
+    for (int j = 0; j < num_partitions; ++j) {
+      acc += float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
+             inv_global_exp_sum;
+    }
+    out_ptr[i] = T(acc);
+  }
+}
+
+#define instantiate_paged_attention_inner(type, cache_type, head_size, \
+                                          block_size, num_threads, \
+                                          num_simd_lanes, partition_size) \
+  template [[host_name("paged_attention_" #type "_cache_" #cache_type \
+                       "_hs" #head_size "_bs" #block_size "_nt" #num_threads \
+                       "_nsl" #num_simd_lanes \
+                       "_ps" #partition_size)]] [[kernel]] void \
+  paged_attention<type, cache_type, head_size, block_size, num_threads, \
+                  num_simd_lanes, partition_size>( \
+      device float *exp_sums [[buffer(0)]], \
+      device float *max_logits [[buffer(1)]], \
+      device type *out [[buffer(2)]], device const type *q [[buffer(3)]], \
+      device const cache_type *k_cache [[buffer(4)]], \
+      device const cache_type *v_cache [[buffer(5)]], \
+      const device float *__restrict__ k_scale [[buffer(6)]], \
+      const device float *__restrict__ v_scale [[buffer(7)]], \
+      const constant int &num_kv_heads [[buffer(8)]], \
+      const constant float &scale [[buffer(9)]], \
+      const constant float &softcapping [[buffer(10)]], \
+      device const uint32_t *block_tables [[buffer(11)]], \
+      device const uint32_t *context_lens [[buffer(12)]], \
+      const constant int &max_num_blocks_per_seq [[buffer(13)]], \
+      device const float *alibi_slopes [[buffer(14)]], \
+      const constant int &q_stride [[buffer(15)]], \
+      const constant int &kv_block_stride [[buffer(16)]], \
+      const constant int &kv_head_stride [[buffer(17)]], \
+      threadgroup char *shared_mem [[threadgroup(0)]], \
+      uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
+      uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
+      uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
+      uint simd_tid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+#define instantiate_paged_attention_v2_reduce_inner( \
+    type, head_size, num_threads, num_simd_lanes, partition_size) \
+  template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \
+                       "_nt" #num_threads "_nsl" #num_simd_lanes \
+                       "_ps" #partition_size)]] [[kernel]] void \
+  paged_attention_v2_reduce<type, head_size, num_threads, num_simd_lanes, \
+                            partition_size>( \
+      device type * out [[buffer(0)]], \
+      const device float *exp_sums [[buffer(1)]], \
+      const device float *max_logits [[buffer(2)]], \
+      const device type *tmp_out [[buffer(3)]], \
+      device uint32_t *context_lens [[buffer(4)]], \
+      const constant int &max_num_partitions [[buffer(5)]], \
+      threadgroup char *shared_mem [[threadgroup(0)]], \
+      uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
+      uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
+      uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
+      uint3 threads_per_threadgroup [[threads_per_threadgroup]], \
+      uint simd_tid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+#define instantiate_paged_attention_heads( \
+    type, cache_type, block_size, num_threads, num_simd_lanes, partition_size) \
+  instantiate_paged_attention_inner(type, cache_type, 32, block_size, \
+                                    num_threads, num_simd_lanes, partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 64, block_size, \
+                                    num_threads, num_simd_lanes, partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 80, block_size, \
+                                    num_threads, num_simd_lanes, partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 96, block_size, \
+                                    num_threads, num_simd_lanes, partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 112, block_size, \
+                                    num_threads, num_simd_lanes, partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 120, block_size, \
+                                    num_threads, num_simd_lanes, partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 128, block_size, \
+                                    num_threads, num_simd_lanes, partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 192, block_size, \
+                                    num_threads, num_simd_lanes, partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 256, block_size, \
+                                    num_threads, num_simd_lanes, partition_size);
+
+#define instantiate_paged_attention_v2_reduce_heads( \
+    type, num_threads, num_simd_lanes, partition_size) \
+  instantiate_paged_attention_v2_reduce_inner(type, 32, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 64, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 80, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 96, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 112, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 120, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 128, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 192, num_threads, \
+                                              num_simd_lanes, partition_size); \
+  instantiate_paged_attention_v2_reduce_inner(type, 256, num_threads, \
+                                              num_simd_lanes, partition_size);
+
+#define instantiate_paged_attention_block_size(type, cache_type, num_threads, \
+                                               num_simd_lanes, partition_size) \
+  instantiate_paged_attention_heads(type, cache_type, 8, num_threads, \
+                                    num_simd_lanes, partition_size); \
+  instantiate_paged_attention_heads(type, cache_type, 16, num_threads, \
+                                    num_simd_lanes, partition_size); \
+  instantiate_paged_attention_heads(type, cache_type, 32, num_threads, \
+                                    num_simd_lanes, partition_size);
+
+// TODO: tune num_threads = 256
+// NOTE: partition_size = 0
+#define instantiate_paged_attention_v1(type, cache_type, num_simd_lanes) \
+  instantiate_paged_attention_block_size(type, cache_type, 256, \
+                                         num_simd_lanes, 0);
+
+// TODO: tune num_threads = 256
+// NOTE: partition_size = 512
+#define instantiate_paged_attention_v2(type, cache_type, num_simd_lanes) \
+  instantiate_paged_attention_block_size(type, cache_type, 256, \
+                                         num_simd_lanes, 512);
+
+// TODO: tune num_threads = 256
+// NOTE: partition_size = 512
+#define instantiate_paged_attention_v2_reduce(type, num_simd_lanes) \
+  instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
+
+instantiate_paged_attention_v1(float, float, 32);
+instantiate_paged_attention_v1(bfloat16_t, bfloat16_t, 32);
+instantiate_paged_attention_v1(half, half, 32);
+
+instantiate_paged_attention_v1(float, uchar, 32);
+instantiate_paged_attention_v1(bfloat16_t, uchar, 32);
+instantiate_paged_attention_v1(half, uchar, 32);
+
+instantiate_paged_attention_v2_reduce(float, 32);
+instantiate_paged_attention_v2_reduce(bfloat16_t, 32);
+instantiate_paged_attention_v2_reduce(half, 32);
+
+instantiate_paged_attention_v2(float, float, 32);
+instantiate_paged_attention_v2(bfloat16_t, bfloat16_t, 32);
+instantiate_paged_attention_v2(half, half, 32);
+
+instantiate_paged_attention_v2(float, uchar, 32);
+instantiate_paged_attention_v2(bfloat16_t, uchar, 32);
+instantiate_paged_attention_v2(half, uchar, 32);
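Note: each instantiation above is only reachable from the host through the string stamped out by its [[host_name]] attribute, so whatever looks the kernel up in the metallib has to assemble exactly that name from the dtype, cache dtype, head size, block size, thread count, SIMD width, and partition size. The real lookup lives in the C++ primitive (paged_attention.cpp, which is not part of this hunk); the Python helper below is purely an illustrative mirror of the naming scheme, and its name and use are assumptions, not part of the extension.

    # Illustrative only: mirrors the [[host_name]] pattern produced by
    # instantiate_paged_attention_inner. Not used by the extension itself.
    def paged_attention_kernel_name(
        dtype: str,          # "float", "bfloat16_t", or "half"
        cache_dtype: str,    # same as dtype, or "uchar" for fp8 caches
        head_size: int,      # 32, 64, 80, 96, 112, 120, 128, 192, or 256
        block_size: int,     # 8, 16, or 32
        num_threads: int = 256,
        num_simd_lanes: int = 32,
        partition_size: int = 0,  # 0 -> V1 kernels, 512 -> V2 kernels
    ) -> str:
        return (
            f"paged_attention_{dtype}_cache_{cache_dtype}"
            f"_hs{head_size}_bs{block_size}_nt{num_threads}"
            f"_nsl{num_simd_lanes}_ps{partition_size}"
        )

    # e.g. paged_attention_kernel_name("half", "uchar", 128, 16)
    #   -> "paged_attention_half_cache_uchar_hs128_bs16_nt256_nsl32_ps0"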
diff --git a/src/parallax/metal/extensions/paged_attention/reshape_and_cache.metal b/src/parallax/metal/extensions/paged_attention/reshape_and_cache.metal
new file mode 100644
index 00000000..e597c1e2
--- /dev/null
+++ b/src/parallax/metal/extensions/paged_attention/reshape_and_cache.metal
@@ -0,0 +1,193 @@
+#include "./utils.metal"
+#include "./float8.metal"
+#include <metal_stdlib>
+
+using namespace metal;
+
+template <typename KV_T, typename CACHE_T>
+inline CACHE_T to_cache(KV_T v) = delete;
+
+template <> inline uchar to_cache<float, uchar>(float v) {
+  return float_to_fp8_e4m3(v);
+}
+
+template <> inline uchar to_cache<bfloat16_t, uchar>(bfloat16_t v) {
+  return float_to_fp8_e4m3((float)v);
+}
+
+template <> inline uchar to_cache<half, uchar>(half v) {
+  return float_to_fp8_e4m3((float)v);
+}
+
+template <> inline float to_cache<float, float>(float v) { return v; }
+
+template <> inline bfloat16_t to_cache<bfloat16_t, bfloat16_t>(bfloat16_t v) {
+  return v;
+}
+
+template <> inline half to_cache<half, half>(half v) { return v; }
+
+constant bool use_fp8_scales [[function_constant(10)]];
+
+template <typename KV_T, typename CACHE_T>
+[[kernel]] void reshape_and_cache(
+    const device KV_T *__restrict__ key
+    [[buffer(0)]], // [num_tokens, num_heads, head_size]
+    const device KV_T *__restrict__ value
+    [[buffer(1)]], // [num_tokens, num_heads, head_size]
+    device CACHE_T *__restrict__ key_cache
+    [[buffer(2)]], // [num_blocks, num_heads, head_size/x, block_size, x]
+    device CACHE_T *__restrict__ value_cache
+    [[buffer(3)]], // [num_blocks, num_heads, head_size, block_size]
+    const device int64_t *__restrict__ slot_mapping
+    [[buffer(4)]], // [num_tokens]
+    const device float *__restrict__ k_scale
+    [[buffer(5)]], // [1] - only used when use_fp8_scales
+    const device float *__restrict__ v_scale
+    [[buffer(6)]], // [1] - only used when use_fp8_scales
+    device const int &key_stride [[buffer(7)]],
+    device const int &value_stride [[buffer(8)]],
+    device const int &num_heads [[buffer(9)]],
+    device const int &head_size [[buffer(10)]],
+    device const int &block_size [[buffer(11)]],
+    device const int &x [[buffer(12)]],
+    uint gid [[threadgroup_position_in_grid]],
+    uint tid [[thread_position_in_threadgroup]],
+    uint threads_per_threadgroup [[threads_per_threadgroup]]) {
+  const int64_t token_idx = gid;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  if (slot_idx < 0) {
+    // Padding token that should be ignored.
+    return;
+  }
+
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+
+  const int n = num_heads * head_size;
+  for (int i = tid; i < n; i += threads_per_threadgroup) {
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
+
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int x_idx = head_offset / x;
+    const int x_offset = head_offset % x;
+
+    const int64_t tgt_key_idx =
+        block_idx * num_heads * (head_size / x) * block_size * x +
+        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+        block_offset * x + x_offset;
+    const int64_t tgt_value_idx =
+        block_idx * num_heads * head_size * block_size +
+        head_idx * head_size * block_size + head_offset * block_size +
+        block_offset;
+
+    if (use_fp8_scales) {
+      key_cache[tgt_key_idx] =
+          to_cache<KV_T, CACHE_T>(KV_T((float)key[src_key_idx] / *k_scale));
+      value_cache[tgt_value_idx] =
+          to_cache<KV_T, CACHE_T>(KV_T((float)value[src_value_idx] / *v_scale));
+    } else {
+      key_cache[tgt_key_idx] = to_cache<KV_T, CACHE_T>(key[src_key_idx]);
+      value_cache[tgt_value_idx] = to_cache<KV_T, CACHE_T>(value[src_value_idx]);
+    }
+  }
+}
+
+#define instantiate_reshape_and_cache(kv_type, cache_type) \
+  template [[host_name("reshape_and_cache_kv_" #kv_type \
+                       "_cache_" #cache_type)]] [[kernel]] void \
+  reshape_and_cache<kv_type, cache_type>( \
+      const device kv_type *__restrict__ key [[buffer(0)]], \
+      const device kv_type *__restrict__ value [[buffer(1)]], \
+      device cache_type *__restrict__ key_cache [[buffer(2)]], \
+      device cache_type *__restrict__ value_cache [[buffer(3)]], \
+      const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \
+      const device float *__restrict__ k_scale [[buffer(5)]], \
+      const device float *__restrict__ v_scale [[buffer(6)]], \
+      device const int &key_stride [[buffer(7)]], \
+      device const int &value_stride [[buffer(8)]], \
+      device const int &num_heads [[buffer(9)]], \
+      device const int &head_size [[buffer(10)]], \
+      device const int &block_size [[buffer(11)]], \
+      device const int &x [[buffer(12)]], \
+      uint gid [[threadgroup_position_in_grid]], \
+      uint tid [[thread_position_in_threadgroup]], \
+      uint threads_per_threadgroup [[threads_per_threadgroup]]);
+
+instantiate_reshape_and_cache(float, float);
+instantiate_reshape_and_cache(bfloat16_t, bfloat16_t);
+instantiate_reshape_and_cache(half, half);
+
+instantiate_reshape_and_cache(float, uchar);
+instantiate_reshape_and_cache(bfloat16_t, uchar);
+instantiate_reshape_and_cache(half, uchar);
+
+// Flash version with different cache layout: [num_blocks, block_size,
+// num_heads, head_size]
+template <typename T>
+[[kernel]] void reshape_and_cache_flash(
+    const device T *__restrict__ key
+    [[buffer(0)]], // [num_tokens, num_heads, head_size]
+    const device T *__restrict__ value
+    [[buffer(1)]], // [num_tokens, num_heads, head_size]
+    device T *__restrict__ key_cache
+    [[buffer(2)]], // [num_blocks, block_size, num_heads, head_size]
+    device T *__restrict__ value_cache
+    [[buffer(3)]], // [num_blocks, block_size, num_heads, head_size]
+    const device int64_t *__restrict__ slot_mapping
+    [[buffer(4)]], // [num_tokens]
+    device const int &key_stride, device const int &value_stride,
+    device const int &num_heads, device const int &head_size,
+    device const int &block_size, uint gid [[threadgroup_position_in_grid]],
+    uint tid [[thread_position_in_threadgroup]],
+    uint threads_per_threadgroup [[threads_per_threadgroup]]) {
+  const int64_t token_idx = gid;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  if (slot_idx < 0) {
+    // Padding token that should be ignored.
+    return;
+  }
+
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+
+  const int n = num_heads * head_size;
+  for (int i = tid; i < n; i += threads_per_threadgroup) {
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
+
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+
+    // Flash cache layout: [num_blocks, block_size, num_heads, head_size]
+    const int64_t tgt_key_idx =
+        block_idx * block_size * num_heads * head_size +
+        block_offset * num_heads * head_size + head_idx * head_size +
+        head_offset;
+    const int64_t tgt_value_idx =
+        block_idx * block_size * num_heads * head_size +
+        block_offset * num_heads * head_size + head_idx * head_size +
+        head_offset;
+    key_cache[tgt_key_idx] = key[src_key_idx];
+    value_cache[tgt_value_idx] = value[src_value_idx];
+  }
+}
+
+#define instantiate_reshape_and_cache_flash(type) \
+  template [[host_name("reshape_and_cache_flash_" #type)]] [[kernel]] void \
+  reshape_and_cache_flash<type>( \
+      const device type *__restrict__ key [[buffer(0)]], \
+      const device type *__restrict__ value [[buffer(1)]], \
+      device type *__restrict__ key_cache [[buffer(2)]], \
+      device type *__restrict__ value_cache [[buffer(3)]], \
+      const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \
+      device const int &key_stride, device const int &value_stride, \
+      device const int &num_heads, device const int &head_size, \
+      device const int &block_size, uint gid [[threadgroup_position_in_grid]], \
+      uint tid [[thread_position_in_threadgroup]], \
+      uint threads_per_threadgroup [[threads_per_threadgroup]]);
+
+instantiate_reshape_and_cache_flash(float);
+instantiate_reshape_and_cache_flash(bfloat16_t);
+instantiate_reshape_and_cache_flash(half);
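Note: the two kernels above store the same token into two different cache layouts, and the index arithmetic is the core of the file. The sketch below is a host-side Python mirror of exactly those formulas (names are illustrative, nothing here is used by the extension); it can serve as a reference when validating the kernels against a CPU implementation.

    # Host-side mirror of the index math in reshape_and_cache and
    # reshape_and_cache_flash (illustrative only). i is the flat offset
    # i = head_idx * head_size + head_offset for one token.
    def paged_kv_indices(slot_idx, i, num_heads, head_size, block_size, x):
        block_idx, block_offset = divmod(slot_idx, block_size)
        head_idx, head_offset = divmod(i, head_size)
        x_idx, x_offset = divmod(head_offset, x)
        # key_cache: [num_blocks, num_heads, head_size/x, block_size, x]
        key_idx = (
            block_idx * num_heads * (head_size // x) * block_size * x
            + head_idx * (head_size // x) * block_size * x
            + x_idx * block_size * x
            + block_offset * x
            + x_offset
        )
        # value_cache: [num_blocks, num_heads, head_size, block_size]
        value_idx = (
            block_idx * num_heads * head_size * block_size
            + head_idx * head_size * block_size
            + head_offset * block_size
            + block_offset
        )
        return key_idx, value_idx

    def flash_kv_index(slot_idx, i, num_heads, head_size, block_size):
        # Flash layout: [num_blocks, block_size, num_heads, head_size];
        # keys and values use the same indexing.
        block_idx, block_offset = divmod(slot_idx, block_size)
        head_idx, head_offset = divmod(i, head_size)
        return (
            block_idx * block_size * num_heads * head_size
            + block_offset * num_heads * head_size
            + head_idx * head_size
            + head_offset
        )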
diff --git a/src/parallax/metal/extensions/paged_attention/utils.metal b/src/parallax/metal/extensions/paged_attention/utils.metal
new file mode 100644
index 00000000..b93bd718
--- /dev/null
+++ b/src/parallax/metal/extensions/paged_attention/utils.metal
@@ -0,0 +1,238 @@
+#include <metal_stdlib>
+using namespace metal;
+
+#if defined(__HAVE_BFLOAT__)
+
+typedef bfloat bfloat16_t;
+
+#else
+
+/////////////////////////////////////////////////////////////////////////////
+// Helpers
+/////////////////////////////////////////////////////////////////////////////
+
+constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
+  // Check for nan
+  if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
+      _fp_encoding_traits<float>::inf_mask) {
+    return uint16_t(as_type<uint32_t>(0x7FC0));
+  }
+  // Take bits
+  uint32_t float_bits = as_type<uint32_t>(x);
+
+  // Round to nearest even
+  float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
+
+  // Take upper 16 bits
+  return float_bits >> 16;
+}
+
+constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
+  // Upper 16 bits are the data and lower 16 bits are 0s
+  return as_type<float>((uint32_t)x << 16);
+}
+
+struct _MLX_BFloat16;
+
+template <typename T>
+static constexpr constant bool can_convert_to_bfloat =
+    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
+
+template <typename T>
+static constexpr constant bool can_convert_from_bfloat =
+    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
+
+/////////////////////////////////////////////////////////////////////////////
+// Bfloat struct
+/////////////////////////////////////////////////////////////////////////////
+
+struct _MLX_BFloat16 {
+  ///////////////////////////////////////////////////////////////////////////
+  // Constructors
+  uint16_t bits_;
+  _MLX_BFloat16() thread = default;
+  _MLX_BFloat16() threadgroup = default;
+  _MLX_BFloat16() device = default;
+  _MLX_BFloat16() constant = default;
+  struct bits_to_bfloat_struct {};
+  static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
+    return bits_to_bfloat_struct();
+  }
+  constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
+      : bits_(bits) {}
+  ///////////////////////////////////////////////////////////////////////////
+  // Conversions to bfloat
+
+  template <typename T,
+            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+  constexpr METAL_FUNC _MLX_BFloat16(T x) thread
+      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+  template <typename T,
+            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+  constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
+      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+  template <typename T,
+            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+  constexpr METAL_FUNC _MLX_BFloat16(T x) device
+      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+  template <typename T,
+            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+  constexpr METAL_FUNC _MLX_BFloat16(T x) constant
+      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Conversions from bfloat
+  template <typename T,
+            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+  constexpr METAL_FUNC operator T() const thread {
+    return static_cast<T>(bfloat_bits_to_float(bits_));
+  }
+  template <typename T,
+            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+  constexpr METAL_FUNC operator T() const threadgroup {
+    return static_cast<T>(bfloat_bits_to_float(bits_));
+  }
+  template <typename T,
+            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+  constexpr METAL_FUNC operator T() const device {
+    return static_cast<T>(bfloat_bits_to_float(bits_));
+  }
+  template <typename T,
+            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+  constexpr METAL_FUNC operator T() const constant {
+    return static_cast<T>(bfloat_bits_to_float(bits_));
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////
+// Bfloat operators
+/////////////////////////////////////////////////////////////////////////////
+
+/////////////////////////////////////////////////////////////////////////////
+// Unary ops
+constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
+  return -static_cast<float>(x);
+}
+
+/////////////////////////////////////////////////////////////////////////////
+// Binary operators
+#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
+  constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
+    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+  }
+
+#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
+  constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
+    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+  } \
+  constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
+    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+  }
+
+/////////////////////////////////////////////////////////////////////////////
+// Arithmetic Operators
+#define bfloat_binop(_op_, _operator_) \
+  bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \
+                    _MLX_BFloat16, float); \
+  bfloat_binop_helper(_op_, _operator_, float, float, float); \
+  bfloat_binop_helper(_op_, _operator_, float, half, float); \
+  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
+  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
+  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
+  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
+
+bfloat_binop(+, operator+);
+bfloat_binop(-, operator-);
+bfloat_binop(*, operator*);
+bfloat_binop(/, operator/);
+
+/////////////////////////////////////////////////////////////////////////////
+// Comparison ops
+#define bfloat_compop(__op__, __operator__) \
+  bfloat_binop_base(__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, \
+                    float); \
+  bfloat_binop_helper(__op__, __operator__, bool, float, float); \
+  bfloat_binop_helper(__op__, __operator__, bool, half, float); \
+  bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
+  bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
+  bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
+  bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
+
+bfloat_compop(>, operator>);
+bfloat_compop(<, operator<);
+bfloat_compop(>=, operator>=);
+bfloat_compop(<=, operator<=);
+bfloat_compop(==, operator==);
+bfloat_compop(!=, operator!=);
+
+#undef bfloat_compop
+#undef bfloat_binop_base
+#undef bfloat_binop_helper
+#undef bfloat_binop
+
+/////////////////////////////////////////////////////////////////////////////
+// Inplace Operators
+#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
+  constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \
+      addr_space _MLX_BFloat16 &lhs, itype rhs) { \
+    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+    return lhs; \
+  } \
+  constexpr METAL_FUNC addr_space itype &__operator__(addr_space itype &lhs, \
+                                                      _MLX_BFloat16 rhs) { \
+    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+    return lhs; \
+  }
+
+#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
+  bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
+  bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
+  bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
+
+#define bfloat_inplace_op(itype) \
+  bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
+  bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
+  bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
+  bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
+
+bfloat_inplace_op(float);
+bfloat_inplace_op(half);
+bfloat_inplace_op(int16_t);
+bfloat_inplace_op(int32_t);
+bfloat_inplace_op(int64_t);
+bfloat_inplace_op(uint16_t);
+bfloat_inplace_op(uint32_t);
+bfloat_inplace_op(uint64_t);
+
+#undef bfloat_inplace_op_helper
+#undef bfloat_inplace_op_addr_space_helper
+#undef bfloat_inplace_op
+
+#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
+  constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \
+      addr_space _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs) { \
+    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+    return lhs; \
+  }
+
+#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
+  bfloat_inplace_op_helper(__op__, __operator__, device); \
+  bfloat_inplace_op_helper(__op__, __operator__, thread); \
+  bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
+
+bfloat_inplace_op_addr_space_helper(+, operator+=);
+bfloat_inplace_op_addr_space_helper(-, operator-=);
+bfloat_inplace_op_addr_space_helper(*, operator*=);
+bfloat_inplace_op_addr_space_helper(/, operator/=);
+
+#undef bfloat_inplace_op_helper
+#undef bfloat_inplace_op_addr_space_helper
+
+/////////////////////////////////////////////////////////////////////////////
+// Bfloat typedef
+/////////////////////////////////////////////////////////////////////////////
+
+typedef struct _MLX_BFloat16 bfloat16_t;
+
+#endif
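Note: utils.metal only supplies the software bfloat16_t when the toolchain lacks a native bfloat (__HAVE_BFLOAT__), and the heart of that fallback is the float-to-bits conversion with round-to-nearest-even. Below is a hedged host-side Python sketch of the same bit manipulation, useful for sanity-checking the Metal helpers against reference values; the function names are illustrative and nothing here ships with the extension.

    import struct

    # Illustrative host-side analogues of float_to_bfloat_bits /
    # bfloat_bits_to_float in utils.metal (not part of the extension).
    def float_to_bfloat_bits(x: float) -> int:
        bits = struct.unpack("<I", struct.pack("<f", x))[0]
        if (bits & 0x7FFFFFFF) > 0x7F800000:
            return 0x7FC0  # collapse NaNs to a quiet-NaN pattern
        # Round to nearest even, then keep the upper 16 bits.
        bits += ((bits >> 16) & 1) + 0x7FFF
        return (bits >> 16) & 0xFFFF

    def bfloat_bits_to_float(b: int) -> float:
        # Upper 16 bits carry the data; lower 16 bits are zero.
        return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]

    # Round trips: 1.0 -> 0x3F80 -> 1.0; 3.1415926 -> 0x4049 -> 3.140625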
diff --git a/src/parallax/metal/extensions/setup.py b/src/parallax/metal/extensions/setup.py
new file mode 100755
index 00000000..e8e68b78
--- /dev/null
+++ b/src/parallax/metal/extensions/setup.py
@@ -0,0 +1,18 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from setuptools import setup
+
+from mlx import extension
+
+if __name__ == "__main__":
+    setup(
+        name="ops",
+        version="0.0.0",
+        description="Metal op extensions.",
+        ext_modules=[extension.CMakeExtension("ops._ext")],
+        cmdclass={"build_ext": extension.CMakeBuild},
+        packages=["ops"],
+        package_data={"ops": ["*.so", "*.dylib", "*.metallib"]},
+        zip_safe=False,
+        python_requires=">=3.10",
+    )
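Note: setup.py delegates to MLX's CMake extension machinery, so an in-place build of the nanobind module and the metallib should follow the usual setuptools flow. The snippet below is a hedged smoke-test sketch, not part of the diff; it assumes MLX and nanobind are installed and that the commands are run from src/parallax/metal/extensions.

    # Build the extension in place (parallel build is optional):
    #
    #   python setup.py build_ext -j8 --inplace
    #
    # The in-place build should drop the compiled module and *.metallib
    # next to the ops package, after which it can be imported:
    from ops import _ext

    print([name for name in dir(_ext) if not name.startswith("_")])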