Skip to content

Commit 9df47bb

Browse files
committed
Numba linalg: handle dtypes more strictly
1 parent be76954 commit 9df47bb

File tree

13 files changed

+285
-166
lines changed

13 files changed

+285
-166
lines changed

pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import numpy as np
22
from numba.core.extending import overload
33
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
4+
from numba.types import Float
45
from scipy import linalg
56

67
from pytensor.link.numba.dispatch.linalg._LAPACK import (
78
_LAPACK,
8-
_get_underlying_float,
99
int_ptr_to_val,
1010
val_to_int_ptr,
1111
)
12-
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
12+
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
1313

1414

1515
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
@@ -24,9 +24,9 @@ def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
2424
@overload(_cholesky)
2525
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
2626
ensure_lapack()
27-
_check_scipy_linalg_matrix(A, "cholesky")
27+
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky")
2828
dtype = A.dtype
29-
w_type = _get_underlying_float(dtype)
29+
3030
numba_potrf = _LAPACK().numba_xpotrf(dtype)
3131

3232
def impl(A, lower=0, overwrite_a=False, check_finite=True):
@@ -47,7 +47,7 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
4747
numba_potrf(
4848
UPLO,
4949
N,
50-
A_copy.view(w_type).ctypes,
50+
A_copy.ctypes,
5151
LDA,
5252
INFO,
5353
)

pytensor/link/numba/dispatch/linalg/decomposition/lu.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33

44
import numpy as np
55
from numba.core.extending import overload
6+
from numba.core.types import Float
67
from numba.np.linalg import ensure_lapack
78
from scipy import linalg
89

910
from pytensor.link.numba.dispatch import basic as numba_basic
1011
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
11-
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
12+
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
1213

1314

