Skip to content

Commit b7ee3d2

Browse files
authored
Merge pull request #397 from ev-br/take_along_axis_torch
ENH: test `take` and `take_along_axis` with indices < 0
2 parents ee3d9b7 + 2784516 commit b7ee3d2

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

array_api_tests/test_indexing_functions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
)
1818
def test_take(x, data):
1919
# TODO:
20-
# * negative indices
2120
# * different dtypes for indices
2221

2322
# axis is optional but only if x.ndim == 1
@@ -28,7 +27,7 @@ def test_take(x, data):
2827
kw = {"axis": data.draw(_axis_st)}
2928
axis = kw.get("axis", 0)
3029
_indices = data.draw(
31-
st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True),
30+
st.lists(st.integers(-x.shape[axis], x.shape[axis] - 1), min_size=1, unique=True),
3231
label="_indices",
3332
)
3433
n_axis = axis if axis>=0 else x.ndim + axis
@@ -81,7 +80,6 @@ def test_take(x, data):
8180
)
8281
def test_take_along_axis(x, data):
8382
# TODO
84-
# 2. negative indices
8583
# 3. different dtypes for indices
8684
# 4. "broadcast-compatible" indices
8785
axis = data.draw(
@@ -101,7 +99,7 @@ def test_take_along_axis(x, data):
10199
hh.arrays(
102100
shape=idx_shape,
103101
dtype=dh.default_int,
104-
elements={"min_value": 0, "max_value": x.shape[n_axis]-1}
102+
elements={"min_value": -x.shape[n_axis], "max_value": x.shape[n_axis]-1}
105103
),
106104
label="indices"
107105
)

0 commit comments

Comments
 (0)