Skip to content

Commit 73c8cb3

Browse files
support gpt-oss GPU by OP add-id, mul_mat for mxfp4, swiglu_oai, fix warning
1 parent d9e03db commit 73c8cb3

File tree

13 files changed

+399
-16
lines changed

13 files changed

+399
-16
lines changed

ggml/src/ggml-sycl/add-id.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include <sycl/sycl.hpp>
2+
#include "common.hpp"
3+
#include "add-id.hpp"
4+
5+
static void add_id_kernel(
6+
const float* src0,
7+
const float* src1,
8+
const int32_t* src2,
9+
float* dst,
10+
int64_t ne0,
11+
int64_t ne1,
12+
size_t nb01,
13+
size_t nb02,
14+
size_t nb11,
15+
size_t nb21,
16+
sycl::nd_item<3> item_ct1) {
17+
const int64_t i1 = item_ct1.get_group(2);
18+
const int64_t i2 = item_ct1.get_group(1);
19+
20+
const int i11 =
21+
*(const int32_t*)((const char*)src2 + i1 * sizeof(int32_t) + i2 * nb21);
22+
23+
const size_t nb1 = ne0 * sizeof(float);
24+
const size_t nb2 = ne1 * nb1;
25+
26+
float* dst_row = (float*)((char*)dst + i1 * nb1 + i2 * nb2);
27+
const float* src0_row =
28+
(const float*)((const char*)src0 + i1 * nb01 + i2 * nb02);
29+
const float* src1_row = (const float*)((const char*)src1 + i11 * nb11);
30+
31+
for (int64_t i0 = item_ct1.get_local_id(2); i0 < ne0;
32+
i0 += item_ct1.get_local_range(2)) {
33+
dst_row[i0] = src0_row[i0] + src1_row[i0];
34+
}
35+
}
36+
37+
void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
38+
const ggml_tensor* src0 = dst->src[0];
39+
const ggml_tensor* src1 = dst->src[1];
40+
const ggml_tensor* src2 = dst->src[2];
41+
42+
GGML_TENSOR_TERNARY_OP_LOCALS
43+
44+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
45+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
46+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
47+
GGML_ASSERT(src2->type == GGML_TYPE_I32);
48+
49+
GGML_ASSERT(nb00 == sizeof(float));
50+
GGML_ASSERT(nb10 == sizeof(float));
51+
GGML_ASSERT(nb20 == sizeof(int32_t));
52+
53+
const float* src0_d = (const float*)src0->data;
54+
const float* src1_d = (const float*)src1->data;
55+
const int32_t* src2_d = (const int32_t*)src2->data;
56+
float* dst_d = (float*)dst->data;
57+
58+
int threads = std::min((int)ne00, 768); // cols
59+
ctx.stream()->parallel_for(
60+
sycl::nd_range<3>(
61+
sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads),
62+
sycl::range<3>(1, 1, threads)),
63+
[=](sycl::nd_item<3> item_ct1) {
64+
add_id_kernel(
65+
src0_d,
66+
src1_d,
67+
src2_d,
68+
dst_d,
69+
ne0,
70+
ne1,
71+
nb01,
72+
nb02,
73+
nb11,
74+
nb21,
75+
item_ct1);
76+
});
77+
}

ggml/src/ggml-sycl/add-id.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef GGML_SYCL_ADD_ID_HPP
2+
#define GGML_SYCL_ADD_ID_HPP
3+
4+
#include "common.hpp"
5+
6+
void ggml_sycl_add_id(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7+
8+
#endif // GGML_SYCL_ADD_ID_HPP

ggml/src/ggml-sycl/common.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,5 +642,22 @@ static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3
642642
return sycl::uint2(div_val, mod_val);
643643
}
644644

645+
static __dpct_inline__ int ggml_sycl_dp4a(const int a, const int b, int c) {
646+
return dpct::dp4a(a, b, c);
647+
}
648+
649+
static __dpct_inline__ float ggml_sycl_e8m0_to_fp32(uint8_t x) {
650+
uint32_t bits;
651+
if (x == 0) {
652+
bits = 0x00400000;
653+
} else {
654+
bits = (uint32_t) x << 23;
655+
}
656+
657+
float result;
658+
memcpy(&result, &bits, sizeof(float));
659+
return result;
660+
}
661+
645662

646663
#endif // GGML_SYCL_COMMON_HPP

ggml/src/ggml-sycl/convert.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,16 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
472472
}
473473
}
474474

475+
template <typename dst_t>
476+
static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
477+
const int nb = (k + QK_K - 1) / QK_K;
478+
stream->parallel_for(
479+
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
480+
[=](sycl::nd_item<3> item_ct1) {
481+
dequantize_block_mxfp4(vx, y, item_ct1);
482+
});
483+
}
484+
475485
template <typename src_t, typename dst_t>
476486
static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
477487
const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
@@ -518,6 +528,7 @@ static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct
518528
convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);
519529
}
520530

