Skip to content

Commit 1c50709

Browse files
committed
Use objmode in scipy.special without numba-scipy
1 parent b2caa73 commit 1c50709

File tree

5 files changed

+58
-8
lines changed

5 files changed

+58
-8
lines changed

pytensor/configdefaults.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,12 @@ def add_numba_configvars():
12521252
BoolParam(True),
12531253
in_c_key=False,
12541254
)
1255+
config.add(
1256+
"numba_scipy",
1257+
("Enable usage of the numba_scipy package for special functions",),
1258+
BoolParam(True),
1259+
in_c_key=False,
1260+
)
12551261

12561262

12571263
def _default_compiledirname():

pytensor/link/numba/dispatch/basic.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,8 @@ def numba_typify(data, dtype=None, **kwargs):
323323
return data
324324

325325

326-
@singledispatch
327-
def numba_funcify(op, node=None, storage_map=None, **kwargs):
328-
"""Create a Numba compatible function from an PyTensor `Op`."""
326+
def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
327+
"""Create a Numba compatible function from an Aesara `Op`."""
329328

330329
warnings.warn(
331330
f"Numba will use object mode to run {op}'s perform method",
@@ -379,6 +378,12 @@ def perform(*inputs):
379378
return perform
380379

381380

381+
@singledispatch
382+
def numba_funcify(op, node=None, storage_map=None, **kwargs):
383+
"""Generate a numba function for a given op and apply node."""
384+
return generate_fallback_impl(op, node, storage_map, **kwargs)
385+
386+
382387
@numba_funcify.register(OpFromGraph)
383388
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
384389

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
OR,
2828
XOR,
2929
Add,
30+
Composite,
3031
IntDiv,
3132
Mean,
3233
Mul,
@@ -40,6 +41,7 @@
4041
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
4142
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros
4243
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
44+
from pytensor.tensor.type import scalar
4345

4446

4547
@singledispatch
@@ -424,8 +426,17 @@ def axis_apply_fn(x):
424426

425427
@numba_funcify.register(Elemwise)
426428
def numba_funcify_Elemwise(op, node, **kwargs):
427-
428-
scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
429+
# Creating a new scalar node is more involved and unnecessary
430+
# if the scalar_op is composite, as the fgraph already contains
431+
# all the necessary information.
432+
scalar_node = None
433+
if not isinstance(op.scalar_op, Composite):
434+
scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
435+
scalar_node = op.scalar_op.make_node(*scalar_inputs)
436+
437+
scalar_op_fn = numba_funcify(
438+
op.scalar_op, node=scalar_node, parent_node=node, inline="always", **kwargs
439+
)
429440
elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
430441
elemwise_fn_name = elemwise_fn.__name__
431442

pytensor/link/numba/dispatch/scalar.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import warnings
23
from functools import reduce
34
from typing import List
45

@@ -10,7 +11,11 @@
1011
from pytensor.compile.ops import ViewOp
1112
from pytensor.graph.basic import Variable
1213
from pytensor.link.numba.dispatch import basic as numba_basic
13-
from pytensor.link.numba.dispatch.basic import create_numba_signature, numba_funcify
14+
from pytensor.link.numba.dispatch.basic import (
15+
create_numba_signature,
16+
generate_fallback_impl,
17+
numba_funcify,
18+
)
1419
from pytensor.link.utils import (
1520
compile_function_src,
1621
get_name_for_object,
@@ -37,14 +42,31 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
3742
# compiling the same Numba function over and over again?
3843

3944
scalar_func_name = op.nfunc_spec[0]
45+
scalar_func = None
4046

4147
if scalar_func_name.startswith("scipy."):
4248
func_package = scipy
4349
scalar_func_name = scalar_func_name.split(".", 1)[-1]
50+
51+
use_numba_scipy = config.numba_scipy
52+
if use_numba_scipy:
53+
try:
54+
import numba_scipy # noqa: F401
55+
except ImportError:
56+
use_numba_scipy = False
57+
if not use_numba_scipy:
58+
warnings.warn(
59+
"Native numba versions of scipy functions might be "
60+
"avalable if numba-scipy is installed.",
61+
UserWarning,
62+
)
63+
scalar_func = generate_fallback_impl(op, node, **kwargs)
4464
else:
4565
func_package = np
4666

47-
if "." in scalar_func_name:
67+
if scalar_func is not None:
68+
pass
69+
elif "." in scalar_func_name:
4870
scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
4971
else:
5072
scalar_func = getattr(func_package, scalar_func_name)
@@ -220,7 +242,7 @@ def clip(_x, _min, _max):
220242

221243
@numba_funcify.register(Composite)
222244
def numba_funcify_Composite(op, node, **kwargs):
223-
signature = create_numba_signature(node, force_scalar=True)
245+
signature = create_numba_signature(op.fgraph, force_scalar=True)
224246

225247
_ = kwargs.pop("storage_map", None)
226248

tests/link/numba/test_elemwise.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@
5757
lambda x: at.erfc(x),
5858
None,
5959
),
60+
(
61+
[at.vector()],
62+
[rng.standard_normal(100).astype(config.floatX)],
63+
lambda x: at.erfcx(x),
64+
None,
65+
),
6066
(
6167
[at.vector() for i in range(4)],
6268
[rng.standard_normal(100).astype(config.floatX) for i in range(4)],

0 commit comments

Comments
 (0)