Skip to content

Commit aa790e4

Browse files
ricardoV94jessegrabowski
authored andcommitted
Numba QR: Support complex dtype inputs
1 parent eab0961 commit aa790e4

File tree

5 files changed

+182
-78
lines changed

5 files changed

+182
-78
lines changed

pytensor/link/numba/dispatch/linalg/_LAPACK.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
from numba.core import cgutils, types
55
from numba.core.extending import get_cython_function_address, intrinsic
6+
from numba.core.types import Complex
67
from numba.np.linalg import ensure_lapack, get_blas_kind
78

89

@@ -486,8 +487,7 @@ def numba_xgeqp3(cls, dtype):
486487
Used in QR decomposition with pivoting.
487488
"""
488489
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
489-
functype = ctypes.CFUNCTYPE(
490-
None,
490+
ctype_args = (
491491
_ptr_int, # M
492492
_ptr_int, # N
493493
float_pointer, # A
@@ -496,8 +496,20 @@ def numba_xgeqp3(cls, dtype):
496496
float_pointer, # TAU
497497
float_pointer, # WORK
498498
_ptr_int, # LWORK
499+
)
500+
501+
if isinstance(dtype, Complex):
502+
ctype_args = (
503+
*ctype_args,
504+
float_pointer, # RWORK)
505+
)
506+
507+
functype = ctypes.CFUNCTYPE(
508+
None,
509+
*ctype_args,
499510
_ptr_int, # INFO
500511
)
512+
501513
return functype(lapack_ptr)
502514

503515
@classmethod

0 commit comments

Comments
 (0)