Skip to content

Commit 4f0dc02

Browse files
committed
Add divmod implementation to ufunc extensions
1 parent 65b0876 commit 4f0dc02

File tree

5 files changed

+326
-0
lines changed

5 files changed

+326
-0
lines changed

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ set(_elementwise_sources
3131
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/bitwise_count.cpp
3232
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/common.cpp
3333
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/degrees.cpp
34+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/divmod.cpp
3435
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/erf_funcs.cpp
3536
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fabs.cpp
3637
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fix.cpp

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
#include "bitwise_count.hpp"
3232
#include "degrees.hpp"
33+
#include "divmod.hpp"
3334
#include "erf_funcs.hpp"
3435
#include "fabs.hpp"
3536
#include "fix.hpp"
@@ -63,6 +64,7 @@ void init_elementwise_functions(py::module_ m)
6364
{
6465
init_bitwise_count(m);
6566
init_degrees(m);
67+
init_divmod(m);
6668
init_erf_funcs(m);
6769
init_fabs(m);
6870
init_fix(m);
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#include <complex>
30+
#include <type_traits>
31+
#include <vector>
32+
33+
#include <sycl/sycl.hpp>
34+
35+
#include "dpctl4pybind11.hpp"
36+
37+
#include "divmod.hpp"
38+
#include "kernels/elementwise_functions/divmod.hpp"
39+
#include "populate.hpp"
40+
41+
// include a local copy of elementwise common header from dpctl tensor:
42+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
43+
// TODO: replace by including dpctl header once available
44+
#include "../../elementwise_functions/elementwise_functions.hpp"
45+
46+
#include "../../elementwise_functions/common.hpp"
47+
#include "../../elementwise_functions/type_dispatch_building.hpp"
48+
49+
// utils extension header
50+
#include "ext/common.hpp"
51+
52+
// dpctl tensor headers
53+
#include "kernels/elementwise_functions/common.hpp"
54+
#include "utils/type_dispatch.hpp"
55+
56+
namespace dpnp::extensions::ufunc
57+
{
58+
namespace py = pybind11;
59+
namespace py_int = dpnp::extensions::py_internal;
60+
61+
namespace impl
62+
{
63+
namespace ew_cmn_ns = dpnp::extensions::py_internal::elementwise_common;
64+
namespace td_int_ns = py_int::type_dispatch;
65+
namespace td_ns = dpctl::tensor::type_dispatch;
66+
67+
using dpnp::kernels::divmod::DivmodFunctor;
68+
69+
template <typename T1, typename T2>
70+
struct OutputType
71+
{
72+
using table_type = typename std::disjunction< // disjunction is C++17
73+
// feature, supported by DPC++
74+
td_int_ns::
75+
BinaryTypeMapTwoResultsEntry<T1, std::uint8_t, T2, std::uint8_t>,
76+
td_int_ns::
77+
BinaryTypeMapTwoResultsEntry<T1, std::int8_t, T2, std::int8_t>,
78+
td_int_ns::
79+
BinaryTypeMapTwoResultsEntry<T1, std::uint16_t, T2, std::uint16_t>,
80+
td_int_ns::
81+
BinaryTypeMapTwoResultsEntry<T1, std::int16_t, T2, std::int16_t>,
82+
td_int_ns::
83+
BinaryTypeMapTwoResultsEntry<T1, std::uint32_t, T2, std::uint32_t>,
84+
td_int_ns::
85+
BinaryTypeMapTwoResultsEntry<T1, std::int32_t, T2, std::int32_t>,
86+
td_int_ns::
87+
BinaryTypeMapTwoResultsEntry<T1, std::uint64_t, T2, std::uint64_t>,
88+
td_int_ns::
89+
BinaryTypeMapTwoResultsEntry<T1, std::int64_t, T2, std::int64_t>,
90+
td_int_ns::BinaryTypeMapTwoResultsEntry<T1, sycl::half, T2, sycl::half>,
91+
td_int_ns::BinaryTypeMapTwoResultsEntry<T1, float, T2, float>,
92+
td_int_ns::BinaryTypeMapTwoResultsEntry<T1, double, T2, double>,
93+
td_int_ns::DefaultTwoResultsEntry<void>>;
94+
using value_type1 = typename table_type::result_type1;
95+
using value_type2 = typename table_type::result_type2;
96+
};
97+
98+
template <typename argTy1,
99+
typename argTy2,
100+
typename resTy1,
101+
typename resTy2,
102+
unsigned int vec_sz = 4,
103+
unsigned int n_vecs = 2,
104+
bool enable_sg_loadstore = true>
105+
using ContigFunctor = ew_cmn_ns::BinaryTwoOutputsContigFunctor<
106+
argTy1,
107+
argTy2,
108+
resTy1,
109+
resTy2,
110+
DivmodFunctor<argTy1, argTy2, resTy1, resTy2>,
111+
vec_sz,
112+
n_vecs,
113+
enable_sg_loadstore>;
114+
115+
template <typename argTy1,
116+
typename argTy2,
117+
typename resTy1,
118+
typename resTy2,
119+
typename IndexerT>
120+
using StridedFunctor = ew_cmn_ns::BinaryTwoOutputsStridedFunctor<
121+
argTy1,
122+
argTy2,
123+
resTy1,
124+
resTy2,
125+
IndexerT,
126+
DivmodFunctor<argTy1, argTy2, resTy1, resTy2>>;
127+
128+
using ew_cmn_ns::binary_two_outputs_contig_impl_fn_ptr_t;
129+
using ew_cmn_ns::binary_two_outputs_strided_impl_fn_ptr_t;
130+
131+
static binary_two_outputs_contig_impl_fn_ptr_t
132+
divmod_contig_dispatch_table[td_ns::num_types][td_ns::num_types];
133+
static std::pair<int, int> divmod_output_typeid_table[td_ns::num_types]
134+
[td_ns::num_types];
135+
static binary_two_outputs_strided_impl_fn_ptr_t
136+
divmod_strided_dispatch_table[td_ns::num_types][td_ns::num_types];
137+
138+
MACRO_POPULATE_DISPATCH_2OUTS_TABLES(divmod);
139+
} // namespace impl
140+
141+
/**
 * @brief Populates the divmod dispatch tables and registers the divmod
 * bindings on the given Python module.
 *
 * Two functions are added to @p m:
 *  - `_divmod(src1, src2, dst1, dst2, sycl_queue, depends=[])`, which
 *    forwards to py_int::py_binary_two_outputs_ufunc together with the
 *    three divmod tables;
 *  - `_divmod_result_type(dtype1, dtype2)`, which forwards to
 *    py_int::py_binary_two_outputs_ufunc_result_type together with the
 *    output type id table.
 *
 * @param m Module that receives the bindings.
 */
void init_divmod(py::module_ m)
{
    using usm_array_t = dpctl::tensor::usm_ndarray;
    using sycl_events_t = std::vector<sycl::event>;

    using impl::divmod_contig_dispatch_table;
    using impl::divmod_output_typeid_table;
    using impl::divmod_strided_dispatch_table;

    // The tables must be filled before the bindings below can use them.
    impl::populate_divmod_dispatch_tables();

    auto call_divmod = [&](const usm_array_t &src1, const usm_array_t &src2,
                           const usm_array_t &dst1, const usm_array_t &dst2,
                           sycl::queue &exec_q,
                           const sycl_events_t &depends = {}) {
        return py_int::py_binary_two_outputs_ufunc(
            src1, src2, dst1, dst2, exec_q, depends,
            divmod_output_typeid_table, divmod_contig_dispatch_table,
            divmod_strided_dispatch_table);
    };
    m.def("_divmod", call_divmod, "", py::arg("src1"), py::arg("src2"),
          py::arg("dst1"), py::arg("dst2"), py::arg("sycl_queue"),
          py::arg("depends") = py::list());

    auto query_result_type = [&](const py::dtype &dtype1,
                                 const py::dtype &dtype2) {
        return py_int::py_binary_two_outputs_ufunc_result_type(
            dtype1, dtype2, divmod_output_typeid_table);
    };
    m.def("_divmod_result_type", query_result_type);
}
172+
} // namespace dpnp::extensions::ufunc
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#pragma once
30+
31+
#include <pybind11/pybind11.h>
32+
33+
// Short alias for the pybind11 namespace, used by the declaration below.
namespace py = pybind11;
34+
35+
namespace dpnp::extensions::ufunc
36+
{
37+
// Registers the divmod bindings (`_divmod` and `_divmod_result_type`) on the
// given Python module; defined in divmod.cpp.
void init_divmod(py::module_ m);
38+
} // namespace dpnp::extensions::ufunc

dpnp/backend/extensions/ufunc/elementwise_functions/populate.hpp

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,116 @@ namespace ext_ns = ext::common;
335335
ext_ns::init_dispatch_table<int, TypeMapFactory>( \
336336
__name__##_output_typeid_table); \
337337
};
338+
339+
/**
 * @brief Defines the factories and the populating function for a binary
 * universal function that writes two output arrays.
 *
 * For a given `__name__` the macro expands to:
 *  - kernel name class templates `__name__##_contig_kernel` and
 *    `__name__##_strided_kernel`;
 *  - `__name__##_contig_impl` and `__name__##_strided_impl`, thin wrappers
 *    over `ew_cmn_ns::binary_two_outputs_contig_impl` and
 *    `ew_cmn_ns::binary_two_outputs_strided_impl`;
 *  - `ContigFactory` and `StridedFactory`, whose `get()` yields the matching
 *    wrapper, or `nullptr` when either result type reported by `OutputType`
 *    is `void`;
 *  - `TypeMapFactory`, whose `get()` yields the pair of result type ids;
 *  - `populate_##__name__##_dispatch_tables()`, which fills the three tables
 *    through `ext_ns::init_dispatch_table`.
 *
 * The expansion site must already provide `OutputType`, `ContigFunctor`,
 * `StridedFunctor`, the `ew_cmn_ns`, `td_ns`, `ext_ns` and `py` namespace
 * names, the `binary_two_outputs_*_impl_fn_ptr_t` types, and the tables
 * `__name__##_contig_dispatch_table`, `__name__##_strided_dispatch_table`
 * and `__name__##_output_typeid_table`.
 */
#define MACRO_POPULATE_DISPATCH_2OUTS_TABLES(__name__)                         \
    template <typename argT1, typename argT2, typename resT1, typename resT2,  \
              unsigned int vec_sz, unsigned int n_vecs>                        \
    class __name__##_contig_kernel;                                            \
                                                                               \
    template <typename argTy1, typename argTy2>                                \
    sycl::event __name__##_contig_impl(                                        \
        sycl::queue &exec_q, size_t nelems,                                    \
        const char *arg1_p, py::ssize_t arg1_offset,                           \
        const char *arg2_p, py::ssize_t arg2_offset,                           \
        char *res1_p, py::ssize_t res1_offset,                                 \
        char *res2_p, py::ssize_t res2_offset,                                 \
        const std::vector<sycl::event> &depends = {})                          \
    {                                                                          \
        return ew_cmn_ns::binary_two_outputs_contig_impl<                      \
            argTy1, argTy2, OutputType, ContigFunctor,                         \
            __name__##_contig_kernel>(                                         \
            exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset,          \
            res1_p, res1_offset, res2_p, res2_offset, depends);                \
    }                                                                          \
                                                                               \
    template <typename fnT, typename T1, typename T2>                          \
    struct ContigFactory                                                       \
    {                                                                          \
        fnT get()                                                              \
        {                                                                      \
            using rT1 = typename OutputType<T1, T2>::value_type1;              \
            using rT2 = typename OutputType<T1, T2>::value_type2;              \
            constexpr bool is_supported =                                      \
                !std::is_same_v<rT1, void> && !std::is_same_v<rT2, void>;      \
                                                                               \
            fnT fn = nullptr;                                                  \
            if constexpr (is_supported) {                                      \
                fn = __name__##_contig_impl<T1, T2>;                           \
            }                                                                  \
            return fn;                                                         \
        }                                                                      \
    };                                                                         \
                                                                               \
    template <typename fnT, typename T1, typename T2>                          \
    struct TypeMapFactory                                                      \
    {                                                                          \
        std::enable_if_t<std::is_same<fnT, std::pair<int, int>>::value,        \
                         std::pair<int, int>>                                  \
            get()                                                              \
        {                                                                      \
            using rT1 = typename OutputType<T1, T2>::value_type1;              \
            using rT2 = typename OutputType<T1, T2>::value_type2;              \
            const auto typeid1 = td_ns::GetTypeid<rT1>{}.get();                \
            const auto typeid2 = td_ns::GetTypeid<rT2>{}.get();                \
            return std::make_pair(typeid1, typeid2);                           \
        }                                                                      \
    };                                                                         \
                                                                               \
    template <typename T1, typename T2, typename resT1, typename resT2,        \
              typename IndexerT>                                               \
    class __name__##_strided_kernel;                                           \
                                                                               \
    template <typename argTy1, typename argTy2>                                \
    sycl::event __name__##_strided_impl(                                       \
        sycl::queue &exec_q, size_t nelems, int nd,                            \
        const py::ssize_t *shape_and_strides,                                  \
        const char *arg1_p, py::ssize_t arg1_offset,                           \
        const char *arg2_p, py::ssize_t arg2_offset,                           \
        char *res1_p, py::ssize_t res1_offset,                                 \
        char *res2_p, py::ssize_t res2_offset,                                 \
        const std::vector<sycl::event> &depends,                               \
        const std::vector<sycl::event> &additional_depends)                    \
    {                                                                          \
        return ew_cmn_ns::binary_two_outputs_strided_impl<                     \
            argTy1, argTy2, OutputType, StridedFunctor,                        \
            __name__##_strided_kernel>(                                        \
            exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset,        \
            arg2_p, arg2_offset, res1_p, res1_offset, res2_p, res2_offset,     \
            depends, additional_depends);                                      \
    }                                                                          \
                                                                               \
    template <typename fnT, typename T1, typename T2>                          \
    struct StridedFactory                                                      \
    {                                                                          \
        fnT get()                                                              \
        {                                                                      \
            using rT1 = typename OutputType<T1, T2>::value_type1;              \
            using rT2 = typename OutputType<T1, T2>::value_type2;              \
            constexpr bool is_supported =                                      \
                !std::is_same_v<rT1, void> && !std::is_same_v<rT2, void>;      \
                                                                               \
            fnT fn = nullptr;                                                  \
            if constexpr (is_supported) {                                      \
                fn = __name__##_strided_impl<T1, T2>;                          \
            }                                                                  \
            return fn;                                                         \
        }                                                                      \
    };                                                                         \
                                                                               \
    void populate_##__name__##_dispatch_tables(void)                           \
    {                                                                          \
        ext_ns::init_dispatch_table<binary_two_outputs_contig_impl_fn_ptr_t,   \
                                    ContigFactory>(                            \
            __name__##_contig_dispatch_table);                                 \
        ext_ns::init_dispatch_table<binary_two_outputs_strided_impl_fn_ptr_t,  \
                                    StridedFactory>(                           \
            __name__##_strided_dispatch_table);                                \
        ext_ns::init_dispatch_table<std::pair<int, int>, TypeMapFactory>(      \
            __name__##_output_typeid_table);                                   \
    };

0 commit comments

Comments
 (0)