Skip to content

Commit ae5b2b7

Browse files
authored
ENH: support 'out' keyword in Categorical.__array_ufunc__ (#45381)
1 parent 5357f79 commit ae5b2b7

File tree

5 files changed

+56
-1
lines changed

5 files changed

+56
-1
lines changed

pandas/core/arrays/categorical.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,6 +1517,12 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
15171517
if result is not NotImplemented:
15181518
return result
15191519

1520+
if "out" in kwargs:
1521+
# e.g. test_numpy_ufuncs_out
1522+
return arraylike.dispatch_ufunc_with_out(
1523+
self, ufunc, method, *inputs, **kwargs
1524+
)
1525+
15201526
if method == "reduce":
15211527
# e.g. TestCategoricalAnalytics::test_min_max_ordered
15221528
result = arraylike.dispatch_reduction_ufunc(

pandas/core/arrays/numpy_.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
141141
if result is not NotImplemented:
142142
return result
143143

144+
if "out" in kwargs:
145+
# e.g. test_ufunc_unary
146+
return arraylike.dispatch_ufunc_with_out(
147+
self, ufunc, method, *inputs, **kwargs
148+
)
149+
144150
if method == "reduce":
145151
result = arraylike.dispatch_reduction_ufunc(
146152
self, ufunc, method, *inputs, **kwargs

pandas/tests/arrays/test_numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,11 @@ def test_ufunc_unary(ufunc):
237237
expected = PandasArray(ufunc(arr._ndarray))
238238
tm.assert_extension_array_equal(result, expected)
239239

240+
# same thing but with the 'out' keyword
241+
out = PandasArray(np.array([-9.0, -9.0, -9.0]))
242+
ufunc(arr, out=out)
243+
tm.assert_extension_array_equal(out, expected)
244+
240245

241246
def test_ufunc():
242247
arr = PandasArray(np.array([-1.0, 0.0, 1.0]))

pandas/tests/extension/decimal/array.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,18 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
119119
):
120120
return NotImplemented
121121

122+
result = arraylike.maybe_dispatch_ufunc_to_dunder_op(
123+
self, ufunc, method, *inputs, **kwargs
124+
)
125+
if result is not NotImplemented:
126+
# e.g. test_array_ufunc_series_scalar_other
127+
return result
128+
129+
if "out" in kwargs:
130+
return arraylike.dispatch_ufunc_with_out(
131+
self, ufunc, method, *inputs, **kwargs
132+
)
133+
122134
inputs = tuple(x._data if isinstance(x, DecimalArray) else x for x in inputs)
123135
result = getattr(ufunc, method)(*inputs, **kwargs)
124136

pandas/tests/indexes/test_numpy_compat.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,20 @@
1818
from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin
1919

2020

21+
def test_numpy_ufuncs_out(index):
22+
result = index == index
23+
24+
out = np.empty(index.shape, dtype=bool)
25+
np.equal(index, index, out=out)
26+
tm.assert_numpy_array_equal(out, result)
27+
28+
if not index._is_multi:
29+
# same thing on the ExtensionArray
30+
out = np.empty(index.shape, dtype=bool)
31+
np.equal(index.array, index.array, out=out)
32+
tm.assert_numpy_array_equal(out, result)
33+
34+
2135
@pytest.mark.parametrize(
2236
"func",
2337
[
@@ -91,6 +105,10 @@ def test_numpy_ufuncs_other(index, func, request):
91105
# numpy 1.18 changed isinf and isnan to not raise on dt64/td64
92106
result = func(index)
93107
assert isinstance(result, np.ndarray)
108+
109+
out = np.empty(index.shape, dtype=bool)
110+
func(index, out=out)
111+
tm.assert_numpy_array_equal(out, result)
94112
else:
95113
with tm.external_error_raised(TypeError):
96114
func(index)
@@ -109,7 +127,15 @@ def test_numpy_ufuncs_other(index, func, request):
109127
assert isinstance(result, BooleanArray)
110128
else:
111129
assert isinstance(result, np.ndarray)
112-
assert not isinstance(result, Index)
130+
131+
out = np.empty(index.shape, dtype=bool)
132+
func(index, out=out)
133+
134+
if not isinstance(index.dtype, np.dtype):
135+
tm.assert_numpy_array_equal(out, result._data)
136+
else:
137+
tm.assert_numpy_array_equal(out, result)
138+
113139
else:
114140
if len(index) == 0:
115141
pass

0 commit comments

Comments
 (0)