|
| 1 | +import ctypes |
| 2 | +import importlib |
| 3 | +import re |
| 4 | +from dataclasses import dataclass |
| 5 | +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast |
| 6 | + |
| 7 | +import numba |
| 8 | +import numpy as np |
| 9 | +from numpy.typing import DTypeLike |
| 10 | +from scipy import LowLevelCallable |
| 11 | + |
| 12 | + |
| 13 | +_C_TO_NUMPY: Dict[str, DTypeLike] = { |
| 14 | + "bool": np.bool_, |
| 15 | + "signed char": np.byte, |
| 16 | + "unsigned char": np.ubyte, |
| 17 | + "short": np.short, |
| 18 | + "unsigned short": np.ushort, |
| 19 | + "int": np.intc, |
| 20 | + "unsigned int": np.uintc, |
| 21 | + "long": np.int_, |
| 22 | + "unsigned long": np.uint, |
| 23 | + "long long": np.longlong, |
| 24 | + "float": np.single, |
| 25 | + "double": np.double, |
| 26 | + "long double": np.longdouble, |
| 27 | + "float complex": np.csingle, |
| 28 | + "double complex": np.cdouble, |
| 29 | +} |
| 30 | + |
| 31 | + |
| 32 | +@dataclass |
| 33 | +class Signature: |
| 34 | + res_dtype: DTypeLike |
| 35 | + res_c_type: str |
| 36 | + arg_dtypes: List[DTypeLike] |
| 37 | + arg_c_types: List[str] |
| 38 | + arg_names: List[Optional[str]] |
| 39 | + |
| 40 | + @property |
| 41 | + def arg_numba_types(self) -> List[DTypeLike]: |
| 42 | + return [numba.from_dtype(dtype) for dtype in self.arg_dtypes] |
| 43 | + |
| 44 | + def can_cast_args(self, args: List[DTypeLike]) -> bool: |
| 45 | + ok = True |
| 46 | + count = 0 |
| 47 | + for name, dtype in zip(self.arg_names, self.arg_dtypes): |
| 48 | + if name == "__pyx_skip_dispatch": |
| 49 | + continue |
| 50 | + if len(args) <= count: |
| 51 | + raise ValueError("Incorrect number of arguments") |
| 52 | + ok &= np.can_cast(args[count], dtype) |
| 53 | + count += 1 |
| 54 | + if count != len(args): |
| 55 | + return False |
| 56 | + return ok |
| 57 | + |
| 58 | + def provides(self, restype: DTypeLike, arg_dtypes: List[DTypeLike]) -> bool: |
| 59 | + args_ok = self.can_cast_args(arg_dtypes) |
| 60 | + if np.issubdtype(restype, np.inexact): |
| 61 | + result_ok = np.can_cast(self.res_dtype, restype, casting="same_kind") |
| 62 | + # We do not want to provide less accuracy than advertised |
| 63 | + result_ok &= np.dtype(self.res_dtype).itemsize >= np.dtype(restype).itemsize |
| 64 | + else: |
| 65 | + result_ok = np.can_cast(self.res_dtype, restype) |
| 66 | + return args_ok and result_ok |
| 67 | + |
| 68 | + @staticmethod |
| 69 | + def from_c_types(signature: bytes) -> "Signature": |
| 70 | + # Match strings like "double(int, double)" |
| 71 | + # and extract the return type and the joined arguments |
| 72 | + expr = re.compile(rb"\s*(?P<restype>[\w ]*\w+)\s*\((?P<args>[\w\s,]*)\)") |
| 73 | + re_match = re.fullmatch(expr, signature) |
| 74 | + |
| 75 | + if re_match is None: |
| 76 | + raise ValueError(f"Invalid signature: {signature.decode()}") |
| 77 | + |
| 78 | + groups = re_match.groupdict() |
| 79 | + res_c_type = groups["restype"].decode() |
| 80 | + res_dtype: DTypeLike = _C_TO_NUMPY[res_c_type] |
| 81 | + |
| 82 | + raw_args = groups["args"] |
| 83 | + |
| 84 | + decl_expr = re.compile( |
| 85 | + rb"\s*(?P<type>((long )|(unsigned )|(signed )|(double )|)" |
| 86 | + rb"((double)|(float)|(int)|(short)|(char)|(long)|(bool)|(complex)))" |
| 87 | + rb"(\s(?P<name>[\w_]*))?\s*" |
| 88 | + ) |
| 89 | + |
| 90 | + arg_dtypes = [] |
| 91 | + arg_names: List[Optional[str]] = [] |
| 92 | + arg_c_types = [] |
| 93 | + for raw_arg in raw_args.split(b","): |
| 94 | + re_match = re.fullmatch(decl_expr, raw_arg) |
| 95 | + if re_match is None: |
| 96 | + raise ValueError(f"Invalid signature: {signature.decode()}") |
| 97 | + groups = re_match.groupdict() |
| 98 | + arg_c_type = groups["type"].decode() |
| 99 | + try: |
| 100 | + arg_dtype = _C_TO_NUMPY[arg_c_type] |
| 101 | + except KeyError: |
| 102 | + raise ValueError(f"Unknown C type: {arg_c_type}") |
| 103 | + |
| 104 | + arg_c_types.append(arg_c_type) |
| 105 | + arg_dtypes.append(arg_dtype) |
| 106 | + name = groups["name"] |
| 107 | + if not name: |
| 108 | + arg_names.append(None) |
| 109 | + else: |
| 110 | + arg_names.append(name.decode()) |
| 111 | + |
| 112 | + return Signature(res_dtype, res_c_type, arg_dtypes, arg_c_types, arg_names) |
| 113 | + |
| 114 | + |
| 115 | +def _available_impls(func: Callable) -> List[Tuple[Signature, Any]]: |
| 116 | + """Find all available implementations for a fused cython function.""" |
| 117 | + impls = [] |
| 118 | + mod = importlib.import_module(func.__module__) |
| 119 | + |
| 120 | + signatures = getattr(func, "__signatures__", None) |
| 121 | + if signatures is not None: |
| 122 | + # Cython function with __signatures__ should be fused and thus |
| 123 | + # indexable |
| 124 | + func_map = cast(Mapping, func) |
| 125 | + candidates = [func_map[key] for key in signatures] |
| 126 | + else: |
| 127 | + candidates = [func] |
| 128 | + for candidate in candidates: |
| 129 | + name = candidate.__name__ |
| 130 | + capsule = mod.__pyx_capi__[name] |
| 131 | + llc = LowLevelCallable(capsule) |
| 132 | + try: |
| 133 | + signature = Signature.from_c_types(llc.signature.encode()) |
| 134 | + except KeyError: |
| 135 | + continue |
| 136 | + impls.append((signature, capsule)) |
| 137 | + return impls |
| 138 | + |
| 139 | + |
| 140 | +class _CythonWrapper(numba.types.WrapperAddressProtocol): |
| 141 | + def __init__(self, pyfunc, signature, capsule): |
| 142 | + self._keep_alive = capsule |
| 143 | + get_name = ctypes.pythonapi.PyCapsule_GetName |
| 144 | + get_name.restype = ctypes.c_char_p |
| 145 | + get_name.argtypes = (ctypes.py_object,) |
| 146 | + |
| 147 | + raw_signature = get_name(capsule) |
| 148 | + |
| 149 | + get_pointer = ctypes.pythonapi.PyCapsule_GetPointer |
| 150 | + get_pointer.restype = ctypes.c_void_p |
| 151 | + get_pointer.argtypes = (ctypes.py_object, ctypes.c_char_p) |
| 152 | + self._func_ptr = get_pointer(capsule, raw_signature) |
| 153 | + |
| 154 | + self._signature = signature |
| 155 | + self._pyfunc = pyfunc |
| 156 | + |
| 157 | + def signature(self): |
| 158 | + return numba.from_dtype(self._signature.res_dtype)( |
| 159 | + *self._signature.arg_numba_types |
| 160 | + ) |
| 161 | + |
| 162 | + def __wrapper_address__(self): |
| 163 | + return self._func_ptr |
| 164 | + |
| 165 | + def __call__(self, *args, **kwargs): |
| 166 | + args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)] |
| 167 | + if self.has_pyx_skip_dispatch(): |
| 168 | + output = self._pyfunc(*args[:-1], **kwargs) |
| 169 | + else: |
| 170 | + output = self._pyfunc(*args, **kwargs) |
| 171 | + return self._signature.res_dtype(output) |
| 172 | + |
| 173 | + def has_pyx_skip_dispatch(self): |
| 174 | + if not self._signature.arg_names: |
| 175 | + return False |
| 176 | + if any( |
| 177 | + name == "__pyx_skip_dispatch" for name in self._signature.arg_names[:-1] |
| 178 | + ): |
| 179 | + raise ValueError("skip_dispatch parameter must be last") |
| 180 | + return self._signature.arg_names[-1] == "__pyx_skip_dispatch" |
| 181 | + |
| 182 | + def numpy_arg_dtypes(self): |
| 183 | + return self._signature.arg_dtypes |
| 184 | + |
| 185 | + def numpy_output_dtype(self): |
| 186 | + return self._signature.res_dtype |
| 187 | + |
| 188 | + |
| 189 | +def wrap_cython_function(func, restype, arg_types): |
| 190 | + impls = _available_impls(func) |
| 191 | + compatible = [] |
| 192 | + for sig, capsule in impls: |
| 193 | + if sig.provides(restype, arg_types): |
| 194 | + compatible.append((sig, capsule)) |
| 195 | + |
| 196 | + def sort_key(args): |
| 197 | + sig, _ = args |
| 198 | + |
| 199 | + # Prefer functions with less inputs bytes |
| 200 | + argsize = sum(np.dtype(dtype).itemsize for dtype in sig.arg_dtypes) |
| 201 | + |
| 202 | + # Prefer functions with more exact (integer) arguments |
| 203 | + num_inexact = sum(np.issubdtype(dtype, np.inexact) for dtype in sig.arg_dtypes) |
| 204 | + return (num_inexact, argsize) |
| 205 | + |
| 206 | + compatible.sort(key=sort_key) |
| 207 | + |
| 208 | + if not compatible: |
| 209 | + raise NotImplementedError(f"Could not find a compatible impl of {func}") |
| 210 | + sig, capsule = compatible[0] |
| 211 | + return _CythonWrapper(func, sig, capsule) |
0 commit comments