1415
@numba_basic.numba_njit
@@ -116,7 +117,7 @@ def lu_impl_1(
116117
False. Returns a tuple of (perm, L, U), where perm an integer array of row swaps, such that L[perm] @ U = A.
117118
"""
118119
ensure_lapack()
119-
_check_scipy_linalg_matrix(a, "lu")
120+
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
120121
dtype = a.dtype
121122

122123
def impl(
@@ -146,7 +147,7 @@ def lu_impl_2(
146147
"""
147148

148149
ensure_lapack()
149-
_check_scipy_linalg_matrix(a, "lu")
150+
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
150151
dtype = a.dtype
151152

152153
def impl(
@@ -179,7 +180,7 @@ def lu_impl_3(
179180
False. Returns a tuple of (P, L, U), such that P @ L @ U = A.
180181
"""
181182
ensure_lapack()
182-
_check_scipy_linalg_matrix(a, "lu")
183+
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
183184
dtype = a.dtype
184185

185186
def impl(

pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,16 @@
33

44
import numpy as np
55
from numba.core.extending import overload
6+
from numba.core.types import Float
67
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
78
from scipy import linalg
89

910
from pytensor.link.numba.dispatch.linalg._LAPACK import (
1011
_LAPACK,
11-
_get_underlying_float,
1212
int_ptr_to_val,
1313
val_to_int_ptr,
1414
)
15-
from pytensor.link.numba.dispatch.linalg.utils import (
16-
_check_scipy_linalg_matrix,
17-
)
15+
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
1816

1917

2018
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
@@ -38,9 +36,8 @@ def getrf_impl(
3836
A: np.ndarray, overwrite_a: bool = False
3937
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
4038
ensure_lapack()
41-
_check_scipy_linalg_matrix(A, "getrf")
39+
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="getrf")
4240
dtype = A.dtype
43-
w_type = _get_underlying_float(dtype)
4441
numba_getrf = _LAPACK().numba_xgetrf(dtype)
4542

4643
def impl(
@@ -59,7 +56,7 @@ def impl(
5956
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
6057
INFO = val_to_int_ptr(0)
6158

62-
numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)
59+
numba_getrf(M, N, A_copy.ctypes, LDA, IPIV.ctypes, INFO)
6360

6461
return A_copy, IPIV, int_ptr_to_val(INFO)
6562

@@ -79,7 +76,7 @@ def lu_factor_impl(
7976
A: np.ndarray, overwrite_a: bool = False
8077
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
8178
ensure_lapack()
82-
_check_scipy_linalg_matrix(A, "lu_factor")
79+
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor")
8380

8481
def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
8582
A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)

pytensor/link/numba/dispatch/linalg/solve/cholesky.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
import numpy as np
22
from numba.core.extending import overload
3+
from numba.core.types import Float
34
from numba.np.linalg import ensure_lapack
45
from scipy import linalg
56

67
from pytensor.link.numba.dispatch.linalg._LAPACK import (
78
_LAPACK,
8-
_get_underlying_float,
99
int_ptr_to_val,
1010
val_to_int_ptr,
1111
)
1212
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
1313
from pytensor.link.numba.dispatch.linalg.utils import (
14-
_check_scipy_linalg_matrix,
14+
_check_dtypes_match,
15+
_check_linalg_matrix,
1516
_copy_to_fortran_order_even_if_1d,
1617
_solve_check,
1718
)
@@ -31,10 +32,10 @@ def _cho_solve(
3132
@overload(_cho_solve)
3233
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
3334
ensure_lapack()
34-
_check_scipy_linalg_matrix(C, "cho_solve")
35-
_check_scipy_linalg_matrix(B, "cho_solve")
35+
_check_linalg_matrix(C, ndim=2, dtype=Float, func_name="cho_solve")
36+
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="cho_solve")
37+
_check_dtypes_match((C, B), func_name="cho_solve")
3638
dtype = C.dtype
37-
w_type = _get_underlying_float(dtype)
3839
numba_potrs = _LAPACK().numba_xpotrs(dtype)
3940

4041
def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
@@ -71,9 +72,9 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
7172
UPLO,
7273
N,
7374
NRHS,
74-
C_f.view(w_type).ctypes,
75+
C_f.ctypes,
7576
LDA,
76-
B_copy.view(w_type).ctypes,
77+
B_copy.ctypes,
7778
LDB,
7879
INFO,
7980
)

pytensor/link/numba/dispatch/linalg/solve/general.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
import numpy as np
44
from numba.core.extending import overload
5+
from numba.core.types import Float
56
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
67
from scipy import linalg
78

89
from pytensor.link.numba.dispatch.linalg._LAPACK import (
910
_LAPACK,
10-
_get_underlying_float,
1111
int_ptr_to_val,
1212
val_to_int_ptr,
1313
)
@@ -16,7 +16,8 @@
1616
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
1717
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
1818
from pytensor.link.numba.dispatch.linalg.utils import (
19-
_check_scipy_linalg_matrix,
19+
_check_dtypes_match,
20+
_check_linalg_matrix,
2021
_solve_check,
2122
)
2223

@@ -37,9 +38,8 @@ def xgecon_impl(
3738
Compute the condition number of a matrix A.
3839
"""
3940
ensure_lapack()
40-
_check_scipy_linalg_matrix(A, "gecon")
41+
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="gecon")
4142
dtype = A.dtype
42-
w_type = _get_underlying_float(dtype)
4343
numba_gecon = _LAPACK().numba_xgecon(dtype)
4444

4545
def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
@@ -58,11 +58,11 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
5858
numba_gecon(
5959
NORM,
6060
N,
61-
A_copy.view(w_type).ctypes,
61+
A_copy.ctypes,
6262
LDA,
63-
A_NORM.view(w_type).ctypes,
64-
RCOND.view(w_type).ctypes,
65-
WORK.view(w_type).ctypes,
63+
A_NORM.ctypes,
64+
RCOND.ctypes,
65+
WORK.ctypes,
6666
IWORK.ctypes,
6767
INFO,
6868
)
@@ -106,8 +106,9 @@ def solve_gen_impl(
106106
transposed: bool,
107107
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
108108
ensure_lapack()
109-
_check_scipy_linalg_matrix(A, "solve")
110-
_check_scipy_linalg_matrix(B, "solve")
109+
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
110+
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
111+
_check_dtypes_match((A, B), "solve")
111112

112113
def impl(
113114
A: np.ndarray,

pytensor/link/numba/dispatch/linalg/solve/lu_solve.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,19 @@
33

44
import numpy as np
55
from numba.core.extending import overload
6+
from numba.core.types import Float, int32
67
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
78
from scipy import linalg
89

910
from pytensor.link.numba.dispatch.linalg._LAPACK import (
1011
_LAPACK,
11-
_get_underlying_float,
1212
int_ptr_to_val,
1313
val_to_int_ptr,
1414
)
1515
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
1616
from pytensor.link.numba.dispatch.linalg.utils import (
17-
_check_scipy_linalg_matrix,
17+
_check_dtypes_match,
18+
_check_linalg_matrix,
1819
_copy_to_fortran_order_even_if_1d,
1920
_solve_check,
2021
_trans_char_to_int,
@@ -44,10 +45,11 @@ def getrs_impl(
4445
[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int]
4546
]:
4647
ensure_lapack()
47-
_check_scipy_linalg_matrix(LU, "getrs")
48-
_check_scipy_linalg_matrix(B, "getrs")
48+
_check_linalg_matrix(LU, ndim=2, dtype=Float, func_name="getrs")
49+
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="getrs")
50+
_check_dtypes_match((LU, B), func_name="getrs")
51+
_check_linalg_matrix(IPIV, ndim=1, dtype=int32, func_name="getrs")
4952
dtype = LU.dtype
50-
w_type = _get_underlying_float(dtype)
5153
numba_getrs = _LAPACK().numba_xgetrs(dtype)
5254

5355
def impl(
@@ -84,10 +86,10 @@ def impl(
8486
TRANS,
8587
N,
8688
NRHS,
87-
LU.view(w_type).ctypes,
89+
LU.ctypes,
8890
LDA,
8991
IPIV.ctypes,
90-
B_copy.view(w_type).ctypes,
92+
B_copy.ctypes,
9193
LDB,
9294
INFO,
9395
)
@@ -124,8 +126,10 @@ def lu_solve_impl(
124126
check_finite: bool,
125127
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]:
126128
ensure_lapack()
127-
_check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve")
128-
_check_scipy_linalg_matrix(b, "lu_solve")
129+
lu, _piv = lu_and_piv
130+
_check_linalg_matrix(lu, ndim=2, dtype=Float, func_name="lu_solve")
131+
_check_linalg_matrix(b, ndim=(1, 2), dtype=Float, func_name="lu_solve")
132+
_check_dtypes_match((lu, b), func_name="lu_solve")
129133

130134
def impl(
131135
lu: np.ndarray,

pytensor/link/numba/dispatch/linalg/solve/norm.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
import numpy as np
44
from numba.core.extending import overload
5+
from numba.core.types import Float
56
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
67

78
from pytensor.link.numba.dispatch.linalg._LAPACK import (
89
_LAPACK,
9-
_get_underlying_float,
1010
val_to_int_ptr,
1111
)
12-
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
12+
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
1313

1414

1515
def _xlange(A: np.ndarray, order: str | None = None) -> float:
@@ -28,9 +28,8 @@ def xlange_impl(
2828
largest absolute value of a matrix A.
2929
"""
3030
ensure_lapack()
31-
_check_scipy_linalg_matrix(A, "norm")
31+
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="norm")
3232
dtype = A.dtype
33-
w_type = _get_underlying_float(dtype)
3433
numba_lange = _LAPACK().numba_xlange(dtype)
3534

3635
def impl(A: np.ndarray, order: str | None = None):
@@ -49,9 +48,7 @@ def impl(A: np.ndarray, order: str | None = None):
4948
)
5049
WORK = np.empty(_M, dtype=dtype) # type: ignore
5150

52-
result = numba_lange(
53-
NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes
54-
)
51+
result = numba_lange(NORM, M, N, A_copy.ctypes, LDA, WORK.ctypes)
5552

5653
return result
5754

0 commit comments

Comments
 (0)