@@ -22,13 +22,46 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
2222 return lu! (copy (A), normalize_pivot (pivot), thread; kwargs... )
2323end
2424
# Feature flag: `true` when the running Julia is at least 1.8.0-DEV.1507.
# It gates construction of the `NotIPIV` placeholder pivot vector and the
# `LinearAlgebra` method extensions defined for it.
const CUSTOMIZABLE_PIVOT = !(VERSION < v"1.8.0-DEV.1507")
26+
"""
    NotIPIV(len)

Placeholder pivot vector of length `len` whose `i`-th entry is `i`. Only the length
is stored; entries are computed on access.
"""
struct NotIPIV <: AbstractVector{BlasInt}
    len::Int
end

# One-dimensional, with `len` entries.
Base.size(A::NotIPIV) = (A.len,)

# Entry `i` is `i` itself. The value is converted so that it matches the declared
# element type `BlasInt` even on builds where `BlasInt !== Int`.
Base.getindex(::NotIPIV, i::Int) = convert(BlasInt, i)

# A view keeps only the length of `r`: the result is again a `NotIPIV` whose entries
# start at 1, regardless of where `r` starts.
Base.view(::NotIPIV, r::AbstractUnitRange) = NotIPIV(length(r))
# Pivot storage for an unpivoted factorization (`Val(false)`): a `NotIPIV` of length
# `minmn` when `CUSTOMIZABLE_PIVOT` is set, otherwise the same buffer the pivoted
# method returns. The choice is made at macro-expansion time via `@static`.
function init_pivot(::Val{false}, minmn)
    return @static CUSTOMIZABLE_PIVOT ? NotIPIV(minmn) : init_pivot(Val(true), minmn)
end

# Pivot storage for a pivoted factorization (`Val(true)`): an uninitialized
# `Vector{BlasInt}` of length `minmn`.
function init_pivot(::Val{true}, minmn)
    return Vector{BlasInt}(undef, minmn)
end
41+
# Extend `LinearAlgebra._ipiv_cols!` for an `LU` whose pivot vector is a `NotIPIV`:
# `B` is returned unchanged (a `NotIPIV` maps every index to itself). The method is
# only added when `CUSTOMIZABLE_PIVOT` is set and this LinearAlgebra defines
# `_ipiv_cols!`.
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_cols!)
    LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
                              B::StridedVecOrMat) = B
end
# Extend `LinearAlgebra._ipiv_rows!` for an `LU` whose pivot vector is a `NotIPIV`:
# `B` is returned unchanged (a `NotIPIV` maps every index to itself). The method is
# only added when `CUSTOMIZABLE_PIVOT` is set and this LinearAlgebra defines
# `_ipiv_rows!`.
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_rows!)
    LinearAlgebra._ipiv_rows!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
                              B::StridedVecOrMat) = B
end
54+
2555function lu! (A, pivot = Val (true ), thread = Val (true ); check = true , kwargs... )
2656 m, n = size (A)
2757 minmn = min (m, n)
28- F = if minmn < 10 # avx introduces small performance degradation
58+ npivot = normalize_pivot (pivot)
59+ # we want the type on both branches to match. When pivot = Val(false), we construct
60+ # a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
61+ F = if pivot === Val (true ) && minmn < 10 # avx introduces small performance degradation
2962 LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check = check)
3063 else
31- lu! (A, Vector {BlasInt} (undef , minmn), normalize_pivot (pivot) , thread; check = check,
64+ lu! (A, init_pivot (npivot , minmn), npivot , thread; check = check,
3265 kwargs... )
3366 end
3467 return F
@@ -44,6 +77,8 @@ pick_threshold() = LoopVectorization.register_size() == 64 ? 48 : 40
# Trait consulted by `lu!` before taking the recursive path: `true` for any
# `StridedArray`, `false` for every other argument.
function recurse(::StridedArray)
    return true
end
function recurse(::Any)
    return false
end
4679
# Wrap a pivot buffer in a `PtrArray`; a `NotIPIV` is passed through unchanged
# instead of being wrapped.
function _ptrarray(ipiv)
    return PtrArray(ipiv)
end
function _ptrarray(ipiv::NotIPIV)
    return ipiv
end
4782function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
4883 pivot = Val (true ), thread = Val (true );
4984 check:: Bool = true ,
@@ -54,11 +89,14 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
5489 info = zero (BlasInt)
5590 m, n = size (A)
5691 mnmin = min (m, n)
92+ if pivot === Val (false ) && ! CUSTOMIZABLE_PIVOT
93+ copyto! (ipiv, 1 : mnmin)
94+ end
5795 if recurse (A) && mnmin > threshold
5896 if T <: Union{Float32, Float64}
5997 GC. @preserve ipiv A begin info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
6098 m, n, mnmin,
61- PtrArray (ipiv), info, blocksize,
99+ _ptrarray (ipiv), info, blocksize,
62100 thread) end
63101 else
64102 info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
90128 # [AL AR]
91129 AL = @view A[:, 1 : m]
92130 AR = @view A[:, (m + 1 ): n]
93- apply_permutation! (ipiv, AR, Val ( Thread))
94- ldiv! (_unit_lower_triangular (AL), AR, Val ( Thread))
131+ Pivot && apply_permutation! (ipiv, AR, Val { Thread} ( ))
132+ ldiv! (_unit_lower_triangular (AL), AR, Val { Thread} ( ))
95133 end
96134 info
97135end
@@ -187,8 +225,10 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
187225 Pivot && apply_permutation! (P2, A21, thread)
188226
189227 info != previnfo && (info += n1)
190- @turbo warn_check_args= false for i in 1 : n2
191- P2[i] += n1
228+ if Pivot
229+ @turbo warn_check_args= false for i in 1 : n2
230+ P2[i] += n1
231+ end
192232 end
193233 return info
194234 end # inbounds
@@ -234,8 +274,8 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
234274 amax = absi
235275 end
236276 end
277+ ipiv[k] = kp
237278 end
238- ipiv[k] = kp
239279 if ! iszero (A[kp, k])
240280 if k != kp
241281 # Interchange
0 commit comments