
Commit 0433078

Numba Blockwise: Fix OpFromGraph as core_op
1 parent 3ff7603 commit 0433078
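
Context: when a graph containing an OpFromGraph is vectorized, the OpFromGraph becomes the core_op of a Blockwise, and compiling such a graph with the Numba backend previously errored (see the regression test added below, which references pymc-devs/pytensor#1507). A minimal sketch of the affected pattern, mirroring that test; the mode="NUMBA" compilation step is an illustration added here, not part of the commit:

import numpy as np
from pytensor import OpFromGraph, function
from pytensor.graph import vectorize_graph
from pytensor.tensor import dmatrix, dtensor3

# An OpFromGraph wrapping a simple elemwise computation
x = dmatrix("x")
z = OpFromGraph(inputs=[x], outputs=[x + 1])(x)

# Vectorizing the graph puts the OpFromGraph inside a Blockwise as its core_op
x_batched = dtensor3("x_batched")
z_batched = vectorize_graph(z, {x: x_batched})

# Compiling this with the Numba backend is the scenario this commit fixes
fn = function([x_batched], z_batched, mode="NUMBA")
out = fn(np.random.normal(size=(3, 2, 4)))
print(out.shape)  # expected: (3, 2, 4)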

File tree: 4 files changed (+27 −5 lines)

pytensor/compile/builders.py

Lines changed: 7 additions & 2 deletions
@@ -4,6 +4,7 @@
 from collections.abc import Callable, Sequence
 from copy import copy
 from functools import partial
+from itertools import chain
 from typing import Union, cast

 from pytensor.compile.function import function
@@ -47,11 +48,15 @@ def infer_shape(outs, inputs, input_shapes):
             assert len(inp_shp) == inp.type.ndim

     shape_feature = ShapeFeature()
-    shape_feature.on_attach(FunctionGraph([], []))
+    fgraph = FunctionGraph([], [], features=[shape_feature])
+    for v in chain.from_iterable(s for s in input_shapes if s is not None):
+        # Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before
+        if (node := v.owner) is not None:
+            fgraph.import_node(node, import_missing=True)

     # Initialize shape_of with the input shapes
     for inp, inp_shp in zip(inputs, input_shapes, strict=True):
-        shape_feature.set_shape(inp, inp_shp)
+        shape_feature.set_shape(inp, inp_shp, override=True)

     def local_traverse(out):
         """

pytensor/link/numba/dispatch/blockwise.py

Lines changed: 0 additions & 1 deletion
@@ -36,7 +36,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
     core_op_fn, core_op_key = numba_funcify_and_cache_key(
         core_op,
         node=core_node,
-        parent_node=node,
         **kwargs,
     )
     core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 0 additions & 1 deletion
@@ -273,7 +273,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
     scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key(
         op.scalar_op,
         node=scalar_node,
-        parent_node=node,
         **kwargs,
     )


tests/link/numba/test_compile_ops.py

Lines changed: 20 additions & 1 deletion
@@ -4,9 +4,10 @@
 from pytensor import OpFromGraph, config, function, ifelse
 from pytensor import tensor as pt
 from pytensor.compile import ViewOp
+from pytensor.graph import vectorize_graph
 from pytensor.raise_op import assert_op
 from pytensor.scalar import Add
-from pytensor.tensor import matrix
+from pytensor.tensor import dmatrix, dtensor3, matrix
 from pytensor.tensor.elemwise import Elemwise
 from tests.link.numba.test_basic import compare_numba_and_py

@@ -171,6 +172,24 @@ def test_ofg_aliased_outputs():
     np.testing.assert_allclose(res, np.ones((2, 2)))


+def test_ofg_elemwise_regression():
+    # Regression test for https://github.com/pymc-devs/pytensor/issues/1507
+    x = dmatrix("x", shape=(None, None))
+    z = OpFromGraph(
+        inputs=[x],
+        outputs=[x + 1],
+    )(x)
+
+    x_batched = dtensor3("X_batched", shape=(None, None, None))
+    z_batched = vectorize_graph(z, {x: x_batched})
+    compare_numba_and_py(
+        [x_batched],
+        [z_batched],
+        [np.random.normal(size=(3, 2, 4))],
+        eval_obj_mode=False,
+    )
+
+
 def test_check_and_raise():
     x = pt.vector()
     x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
