Skip to content

Commit 0226ce6

Browse files
committed
.wip qr complex type support
1 parent eab0961 commit 0226ce6

File tree

4 files changed

+63
-38
lines changed

4 files changed

+63
-38
lines changed

pytensor/link/numba/dispatch/linalg/decomposition/qr.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
from numba.core.extending import overload
5+
from numba.core.types import Complex, Float
56
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
67
from scipy.linalg import get_lapack_funcs, qr
78

@@ -11,6 +12,7 @@
1112
int_ptr_to_val,
1213
val_to_int_ptr,
1314
)
15+
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
1416

1517

1618
def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
@@ -184,6 +186,7 @@ def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
184186
@overload(_xungqr)
185187
def xungqr_impl(A, tau, overwrite_a, lwork):
186188
ensure_lapack()
189+
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="qr")
187190
dtype = A.dtype
188191
w_type = _get_underlying_float(dtype)
189192
ungqr = _LAPACK().numba_xungqr(dtype)
@@ -211,7 +214,7 @@ def impl(A, tau, overwrite_a, lwork):
211214
val_to_int_ptr(M),
212215
val_to_int_ptr(N),
213216
val_to_int_ptr(K),
214-
A_copy.view(w_type).ctypes,
217+
A_copy.T.view(w_type).T.ctypes,
215218
LDA,
216219
tau.view(w_type).ctypes,
217220
WORK.view(w_type).ctypes,
@@ -378,10 +381,15 @@ def qr_full_pivot_impl(
378381
x, mode="full", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
379382
):
380383
ensure_lapack()
384+
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
381385
dtype = x.dtype
382386
w_type = _get_underlying_float(dtype)
383387
geqp3 = _LAPACK().numba_xgeqp3(dtype)
384-
orgqr = _LAPACK().numba_xorgqr(dtype)
388+
orgqr = (
389+
_LAPACK().numba_xorgqr(dtype)
390+
if isinstance(dtype, Float)
391+
else _LAPACK().numba_xungqr(dtype)
392+
)
385393

