Skip to content

Commit 5a2efea

Browse files
committed
Numba linalg: Handle empty inputs
1 parent 9df47bb commit 5a2efea

File tree

3 files changed

+90
-3
lines changed

3 files changed

+90
-3
lines changed

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,15 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
362362

363363
@numba_basic.numba_njit(cache=False)
364364
def lu_factor_tridiagonal(dl, d, du):
365+
if d.size == 0:
366+
return (
367+
np.zeros(dl.shape, dtype=out_dtype),
368+
np.zeros(d.shape, dtype=out_dtype),
369+
np.zeros(du.shape, dtype=out_dtype),
370+
np.zeros(d.shape, dtype=out_dtype),
371+
np.zeros(d.shape, dtype="int32"),
372+
)
373+
365374
if must_cast_inputs[0]:
366375
d = d.astype(out_dtype)
367376
if must_cast_inputs[1]:
@@ -389,6 +398,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
389398
return generate_fallback_impl(op, node=node)
390399
out_dtype = node.outputs[0].type.numpy_dtype
391400

401+
b_ndim = op.b_ndim
392402
overwrite_b = op.overwrite_b
393403
transposed = op.transposed
394404

@@ -401,6 +411,12 @@ def numba_funcify_SolveLUFactorTridiagonal(
401411

402412
@numba_basic.numba_njit(cache=False)
403413
def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
414+
if d.size == 0:
415+
if b_ndim == 1:
416+
return np.zeros(d.shape, dtype=out_dtype)
417+
else:
418+
return np.zeros((d.shape[0], b.shape[1]), dtype=out_dtype)
419+
404420
if must_cast_inputs[0]:
405421
dl = dl.astype(out_dtype)
406422
if must_cast_inputs[1]:

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def numba_funcify_Cholesky(op, node, **kwargs):
7474

7575
@numba_basic.numba_njit
7676
def cholesky(a):
77+
if a.size == 0:
78+
return np.zeros(a.shape, dtype=out_dtype)
79+
7780
if discrete_inp:
7881
a = a.astype(out_dtype)
7982
elif check_finite:
@@ -114,7 +117,8 @@ def numba_pivot_to_permutation(piv):
114117

115118
return np.argsort(p_inv)
116119

117-
return numba_pivot_to_permutation
120+
cache_key = 1
121+
return numba_pivot_to_permutation, cache_key
118122

119123

120124
@numba_funcify.register(LU)
@@ -134,6 +138,18 @@ def numba_funcify_LU(op, node, **kwargs):
134138

135139
@numba_basic.numba_njit
136140
def lu(a):
141+
if a.size == 0:
142+
L = np.zeros(a.shape, dtype=a.dtype)
143+
U = np.zeros(a.shape, dtype=a.dtype)
144+
if permute_l:
145+
return L, U
146+
elif p_indices:
147+
P = np.zeros(a.shape[0], dtype="int32")
148+
return P, L, U
149+
else:
150+
P = np.zeros(a.shape, dtype=a.dtype)
151+
return P, L, U
152+
137153
if discrete_inp:
138154
a = a.astype(out_dtype)
139155
elif check_finite:
@@ -187,6 +203,12 @@ def numba_funcify_LUFactor(op, node, **kwargs):
187203

188204
@numba_basic.numba_njit
189205
def lu_factor(a):
206+
if a.size == 0:
207+
return (
208+
np.zeros(a.shape, dtype=out_dtype),
209+
np.zeros(a.shape[0], dtype="int32"),
210+
)
211+
190212
if discrete_inp:
191213
a = a.astype(out_dtype)
192214
elif check_finite:
@@ -226,7 +248,7 @@ def block_diag(*arrs):
226248

227249
@numba_funcify.register(Solve)
228250
def numba_funcify_Solve(op, node, **kwargs):
229-
A_dtype, b_dtype = (i.numpy_dtype for i in node.inputs)
251+
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
230252
out_dtype = node.outputs[0].type.numpy_dtype
231253

232254
if A_dtype.kind == "c" or b_dtype.kind == "c":
@@ -269,6 +291,9 @@ def numba_funcify_Solve(op, node, **kwargs):
269291

270292
@numba_basic.numba_njit
271293
def solve(a, b):
294+
if b.size == 0:
295+
return np.zeros(b.shape, dtype=out_dtype)
296+
272297
if must_cast_A:
273298
a = a.astype(out_dtype)
274299
if must_cast_B:
@@ -297,7 +322,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
297322
overwrite_b = op.overwrite_b
298323
b_ndim = op.b_ndim
299324

300-
A_dtype, b_dtype = (i.numpy_dtype for i in node.inputs)
325+
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
301326
out_dtype = node.outputs[0].type.numpy_dtype
302327

303328
if A_dtype.kind == "c" or b_dtype.kind == "c":
@@ -311,6 +336,8 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
311336

312337
@numba_basic.numba_njit
313338
def solve_triangular(a, b):
339+
if b.size == 0:
340+
return np.zeros(b.shape, dtype=out_dtype)
314341
if must_cast_A:
315342
a = a.astype(out_dtype)
316343
if must_cast_B:
@@ -360,6 +387,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
360387

361388
@numba_basic.numba_njit
362389
def cho_solve(c, b):
390+
if b.size == 0:
391+
return np.zeros(b.shape, dtype=out_dtype)
363392
if must_cast_c:
364393
c = c.astype(out_dtype)
365394
if check_finite:

tests/link/numba/test_slinalg.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616
LUFactor,
1717
Solve,
1818
SolveTriangular,
19+
cho_solve,
20+
cholesky,
21+
lu,
22+
lu_factor,
23+
lu_solve,
24+
solve,
25+
solve_triangular,
1926
)
2027
from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode
2128

@@ -483,6 +490,27 @@ def test_lu_solve(
483490
# Can never destroy non-contiguous inputs
484491
np.testing.assert_allclose(b_val_not_contig, b_val)
485492

493+
@pytest.mark.parametrize(
494+
"solve_op",
495+
[solve, solve_triangular, cho_solve, lu_solve],
496+
ids=lambda x: x.__name__,
497+
)
498+
def test_empty(self, solve_op):
499+
a = pt.matrix("x")
500+
b = pt.vector("b")
501+
if solve_op is cho_solve:
502+
out = solve_op((a, True), b)
503+
elif solve_op is lu_solve:
504+
out = solve_op((a, b.astype("int32")), b)
505+
else:
506+
out = solve_op(a, b)
507+
compare_numba_and_py(
508+
[a, b],
509+
[out],
510+
[np.zeros((0, 0)), np.zeros(0)],
511+
eval_obj_mode=False, # pivot_to_permutation seems to still be jitted despite the monkey patching
512+
)
513+
486514

487515
class TestDecompositions:
488516
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@@ -750,6 +778,20 @@ def test_qr(self, mode, pivoting, overwrite_a):
750778
# Cannot destroy non-contiguous input
751779
np.testing.assert_allclose(val_not_contig, A_val)
752780

781+
@pytest.mark.parametrize(
782+
"decomp_op", (cholesky, lu, lu_factor), ids=lambda x: x.__name__
783+
)
784+
def test_empty(self, decomp_op):
785+
x = pt.matrix("x")
786+
outs = decomp_op(x)
787+
if not isinstance(outs, tuple | list):
788+
outs = [outs]
789+
compare_numba_and_py(
790+
[x],
791+
outs,
792+
[np.zeros((0, 0))],
793+
)
794+
753795

754796
def test_block_diag():
755797
A = pt.matrix("A")

0 commit comments

Comments
 (0)