Skip to content

Commit d4a0433

Browse files
committed
compare_numba_and_py: Check for accidental input mutation
1 parent fbee416 commit d4a0433

File tree

5 files changed

+62
-10
lines changed

5 files changed

+62
-10
lines changed

tests/link/numba/test_basic.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import copy
23
from collections.abc import Callable, Iterable
34
from typing import TYPE_CHECKING, Any
45
from unittest import mock
@@ -232,8 +233,19 @@ def assert_fn(x, y):
232233
):
233234
raise ValueError("Inputs must be root variables")
234235

236+
test_input_deepcopy = None
237+
if not inplace:
238+
test_input_deepcopy = [
239+
i.copy() if isinstance(i, np.ndarray) else copy.deepcopy(i)
240+
for i in test_inputs
241+
]
242+
235243
pytensor_py_fn = function(
236-
graph_inputs, graph_outputs, mode=py_mode, accept_inplace=True, updates=updates
244+
graph_inputs,
245+
graph_outputs,
246+
mode=py_mode,
247+
accept_inplace=inplace,
248+
updates=updates,
237249
)
238250

239251
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
@@ -250,11 +262,20 @@ def assert_fn(x, y):
250262
graph_inputs,
251263
graph_outputs,
252264
mode=numba_mode,
253-
accept_inplace=True,
265+
accept_inplace=inplace,
254266
updates=updates,
255267
)
256268
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
257269
numba_res = pytensor_numba_fn(*test_inputs_copy)
270+
271+
if not inplace:
272+
# Check we did not accidentally modify the inputs inplace
273+
for test_input, test_input_copy in zip(test_inputs, test_input_deepcopy):
274+
try:
275+
assert_fn(test_input, test_input_copy)
276+
except AssertionError as e:
277+
raise AssertionError("Inputs were modified inplace") from e
278+
258279
if isinstance(graph_outputs, tuple | list):
259280
for numba_res_i, python_res_i in zip(numba_res, py_res, strict=True):
260281
assert_fn(numba_res_i, python_res_i)

tests/link/numba/test_elemwise.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,14 @@
117117
)
118118
def test_Elemwise(inputs, input_vals, output_fn):
119119
outputs = output_fn(*inputs)
120+
if not isinstance(outputs, tuple | list):
121+
outputs = [outputs]
120122

121123
compare_numba_and_py(
122124
inputs,
123125
outputs,
124126
input_vals,
127+
inplace=outputs[0].owner.op.destroy_map,
125128
)
126129

127130

tests/link/numba/test_extra_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def test_CumOp(val, axis, mode):
8484
)
8585

8686

87+
@pytest.mark.xfail(reason="Implementation works inplace!")
8788
def test_FillDiagonal():
8889
a = pt.lmatrix("a")
8990
test_a = np.zeros((10, 2), dtype="int64")

tests/link/numba/test_sparse.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import numpy as np
24
import pytest
35
import scipy as sp
@@ -16,6 +18,23 @@
1618
pytestmark = pytest.mark.filterwarnings("error")
1719

1820

21+
def sparse_assert_fn(a, b):
22+
a_is_sparse = sp.sparse.issparse(a)
23+
assert a_is_sparse == sp.sparse.issparse(b)
24+
if a_is_sparse:
25+
assert a.format == b.format
26+
assert a.dtype == b.dtype
27+
assert a.shape == b.shape
28+
np.testing.assert_allclose(a.data, b.data, strict=True)
29+
np.testing.assert_allclose(a.indices, b.indices, strict=True)
30+
np.testing.assert_allclose(a.indptr, b.indptr, strict=True)
31+
else:
32+
np.testing.assert_allclose(a, b, strict=True)
33+
34+
35+
compare_numba_and_py_sparse = partial(compare_numba_and_py, assert_fn=sparse_assert_fn)
36+
37+
1938
def test_sparse_unboxing():
2039
@numba.njit
2140
def test_unboxing(x, y):
@@ -93,11 +112,15 @@ def test_sparse_objmode():
93112

94113
out = Dot()(x, y)
95114

96-
x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
97-
y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
115+
x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX, format="csc")
116+
y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX, format="csc")
98117

99118
with pytest.warns(
100119
UserWarning,
101120
match="Numba will use object mode to run SparseDot's perform method",
102121
):
103-
compare_numba_and_py([x, y], out, [x_val, y_val])
122+
compare_numba_and_py_sparse(
123+
[x, y],
124+
out,
125+
[x_val, y_val],
126+
)

tests/link/numba/test_subtensor.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def test_IncSubtensor(x, y, indices):
259259
x_pt = x.type()
260260
out_pt = set_subtensor(x_pt[indices], y, inplace=True)
261261
assert isinstance(out_pt.owner.op, IncSubtensor)
262-
compare_numba_and_py([x_pt], [out_pt], [x.data])
262+
compare_numba_and_py([x_pt], [out_pt], [x.data], inplace=True)
263263

264264

265265
@pytest.mark.parametrize(
@@ -313,13 +313,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
313313

314314
out_pt = AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices)
315315
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
316-
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data])
316+
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data], inplace=True)
317317

318318
out_pt = AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)(
319319
x_pt, y_pt, *indices
320320
)
321321
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
322-
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data])
322+
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data], inplace=True)
323323

324324

325325
@pytest.mark.parametrize(
@@ -526,7 +526,9 @@ def test_AdvancedIncSubtensor(
526526
if set_requires_objmode
527527
else contextlib.nullcontext()
528528
):
529-
fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode)
529+
fn, _ = compare_numba_and_py(
530+
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
531+
)
530532

531533
if inplace:
532534
# Test updates inplace
@@ -546,7 +548,9 @@ def test_AdvancedIncSubtensor(
546548
if inc_requires_objmode
547549
else contextlib.nullcontext()
548550
):
549-
fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode)
551+
fn, _ = compare_numba_and_py(
552+
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
553+
)
550554
if inplace:
551555
# Test updates inplace
552556
x_orig = x.copy()

0 commit comments

Comments
 (0)