Skip to content

Commit 98be9c5

Browse files
committed
Replace numba_scipy
1 parent 1c50709 commit 98be9c5

File tree

4 files changed

+370
-73
lines changed

4 files changed

+370
-73
lines changed

pytensor/configdefaults.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,12 +1252,6 @@ def add_numba_configvars():
12521252
BoolParam(True),
12531253
in_c_key=False,
12541254
)
1255-
config.add(
1256-
"numba_scipy",
1257-
("Enable usage of the numba_scipy package for special functions",),
1258-
BoolParam(True),
1259-
in_c_key=False,
1260-
)
12611255

12621256

12631257
def _default_compiledirname():
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
import ctypes
2+
import importlib
3+
import re
4+
from dataclasses import dataclass
5+
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
6+
7+
import numba
8+
import numpy as np
9+
from numpy.typing import DTypeLike
10+
from scipy import LowLevelCallable
11+
12+
13+
_C_TO_NUMPY: Dict[str, DTypeLike] = {
14+
"bool": np.bool_,
15+
"signed char": np.byte,
16+
"unsigned char": np.ubyte,
17+
"short": np.short,
18+
"unsigned short": np.ushort,
19+
"int": np.intc,
20+
"unsigned int": np.uintc,
21+
"long": np.int_,
22+
"unsigned long": np.uint,
23+
"long long": np.longlong,
24+
"float": np.single,
25+
"double": np.double,
26+
"long double": np.longdouble,
27+
"float complex": np.csingle,
28+
"double complex": np.cdouble,
29+
}
30+
31+
32+
@dataclass
33+
class Signature:
34+
res_dtype: DTypeLike
35+
res_c_type: str
36+
arg_dtypes: List[DTypeLike]
37+
arg_c_types: List[str]
38+
arg_names: List[Optional[str]]
39+
40+
@property
41+
def arg_numba_types(self) -> List[DTypeLike]:
42+
return [numba.from_dtype(dtype) for dtype in self.arg_dtypes]
43+
44+
def can_cast_args(self, args: List[DTypeLike]) -> bool:
45+
ok = True
46+
count = 0
47+
for name, dtype in zip(self.arg_names, self.arg_dtypes):
48+
if name == "__pyx_skip_dispatch":
49+
continue
50+
if len(args) <= count:
51+
raise ValueError("Incorrect number of arguments")
52+
ok &= np.can_cast(args[count], dtype)
53+
count += 1
54+
if count != len(args):
55+
return False
56+
return ok
57+
58+
def provides(self, restype: DTypeLike, arg_dtypes: List[DTypeLike]) -> bool:
59+
args_ok = self.can_cast_args(arg_dtypes)
60+
if np.issubdtype(restype, np.inexact):
61+
result_ok = np.can_cast(self.res_dtype, restype, casting="same_kind")
62+
# We do not want to provide less accuracy than advertised
63+
result_ok &= np.dtype(self.res_dtype).itemsize >= np.dtype(restype).itemsize
64+
else:
65+
result_ok = np.can_cast(self.res_dtype, restype)
66+
return args_ok and result_ok
67+
68+
@staticmethod
69+
def from_c_types(signature: bytes) -> "Signature":
70+
# Match strings like "double(int, double)"
71+
# and extract the return type and the joined arguments
72+
expr = re.compile(rb"\s*(?P<restype>[\w ]*\w+)\s*\((?P<args>[\w\s,]*)\)")
73+
re_match = re.fullmatch(expr, signature)
74+
75+
if re_match is None:
76+
raise ValueError(f"Invalid signature: {signature.decode()}")
77+
78+
groups = re_match.groupdict()
79+
res_c_type = groups["restype"].decode()
80+
res_dtype: DTypeLike = _C_TO_NUMPY[res_c_type]
81+
82+
raw_args = groups["args"]
83+
84+
decl_expr = re.compile(
85+
rb"\s*(?P<type>((long )|(unsigned )|(signed )|(double )|)"
86+
rb"((double)|(float)|(int)|(short)|(char)|(long)|(bool)|(complex)))"
87+
rb"(\s(?P<name>[\w_]*))?\s*"
88+
)
89+
90+
arg_dtypes = []
91+
arg_names: List[Optional[str]] = []
92+
arg_c_types = []
93+
for raw_arg in raw_args.split(b","):
94+
re_match = re.fullmatch(decl_expr, raw_arg)
95+
if re_match is None:
96+
raise ValueError(f"Invalid signature: {signature.decode()}")
97+
groups = re_match.groupdict()
98+
arg_c_type = groups["type"].decode()
99+
try:
100+
arg_dtype = _C_TO_NUMPY[arg_c_type]
101+
except KeyError:
102+
raise ValueError(f"Unknown C type: {arg_c_type}")
103+
104+
arg_c_types.append(arg_c_type)
105+
arg_dtypes.append(arg_dtype)
106+
name = groups["name"]
107+
if not name:
108+
arg_names.append(None)
109+
else:
110+
arg_names.append(name.decode())
111+
112+
return Signature(res_dtype, res_c_type, arg_dtypes, arg_c_types, arg_names)
113+
114+
115+
def _available_impls(func: Callable) -> List[Tuple[Signature, Any]]:
116+
"""Find all available implementations for a fused cython function."""
117+
impls = []
118+
mod = importlib.import_module(func.__module__)
119+
120+
signatures = getattr(func, "__signatures__", None)
121+
if signatures is not None:
122+
# Cython function with __signatures__ should be fused and thus
123+
# indexable
124+
func_map = cast(Mapping, func)
125+
candidates = [func_map[key] for key in signatures]
126+
else:
127+
candidates = [func]
128+
for candidate in candidates:
129+
name = candidate.__name__
130+
capsule = mod.__pyx_capi__[name]
131+
llc = LowLevelCallable(capsule)
132+
try:
133+
signature = Signature.from_c_types(llc.signature.encode())
134+
except KeyError:
135+
continue
136+
impls.append((signature, capsule))
137+
return impls
138+
139+
140+
class _CythonWrapper(numba.types.WrapperAddressProtocol):
141+
def __init__(self, pyfunc, signature, capsule):
142+
self._keep_alive = capsule
143+
get_name = ctypes.pythonapi.PyCapsule_GetName
144+
get_name.restype = ctypes.c_char_p
145+
get_name.argtypes = (ctypes.py_object,)
146+
147+
raw_signature = get_name(capsule)
148+
149+
get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
150+
get_pointer.restype = ctypes.c_void_p
151+
get_pointer.argtypes = (ctypes.py_object, ctypes.c_char_p)
152+
self._func_ptr = get_pointer(capsule, raw_signature)
153+
154+
self._signature = signature
155+
self._pyfunc = pyfunc
156+
157+
def signature(self):
158+
return numba.from_dtype(self._signature.res_dtype)(
159+
*self._signature.arg_numba_types
160+
)
161+
162+
def __wrapper_address__(self):
163+
return self._func_ptr
164+
165+
def __call__(self, *args, **kwargs):
166+
args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
167+
if self.has_pyx_skip_dispatch():
168+
output = self._pyfunc(*args[:-1], **kwargs)
169+
else:
170+
output = self._pyfunc(*args, **kwargs)
171+
return self._signature.res_dtype(output)
172+
173+
def has_pyx_skip_dispatch(self):
174+
if not self._signature.arg_names:
175+
return False
176+
if any(
177+
name == "__pyx_skip_dispatch" for name in self._signature.arg_names[:-1]
178+
):
179+
raise ValueError("skip_dispatch parameter must be last")
180+
return self._signature.arg_names[-1] == "__pyx_skip_dispatch"
181+
182+
def numpy_arg_dtypes(self):
183+
return self._signature.arg_dtypes
184+
185+
def numpy_output_dtype(self):
186+
return self._signature.res_dtype
187+
188+
189+
def wrap_cython_function(func, restype, arg_types):
190+
impls = _available_impls(func)
191+
compatible = []
192+
for sig, capsule in impls:
193+
if sig.provides(restype, arg_types):
194+
compatible.append((sig, capsule))
195+
196+
def sort_key(args):
197+
sig, _ = args
198+
199+
# Prefer functions with less inputs bytes
200+
argsize = sum(np.dtype(dtype).itemsize for dtype in sig.arg_dtypes)
201+
202+
# Prefer functions with more exact (integer) arguments
203+
num_inexact = sum(np.issubdtype(dtype, np.inexact) for dtype in sig.arg_dtypes)
204+
return (num_inexact, argsize)
205+
206+
compatible.sort(key=sort_key)
207+
208+
if not compatible:
209+
raise NotImplementedError(f"Could not find a compatible impl of {func}")
210+
sig, capsule = compatible[0]
211+
return _CythonWrapper(func, sig, capsule)

0 commit comments

Comments
 (0)