30 changes: 0 additions & 30 deletions pytensor/link/numba/dispatch/basic.py
@@ -224,36 +224,6 @@ def codegen(context, builder, signature, args):
return sig, codegen


def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""

if (
all(inp.type.dtype == out_dtype for inp in inputs)
and np.dtype(out_dtype).kind == "f"
):

@numba_njit(inline="always")
def inputs_cast(x):
return x

elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
args_dtype = np.dtype(f"f{out_dtype.itemsize}")

@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)

else:
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
args_dtype = np.dtype(f"f{args_dtype_sz}")

@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)

return inputs_cast


@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data
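For reference, the casting rule that the removed `int_to_float_fn` helper encoded can be restated in plain NumPy. This is a minimal sketch for illustration only (the `cast_rule` name is hypothetical, and `out_dtype` is assumed to be anything `np.dtype` accepts):

```python
import numpy as np


def cast_rule(input_dtypes, out_dtype):
    # Dtype that inputs would be cast to before calling a float-only kernel;
    # mirrors the three branches of the removed helper.
    out_dtype = np.dtype(out_dtype)
    input_dtypes = [np.dtype(d) for d in input_dtypes]

    if all(d == out_dtype for d in input_dtypes) and out_dtype.kind == "f":
        # Inputs already match a floating-point output: no cast needed.
        return out_dtype
    if any(d.kind in "uib" for d in input_dtypes):
        # Integer/boolean inputs: cast to a float of the output's width.
        return np.dtype(f"f{out_dtype.itemsize}")
    # Otherwise use a float wide enough for the widest input.
    return np.dtype(f"f{max(d.itemsize for d in input_dtypes)}")


cast_rule(["int32", "int64"], "float64")      # dtype('float64')
cast_rule(["float32", "float32"], "float32")  # dtype('float32')
```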
16 changes: 14 additions & 2 deletions pytensor/link/numba/dispatch/linalg/_LAPACK.py
@@ -3,6 +3,7 @@
import numpy as np
from numba.core import cgutils, types
from numba.core.extending import get_cython_function_address, intrinsic
from numba.core.types import Complex
from numba.np.linalg import ensure_lapack, get_blas_kind


@@ -486,8 +487,7 @@ def numba_xgeqp3(cls, dtype):
Used in QR decomposition with pivoting.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
functype = ctypes.CFUNCTYPE(
None,
ctype_args = (
_ptr_int, # M
_ptr_int, # N
float_pointer, # A
@@ -496,8 +496,20 @@ def numba_xgeqp3(cls, dtype):
float_pointer, # TAU
float_pointer, # WORK
_ptr_int, # LWORK
)

if isinstance(dtype, Complex):
ctype_args = (
*ctype_args,
float_pointer, # RWORK
)

functype = ctypes.CFUNCTYPE(
None,
*ctype_args,
_ptr_int, # INFO
)

return functype(lapack_ptr)

@classmethod
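The `Complex` branch mirrors the LAPACK interface: `cgeqp3`/`zgeqp3` take an extra `RWORK` array that the real `sgeqp3`/`dgeqp3` do not, hence the additional pointer inserted before `INFO`. A small SciPy sketch of the operation this routine backs, QR with column pivoting, exercising both the real and the complex path (illustrative only):

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)

# Real input: the dgeqp3 path (no RWORK argument).
A = rng.normal(size=(5, 3))
Q, R, P = linalg.qr(A, pivoting=True)  # pivoting=True uses ?geqp3 under the hood
np.testing.assert_allclose(Q @ R, A[:, P], atol=1e-12)

# Complex input: the zgeqp3 path, which is the case needing RWORK.
B = rng.normal(size=(5, 3)) + 1j * rng.normal(size=(5, 3))
Q, R, P = linalg.qr(B, pivoting=True)
np.testing.assert_allclose(Q @ R, B[:, P], atol=1e-12)
```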
34 changes: 22 additions & 12 deletions pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py
@@ -1,15 +1,15 @@
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numba.types import Float
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix


def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
@@ -24,30 +24,36 @@ def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(A, "cholesky")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky")
dtype = A.dtype
w_type = _get_underlying_float(dtype)

numba_potrf = _LAPACK().numba_xpotrf(dtype)

def impl(A, lower=0, overwrite_a=False, check_finite=True):
def impl(A, lower=False, overwrite_a=False, check_finite=True):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")

UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)

transposed = False
if overwrite_a and A.flags.f_contiguous:
A_copy = A
elif overwrite_a and A.flags.c_contiguous:
# We can work on the transpose of A directly
A_copy = A.T
transposed = True
lower = not lower
else:
A_copy = _copy_to_fortran_order(A)

UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)

numba_potrf(
UPLO,
N,
A_copy.view(w_type).ctypes,
A_copy.ctypes,
LDA,
INFO,
)
@@ -61,6 +67,10 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
for i in range(j + 1, _N):
A_copy[i, j] = 0.0

return A_copy, int_ptr_to_val(INFO)
info_int = int_ptr_to_val(INFO)

if transposed:
return A_copy.T, info_int
return A_copy, info_int

return impl
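The new `overwrite_a` branches rely on a property of real Cholesky factors: the upper factor is the transpose of the lower one, so a C-contiguous `A` can be handed to LAPACK as its Fortran-ordered transpose with `UPLO` flipped, and the result transposed back at the end. A quick SciPy check of that identity (illustrative only; real symmetric positive definite input, matching the `Float` dtype restriction above):

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 4))
A = X @ X.T + 4 * np.eye(4)  # symmetric positive definite

L = linalg.cholesky(A, lower=True)
U = linalg.cholesky(A, lower=False)

# For real SPD matrices the two factors are transposes of each other,
# which is what lets the implementation flip `lower` and work on A.T.
np.testing.assert_allclose(L, U.T, atol=1e-12)
```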
15 changes: 8 additions & 7 deletions pytensor/link/numba/dispatch/linalg/decomposition/lu.py
@@ -3,17 +3,18 @@

import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix


@numba_basic.numba_njit
def _pivot_to_permutation(p, dtype):
p_inv = np.arange(len(p)).astype(dtype)
def _pivot_to_permutation(p):
p_inv = np.arange(len(p))
for i in range(len(p)):
p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
return p_inv
@@ -29,7 +30,7 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):

# Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
IPIV = IPIV - 1
p_inv = _pivot_to_permutation(IPIV, dtype=dtype)
p_inv = _pivot_to_permutation(IPIV)
perm = np.argsort(p_inv).astype("int32")

return perm, L, U
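`_pivot_to_permutation` replays `getrf`'s sequential row swaps on an identity permutation, and `argsort` of the result gives `perm` with `L[perm] @ U == A`. A NumPy/SciPy sketch of the same bookkeeping, using `scipy.linalg.lu_factor`, whose `piv` output matches the 0-indexed `IPIV - 1` used above (illustrative only):

```python
import numpy as np
from scipy.linalg import lu_factor

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))

LU, piv = lu_factor(A)             # piv: 0-indexed sequential row swaps
L = np.tril(LU, k=-1) + np.eye(4)  # unit lower-triangular factor
U = np.triu(LU)

# Replay the swaps on an identity permutation (what _pivot_to_permutation does),
# then invert it with argsort to obtain the row permutation of L.
p_inv = np.arange(4)
for i, j in enumerate(piv):
    p_inv[i], p_inv[j] = p_inv[j], p_inv[i]
perm = np.argsort(p_inv)

np.testing.assert_allclose(L[perm] @ U, A, atol=1e-12)
```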
@@ -116,7 +117,7 @@ def lu_impl_1(
False. Returns a tuple of (perm, L, U), where perm is an integer array of row swaps such that L[perm] @ U = A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(a, "lu")
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype

def impl(
@@ -146,7 +147,7 @@ def lu_impl_2(
"""

ensure_lapack()
_check_scipy_linalg_matrix(a, "lu")
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype

def impl(
@@ -179,7 +180,7 @@ def lu_impl_3(
False. Returns a tuple of (P, L, U), such that P @ L @ U = A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(a, "lu")
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype

def impl(
13 changes: 5 additions & 8 deletions pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py
@@ -3,18 +3,16 @@

import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix


def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
@@ -38,9 +36,8 @@ def getrf_impl(
A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "getrf")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="getrf")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_getrf = _LAPACK().numba_xgetrf(dtype)

def impl(
@@ -59,7 +56,7 @@ def impl(
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
INFO = val_to_int_ptr(0)

numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)
numba_getrf(M, N, A_copy.ctypes, LDA, IPIV.ctypes, INFO)

return A_copy, IPIV, int_ptr_to_val(INFO)

@@ -79,7 +76,7 @@ def lu_factor_impl(
A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "lu_factor")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor")

def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)