@@ -911,6 +911,98 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
911911 });
912912}
913913
914+ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single (float x, float g, float alpha = 1 .702f , float limit = 7 .0f ) {
915+ x = sycl::fmin (x, limit);
916+ g = sycl::fmax (sycl::fmin (g, limit), -limit);
917+
918+ float out_glu = x / (1 .0f + sycl::native::exp (-x * alpha));
919+ out_glu = out_glu * (1 .0f + g);
920+ return out_glu;
921+ }
922+
923+
924+ template <typename T>
925+ static void swiglu_oai_kernel (const T * x, const T * g, T * dst, const int64_t k,
926+ const int64_t n, const int64_t o0, const int64_t o1,
927+ float alpha, float limit, sycl::nd_item<3 > item_ct1) {
928+ const int64_t i = int64_t (item_ct1.get_local_range (2 )) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 );
929+
930+ if (i >= k) {
931+ return ;
932+ }
933+
934+ const int64_t j0 = (i / n) * o0 + (i % n);
935+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
936+
937+ float xi = x[j0];
938+ float gi = g[j1];
939+
940+ dst[i] = ggml_sycl_op_swiglu_oai_single (xi, gi, alpha, limit);
941+ }
942+
943+ template <typename T>
944+ static void swiglu_oai_sycl (const T * x,
945+ const T * g,
946+ T * dst,
947+ const int64_t k,
948+ const int64_t n,
949+ const int64_t o0,
950+ const int64_t o1,
951+ const float alpha,
952+ const float limit,
953+ dpct::queue_ptr stream) {
954+ const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1 ) / SYCL_GLU_BLOCK_SIZE;
955+ stream->parallel_for (sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , SYCL_GLU_BLOCK_SIZE),
956+ sycl::range<3 >(1 , 1 , SYCL_GLU_BLOCK_SIZE)),
957+ [=](sycl::nd_item<3 > item_ct1) {
958+ swiglu_oai_kernel (x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
959+ });
960+ }
961+
962+ void ggml_sycl_op_swiglu_oai (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
963+ const ggml_tensor * src0 = dst->src [0 ];
964+ const ggml_tensor * src1 = dst->src [1 ];
965+ void * src0_d = src0->data ;
966+ void * src1_d = src1 ? src1->data : src0->data ;
967+ const int64_t src0_o = src0->nb [1 ];
968+ const int64_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
969+ void * dst_d = dst->data ;
970+ const int64_t nc = src1 ? src0->ne [0 ] : src0->ne [0 ] / 2 ;
971+ dpct::queue_ptr stream = ctx.stream ();
972+
973+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
974+ GGML_ASSERT (src0->nb [0 ] == ggml_element_size (src0));
975+ GGML_ASSERT (ggml_is_contiguous (dst));
976+
977+ GGML_ASSERT (src0->type == GGML_TYPE_F32);
978+ GGML_ASSERT ( dst->type == GGML_TYPE_F32);
979+ GGML_ASSERT (src0->type == dst->type );
980+ GGML_ASSERT (dst->ne [0 ] == nc);
981+ GGML_ASSERT (ggml_nrows (dst) == ggml_nrows (src0));
982+
983+ if (src1) {
984+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
985+ GGML_ASSERT (src1->nb [0 ] == ggml_element_size (src1));
986+ GGML_ASSERT (src1->ne [0 ] == nc);
987+ GGML_ASSERT (src0->type == src1->type );
988+ }
989+
990+ // const int32_t swapped = ((const int32_t *) dst->op_params)[1];
991+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
992+ const float alpha = ggml_get_op_params_f32 (dst, 2 );
993+ const float limit = ggml_get_op_params_f32 (dst, 3 );
994+
995+ float * src0_p = (float *) src0_d;
996+ float * src1_p = (float *) src1_d;
997+
998+ if (!src1) {
999+ src0_p += swapped ? nc : 0 ;
1000+ src1_p += swapped ? 0 : nc;
1001+ }
1002+
1003+ swiglu_oai_sycl (src0_p, src1_p, (float *)dst_d, ggml_nelements (dst), nc, src0_o / sizeof (float ), src1_o / sizeof (float ), alpha, limit, stream);
1004+ }
1005+
9141006static inline void ggml_sycl_op_geglu_erf (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
9151007 ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu (ctx, dst,
9161008 [](const auto * x_ptr, const auto * g_ptr, auto * dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
@@ -1070,6 +1162,11 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
10701162 ggml_sycl_op_swiglu (ctx, dst);
10711163}
10721164
1165+ void ggml_sycl_swiglu_oai (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1166+ scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 1 );
1167+ ggml_sycl_op_swiglu_oai (ctx, dst);
1168+ }
1169+
10731170void ggml_sycl_geglu_erf (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
10741171 scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 1 );
10751172 ggml_sycl_op_geglu_erf (ctx, dst);
0 commit comments