83 changes: 83 additions & 0 deletions src/parallax/metal/extensions/CMakelists.txt
@@ -0,0 +1,83 @@
cmake_minimum_required(VERSION 3.27)

project(_ext LANGUAGES CXX)

# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

# ----------------------------- Dependencies -----------------------------
find_package(
  Python 3.10
  COMPONENTS Interpreter Development.Module
  REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)

execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

# ----------------------------- Extensions -----------------------------

# Add library
add_library(parallax_ext)

# Add sources
target_sources(
  parallax_ext
  PUBLIC ${CMAKE_CURRENT_LIST_DIR}/paged_attention/paged_attention.cpp)

# Add include headers
target_include_directories(parallax_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})

# Link to mlx
target_link_libraries(parallax_ext PUBLIC mlx)

# ----------------------------- Metal -----------------------------

# Build metallib
if(MLX_BUILD_METAL)
  mlx_build_metallib(
    TARGET
    parallax_ext_metallib
    TITLE
    parallax_ext
    SOURCES
    ${CMAKE_CURRENT_LIST_DIR}/paged_attention/utils.metal
    ${CMAKE_CURRENT_LIST_DIR}/paged_attention/float8.metal
    ${CMAKE_CURRENT_LIST_DIR}/paged_attention/paged_attention.metal
    ${CMAKE_CURRENT_LIST_DIR}/paged_attention/reshape_and_cache.metal
    INCLUDE_DIRS
    ${PROJECT_SOURCE_DIR}
    ${MLX_INCLUDE_DIRS}
    OUTPUT_DIRECTORY
    ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})

  add_dependencies(parallax_ext parallax_ext_metallib)

endif()

# ----------------------------- Python Bindings -----------------------------
nanobind_add_module(
  _ext
  NB_STATIC
  STABLE_ABI
  LTO
  NOMINSIZE
  NB_DOMAIN
  mlx
  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
target_link_libraries(_ext PRIVATE parallax_ext)

if(BUILD_SHARED_LIBS)
  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()
43 changes: 43 additions & 0 deletions src/parallax/metal/extensions/bindings.cpp
@@ -0,0 +1,43 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/variant.h>

#include "paged_attention/paged_attention.h"

namespace nb = nanobind;
using namespace nb::literals;

NB_MODULE(_ext, m) {
  m.doc() = "vLLM PagedAttentionV1";

  m.def(
      "PagedAttentionV1",
      &parallax_ext::paged_attention_v1,
      "query"_a,
      "key_cache"_a,
      "value_cache"_a,
      "block_tables"_a,
      "seq_lens"_a,
      "num_kv_heads"_a,
      "block_size"_a,
      "max_seq_len"_a,
      "scale"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      R"(
        vLLM PagedAttentionV1 operation

        Args:
            query (array): Input array [num_seqs, num_heads, head_size].
            key_cache (array): Input array [num_blocks, num_heads, head_size/x, block_size, x].
            value_cache (array): Input array [num_blocks, num_heads, head_size, block_size].
            block_tables (array): Input array [num_seqs, max_num_blocks_per_seq].
            seq_lens (array): Input array [num_seqs].
            num_kv_heads (int): Number of key/value heads.
            block_size (int): Number of tokens stored per cache block.
            max_seq_len (int): Maximum sequence length in the batch.
            scale (float): Attention softmax scaling factor.
            stream (Stream, optional): Stream on which to schedule the operation.

        Returns:
            array: ``Paged attention result``
      )");
}
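For reference, a minimal call of this binding from Python could look like the sketch below. The import path and the key-cache packing factor x = 8 are assumptions for illustration, not taken from this PR; the argument order follows the m.def above.

# Illustrative usage sketch only; import path and x = 8 are assumed.
import mlx.core as mx
from parallax.metal.extensions import _ext  # hypothetical import path

num_seqs, num_heads, num_kv_heads, head_size = 2, 8, 8, 128
block_size, num_blocks, max_seq_len, x = 16, 64, 256, 8

query = mx.zeros((num_seqs, num_heads, head_size), dtype=mx.float16)
key_cache = mx.zeros((num_blocks, num_kv_heads, head_size // x, block_size, x), dtype=mx.float16)
value_cache = mx.zeros((num_blocks, num_kv_heads, head_size, block_size), dtype=mx.float16)
block_tables = mx.zeros((num_seqs, max_seq_len // block_size), dtype=mx.int32)
seq_lens = mx.array([3, 5], dtype=mx.int32)

out = _ext.PagedAttentionV1(
    query, key_cache, value_cache, block_tables, seq_lens,
    num_kv_heads, block_size, max_seq_len, 1.0 / head_size**0.5)
print(out.shape)  # same shape as query: (num_seqs, num_heads, head_size)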
Empty file.
122 changes: 122 additions & 0 deletions src/parallax/metal/extensions/paged_attention/float8.metal
@@ -0,0 +1,122 @@
#include <metal_stdlib>
using namespace metal;

// Helpers ------------------------------------------------------------
static inline uint as_bits(float x) { return as_type<uint>(x); }
static inline float from_bits(uint b) { return as_type<float>(b); }

// -------------------------------------------------------------------
// FP8 E4M3 (bias = 7)
// -------------------------------------------------------------------
inline float fp8_e4m3_to_float(uchar v) {
  const uint s = v >> 7;
  const uint exp = (v >> 3) & 0xF;
  const uint man = v & 0x7;

  if (exp == 0) { // zero / sub-normal
    if (man == 0)
      return s ? -0.f : 0.f;
    const float m = float(man) / 8.f; // already scaled by 2^-3
    float val = ldexp(m, 1 - 7); // 2^(1-bias) = 2^-6
    return s ? -val : val;
  }

  if (exp == 0xF) { // all-ones exponent: Inf (man == 0) or NaN
    if (man != 0)
      return NAN;
    return s ? -INFINITY : INFINITY;
  }

  const float m = 1.f + float(man) / 8.f;
  float val = ldexp(m, int(exp) - 7);
  return s ? -val : val;
}

// -------------------------------------------------------------------
// FP8 E5M2 (bias = 15)
// -------------------------------------------------------------------
inline float fp8_e5m2_to_float(uchar v) {
  const uint s = v >> 7;
  const uint exp = (v >> 2) & 0x1F;
  const uint man = v & 0x3;

  if (exp == 0) {
    if (man == 0)
      return s ? -0.f : 0.f;
    const float m = float(man) / 4.f;
    float val = ldexp(m, 1 - 15); // 2^(1-bias) = 2^-14
    return s ? -val : val;
  }

  if (exp == 0x1F) {
    if (man != 0)
      return NAN;
    return s ? -INFINITY : INFINITY;
  }

  const float m = 1.f + float(man) / 4.f;
  float val = ldexp(m, int(exp) - 15);
  return s ? -val : val;
}

// -------------------------------------------------------------------
// Encoding helpers (round-to-nearest-even, gradual under-flow, sat-to-∞)
// -------------------------------------------------------------------
namespace detail {
template <int EXP_BITS, int MAN_BITS, int BIAS>
inline uchar fp32_to_fp8(float f) {
  const uint bits = as_bits(f);
  const uint s = bits >> 31;
  const uint abs = bits & 0x7FFFFFFF;

  // NaN propagates, Inf saturates
  if (abs >= 0x7F800000u) {
    return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS) |
                 (abs != 0x7F800000u));
  }

  int e = int((abs >> 23) & 0xFF) - 127;   // unbiased exponent
  uint m = abs & 0x7FFFFFu;                // 23-bit mantissa
  const int EXP_MAX = (1 << EXP_BITS) - 2; // last finite exponent

  // ---------- Normal path -------------------------------------------------
  int e_fp8 = e + BIAS;
  if (e_fp8 >= 1 && e_fp8 <= EXP_MAX) {
    // round-to-nearest-even
    const int shift = 23 - MAN_BITS;
    uint mant = m >> shift;
    const uint lsb = mant & 1u;
    const uint round = (m >> (shift - 1)) & 1u;
    const uint sticky = (m & ((1u << (shift - 1)) - 1u)) != 0u;
    mant += (round & (sticky | lsb));
    if (mant >> MAN_BITS) { // mantissa overflow
      mant = 0;
      ++e_fp8;
      if (e_fp8 > EXP_MAX)
        return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS)); // ∞
    }
    return uchar((s << 7) | (uint(e_fp8) << MAN_BITS) |
                 (mant & ((1u << MAN_BITS) - 1u)));
  }

  // ---------- Sub-normal / under-flow ------------------------------------
  if (e_fp8 < 1 - MAN_BITS) // too small -> ±0
    return uchar(s << 7);

  // shift so that exponent becomes 1
  int rshift = (1 - e_fp8) + (23 - MAN_BITS);
  uint mant = (0x800000u | m); // implicit 1
  uint rounded = (mant + (1u << (rshift - 1))) >> rshift;
  if (rounded == 0)
    return uchar(s << 7); // rounds to zero

  return uchar((s << 7) | (rounded & ((1u << MAN_BITS) - 1u)));
}
} // namespace detail

inline uchar float_to_fp8_e4m3(float f) {
  return detail::fp32_to_fp8<4, 3, 7>(f);
}
inline uchar float_to_fp8_e5m2(float f) {
  return detail::fp32_to_fp8<5, 2, 15>(f);
}
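A host-side reference decoder is useful for unit-testing these Metal helpers against known bit patterns. The Python sketch below mirrors fp8_e4m3_to_float under the convention used in this file (bias 7, all-ones exponent reserved for Inf/NaN); it is an assumed test utility, not part of this PR.

import math

def fp8_e4m3_to_float_ref(v: int) -> float:
    """Reference decode of one E4M3 byte, mirroring float8.metal (bias = 7)."""
    s = (v >> 7) & 0x1
    exp = (v >> 3) & 0xF
    man = v & 0x7
    sign = -1.0 if s else 1.0
    if exp == 0:                        # zero / sub-normal: (man/8) * 2^(1 - bias)
        return sign * (man / 8.0) * 2.0 ** (1 - 7)
    if exp == 0xF:                      # all-ones exponent: Inf or NaN
        return math.nan if man != 0 else sign * math.inf
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

# Spot checks against hand-computed values:
assert fp8_e4m3_to_float_ref(0x38) == 1.0        # exp=7, man=0 -> 1.0
assert fp8_e4m3_to_float_ref(0xB8) == -1.0       # sign bit set
assert fp8_e4m3_to_float_ref(0x01) == 2.0 ** -9  # smallest sub-normal: (1/8) * 2^-6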
126 changes: 126 additions & 0 deletions src/parallax/metal/extensions/paged_attention/paged_attention.cpp
@@ -0,0 +1,126 @@
#include <iostream>
#include <sstream>
#include <string>

#include "paged_attention.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"

namespace parallax_ext {

mx::array paged_attention_v1(
    const mx::array& query, // [num_seqs, num_heads, head_size]
    const mx::array& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
    const mx::array& value_cache, // [num_blocks, num_heads, head_size, block_size]
    const mx::array& block_tables, // [num_seqs, max_num_blocks_per_seq]
    const mx::array& seq_lens, // [num_seqs]
    const int64_t num_kv_heads,
    const int64_t block_size,
    const int64_t max_seq_len,
    const float scale,
    mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
  // The output has the same dtype and shape as the query
  auto out_dtype = query.dtype();
  auto out_shape = query.shape();
  const std::vector<mx::array> inputs = {query, key_cache, value_cache, block_tables, seq_lens};
  // Construct the array as the output of the PagedAttentionV1 primitive
  return mx::array(
      /* const std::vector<int>& shape = */ out_shape,
      /* Dtype dtype = */ out_dtype,
      /* std::unique_ptr<Primitive> primitive = */
      std::make_shared<PagedAttentionV1>(to_stream(s), num_kv_heads, block_size, max_seq_len, scale),
      /* const std::vector<array>& inputs = */ inputs);
}

/** Evaluate primitive on GPU */
void PagedAttentionV1::eval_gpu(
    const std::vector<mx::array>& inputs,
    std::vector<mx::array>& outputs) {
  // Prepare inputs
  assert(inputs.size() == 5);
  auto& q = inputs[0];
  auto& k = inputs[1];
  auto& v = inputs[2];
  auto& block_tables = inputs[3];
  auto& seq_lens = inputs[4];
  auto& out = outputs[0];

  // Each primitive carries the stream it should execute on
  // and each stream carries its device identifiers
  auto& s = stream();
  // We get the needed metal device using the stream
  auto& d = metal::device(s.device);

  // Allocate output memory
  out.set_data(allocator::malloc(out.nbytes()));

  // Set kernel params
  const int num_threads = 256;
  const int num_simd_lanes = 32;
  const int partition_size = 0; // v1 doesn't use partitioning

  // Resolve name of kernel
  std::string kname;
  kname = "paged_attention_" + type_to_name(out);
  kname += "_cache_" + type_to_name(k);
  kname += "_hs" + std::to_string(num_kv_heads_);
  kname += "_bs" + std::to_string(block_size_);
  kname += "_nt" + std::to_string(num_threads);
  kname += "_nsl" + std::to_string(num_simd_lanes);
  kname += "_ps" + std::to_string(partition_size);

  // Load the metal library
  auto lib = d.get_library("parallax_ext", current_binary_dir());

  // Make a kernel from this metal library
  auto kernel = d.get_kernel(kname, lib);

  // Prepare to encode kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Calculate parameters
  float softcapping_ = 1.0f; // hard-coded: logit soft-capping is not used
  const int64_t num_seqs = q.shape(0);
  const int64_t num_heads = q.shape(1);
  const int64_t max_num_blocks_per_seq = block_tables.shape(1);
  int32_t q_stride = static_cast<int32_t>(q.strides(0));
  int32_t kv_block_stride = static_cast<int32_t>(k.strides(0));
  int32_t kv_head_stride = static_cast<int32_t>(k.strides(1));

  // Encode arrays to kernel
  // Skip exp_sums and max_logits for v1 (buffers 0, 1)
  compute_encoder.set_output_array(out, 2);
  compute_encoder.set_input_array(q, 3);
  compute_encoder.set_input_array(k, 4);
  compute_encoder.set_input_array(v, 5);
  // Skip k_scale and v_scale for non-fp8 (buffers 6, 7)
  compute_encoder.set_bytes(num_kv_heads_, 8);
  compute_encoder.set_bytes(scale_, 9);
  compute_encoder.set_bytes(softcapping_, 10);
  compute_encoder.set_input_array(block_tables, 11);
  compute_encoder.set_input_array(seq_lens, 12);
  compute_encoder.set_bytes(max_num_blocks_per_seq, 13);
  // Skip alibi_slopes (buffer 14)
  compute_encoder.set_bytes(q_stride, 14);
  compute_encoder.set_bytes(kv_block_stride, 15);
  compute_encoder.set_bytes(kv_head_stride, 16);

  // Dispatch configuration
  // Grid: (num_heads, num_seqs, 1) - no partitioning for v1
  MTL::Size grid = MTL::Size(num_heads, num_seqs, 1);
  MTL::Size threadgroup = MTL::Size(num_threads, 1, 1);

  // Launch one threadgroup per (head, seq) pair, each with num_threads threads
  compute_encoder.dispatch_threadgroups(grid, threadgroup);
}

/** Equivalence check **/
bool PagedAttentionV1::is_equivalent(const mx::Primitive& other) const {
  const PagedAttentionV1& r_other = static_cast<const PagedAttentionV1&>(other);
  return num_kv_heads_ == r_other.num_kv_heads_ && block_size_ == r_other.block_size_ &&
      max_seq_len_ == r_other.max_seq_len_ && scale_ == r_other.scale_;
}

} // namespace parallax_ext
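As a sanity check on the kernel-name scheme built in eval_gpu, the snippet below reproduces the string for one representative configuration. The dtype names ("float16") and the example parameter values are assumptions for illustration; only the name layout comes from the code above.

def paged_attention_kernel_name(out_dtype: str, cache_dtype: str,
                                num_kv_heads: int, block_size: int,
                                num_threads: int = 256,
                                num_simd_lanes: int = 32,
                                partition_size: int = 0) -> str:
    """Mirror the kname construction in PagedAttentionV1::eval_gpu."""
    return (f"paged_attention_{out_dtype}_cache_{cache_dtype}"
            f"_hs{num_kv_heads}_bs{block_size}"
            f"_nt{num_threads}_nsl{num_simd_lanes}_ps{partition_size}")

# e.g. a float16 model with a float16 cache, 8 KV heads, block size 16:
print(paged_attention_kernel_name("float16", "float16", 8, 16))
# -> paged_attention_float16_cache_float16_hs8_bs16_nt256_nsl32_ps0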