386394
def impl(
387395
x,
@@ -420,8 +428,7 @@ def impl(
420428
val_to_int_ptr(-1),
421429
val_to_int_ptr(1),
422430
)
423-
lwork_val = int(WORK.item())
424-
431+
lwork_val = int(WORK.item().real)
425432
else:
426433
lwork_val = lwork
427434

@@ -460,14 +467,14 @@ def impl(
460467
val_to_int_ptr(M),
461468
val_to_int_ptr(Q_in.shape[1]),
462469
val_to_int_ptr(K),
463-
Q_in.view(w_type).ctypes,
470+
Q_in.T.view(w_type).T.ctypes,
464471
val_to_int_ptr(M),
465472
TAU.view(w_type).ctypes,
466473
WORKQ.view(w_type).ctypes,
467474
val_to_int_ptr(-1),
468475
val_to_int_ptr(1),
469476
)
470-
lwork_q = int(WORKQ.item())
477+
lwork_q = int(WORKQ.item().real)
471478

472479
else:
473480
lwork_q = lwork
@@ -478,7 +485,7 @@ def impl(
478485
val_to_int_ptr(M),
479486
val_to_int_ptr(Q_in.shape[1]),
480487
val_to_int_ptr(K),
481-
Q_in.view(w_type).ctypes,
488+
Q_in.T.view(w_type).T.ctypes,
482489
val_to_int_ptr(M),
483490
TAU.view(w_type).ctypes,
484491
WORKQ.view(w_type).ctypes,
@@ -495,10 +502,15 @@ def qr_full_no_pivot_impl(
495502
x, mode="full", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
496503
):
497504
ensure_lapack()
505+
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
498506
dtype = x.dtype
499507
w_type = _get_underlying_float(dtype)
500508
geqrf = _LAPACK().numba_xgeqrf(dtype)
501-
orgqr = _LAPACK().numba_xorgqr(dtype)
509+
orgqr = (
510+
_LAPACK().numba_xorgqr(dtype)
511+
if isinstance(dtype, Float)
512+
else _LAPACK().numba_xungqr(dtype)
513+
)
502514

503515
def impl(
504516
x,
@@ -528,14 +540,14 @@ def impl(
528540
geqrf(
529541
val_to_int_ptr(M),
530542
val_to_int_ptr(N),
531-
x_copy.view(w_type).ctypes,
543+
x_copy.T.view(w_type).T.ctypes,
532544
LDA,
533545
TAU.view(w_type).ctypes,
534546
WORK.view(w_type).ctypes,
535547
val_to_int_ptr(-1),
536548
val_to_int_ptr(1),
537549
)
538-
lwork_val = int(WORK.item())
550+
lwork_val = int(WORK.item().real)
539551
else:
540552
lwork_val = lwork
541553

@@ -545,7 +557,7 @@ def impl(
545557
geqrf(
546558
val_to_int_ptr(M),
547559
val_to_int_ptr(N),
548-
x_copy.view(w_type).ctypes,
560+
x_copy.T.view(w_type).T.ctypes,
549561
LDA,
550562
TAU.view(w_type).ctypes,
551563
WORK.view(w_type).ctypes,
@@ -573,14 +585,14 @@ def impl(
573585
val_to_int_ptr(M),
574586
val_to_int_ptr(Q_in.shape[1]),
575587
val_to_int_ptr(K),
576-
Q_in.view(w_type).ctypes,
588+
Q_in.T.view(w_type).T.ctypes,
577589
val_to_int_ptr(M),
578590
TAU.view(w_type).ctypes,
579591
WORKQ.view(w_type).ctypes,
580592
val_to_int_ptr(-1),
581593
val_to_int_ptr(1),
582594
)
583-
lwork_q = int(WORKQ.item())
595+
lwork_q = int(WORKQ.real.item())
584596
else:
585597
lwork_q = lwork
586598

@@ -591,7 +603,7 @@ def impl(
591603
val_to_int_ptr(M), # M
592604
val_to_int_ptr(Q_in.shape[1]), # N
593605
val_to_int_ptr(K), # K
594-
Q_in.view(w_type).ctypes, # A
606+
Q_in.T.view(w_type).T.ctypes, # A
595607
val_to_int_ptr(M), # LDA
596608
TAU.view(w_type).ctypes, # TAU
597609
WORKQ.view(w_type).ctypes, # WORK
@@ -608,6 +620,7 @@ def qr_r_pivot_impl(
608620
x, mode="r", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
609621
):
610622
ensure_lapack()
623+
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
611624
dtype = x.dtype
612625
w_type = _get_underlying_float(dtype)
613626
geqp3 = _LAPACK().numba_xgeqp3(dtype)
@@ -640,15 +653,15 @@ def impl(
640653
geqp3(
641654
val_to_int_ptr(M),
642655
val_to_int_ptr(N),
643-
x_copy.view(w_type).ctypes,
656+
x_copy.T.view(w_type).T.ctypes,
644657
LDA,
645658
JPVT.ctypes,
646659
TAU.view(w_type).ctypes,
647660
WORK.view(w_type).ctypes,
648661
val_to_int_ptr(-1),
649662
val_to_int_ptr(1),
650663
)
651-
lwork_val = int(WORK.item())
664+
lwork_val = int(WORK.item().real)
652665
else:
653666
lwork_val = lwork
654667

@@ -658,7 +671,7 @@ def impl(
658671
geqp3(
659672
val_to_int_ptr(M),
660673
val_to_int_ptr(N),
661-
x_copy.view(w_type).ctypes,
674+
x_copy.T.view(w_type).T.ctypes,
662675
LDA,
663676
JPVT.ctypes,
664677
TAU.view(w_type).ctypes,
@@ -683,6 +696,7 @@ def qr_r_no_pivot_impl(
683696
x, mode="r", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
684697
):
685698
ensure_lapack()
699+
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
686700
dtype = x.dtype
687701
w_type = _get_underlying_float(dtype)
688702
geqrf = _LAPACK().numba_xgeqrf(dtype)
@@ -714,14 +728,14 @@ def impl(
714728
geqrf(
715729
val_to_int_ptr(M),
716730
val_to_int_ptr(N),
717-
x_copy.view(w_type).ctypes,
731+
x_copy.T.view(w_type).T.ctypes,
718732
LDA,
719733
TAU.view(w_type).ctypes,
720734
WORK.view(w_type).ctypes,
721735
val_to_int_ptr(-1),
722736
val_to_int_ptr(1),
723737
)
724-
lwork_val = int(WORK.item())
738+
lwork_val = int(WORK.item().real)
725739
else:
726740
lwork_val = lwork
727741

@@ -731,7 +745,7 @@ def impl(
731745
geqrf(
732746
val_to_int_ptr(M),
733747
val_to_int_ptr(N),
734-
x_copy.view(w_type).ctypes,
748+
x_copy.T.view(w_type).T.ctypes,
735749
LDA,
736750
TAU.view(w_type).ctypes,
737751
WORK.view(w_type).ctypes,
@@ -755,6 +769,7 @@ def qr_raw_no_pivot_impl(
755769
x, mode="raw", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
756770
):
757771
ensure_lapack()
772+
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
758773
dtype = x.dtype
759774
w_type = _get_underlying_float(dtype)
760775
geqrf = _LAPACK().numba_xgeqrf(dtype)
@@ -786,14 +801,14 @@ def impl(
786801
geqrf(
787802
val_to_int_ptr(M),
788803
val_to_int_ptr(N),
789-
x_copy.view(w_type).ctypes,
804+
x_copy.T.view(w_type).T.ctypes,
790805
LDA,
791806
TAU.view(w_type).ctypes,
792807
WORK.view(w_type).ctypes,
793808
val_to_int_ptr(-1),
794809
val_to_int_ptr(1),
795810
)
796-
lwork_val = int(WORK.item())
811+
lwork_val = int(WORK.item().real)
797812
else:
798813
lwork_val = lwork
799814

@@ -803,7 +818,7 @@ def impl(
803818
geqrf(
804819
val_to_int_ptr(M),
805820
val_to_int_ptr(N),
806-
x_copy.view(w_type).ctypes,
821+
x_copy.T.view(w_type).T.ctypes,
807822
LDA,
808823
TAU.view(w_type).ctypes,
809824
WORK.view(w_type).ctypes,
@@ -826,6 +841,7 @@ def qr_raw_pivot_impl(
826841
x, mode="raw", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
827842
):
828843
ensure_lapack()
844+
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
829845
dtype = x.dtype
830846
w_type = _get_underlying_float(dtype)
831847
geqp3 = _LAPACK().numba_xgeqp3(dtype)
@@ -858,15 +874,15 @@ def impl(
858874
geqp3(
859875
val_to_int_ptr(M),
860876
val_to_int_ptr(N),
861-
x_copy.view(w_type).ctypes,
877+
x_copy.T.view(w_type).T.ctypes,
862878
LDA,
863879
JPVT.ctypes,
864880
TAU.view(w_type).ctypes,
865881
WORK.view(w_type).ctypes,
866882
val_to_int_ptr(-1),
867883
val_to_int_ptr(1),
868884
)
869-
lwork_val = int(WORK.item())
885+
lwork_val = int(WORK.item().real)
870886
else:
871887
lwork_val = lwork
872888

@@ -876,7 +892,7 @@ def impl(
876892
geqp3(
877893
val_to_int_ptr(M),
878894
val_to_int_ptr(N),
879-
x_copy.view(w_type).ctypes,
895+
x_copy.T.view(w_type).T.ctypes,
880896
LDA,
881897
JPVT.ctypes,
882898
TAU.view(w_type).ctypes,

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
Solve,
4343
SolveTriangular,
4444
)
45-
from pytensor.tensor.type import complex_dtypes, integer_dtypes
4645

4746

4847
@numba_funcify.register(Cholesky)
@@ -418,12 +417,15 @@ def numba_funcify_QR(op, node, **kwargs):
418417
pivoting = op.pivoting
419418
overwrite_a = op.overwrite_a
420419

421-
dtype = node.inputs[0].dtype
422-
if dtype in complex_dtypes:
423-
return generate_fallback_impl(op, node=node, **kwargs)
420+
in_dtype = node.inputs[0].type.numpy_dtype
421+
# if in_dtype.kind == "c":
422+
# return generate_fallback_impl(op, node=node, **kwargs)
423+
424+
integer_input = in_dtype.kind in "ibu"
425+
if integer_input and config.compiler_verbose:
426+
print("QR requires casting discrete input to float") # noqa: T201
424427

425-
integer_input = dtype in integer_dtypes
426-
in_dtype = config.floatX if integer_input else dtype
428+
out_dtype = node.outputs[0].type.numpy_dtype
427429

428430
@numba_basic.numba_njit
429431
def qr(a):
@@ -434,7 +436,7 @@ def qr(a):
434436
)
435437

436438
if integer_input:
437-
a = a.astype(in_dtype)
439+
a = a.astype(out_dtype)
438440

439441
if (mode == "full" or mode == "economic") and pivoting:
440442
Q, R, P = _qr_full_pivot(

pytensor/tensor/slinalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1824,7 +1824,10 @@ def make_node(self, x):
18241824
K = None
18251825

18261826
in_dtype = x.type.numpy_dtype
1827-
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
1827+
if in_dtype.kind in "ibu":
1828+
out_dtype = "float64" if in_dtype.itemsize > 2 else "float32"
1829+
else:
1830+
out_dtype = "float64" if in_dtype.itemsize > 4 else "float32"
18281831

18291832
match self.mode:
18301833
case "full":

tests/link/numba/test_slinalg.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -718,17 +718,21 @@ def test_lu_factor(self, overwrite_a):
718718
ids=["economic", "full_pivot", "r", "raw_pivot"],
719719
)
720720
@pytest.mark.parametrize(
721-
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
721+
"overwrite_a", [False, True][:1], ids=["overwrite_a", "no_overwrite"][:1]
722722
)
723-
def test_qr(self, mode, pivoting, overwrite_a):
723+
@pytest.mark.parametrize("complex", (False, True)[1:])
724+
def test_qr(self, mode, pivoting, overwrite_a, complex):
724725
shape = (5, 5)
725726
rng = np.random.default_rng()
726727
A = pt.tensor(
727728
"A",
728729
shape=shape,
729-
dtype=config.floatX,
730+
dtype="complex128" if complex else "float64",
730731
)
731-
A_val = rng.normal(size=shape).astype(config.floatX)
732+
if complex:
733+
A_val = rng.normal(size=(*shape, 2)).view(dtype=A.dtype).squeeze(-1)
734+
else:
735+
A_val = rng.normal(size=shape).astype(A.dtype)
732736

733737
qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting)
734738

0 commit comments

Comments
 (0)