Skip to content

Commit 4298b76

Browse files
committed
Numba Split: Validate sizes
1 parent de6aca8 commit 4298b76

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,19 @@ def join(axis, *tensors):
124124
@register_funcify_default_op_cache_key(Split)
125125
def numba_funcify_Split(op, **kwargs):
126126
@numba_basic.numba_njit
127-
def split(tensor, axis, indices):
128-
return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item())
127+
def split(x, axis, sizes):
128+
if (sizes < 0).any():
129+
raise ValueError("Split sizes cannot be negative")
130+
axis = axis.item()
131+
split_indices = np.cumsum(sizes)
132+
if split_indices[-1] != x.shape[axis]:
133+
raise ValueError(
134+
f"Split sizes sum to {split_indices[-1]}; expected {x.shape[axis]}"
135+
)
136+
return np.split(x, split_indices[:-1], axis=axis)
129137

130-
return split
138+
cache_version = 1
139+
return split, cache_version
131140

132141

133142
@register_funcify_default_op_cache_key(ExtractDiag)

tests/link/numba/test_tensor_basic.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from tests.link.numba.test_basic import (
1111
compare_numba_and_py,
1212
compare_shape_dtype,
13+
numba_mode,
1314
)
1415
from tests.tensor.test_basic import check_alloc_runtime_broadcast
1516

@@ -245,6 +246,18 @@ def test_Split_view():
245246
)
246247

247248

249+
def test_split_errors():
250+
x = pt.dvector("x", shape=(5,))
251+
splits = pt.tensor(shape=(3,), dtype="int64")
252+
outs = pt.split(x, splits)
253+
fn = function([x, splits], outs, mode=numba_mode)
254+
test_x = np.zeros((5,))
255+
with pytest.raises(ValueError, match="Split sizes sum to 4; expected 5"):
256+
fn(test_x, np.array([1, 2, 1], dtype="int64"))
257+
with pytest.raises(ValueError, match="Split sizes cannot be negative"):
258+
fn(test_x, np.array([2, 4, -1], dtype="int64"))
259+
260+
248261
@pytest.mark.parametrize(
249262
"val, offset",
250263
[

0 commit comments

Comments
 (0)