531+
521532
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
522533
switch (type) {
523534
case GGML_TYPE_Q4_0:
@@ -571,6 +582,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
571582
return dequantize_row_iq4_xs_sycl;
572583
case GGML_TYPE_IQ4_NL:
573584
return dequantize_row_iq4_nl_sycl;
585+
case GGML_TYPE_MXFP4:
586+
return dequantize_row_mxfp4_sycl;
574587
case GGML_TYPE_F32:
575588
return convert_unary_sycl<float>;
576589
#ifdef GGML_SYCL_HAS_BF16
@@ -636,6 +649,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
636649
return dequantize_row_iq4_xs_sycl;
637650
case GGML_TYPE_IQ4_NL:
638651
return dequantize_row_iq4_nl_sycl;
652+
case GGML_TYPE_MXFP4:
653+
return dequantize_row_mxfp4_sycl;
639654
case GGML_TYPE_F16:
640655
return convert_unary_sycl<sycl::half>;
641656
#ifdef GGML_SYCL_HAS_BF16

ggml/src/ggml-sycl/dequantize.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,5 +819,23 @@ dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
819819
}
820820
}
821821

822+
template<typename dst_t>
823+
static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy,
824+
const sycl::nd_item<3> &item_ct1) {
825+
// auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
826+
const int64_t i = item_ct1.get_group(2);
827+
const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
828+
829+
const int64_t tid = item_ct1.get_local_id(2);
830+
const int64_t il = tid/8; // 0...3
831+
const int64_t ib = tid%8; // 0...7
832+
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
833+
const uint8_t * q4 = x[ib].qs + 4*il;
834+
const float d = ggml_sycl_e8m0_to_fp32(x[ib].e);
835+
for (int j = 0; j < 4; ++j) {
836+
y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
837+
y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
838+
}
839+
}
822840

823841
#endif // GGML_SYCL_DEQUANTIZE_HPP

ggml/src/ggml-sycl/dpct/helper.hpp

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,10 +1860,31 @@ namespace dpct
18601860
: id);
18611861
}
18621862

1863+
template <typename T1, typename T2>
1864+
using dot_product_acc_t = std::conditional_t<
1865+
std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
1866+
uint32_t,
1867+
int32_t>;
1868+
1869+
template <typename T>
1870+
sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val) {
1871+
return sycl::vec<T, 1>(val)
1872+
.template as<sycl::vec<
1873+
std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>,
1874+
4>>()
1875+
.template convert<T>();
1876+
}
1877+
18631878
template <typename T1, typename T2, typename T3>
1864-
inline auto dp4a(T1 a, T2 b, T3 c)
1865-
{
1866-
return syclcompat::dp4a(a, b, c);
1879+
inline auto dp4a(T1 a, T2 b, T3 c) {
1880+
dot_product_acc_t<T1, T2> res = c;
1881+
auto va = extract_and_sign_or_zero_extend4(a);
1882+
auto vb = extract_and_sign_or_zero_extend4(b);
1883+
res += va[0] * vb[0];
1884+
res += va[1] * vb[1];
1885+
res += va[2] * vb[2];
1886+
res += va[3] * vb[3];
1887+
return res;
18671888
}
18681889

18691890
struct sub_sat
@@ -2972,6 +2993,38 @@ namespace dpct
29722993
atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
29732994
}
29742995

2996+
inline unsigned int byte_level_permute(
2997+
unsigned int a, unsigned int b, unsigned int s) {
2998+
unsigned int ret;
2999+
ret = ((((std::uint64_t)b << 32 | a) >> (s & 0x7) * 8) & 0xff) |
3000+
(((((std::uint64_t)b << 32 | a) >> ((s >> 4) & 0x7) * 8) & 0xff)
3001+
<< 8) |
3002+
(((((std::uint64_t)b << 32 | a) >> ((s >> 8) & 0x7) * 8) & 0xff)
3003+
<< 16) |
3004+
(((((std::uint64_t)b << 32 | a) >> ((s >> 12) & 0x7) * 8) & 0xff)
3005+
<< 24);
3006+
return ret;
3007+
}
3008+
3009+
inline uint32_t byte_level_permute_custom(
3010+
uint32_t low32, uint32_t high32, uint32_t sel, int mode = 0) {
3011+
constexpr uint16_t lookup[6][4] = {
3012+
{0x3210, 0x4321, 0x5432, 0x6543}, // Forward 4-byte extract
3013+
{0x5670, 0x6701, 0x7012, 0x0123}, // Backward 4-byte extract
3014+
{0x0000, 0x1111, 0x2222, 0x3333}, // Replicate 8-bit values
3015+
{0x3210, 0x3211, 0x3222, 0x3333}, // Edge clamp left
3016+
{0x0000, 0x1110, 0x2210, 0x3210}, // Edge clamp right
3017+
{0x1010, 0x3232, 0x1010, 0x3232} // Replicate 16-bit values
3018+
};
3019+
3020+
if (mode >= 1 && mode <= 6) {
3021+
return byte_level_permute(low32, high32, lookup[mode - 1][sel & 0x3]);
3022+
} else if (!mode) {
3023+
return byte_level_permute(low32, high32, sel);
3024+
}
3025+
return 0;
3026+
}
3027+
29753028
} // COPY from DPCT head files
29763029

