22
33import numpy as np
44from numba .core .extending import overload
5+ from numba .core .types import Complex , Float
56from numba .np .linalg import _copy_to_fortran_order , ensure_lapack
67from scipy .linalg import get_lapack_funcs , qr
78
1112 int_ptr_to_val ,
1213 val_to_int_ptr ,
1314)
15+ from pytensor .link .numba .dispatch .linalg .utils import _check_linalg_matrix
1416
1517
1618def _xgeqrf (A : np .ndarray , overwrite_a : bool , lwork : int ):
@@ -184,6 +186,7 @@ def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
184186@overload (_xungqr )
185187def xungqr_impl (A , tau , overwrite_a , lwork ):
186188 ensure_lapack ()
189+ _check_linalg_matrix (A , ndim = 2 , dtype = (Float , Complex ), func_name = "qr" )
187190 dtype = A .dtype
188191 w_type = _get_underlying_float (dtype )
189192 ungqr = _LAPACK ().numba_xungqr (dtype )
@@ -211,7 +214,7 @@ def impl(A, tau, overwrite_a, lwork):
211214 val_to_int_ptr (M ),
212215 val_to_int_ptr (N ),
213216 val_to_int_ptr (K ),
214- A_copy .view (w_type ).ctypes ,
217+ A_copy .T . view (w_type ). T .ctypes ,
215218 LDA ,
216219 tau .view (w_type ).ctypes ,
217220 WORK .view (w_type ).ctypes ,
@@ -378,10 +381,15 @@ def qr_full_pivot_impl(
378381 x , mode = "full" , pivoting = True , overwrite_a = False , check_finite = False , lwork = None
379382):
380383 ensure_lapack ()
384+ _check_linalg_matrix (x , ndim = 2 , dtype = (Float , Complex ), func_name = "qr" )
381385 dtype = x .dtype
382386 w_type = _get_underlying_float (dtype )
383387 geqp3 = _LAPACK ().numba_xgeqp3 (dtype )
384- orgqr = _LAPACK ().numba_xorgqr (dtype )
388+ orgqr = (
389+ _LAPACK ().numba_xorgqr (dtype )
390+ if isinstance (dtype , Float )
391+ else _LAPACK ().numba_xungqr (dtype )
392+ )
385393
386394 def impl (
387395 x ,
@@ -420,8 +428,7 @@ def impl(
420428 val_to_int_ptr (- 1 ),
421429 val_to_int_ptr (1 ),
422430 )
423- lwork_val = int (WORK .item ())
424-
431+ lwork_val = int (WORK .item ().real )
425432 else :
426433 lwork_val = lwork
427434
@@ -460,14 +467,14 @@ def impl(
460467 val_to_int_ptr (M ),
461468 val_to_int_ptr (Q_in .shape [1 ]),
462469 val_to_int_ptr (K ),
463- Q_in .view (w_type ).ctypes ,
470+ Q_in .T . view (w_type ). T .ctypes ,
464471 val_to_int_ptr (M ),
465472 TAU .view (w_type ).ctypes ,
466473 WORKQ .view (w_type ).ctypes ,
467474 val_to_int_ptr (- 1 ),
468475 val_to_int_ptr (1 ),
469476 )
470- lwork_q = int (WORKQ .item ())
477+ lwork_q = int (WORKQ .item (). real )
471478
472479 else :
473480 lwork_q = lwork
@@ -478,7 +485,7 @@ def impl(
478485 val_to_int_ptr (M ),
479486 val_to_int_ptr (Q_in .shape [1 ]),
480487 val_to_int_ptr (K ),
481- Q_in .view (w_type ).ctypes ,
488+ Q_in .T . view (w_type ). T .ctypes ,
482489 val_to_int_ptr (M ),
483490 TAU .view (w_type ).ctypes ,
484491 WORKQ .view (w_type ).ctypes ,
@@ -495,10 +502,15 @@ def qr_full_no_pivot_impl(
495502 x , mode = "full" , pivoting = False , overwrite_a = False , check_finite = False , lwork = None
496503):
497504 ensure_lapack ()
505+ _check_linalg_matrix (x , ndim = 2 , dtype = (Float , Complex ), func_name = "qr" )
498506 dtype = x .dtype
499507 w_type = _get_underlying_float (dtype )
500508 geqrf = _LAPACK ().numba_xgeqrf (dtype )
501- orgqr = _LAPACK ().numba_xorgqr (dtype )
509+ orgqr = (
510+ _LAPACK ().numba_xorgqr (dtype )
511+ if isinstance (dtype , Float )
512+ else _LAPACK ().numba_xungqr (dtype )
513+ )
502514
503515 def impl (
504516 x ,
@@ -528,14 +540,14 @@ def impl(
528540 geqrf (
529541 val_to_int_ptr (M ),
530542 val_to_int_ptr (N ),
531- x_copy .view (w_type ).ctypes ,
543+ x_copy .T . view (w_type ). T .ctypes ,
532544 LDA ,
533545 TAU .view (w_type ).ctypes ,
534546 WORK .view (w_type ).ctypes ,
535547 val_to_int_ptr (- 1 ),
536548 val_to_int_ptr (1 ),
537549 )
538- lwork_val = int (WORK .item ())
550+ lwork_val = int (WORK .item (). real )
539551 else :
540552 lwork_val = lwork
541553
@@ -545,7 +557,7 @@ def impl(
545557 geqrf (
546558 val_to_int_ptr (M ),
547559 val_to_int_ptr (N ),
548- x_copy .view (w_type ).ctypes ,
560+ x_copy .T . view (w_type ). T .ctypes ,
549561 LDA ,
550562 TAU .view (w_type ).ctypes ,
551563 WORK .view (w_type ).ctypes ,
@@ -573,14 +585,14 @@ def impl(
573585 val_to_int_ptr (M ),
574586 val_to_int_ptr (Q_in .shape [1 ]),
575587 val_to_int_ptr (K ),
576- Q_in .view (w_type ).ctypes ,
588+ Q_in .T . view (w_type ). T .ctypes ,
577589 val_to_int_ptr (M ),
578590 TAU .view (w_type ).ctypes ,
579591 WORKQ .view (w_type ).ctypes ,
580592 val_to_int_ptr (- 1 ),
581593 val_to_int_ptr (1 ),
582594 )
583- lwork_q = int (WORKQ .item ())
595+ lwork_q = int (WORKQ .real . item ())
584596 else :
585597 lwork_q = lwork
586598
@@ -591,7 +603,7 @@ def impl(
591603 val_to_int_ptr (M ), # M
592604 val_to_int_ptr (Q_in .shape [1 ]), # N
593605 val_to_int_ptr (K ), # K
594- Q_in .view (w_type ).ctypes , # A
606+ Q_in .T . view (w_type ). T .ctypes , # A
595607 val_to_int_ptr (M ), # LDA
596608 TAU .view (w_type ).ctypes , # TAU
597609 WORKQ .view (w_type ).ctypes , # WORK
@@ -608,6 +620,7 @@ def qr_r_pivot_impl(
608620 x , mode = "r" , pivoting = True , overwrite_a = False , check_finite = False , lwork = None
609621):
610622 ensure_lapack ()
623+ _check_linalg_matrix (x , ndim = 2 , dtype = (Float , Complex ), func_name = "qr" )
611624 dtype = x .dtype
612625 w_type = _get_underlying_float (dtype )
613626 geqp3 = _LAPACK ().numba_xgeqp3 (dtype )
@@ -640,15 +653,15 @@ def impl(
640653 geqp3 (
641654 val_to_int_ptr (M ),
642655 val_to_int_ptr (N ),
643- x_copy .view (w_type ).ctypes ,
656+ x_copy .T . view (w_type ). T .ctypes ,
644657 LDA ,
645658 JPVT .ctypes ,
646659 TAU .view (w_type ).ctypes ,
647660 WORK .view (w_type ).ctypes ,
648661 val_to_int_ptr (- 1 ),
649662 val_to_int_ptr (1 ),
650663 )
651- lwork_val = int (WORK .item ())
664+ lwork_val = int (WORK .item (). real )
652665 else :
653666 lwork_val = lwork
654667
@@ -658,7 +671,7 @@ def impl(
658671 geqp3 (
659672 val_to_int_ptr (M ),
660673 val_to_int_ptr (N ),
661- x_copy .view (w_type ).ctypes ,
674+ x_copy .T . view (w_type ). T .ctypes ,
662675 LDA ,
663676 JPVT .ctypes ,
664677 TAU .view (w_type ).ctypes ,
@@ -683,6 +696,7 @@ def qr_r_no_pivot_impl(
683696 x , mode = "r" , pivoting = False , overwrite_a = False , check_finite = False , lwork = None
684697):
685698 ensure_lapack ()
699+ _check_linalg_matrix (x , ndim = 2 , dtype = (Float , Complex ), func_name = "qr" )
686700 dtype = x .dtype
687701 w_type = _get_underlying_float (dtype )
688702 geqrf = _LAPACK ().numba_xgeqrf (dtype )
@@ -714,14 +728,14 @@ def impl(
714728 geqrf (
715729 val_to_int_ptr (M ),
716730 val_to_int_ptr (N ),
717- x_copy .view (w_type ).ctypes ,
731+ x_copy .T . view (w_type ). T .ctypes ,
718732 LDA ,
719733 TAU .view (w_type ).ctypes ,
720734 WORK .view (w_type ).ctypes ,
721735 val_to_int_ptr (- 1 ),
722736 val_to_int_ptr (1 ),
723737 )
724- lwork_val = int (WORK .item ())
738+ lwork_val = int (WORK .item (). real )
725739 else :
726740 lwork_val = lwork
727741
@@ -731,7 +745,7 @@ def impl(
731745 geqrf (
732746 val_to_int_ptr (M ),
733747 val_to_int_ptr (N ),
734- x_copy .view (w_type ).ctypes ,
748+ x_copy .T . view (w_type ). T .ctypes ,
735749 LDA ,
736750 TAU .view (w_type ).ctypes ,
737751 WORK .view (w_type ).ctypes ,
@@ -755,6 +769,7 @@ def qr_raw_no_pivot_impl(
755769 x , mode = "raw" , pivoting = False , overwrite_a = False , check_finite = False , lwork = None
756770):
757771 ensure_lapack ()
772+ _check_linalg_matrix (x , ndim = 2 , dtype = (Float , Complex ), func_name = "qr" )
758773 dtype = x .dtype
759774 w_type = _get_underlying_float (dtype )
760775 geqrf = _LAPACK ().numba_xgeqrf (dtype )
@@ -786,14 +801,14 @@ def impl(
786801 geqrf (
787802 val_to_int_ptr (M ),
788803 val_to_int_ptr (N ),
789- x_copy .view (w_type ).ctypes ,
804+ x_copy .T . view (w_type ). T .ctypes ,
790805 LDA ,
791806 TAU .view (w_type ).ctypes ,
792807 WORK .view (w_type ).ctypes ,
793808 val_to_int_ptr (- 1 ),
794809 val_to_int_ptr (1 ),
795810 )
796- lwork_val = int (WORK .item ())
811+ lwork_val = int (WORK .item (). real )
797812 else :
798813 lwork_val = lwork
799814
@@ -803,7 +818,7 @@ def impl(
803818 geqrf (
804819 val_to_int_ptr (M ),
805820 val_to_int_ptr (N ),
806- x_copy .view (w_type ).ctypes ,
821+ x_copy .T . view (w_type ). T .ctypes ,
807822 LDA ,
808823 TAU .view (w_type ).ctypes ,
809824 WORK .view (w_type ).ctypes ,
@@ -826,6 +841,7 @@ def qr_raw_pivot_impl(
826841 x , mode = "raw" , pivoting = True , overwrite_a = False , check_finite = False , lwork = None
827842):
828843 ensure_lapack ()
844+ _check_linalg_matrix (x , ndim = 2 , dtype = (Float , Complex ), func_name = "qr" )
829845 dtype = x .dtype
830846 w_type = _get_underlying_float (dtype )
831847 geqp3 = _LAPACK ().numba_xgeqp3 (dtype )
@@ -858,15 +874,15 @@ def impl(
858874 geqp3 (
859875 val_to_int_ptr (M ),
860876 val_to_int_ptr (N ),
861- x_copy .view (w_type ).ctypes ,
877+ x_copy .T . view (w_type ). T .ctypes ,
862878 LDA ,
863879 JPVT .ctypes ,
864880 TAU .view (w_type ).ctypes ,
865881 WORK .view (w_type ).ctypes ,
866882 val_to_int_ptr (- 1 ),
867883 val_to_int_ptr (1 ),
868884 )
869- lwork_val = int (WORK .item ())
885+ lwork_val = int (WORK .item (). real )
870886 else :
871887 lwork_val = lwork
872888
@@ -876,7 +892,7 @@ def impl(
876892 geqp3 (
877893 val_to_int_ptr (M ),
878894 val_to_int_ptr (N ),
879- x_copy .view (w_type ).ctypes ,
895+ x_copy .T . view (w_type ). T .ctypes ,
880896 LDA ,
881897 JPVT .ctypes ,
882898 TAU .view (w_type ).ctypes ,
0 commit comments