11import math
2+ import warnings
23from functools import reduce
34from typing import List
45
1011from pytensor .compile .ops import ViewOp
1112from pytensor .graph .basic import Variable
1213from pytensor .link .numba .dispatch import basic as numba_basic
13- from pytensor .link .numba .dispatch .basic import create_numba_signature , numba_funcify
14+ from pytensor .link .numba .dispatch .basic import (
15+ create_numba_signature ,
16+ generate_fallback_impl ,
17+ numba_funcify ,
18+ )
1419from pytensor .link .utils import (
1520 compile_function_src ,
1621 get_name_for_object ,
@@ -37,14 +42,31 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
3742 # compiling the same Numba function over and over again?
3843
3944 scalar_func_name = op .nfunc_spec [0 ]
45+ scalar_func = None
4046
4147 if scalar_func_name .startswith ("scipy." ):
4248 func_package = scipy
4349 scalar_func_name = scalar_func_name .split ("." , 1 )[- 1 ]
50+
51+ use_numba_scipy = config .numba_scipy
52+ if use_numba_scipy :
53+ try :
54+ import numba_scipy # noqa: F401
55+ except ImportError :
56+ use_numba_scipy = False
57+ if not use_numba_scipy :
58+ warnings .warn (
59+ "Native numba versions of scipy functions might be "
60+ "avalable if numba-scipy is installed." ,
61+ UserWarning ,
62+ )
63+ scalar_func = generate_fallback_impl (op , node , ** kwargs )
4464 else :
4565 func_package = np
4666
47- if "." in scalar_func_name :
67+ if scalar_func is not None :
68+ pass
69+ elif "." in scalar_func_name :
4870 scalar_func = reduce (getattr , [scipy ] + scalar_func_name .split ("." ))
4971 else :
5072 scalar_func = getattr (func_package , scalar_func_name )
@@ -220,7 +242,7 @@ def clip(_x, _min, _max):
220242
221243@numba_funcify .register (Composite )
222244def numba_funcify_Composite (op , node , ** kwargs ):
223- signature = create_numba_signature (node , force_scalar = True )
245+ signature = create_numba_signature (op . fgraph , force_scalar = True )
224246
225247 _ = kwargs .pop ("storage_map" , None )
226248
0 commit comments