29773030
#endif // GGML_SYCL_DPCT_HELPER_HPP

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,98 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
911911
});
912912
}
913913

914+
__dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
915+
x = sycl::fmin(x, limit);
916+
g = sycl::fmax(sycl::fmin(g, limit), -limit);
917+
918+
float out_glu = x / (1.0f + sycl::native::exp(-x * alpha));
919+
out_glu = out_glu * (1.0f + g);
920+
return out_glu;
921+
}
922+
923+
924+
template <typename T>
925+
static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
926+
const int64_t n, const int64_t o0, const int64_t o1,
927+
float alpha, float limit, sycl::nd_item<3> item_ct1) {
928+
const int64_t i = int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
929+
930+
if (i >= k) {
931+
return;
932+
}
933+
934+
const int64_t j0 = (i / n) * o0 + (i % n);
935+
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
936+
937+
float xi = x[j0];
938+
float gi = g[j1];
939+
940+
dst[i] = ggml_sycl_op_swiglu_oai_single(xi, gi, alpha, limit);
941+
}
942+
943+
template <typename T>
944+
static void swiglu_oai_sycl(const T * x,
945+
const T * g,
946+
T * dst,
947+
const int64_t k,
948+
const int64_t n,
949+
const int64_t o0,
950+
const int64_t o1,
951+
const float alpha,
952+
const float limit,
953+
dpct::queue_ptr stream) {
954+
const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
955+
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
956+
sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
957+
[=](sycl::nd_item<3> item_ct1) {
958+
swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
959+
});
960+
}
961+
962+
void ggml_sycl_op_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
963+
const ggml_tensor * src0 = dst->src[0];
964+
const ggml_tensor * src1 = dst->src[1];
965+
void * src0_d = src0->data;
966+
void * src1_d = src1 ? src1->data : src0->data;
967+
const int64_t src0_o = src0->nb[1];
968+
const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
969+
void * dst_d = dst->data;
970+
const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
971+
dpct::queue_ptr stream = ctx.stream();
972+
973+
GGML_ASSERT(ggml_is_contiguous_1(src0));
974+
GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
975+
GGML_ASSERT(ggml_is_contiguous(dst));
976+
977+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
978+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
979+
GGML_ASSERT(src0->type == dst->type);
980+
GGML_ASSERT(dst->ne[0] == nc);
981+
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
982+
983+
if (src1) {
984+
GGML_ASSERT(ggml_is_contiguous_1(src1));
985+
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
986+
GGML_ASSERT(src1->ne[0] == nc);
987+
GGML_ASSERT(src0->type == src1->type);
988+
}
989+
990+
//const int32_t swapped = ((const int32_t *) dst->op_params)[1];
991+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
992+
const float alpha = ggml_get_op_params_f32(dst, 2);
993+
const float limit = ggml_get_op_params_f32(dst, 3);
994+
995+
float * src0_p = (float *) src0_d;
996+
float * src1_p = (float *) src1_d;
997+
998+
if (!src1) {
999+
src0_p += swapped ? nc : 0;
1000+
src1_p += swapped ? 0 : nc;
1001+
}
1002+
1003+
swiglu_oai_sycl(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
1004+
}
1005+
9141006
static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
9151007
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
9161008
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
@@ -1070,6 +1162,11 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
10701162
ggml_sycl_op_swiglu(ctx, dst);
10711163
}
10721164

1165+
void ggml_sycl_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1166+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1167+
ggml_sycl_op_swiglu_oai(ctx, dst);
1168+
}
1169+
10731170
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
10741171
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
10751172
ggml_sycl_op_geglu_erf(ctx, dst);

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "ggml.h"
66
#include <limits> // For std::numeric_limits
77

8+
#define SYCL_GLU_BLOCK_SIZE 256
9+
810
template <typename T>
911
T neg_infinity() {
1012
return -std::numeric_limits<T>::infinity();
@@ -41,6 +43,8 @@ void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
4143

4244
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
4345

46+
void ggml_sycl_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
47+
4448
void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
4549

4650
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

0 commit comments

Comments
 (0)