|
5 | 5 | from functools import wraps |
6 | 6 | from inspect import signature |
7 | 7 |
|
8 | | -from .common._helpers import get_namespace |
9 | | - |
10 | | -def get_xp(f): |
| 8 | +def get_xp(xp): |
11 | 9 | """ |
12 | | - Decorator to automatically replace xp with the corresponding array module |
| 10 | + Decorator to automatically replace xp with the corresponding array module. |
13 | 11 |
|
14 | 12 | Use like |
15 | 13 |
|
16 | | - @get_xp |
| 14 | + import numpy as np |
| 15 | +
|
| 16 | + @get_xp(np) |
17 | 17 | def func(x, /, xp, kwarg=None): |
18 | 18 | return xp.func(x, kwarg=kwarg) |
19 | 19 |
|
20 | | - Note that xp must be able to be passed as a keyword argument. |
| 20 | + Note that xp must be a keyword argument and come after all non-keyword |
| 21 | + arguments. |
| 22 | +
|
21 | 23 | """ |
22 | | - @wraps(f) |
23 | | - def inner(*args, **kwargs): |
24 | | - xp = get_namespace(*args, _use_compat=False) |
25 | | - return f(*args, xp=xp, **kwargs) |
| 24 | + def inner(f): |
| 25 | + sig = signature(f) |
| 26 | + |
| 27 | + @wraps(f) |
| 28 | + def wrapped_f(*args, **kwargs): |
| 29 | + return f(*args, xp=xp, **kwargs) |
26 | 30 |
|
27 | | - sig = signature(f) |
28 | | - new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) |
| 31 | + new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) |
29 | 32 |
|
30 | | - if inner.__doc__ is None: |
31 | | - inner.__doc__ = f"""\ |
| 33 | + if wrapped_f.__doc__ is None: |
| 34 | + wrapped_f.__doc__ = f"""\ |
32 | 35 | Array API compatibility wrapper for {f.__name__}. |
33 | 36 |
|
34 | 37 | See the corresponding documentation in NumPy/CuPy and/or the array API |
35 | 38 | specification for more details. |
36 | 39 |
|
37 | 40 | """ |
38 | | - inner.__signature__ = new_sig |
| 41 | + # wrapped_f.__signature__ = new_sig |
| 42 | + return wrapped_f |
39 | 43 |
|
40 | 44 | return inner |
0 commit comments