diff --git a/Project.toml b/Project.toml index 95ca7785b..92b356633 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1" SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" @@ -85,7 +86,7 @@ LeastSquaresOptim = "0.8.5" LineSearch = "0.1.4" LineSearches = "7.3" LinearAlgebra = "1.10" -LinearSolve = "3.46" +LinearSolve = "3.48" MINPACK = "1.2" MPI = "0.20.22" NLSolvers = "0.5" @@ -106,9 +107,9 @@ Random = "1.10" ReTestItems = "1.24" Reexport = "1.2.2" ReverseDiff = "1.15" -SciMLLogging = "1.3" SIAMFANLEquations = "1.0.1" SciMLBase = "2.127" +SciMLLogging = "1.3" SimpleNonlinearSolve = "2.11" SparseArrays = "1.10" SparseConnectivityTracer = "1" @@ -146,9 +147,9 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1" SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl index 50512300b..ee6db6760 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl @@ -77,4 +77,8 @@ function set_lincache_A!(lincache, new_A) return end +function LinearSolve.update_tolerances!(cache::LinearSolveJLCache; kwargs...) + LinearSolve.update_tolerances!(cache.lincache; kwargs...) +end + end diff --git a/lib/NonlinearSolveBase/src/verbosity.jl b/lib/NonlinearSolveBase/src/verbosity.jl index f7d50aa63..3df12def2 100644 --- a/lib/NonlinearSolveBase/src/verbosity.jl +++ b/lib/NonlinearSolveBase/src/verbosity.jl @@ -68,6 +68,7 @@ verbose = NonlinearVerbosity( termination_condition # Numerical threshold_state + forcing end # Group classifications @@ -76,7 +77,7 @@ const error_control_options = ( :termination_condition ) const performance_options = () -const numerical_options = (:threshold_state,) +const numerical_options = (:threshold_state,:forcing) function option_group(option::Symbol) if option in error_control_options @@ -140,7 +141,8 @@ function NonlinearVerbosity(; alias_u0_immutable = WarnLevel(), linsolve_failed_noncurrent = WarnLevel(), termination_condition = WarnLevel(), - threshold_state = WarnLevel() + threshold_state = WarnLevel(), + forcing = Silent(), ) # Apply group-level settings @@ -177,7 +179,8 @@ function NonlinearVerbosity(verbose::AbstractVerbosityPreset) alias_u0_immutable = Silent(), linsolve_failed_noncurrent = WarnLevel(), termination_condition = Silent(), - threshold_state = Silent() + threshold_state = Silent(), + forcing = Silent(), ) elseif verbose isa Standard # Standard: Everything from Minimal + non-fatal warnings @@ -190,7 +193,8 @@ function NonlinearVerbosity(verbose::AbstractVerbosityPreset) alias_u0_immutable = WarnLevel(), linsolve_failed_noncurrent = WarnLevel(), termination_condition = WarnLevel(), - threshold_state = WarnLevel() + threshold_state = WarnLevel(), + forcing = InfoLevel(), ) elseif verbose isa All # All: Maximum verbosity - every possible logging message at InfoLevel @@ -200,7 +204,8 @@ function NonlinearVerbosity(verbose::AbstractVerbosityPreset) alias_u0_immutable = WarnLevel(), linsolve_failed_noncurrent = WarnLevel(), termination_condition = WarnLevel(), - threshold_state = InfoLevel() + threshold_state = InfoLevel(), + forcing = InfoLevel(), ) end end @@ -212,7 +217,8 @@ end Silent(), Silent(), Silent(), - Silent() + Silent(), + Silent(), ) end diff --git a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl index 72ce1ae1e..c2dc141db 100644 --- a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl +++ b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl @@ -33,6 +33,7 @@ using ForwardDiff: ForwardDiff, Dual # Default Forward Mode AD include("solve.jl") include("raphson.jl") +include("eisenstat_walker.jl") include("gauss_newton.jl") include("levenberg_marquardt.jl") include("trust_region.jl") @@ -101,6 +102,8 @@ end export NewtonRaphson, PseudoTransient export GaussNewton, LevenbergMarquardt, TrustRegion +export EisenstatWalkerForcing2 + export RadiusUpdateSchemes export GeneralizedFirstOrderAlgorithm diff --git a/lib/NonlinearSolveFirstOrder/src/eisenstat_walker.jl b/lib/NonlinearSolveFirstOrder/src/eisenstat_walker.jl new file mode 100644 index 000000000..66030044c --- /dev/null +++ b/lib/NonlinearSolveFirstOrder/src/eisenstat_walker.jl @@ -0,0 +1,111 @@ +""" + EisenstatWalkerForcing2(; η₀ = 0.5, ηₘₐₓ = 0.9, γ = 0.9, α = 2, safeguard = true, safeguard_threshold = 0.1) + +Algorithm 2 from the classical work by Eisenstat and Walker (1996) as described by formula (2.6): + ηₖ = γ * (||rₖ|| / ||rₖ₋₁||)^α + +Here the variables denote: + rₖ residual at iteration k + η₀ ∈ [0,1) initial value for η + ηₘₐₓ ∈ [0,1) maximum value for η + γ ∈ [0,1) correction factor + α ∈ [1,2) correction exponent + +Furthermore, the proposed safeguard is implemented: + ηₖ = max(ηₖ, γ*ηₖ₋₁^α) if γ*ηₖ₋₁^α > safeguard_threshold +to prevent ηₖ from shrinking too fast. +""" +@concrete struct EisenstatWalkerForcing2 + η₀ + ηₘₐₓ + γ + α + safeguard + safeguard_threshold +end + +function EisenstatWalkerForcing2(; η₀ = 0.5, ηₘₐₓ = 0.9, γ = 0.9, α = 2, safeguard = true, safeguard_threshold = 0.1) + EisenstatWalkerForcing2(η₀, ηₘₐₓ, γ, α, safeguard, safeguard_threshold) +end + + +@concrete mutable struct EisenstatWalkerForcing2Cache + p::EisenstatWalkerForcing2 + η + rnorm + rnorm_prev + internalnorm + verbosity +end + + + +function pre_step_forcing!(cache::EisenstatWalkerForcing2Cache, descend_cache::NonlinearSolveBase.NewtonDescentCache, J, u, fu, iter) + @SciMLMessage("Eisenstat-Walker forcing residual norm $(cache.rnorm) with rate estimate $(cache.rnorm / cache.rnorm_prev).", cache.verbosity, :forcing) + + # On the first iteration we initialize η with the default initial value and stop. + if iter == 0 + cache.η = cache.p.η₀ + @SciMLMessage("Eisenstat-Walker initial iteration to η=$(cache.η).", cache.verbosity, :forcing) + LinearSolve.update_tolerances!(descend_cache.lincache; reltol=cache.η) + return nothing + end + + # Store previous + ηprev = cache.η + + # Formula (2.6) + # ||r|| > 0 should be guaranteed by the convergence criterion + (; rnorm, rnorm_prev) = cache + (; α, γ) = cache.p + cache.η = γ * (rnorm / rnorm_prev)^α + + # Safeguard 2 to prevent over-solving + if cache.p.safeguard + ηsg = γ*ηprev^α + if ηsg > cache.p.safeguard_threshold && ηsg > cache.η + cache.η = ηsg + end + end + + # Far away from the root we also need to respect η ∈ [0,1) + cache.η = clamp(cache.η, 0.0, cache.p.ηₘₐₓ) + + @SciMLMessage("Eisenstat-Walker iter $iter update to η=$(cache.η).", cache.verbosity, :forcing) + + # Communicate new relative tolerance to linear solve + LinearSolve.update_tolerances!(descend_cache.lincache; reltol=cache.η) + + return nothing +end + + + +function post_step_forcing!(cache::EisenstatWalkerForcing2Cache, J, u, fu, δu, iter) + # Cache previous residual norm + cache.rnorm_prev = cache.rnorm + cache.rnorm = cache.internalnorm(fu) + + # @SciMLMessage("Eisenstat-Walker sanity check: $(cache.internalnorm(fu + J*δu)) ≤ $(cache.η * cache.internalnorm(fu)).", cache.verbosity, :linear_verbosity) +end + + + +function InternalAPI.init( + prob::AbstractNonlinearProblem, alg::EisenstatWalkerForcing2, f, fu, u, p, + args...; verbose, internalnorm::F = L2_NORM, kwargs... +) where {F} + fu_norm = internalnorm(fu) + + return EisenstatWalkerForcing2Cache( + alg, alg.η₀, fu_norm, fu_norm, internalnorm, verbose + ) +end + + + +function InternalAPI.reinit!( + cache::EisenstatWalkerForcing2Cache; p = cache.p, kwargs... +) + cache.p = p +end diff --git a/lib/NonlinearSolveFirstOrder/src/raphson.jl b/lib/NonlinearSolveFirstOrder/src/raphson.jl index d482f4846..07cfbb068 100644 --- a/lib/NonlinearSolveFirstOrder/src/raphson.jl +++ b/lib/NonlinearSolveFirstOrder/src/raphson.jl @@ -1,7 +1,8 @@ """ NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, linesearch = missing, - autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing + autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing, + forcing = nothing, ) An advanced NewtonRaphson implementation with support for efficient handling of sparse @@ -10,13 +11,15 @@ for large-scale and numerically-difficult nonlinear systems. """ function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, linesearch = missing, - autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing + autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing, + forcing = nothing, ) return GeneralizedFirstOrderAlgorithm(; linesearch, descent = NewtonDescent(; linsolve), autodiff, vjp_autodiff, jvp_autodiff, concrete_jac, + forcing, name = :NewtonRaphson ) end diff --git a/lib/NonlinearSolveFirstOrder/src/solve.jl b/lib/NonlinearSolveFirstOrder/src/solve.jl index 7b7267aac..ed99c305f 100644 --- a/lib/NonlinearSolveFirstOrder/src/solve.jl +++ b/lib/NonlinearSolveFirstOrder/src/solve.jl @@ -25,6 +25,7 @@ order of convergence. linesearch trustregion descent + forcing max_shrink_times::Int autodiff @@ -38,12 +39,12 @@ end function GeneralizedFirstOrderAlgorithm(; descent, linesearch = missing, trustregion = missing, autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing, max_shrink_times::Int = typemax(Int), - concrete_jac = Val(false), name::Symbol = :unknown + concrete_jac = Val(false), forcing = nothing, name::Symbol = :unknown ) concrete_jac = concrete_jac isa Bool ? Val(concrete_jac) : (concrete_jac isa Val ? concrete_jac : Val(concrete_jac !== nothing)) return GeneralizedFirstOrderAlgorithm( - linesearch, trustregion, descent, max_shrink_times, + linesearch, trustregion, descent, forcing, max_shrink_times, autodiff, vjp_autodiff, jvp_autodiff, concrete_jac, name ) @@ -62,6 +63,7 @@ end # Internal Caches jac_cache descent_cache + forcing_cache linesearch_cache trustregion_cache @@ -125,7 +127,7 @@ function InternalAPI.reinit_self!( end NonlinearSolveBase.@internal_caches(GeneralizedFirstOrderAlgorithmCache, - :jac_cache, :descent_cache, :linesearch_cache, :trustregion_cache) + :jac_cache, :descent_cache, :linesearch_cache, :trustregion_cache, :forcing_cache) function SciMLBase.__init( prob::AbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; @@ -196,6 +198,7 @@ function SciMLBase.__init( has_linesearch = alg.linesearch !== missing && alg.linesearch !== nothing has_trustregion = alg.trustregion !== missing && alg.trustregion !== nothing + has_forcing = alg.forcing !== missing && alg.forcing !== nothing && !(u isa Number) && !(J isa Diagonal) if has_trustregion && has_linesearch error("TrustRegion and LineSearch methods are algorithmically incompatible.") @@ -204,6 +207,7 @@ function SciMLBase.__init( globalization = Val(:None) linesearch_cache = nothing trustregion_cache = nothing + forcing_cache = nothing if has_trustregion NonlinearSolveBase.supports_trust_region(alg.descent) || @@ -228,13 +232,24 @@ function SciMLBase.__init( globalization = Val(:LineSearch) end + if has_forcing + forcing_cache = InternalAPI.init( + prob, alg.forcing, fu, u, u, prob.p; stats, internalnorm, + autodiff = ifelse( + provided_jvp_autodiff, alg.jvp_autodiff, alg.vjp_autodiff + ), + verbose, + kwargs... + ) + end + trace = NonlinearSolveBase.init_nonlinearsolve_trace( prob, alg, u, fu, J, du; kwargs... ) cache = GeneralizedFirstOrderAlgorithmCache( fu, u, u_cache, prob.p, alg, prob, globalization, - jac_cache, descent_cache, linesearch_cache, trustregion_cache, + jac_cache, descent_cache, forcing_cache, linesearch_cache, trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer, 0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs, initializealg, verbose @@ -259,6 +274,12 @@ function InternalAPI.step!( end end + has_forcing = cache.forcing_cache !== nothing && cache.forcing_cache !== missing && !(cache.u isa Number) && !(J isa Diagonal) + + if has_forcing + pre_step_forcing!(cache.forcing_cache, cache.descent_cache, J, cache.u, cache.fu, cache.nsteps) + end + @static_timeit cache.timer "descent" begin if cache.trustregion_cache !== nothing && hasfield(typeof(cache.trustregion_cache), :trust_region) @@ -293,6 +314,10 @@ function InternalAPI.step!( δu, descent_intermediates = descent_result.δu, descent_result.extras if descent_result.success + if has_forcing + post_step_forcing!(cache.forcing_cache, J, cache.u, cache.fu, δu, cache.nsteps) + end + cache.make_new_jacobian = true if cache.globalization isa Val{:LineSearch} @static_timeit cache.timer "linesearch" begin diff --git a/lib/NonlinearSolveFirstOrder/src/trust_region.jl b/lib/NonlinearSolveFirstOrder/src/trust_region.jl index 03a1b3f9a..553748331 100644 --- a/lib/NonlinearSolveFirstOrder/src/trust_region.jl +++ b/lib/NonlinearSolveFirstOrder/src/trust_region.jl @@ -6,7 +6,7 @@ shrink_threshold::Real = 1 // 4, expand_threshold::Real = 3 // 4, shrink_factor::Real = 1 // 4, expand_factor::Real = 2 // 1, max_shrink_times::Int = 32, - vjp_autodiff = nothing, autodiff = nothing, jvp_autodiff = nothing + vjp_autodiff = nothing, autodiff = nothing, jvp_autodiff = nothing, ) An advanced TrustRegion implementation with support for efficient handling of sparse @@ -29,7 +29,7 @@ function TrustRegion(; shrink_threshold::Real = 1 // 4, expand_threshold::Real = 3 // 4, shrink_factor::Real = 1 // 4, expand_factor::Real = 2 // 1, max_shrink_times::Int = 32, - autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing + autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing, ) descent = Dogleg(; linsolve) trustregion = GenericTrustRegionScheme(; diff --git a/lib/NonlinearSolveFirstOrder/test/rootfind_tests.jl b/lib/NonlinearSolveFirstOrder/test/rootfind_tests.jl index 0d122c2d6..6d4b72e87 100644 --- a/lib/NonlinearSolveFirstOrder/test/rootfind_tests.jl +++ b/lib/NonlinearSolveFirstOrder/test/rootfind_tests.jl @@ -4,6 +4,53 @@ include("../../../common/common_rootfind_testing.jl") end +@testitem "Eisenstadt-Walker Newton-Krylov" setup=[CoreRootfindTesting] tags=[:core] begin + using LinearAlgebra, Random, LinearSolve + using BenchmarkTools: @ballocated + using StaticArrays: @SVector + + @testset for (concrete_jac, linsolve) in ( + (Val(false), KrylovJL_CG(; precs = nothing)), + (Val(false), KrylovJL_GMRES(; precs = nothing)), + ( + Val(true), + KrylovJL_GMRES(; + precs = (A, + p = nothing) -> ( + Diagonal(randn!(similar(A, size(A, 1)))), LinearAlgebra.I + ) + ), + ), + ) + @testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0]) + solver = NewtonRaphson(; forcing=EisenstatWalkerForcing2(), linsolve, concrete_jac) + sol = solve_oop(quadratic_f, u0; solver) + @test SciMLBase.successful_retcode(sol) + err = maximum(abs, quadratic_f(sol.u, 2.0)) + @test err < 1e-9 + + cache = init( + NonlinearProblem{false}(quadratic_f, u0, 2.0), solver, abstol = 1e-9 + ) + @test (@ballocated solve!($cache)) < 200 + end + + @testset "[IIP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],) + solver = NewtonRaphson(; forcing=EisenstatWalkerForcing2(), linsolve, concrete_jac) + + sol = solve_iip(quadratic_f!, u0; solver) + @test SciMLBase.successful_retcode(sol) + err = maximum(abs, quadratic_f(sol.u, 2.0)) + @test err < 1e-9 + + cache = init( + NonlinearProblem{true}(quadratic_f!, u0, 2.0), solver, abstol = 1e-9 + ) + @test (@ballocated solve!($cache)) ≤ 64 + end + end +end + @testitem "NewtonRaphson" setup=[CoreRootfindTesting] tags=[:core] begin using ADTypes, LineSearch, LinearAlgebra, Random, LinearSolve using LineSearches: LineSearches