From ef562094125da581b26f558f0da9cb0d38846310 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 14 Nov 2025 01:17:00 -0500 Subject: [PATCH 1/6] Disable Enzyme on Julia v1.12 and enable v1.11 testing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit makes the following changes: - Add Julia v1.11 to the CI testing matrix - Conditionally load Enzyme only on Julia versions < 1.12 - Define dummy types for EnzymeAdjoint and EnzymeVJP on v1.12+ that error if instantiated - Wrap Enzyme-specific method definitions in version checks - Make AutoEnzyme import conditional on Julia version Enzyme is currently incompatible with Julia v1.12, so this change ensures the package can still be used on v1.12 with other AD backends while maintaining full Enzyme support on earlier versions. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .github/workflows/CI.yml | 1 + src/SciMLSensitivity.jl | 25 ++++- src/concrete_solve.jl | 178 +++++++++++++++++----------------- src/sensitivity_algorithms.jl | 123 +++++++++++++---------- 4 files changed, 184 insertions(+), 143 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0d25447dd..5e009b53e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -31,6 +31,7 @@ jobs: - SDE3 version: - '1' + - '1.11' - 'lts' steps: - uses: actions/checkout@v4 diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 7a003b4d3..1e215c898 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -1,7 +1,12 @@ module SciMLSensitivity -using ADTypes: ADTypes, AutoEnzyme, AutoFiniteDiff, AutoForwardDiff, - AutoReverseDiff, AutoTracker, AutoZygote +@static if VERSION < v"1.12" + using ADTypes: ADTypes, AutoEnzyme, AutoFiniteDiff, AutoForwardDiff, + AutoReverseDiff, AutoTracker, AutoZygote +else + using ADTypes: ADTypes, AutoFiniteDiff, AutoForwardDiff, + AutoReverseDiff, AutoTracker, AutoZygote +end using Accessors: @reset using Adapt: Adapt, adapt using ArrayInterface: ArrayInterface @@ -45,7 +50,9 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqCore, BrownFullBasicInit, DefaultInit, # AD Backends using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent, AbstractThunk, AbstractTangent -using Enzyme: Enzyme +@static if VERSION < v"1.12" + using Enzyme: Enzyme +end using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using Tracker: Tracker, TrackedArray @@ -97,14 +104,22 @@ export ODEForwardSensitivityFunction, ODEForwardSensitivityProblem, SensitivityF export BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, GaussKronrodAdjoint, InterpolatingAdjoint, TrackerAdjoint, ZygoteAdjoint, ReverseDiffAdjoint, MooncakeAdjoint, - EnzymeAdjoint, ForwardSensitivity, ForwardDiffSensitivity, + ForwardSensitivity, ForwardDiffSensitivity, ForwardDiffOverAdjoint, SteadyStateAdjoint, ForwardLSS, AdjointLSS, NILSS, NILSAS +@static if VERSION < v"1.12" + export EnzymeAdjoint +end + export second_order_sensitivities, second_order_sensitivity_product -export TrackerVJP, ZygoteVJP, EnzymeVJP, ReverseDiffVJP +export TrackerVJP, ZygoteVJP, ReverseDiffVJP + +@static if VERSION < v"1.12" + export EnzymeVJP +end export StochasticTransformedFunction diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index a2af3e485..e02f39289 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1285,105 +1285,107 @@ function DiffEqBase._concrete_solve_adjoint( p) end -function DiffEqBase._concrete_solve_adjoint( - prob::Union{SciMLBase.AbstractDiscreteProblem, - SciMLBase.AbstractODEProblem, - SciMLBase.AbstractDAEProblem, - SciMLBase.AbstractDDEProblem, - SciMLBase.AbstractSDEProblem, - SciMLBase.AbstractSDDEProblem, - SciMLBase.AbstractRODEProblem - }, - alg, sensealg::EnzymeAdjoint, - u0, p, originator::SciMLBase.ADOriginator, - args...; kwargs...) - kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs)) - du0 = Enzyme.make_zero(u0) - dp = Enzyme.make_zero(p) - mode = sensealg.mode - - # Force no FunctionWrappers for Enzyme - _prob = remake(prob, f = f = ODEFunction{isinplace(prob), SciMLBase.FullSpecialize}(unwrapped_f(prob.f)) ) - - diff_func = (u0, - p) -> solve(_prob, alg, args...; u0 = u0, p = p, - sensealg = SensitivityADPassThrough(), - kwargs_filtered...) - - splitmode = if mode isa Enzyme.ForwardMode - error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.") - elseif mode === nothing || mode isa Enzyme.ReverseMode - Enzyme.set_runtime_activity(Enzyme.ReverseSplitWithPrimal) - end +@static if VERSION < v"1.12" + function DiffEqBase._concrete_solve_adjoint( + prob::Union{SciMLBase.AbstractDiscreteProblem, + SciMLBase.AbstractODEProblem, + SciMLBase.AbstractDAEProblem, + SciMLBase.AbstractDDEProblem, + SciMLBase.AbstractSDEProblem, + SciMLBase.AbstractSDDEProblem, + SciMLBase.AbstractRODEProblem + }, + alg, sensealg::EnzymeAdjoint, + u0, p, originator::SciMLBase.ADOriginator, + args...; kwargs...) + kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs)) + du0 = Enzyme.make_zero(u0) + dp = Enzyme.make_zero(p) + mode = sensealg.mode - forward, - reverse = Enzyme.autodiff_thunk( - splitmode, Enzyme.Const{typeof(diff_func)}, Enzyme.Duplicated, - Enzyme.Duplicated{typeof(u0)}, Enzyme.Duplicated{typeof(p)}) - tape, result, - shadow_result = forward( - Enzyme.Const(diff_func), Enzyme.Duplicated(copy(u0), du0), Enzyme.Duplicated(copy(p), dp)) - - function enzyme_sensitivity_backpass(Δ) - if (Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray) - for (x, y) in zip(shadow_result.u, Δ.u) - x .= y - end - else - error("typeof(Δ) = $(typeof(Δ)) is not currently handled in EnzymeAdjoint. Please open an issue with an MWE to add support") + # Force no FunctionWrappers for Enzyme + _prob = remake(prob, f = f = ODEFunction{isinplace(prob), SciMLBase.FullSpecialize}(unwrapped_f(prob.f)) ) + + diff_func = (u0, + p) -> solve(_prob, alg, args...; u0 = u0, p = p, + sensealg = SensitivityADPassThrough(), + kwargs_filtered...) + + splitmode = if mode isa Enzyme.ForwardMode + error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.") + elseif mode === nothing || mode isa Enzyme.ReverseMode + Enzyme.set_runtime_activity(Enzyme.ReverseSplitWithPrimal) end - reverse(Enzyme.Const(diff_func), Enzyme.Duplicated(u0, du0), Enzyme.Duplicated(p, dp), tape) - if originator isa SciMLBase.TrackerOriginator || - originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(), NoTangent(), du0, dp, NoTangent(), - ntuple(_ -> NoTangent(), length(args))...) - else - (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), - ntuple(_ -> NoTangent(), length(args))...) + + forward, + reverse = Enzyme.autodiff_thunk( + splitmode, Enzyme.Const{typeof(diff_func)}, Enzyme.Duplicated, + Enzyme.Duplicated{typeof(u0)}, Enzyme.Duplicated{typeof(p)}) + tape, result, + shadow_result = forward( + Enzyme.Const(diff_func), Enzyme.Duplicated(copy(u0), du0), Enzyme.Duplicated(copy(p), dp)) + + function enzyme_sensitivity_backpass(Δ) + if (Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray) + for (x, y) in zip(shadow_result.u, Δ.u) + x .= y + end + else + error("typeof(Δ) = $(typeof(Δ)) is not currently handled in EnzymeAdjoint. Please open an issue with an MWE to add support") + end + reverse(Enzyme.Const(diff_func), Enzyme.Duplicated(u0, du0), Enzyme.Duplicated(p, dp), tape) + if originator isa SciMLBase.TrackerOriginator || + originator isa SciMLBase.ReverseDiffOriginator + (NoTangent(), NoTangent(), du0, dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + else + (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + end end + result, enzyme_sensitivity_backpass end - result, enzyme_sensitivity_backpass -end -# NOTE: This is needed to prevent a method ambiguity error -function DiffEqBase._concrete_solve_adjoint( - prob::AbstractNonlinearProblem, alg, sensealg::EnzymeAdjoint, - u0, p, originator::SciMLBase.ADOriginator, - args...; kwargs...) - kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs)) - - du0 = make_zero(u0) - dp = make_zero(p) - mode = sensealg.mode + # NOTE: This is needed to prevent a method ambiguity error + function DiffEqBase._concrete_solve_adjoint( + prob::AbstractNonlinearProblem, alg, sensealg::EnzymeAdjoint, + u0, p, originator::SciMLBase.ADOriginator, + args...; kwargs...) + kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs)) - f = (u0, - p) -> solve(prob, alg, args...; u0 = u0, p = p, - sensealg = SensitivityADPassThrough(), - kwargs_filtered...) + du0 = make_zero(u0) + dp = make_zero(p) + mode = sensealg.mode - splitmode = if mode isa Forward - error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.") - elseif mode === nothing || mode === Reverse - ReverseSplitWithPrimal - end + f = (u0, + p) -> solve(prob, alg, args...; u0 = u0, p = p, + sensealg = SensitivityADPassThrough(), + kwargs_filtered...) - forward, - reverse = autodiff_thunk(splitmode, Const{typeof(f)}, Duplicated, - Duplicated{typeof(u0)}, Duplicated{typeof(p)}) - tape, result, shadow_result = forward(Const(f), Duplicated(u0, du0), Duplicated(p, dp)) + splitmode = if mode isa Forward + error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.") + elseif mode === nothing || mode === Reverse + ReverseSplitWithPrimal + end - function enzyme_sensitivity_backpass(Δ) - reverse(Const(f), Duplicated(u0, du0), Duplicated(p, dp), Δ, tape) - if originator isa SciMLBase.TrackerOriginator || - originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(), NoTangent(), du0, dp, NoTangent(), - ntuple(_ -> NoTangent(), length(args))...) - else - (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), - ntuple(_ -> NoTangent(), length(args))...) + forward, + reverse = autodiff_thunk(splitmode, Const{typeof(f)}, Duplicated, + Duplicated{typeof(u0)}, Duplicated{typeof(p)}) + tape, result, shadow_result = forward(Const(f), Duplicated(u0, du0), Duplicated(p, dp)) + + function enzyme_sensitivity_backpass(Δ) + reverse(Const(f), Duplicated(u0, du0), Duplicated(p, dp), Δ, tape) + if originator isa SciMLBase.TrackerOriginator || + originator isa SciMLBase.ReverseDiffOriginator + (NoTangent(), NoTangent(), du0, dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + else + (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + end end + sol, enzyme_sensitivity_backpass end - sol, enzyme_sensitivity_backpass end const ENZYME_TRACKED_REAL_ERROR_MESSAGE = """ diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index f63b8f0ec..460de80c3 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -772,39 +772,50 @@ Currently fails on almost every solver. """ struct ZygoteAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} end -""" -EnzymeAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} +@static if VERSION < v"1.12" + """ + EnzymeAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} -An implementation of discrete adjoint sensitivity analysis -using the Enzyme.jl source-to-source AD directly on the differential equation -solver. + An implementation of discrete adjoint sensitivity analysis + using the Enzyme.jl source-to-source AD directly on the differential equation + solver. -!!! warn + !!! warn - This is currently experimental and supports only explicit solvers. It will - support all solvers in the future. + This is currently experimental and supports only explicit solvers. It will + support all solvers in the future. -## Constructor + ## Constructor -```julia -EnzymeAdjoint(mode = nothing) -``` + ```julia + EnzymeAdjoint(mode = nothing) + ``` -## Arguments + ## Arguments - - `mode::M` determines the autodiff mode (forward or reverse). It can be: + - `mode::M` determines the autodiff mode (forward or reverse). It can be: - + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required - + `nothing` to choose the best mode automatically + + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required + + `nothing` to choose the best mode automatically -## SciMLProblem Support + ## SciMLProblem Support -Currently fails on almost every solver. -""" -struct EnzymeAdjoint{M <: Union{Nothing, Enzyme.EnzymeCore.Mode}} <: - AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} - mode::M - EnzymeAdjoint(mode = nothing) = new{typeof(mode)}(mode) + Currently fails on almost every solver. + """ + struct EnzymeAdjoint{M <: Union{Nothing, Enzyme.EnzymeCore.Mode}} <: + AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} + mode::M + EnzymeAdjoint(mode = nothing) = new{typeof(mode)}(mode) + end +else + # Dummy type for Julia 1.12+ - Enzyme is not supported on this version + struct EnzymeAdjoint{M <: Nothing} <: + AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} + mode::M + function EnzymeAdjoint(mode = nothing) + error("EnzymeAdjoint is not supported on Julia 1.12+. Please use a different sensitivity algorithm.") + end + end end """ @@ -1291,39 +1302,49 @@ struct ZygoteVJP <: VJPChoice end ZygoteVJP(; allow_nothing = false) = ZygoteVJP(allow_nothing) -""" -```julia -EnzymeVJP <: VJPChoice -``` +@static if VERSION < v"1.12" + """ + ```julia + EnzymeVJP <: VJPChoice + ``` -Uses Enzyme.jl to compute vector-Jacobian products. Is the fastest VJP whenever applicable, -though Enzyme.jl currently has low coverage over the Julia programming language, for example -restricting the user's defined `f` function to not do things like require garbage collection -or calls to BLAS/LAPACK. However, mutation is supported, meaning that in-place `f` with -fully mutating non-allocating code will work with Enzyme (provided no high-level calls to C -like BLAS/LAPACK are used) and this will be the most efficient adjoint implementation. + Uses Enzyme.jl to compute vector-Jacobian products. Is the fastest VJP whenever applicable, + though Enzyme.jl currently has low coverage over the Julia programming language, for example + restricting the user's defined `f` function to not do things like require garbage collection + or calls to BLAS/LAPACK. However, mutation is supported, meaning that in-place `f` with + fully mutating non-allocating code will work with Enzyme (provided no high-level calls to C + like BLAS/LAPACK are used) and this will be the most efficient adjoint implementation. -## Constructor + ## Constructor -```julia -EnzymeVJP(; chunksize = 0) -``` + ```julia + EnzymeVJP(; chunksize = 0) + ``` -## Keyword Arguments + ## Keyword Arguments + + - `chunksize`: the default chunk size for the temporary variables inside the vjp's right + hand side definition. This is used for compatibility with ODE solves that default to using + ForwardDiff.jl for the Jacobian of the stiff ODE solve, such as OrdinaryDiffEq.jl. This + should be set to the maximum chunksize that can occur during an integration to preallocate + the `DualCaches` for PreallocationTools.jl. It defaults to 0, using `ForwardDiff.pickchunksize` + but could be decreased if this value is known to be lower to conserve memory. + """ + struct EnzymeVJP <: VJPChoice + chunksize::Int + end - - `chunksize`: the default chunk size for the temporary variables inside the vjp's right - hand side definition. This is used for compatibility with ODE solves that default to using - ForwardDiff.jl for the Jacobian of the stiff ODE solve, such as OrdinaryDiffEq.jl. This - should be set to the maximum chunksize that can occur during an integration to preallocate - the `DualCaches` for PreallocationTools.jl. It defaults to 0, using `ForwardDiff.pickchunksize` - but could be decreased if this value is known to be lower to conserve memory. -""" -struct EnzymeVJP <: VJPChoice - chunksize::Int + EnzymeVJP(; chunksize = 0) = EnzymeVJP(chunksize) +else + # Dummy type for Julia 1.12+ - Enzyme is not supported on this version + struct EnzymeVJP <: VJPChoice + chunksize::Int + function EnzymeVJP(; chunksize = 0) + error("EnzymeVJP is not supported on Julia 1.12+. Please use a different VJP method.") + end + end end -EnzymeVJP(; chunksize = 0) = EnzymeVJP(chunksize) - """ ```julia TrackerVJP <: VJPChoice @@ -1496,7 +1517,9 @@ function get_autodiff_from_vjp(::ReverseDiffVJP{compile}) where {compile} return AutoReverseDiff(; compile) end get_autodiff_from_vjp(::ZygoteVJP) = AutoZygote() -get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme() +@static if VERSION < v"1.12" + get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme() +end get_autodiff_from_vjp(::TrackerVJP) = AutoTracker() get_autodiff_from_vjp(::Nothing) = AutoZygote() get_autodiff_from_vjp(b::Bool) = ifelse(b, AutoForwardDiff(), AutoFiniteDiff()) From 020e990042e81338cab829ae9d464610d2e431eb Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 14 Nov 2025 01:49:57 -0500 Subject: [PATCH 2/6] Simplify Enzyme conditional loading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Keep EnzymeAdjoint and EnzymeVJP types and exports always available. Only conditionally load `using Enzyme` on Julia < 1.12. The types will error with a clear message if instantiated on v1.12+. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/SciMLSensitivity.jl | 21 ++----- src/sensitivity_algorithms.jl | 108 +++++++++++++++------------------- 2 files changed, 52 insertions(+), 77 deletions(-) diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 1e215c898..faa655b60 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -1,12 +1,7 @@ module SciMLSensitivity -@static if VERSION < v"1.12" - using ADTypes: ADTypes, AutoEnzyme, AutoFiniteDiff, AutoForwardDiff, - AutoReverseDiff, AutoTracker, AutoZygote -else - using ADTypes: ADTypes, AutoFiniteDiff, AutoForwardDiff, - AutoReverseDiff, AutoTracker, AutoZygote -end +using ADTypes: ADTypes, AutoEnzyme, AutoFiniteDiff, AutoForwardDiff, + AutoReverseDiff, AutoTracker, AutoZygote using Accessors: @reset using Adapt: Adapt, adapt using ArrayInterface: ArrayInterface @@ -104,22 +99,14 @@ export ODEForwardSensitivityFunction, ODEForwardSensitivityProblem, SensitivityF export BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, GaussKronrodAdjoint, InterpolatingAdjoint, TrackerAdjoint, ZygoteAdjoint, ReverseDiffAdjoint, MooncakeAdjoint, - ForwardSensitivity, ForwardDiffSensitivity, + EnzymeAdjoint, ForwardSensitivity, ForwardDiffSensitivity, ForwardDiffOverAdjoint, SteadyStateAdjoint, ForwardLSS, AdjointLSS, NILSS, NILSAS -@static if VERSION < v"1.12" - export EnzymeAdjoint -end - export second_order_sensitivities, second_order_sensitivity_product -export TrackerVJP, ZygoteVJP, ReverseDiffVJP - -@static if VERSION < v"1.12" - export EnzymeVJP -end +export TrackerVJP, ZygoteVJP, EnzymeVJP, ReverseDiffVJP export StochasticTransformedFunction diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 460de80c3..fef683a9b 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -772,43 +772,43 @@ Currently fails on almost every solver. """ struct ZygoteAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} end -@static if VERSION < v"1.12" - """ - EnzymeAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} +""" +EnzymeAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing} - An implementation of discrete adjoint sensitivity analysis - using the Enzyme.jl source-to-source AD directly on the differential equation - solver. +An implementation of discrete adjoint sensitivity analysis +using the Enzyme.jl source-to-source AD directly on the differential equation +solver. - !!! warn +!!! warn - This is currently experimental and supports only explicit solvers. It will - support all solvers in the future. + This is currently experimental and supports only explicit solvers. It will + support all solvers in the future. - ## Constructor +## Constructor - ```julia - EnzymeAdjoint(mode = nothing) - ``` +```julia +EnzymeAdjoint(mode = nothing) +``` - ## Arguments +## Arguments - - `mode::M` determines the autodiff mode (forward or reverse). It can be: + - `mode::M` determines the autodiff mode (forward or reverse). It can be: - + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required - + `nothing` to choose the best mode automatically + + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required + + `nothing` to choose the best mode automatically - ## SciMLProblem Support +## SciMLProblem Support - Currently fails on almost every solver. - """ +Currently fails on almost every solver. +""" +@static if VERSION < v"1.12" struct EnzymeAdjoint{M <: Union{Nothing, Enzyme.EnzymeCore.Mode}} <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} mode::M EnzymeAdjoint(mode = nothing) = new{typeof(mode)}(mode) end else - # Dummy type for Julia 1.12+ - Enzyme is not supported on this version + # Dummy type for Julia 1.12+ - Enzyme is not loaded on this version struct EnzymeAdjoint{M <: Nothing} <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} mode::M @@ -1302,49 +1302,39 @@ struct ZygoteVJP <: VJPChoice end ZygoteVJP(; allow_nothing = false) = ZygoteVJP(allow_nothing) -@static if VERSION < v"1.12" - """ - ```julia - EnzymeVJP <: VJPChoice - ``` +""" +```julia +EnzymeVJP <: VJPChoice +``` - Uses Enzyme.jl to compute vector-Jacobian products. Is the fastest VJP whenever applicable, - though Enzyme.jl currently has low coverage over the Julia programming language, for example - restricting the user's defined `f` function to not do things like require garbage collection - or calls to BLAS/LAPACK. However, mutation is supported, meaning that in-place `f` with - fully mutating non-allocating code will work with Enzyme (provided no high-level calls to C - like BLAS/LAPACK are used) and this will be the most efficient adjoint implementation. +Uses Enzyme.jl to compute vector-Jacobian products. Is the fastest VJP whenever applicable, +though Enzyme.jl currently has low coverage over the Julia programming language, for example +restricting the user's defined `f` function to not do things like require garbage collection +or calls to BLAS/LAPACK. However, mutation is supported, meaning that in-place `f` with +fully mutating non-allocating code will work with Enzyme (provided no high-level calls to C +like BLAS/LAPACK are used) and this will be the most efficient adjoint implementation. - ## Constructor +## Constructor - ```julia - EnzymeVJP(; chunksize = 0) - ``` +```julia +EnzymeVJP(; chunksize = 0) +``` - ## Keyword Arguments - - - `chunksize`: the default chunk size for the temporary variables inside the vjp's right - hand side definition. This is used for compatibility with ODE solves that default to using - ForwardDiff.jl for the Jacobian of the stiff ODE solve, such as OrdinaryDiffEq.jl. This - should be set to the maximum chunksize that can occur during an integration to preallocate - the `DualCaches` for PreallocationTools.jl. It defaults to 0, using `ForwardDiff.pickchunksize` - but could be decreased if this value is known to be lower to conserve memory. - """ - struct EnzymeVJP <: VJPChoice - chunksize::Int - end +## Keyword Arguments - EnzymeVJP(; chunksize = 0) = EnzymeVJP(chunksize) -else - # Dummy type for Julia 1.12+ - Enzyme is not supported on this version - struct EnzymeVJP <: VJPChoice - chunksize::Int - function EnzymeVJP(; chunksize = 0) - error("EnzymeVJP is not supported on Julia 1.12+. Please use a different VJP method.") - end - end + - `chunksize`: the default chunk size for the temporary variables inside the vjp's right + hand side definition. This is used for compatibility with ODE solves that default to using + ForwardDiff.jl for the Jacobian of the stiff ODE solve, such as OrdinaryDiffEq.jl. This + should be set to the maximum chunksize that can occur during an integration to preallocate + the `DualCaches` for PreallocationTools.jl. It defaults to 0, using `ForwardDiff.pickchunksize` + but could be decreased if this value is known to be lower to conserve memory. +""" +struct EnzymeVJP <: VJPChoice + chunksize::Int end +EnzymeVJP(; chunksize = 0) = EnzymeVJP(chunksize) + """ ```julia TrackerVJP <: VJPChoice @@ -1517,9 +1507,7 @@ function get_autodiff_from_vjp(::ReverseDiffVJP{compile}) where {compile} return AutoReverseDiff(; compile) end get_autodiff_from_vjp(::ZygoteVJP) = AutoZygote() -@static if VERSION < v"1.12" - get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme() -end +get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme() get_autodiff_from_vjp(::TrackerVJP) = AutoTracker() get_autodiff_from_vjp(::Nothing) = AutoZygote() get_autodiff_from_vjp(b::Bool) = ifelse(b, AutoForwardDiff(), AutoFiniteDiff()) From 89f9efc9c347981519426d1caaceba27740a00d3 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 14 Nov 2025 07:35:06 -0500 Subject: [PATCH 3/6] Conditionally skip Enzyme tests on Julia v1.12 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Skip enzyme_closure.jl test entirely on v1.12 - Wrap EnzymeAdjoint gradient computations in concrete_solve_derivatives.jl - Wrap corresponding test assertions for Enzyme results 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- test/concrete_solve_derivatives.jl | 56 ++++++++++++++++++------------ test/runtests.jl | 4 ++- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/test/concrete_solve_derivatives.jl b/test/concrete_solve_derivatives.jl index 0a7b14c4a..3329a2639 100644 --- a/test/concrete_solve_derivatives.jl +++ b/test/concrete_solve_derivatives.jl @@ -93,15 +93,17 @@ dp7 = Zygote.gradient( sensealg = MooncakeAdjoint())), u0, p) -du08, -dp8 = Zygote.gradient( - (u0, - p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, - abstol = 1e-14, reltol = 1e-14, - saveat = 0.1, - sensealg = EnzymeAdjoint())), - u0, - p) +@static if VERSION < v"1.12" + du08, + dp8 = Zygote.gradient( + (u0, + p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = EnzymeAdjoint())), + u0, + p) +end @test ū0≈du01 rtol=1e-12 @test ū0 == du02 @@ -110,7 +112,9 @@ dp8 = Zygote.gradient( #@test ū0 ≈ du05 rtol=1e-12 @test ū0≈du06 rtol=1e-12 @test_broken ū0≈du07 rtol=1e-12 -@test ū0≈du08 rtol=1e-12 +@static if VERSION < v"1.12" + @test ū0≈du08 rtol=1e-12 +end @test adj≈dp1' rtol=1e-12 @test adj == dp2' @test adj≈dp3' rtol=1e-12 @@ -118,7 +122,9 @@ dp8 = Zygote.gradient( #@test adj ≈ dp5' rtol=1e-12 @test adj≈dp6' rtol=1e-12 @test_broken adj≈dp7' rtol=1e-12 -@test adj≈dp8' rtol=1e-12 +@static if VERSION < v"1.12" + @test adj≈dp8' rtol=1e-12 +end ### ### Direct from prob @@ -406,15 +412,17 @@ dp7 = Zygote.gradient( sensealg = MooncakeAdjoint())), u0, p) -du08, -dp8 = Zygote.gradient( - (u0, - p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, - abstol = 1e-14, reltol = 1e-14, - saveat = 0.1, - sensealg = EnzymeAdjoint())), - u0, - p) +@static if VERSION < v"1.12" + du08, + dp8 = Zygote.gradient( + (u0, + p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p, + abstol = 1e-14, reltol = 1e-14, + saveat = 0.1, + sensealg = EnzymeAdjoint())), + u0, + p) +end du09, dp9 = Zygote.gradient( (u0, @@ -441,7 +449,9 @@ dp10 = Zygote.gradient( #@test ū0 ≈ du05 rtol=1e-12 @test ū0≈du06 rtol=1e-12 @test_broken ū0≈du07 rtol=1e-12 -@test ū0≈du08 rtol=1e-12 +@static if VERSION < v"1.12" + @test ū0≈du08 rtol=1e-12 +end @test ū0≈du09 rtol=1e-12 @test ū0≈du010 rtol=1e-12 @test adj≈dp1' rtol=1e-12 @@ -451,7 +461,9 @@ dp10 = Zygote.gradient( #@test adj ≈ dp5' rtol=1e-12 @test adj≈dp6' rtol=1e-12 @test_broken adj≈dp7' rtol=1e-12 -@test adj≈dp8' rtol=1e-12 +@static if VERSION < v"1.12" + @test adj≈dp8' rtol=1e-12 +end @test adj≈dp9' rtol=1e-12 @test adj≈dp10' rtol=1e-12 diff --git a/test/runtests.jl b/test/runtests.jl index cfb90641e..1eea19cd3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,7 +80,9 @@ end if GROUP == "All" || GROUP == "Core6" @testset "Core 6" begin - @time @safetestset "Enzyme Closures" include("enzyme_closure.jl") + if VERSION < v"1.12" + @time @safetestset "Enzyme Closures" include("enzyme_closure.jl") + end @time @safetestset "Complex Matrix FiniteDiff Adjoint" include("complex_matrix_finitediff.jl") @time @safetestset "Null Parameters" include("null_parameters.jl") @time @safetestset "Forward Mode Prob Kwargs" include("forward_prob_kwargs.jl") From 64c767b929dc888e394bbac52e471846dc219c42 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 14 Nov 2025 07:36:23 -0500 Subject: [PATCH 4/6] Wrap Enzyme tests in adjoint.jl with version checks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Conditionally compile EnzymeVJP tests on Julia < 1.12 - Wrap InterpolatingAdjoint, QuadratureAdjoint, GaussAdjoint, and GaussKronrodAdjoint tests using EnzymeVJP - Wrap BacksolveAdjoint test with EnzymeVJP and its helper function 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude EOF ) --- test/adjoint.jl | 96 ++++++++++++++++++++++++++----------------------- 1 file changed, 52 insertions(+), 44 deletions(-) diff --git a/test/adjoint.jl b/test/adjoint.jl index c6b4ded54..c2efdeee7 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -139,16 +139,18 @@ easy_res11 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.ReverseDiffVJP(true))) -_, -easy_res12 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, - abstol = 1e-14, - reltol = 1e-14, - sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP())) -_, -easy_res13 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, - abstol = 1e-14, - reltol = 1e-14, - sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP())) +@static if VERSION < v"1.12" + _, + easy_res12 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP())) + _, + easy_res13 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP())) +end _, easy_res14 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, @@ -179,11 +181,13 @@ easy_res143 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true))) -_, -easy_res144 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, - abstol = 1e-14, - reltol = 1e-14, - sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP())) +@static if VERSION < v"1.12" + _, + easy_res144 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP())) +end _, easy_res145 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, @@ -212,11 +216,13 @@ easy_res143k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, sensealg = GaussKronrodAdjoint(autojacvec = ReverseDiffVJP(true))) -_, -easy_res144k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, - abstol = 1e-14, - reltol = 1e-14, - sensealg = GaussKronrodAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP())) +@static if VERSION < v"1.12" + _, + easy_res144k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = GaussKronrodAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP())) +end _, easy_res145k = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, @@ -1049,34 +1055,36 @@ function dynamics!(du, u, p, t) du[2] = -u[2] + tanh(p[3] * u[1] + p[4] * u[2]) end -function backsolve_grad(sol, lqr_params, checkpointing) - bwd_sol = solve( - ODEAdjointProblem(sol, - BacksolveAdjoint(autojacvec = EnzymeVJP(), - checkpointing = checkpointing), +@static if VERSION < v"1.12" + function backsolve_grad(sol, lqr_params, checkpointing) + bwd_sol = solve( + ODEAdjointProblem(sol, + BacksolveAdjoint(autojacvec = EnzymeVJP(), + checkpointing = checkpointing), + Tsit5(), + nothing, nothing, nothing, nothing, nothing, + (x, lqr_params, t) -> cost(x, lqr_params)), Tsit5(), - nothing, nothing, nothing, nothing, nothing, - (x, lqr_params, t) -> cost(x, lqr_params)), - Tsit5(), + dense = false, + save_everystep = false) + + bwd_sol.u[end][1:(end - x_dim)] + #fwd_sol, bwd_sol + end + + x0 = ones(x_dim) + fwd_sol = solve(ODEProblem(dynamics!, x0, (0, T), params), + Tsit5(), abstol = 1e-9, reltol = 1e-9, + u0 = x0, + p = params, dense = false, - save_everystep = false) + save_everystep = true) - bwd_sol.u[end][1:(end - x_dim)] - #fwd_sol, bwd_sol -end + backsolve_results = backsolve_grad(fwd_sol, params, false) + backsolve_checkpointing_results = backsolve_grad(fwd_sol, params, true) -x0 = ones(x_dim) -fwd_sol = solve(ODEProblem(dynamics!, x0, (0, T), params), - Tsit5(), abstol = 1e-9, reltol = 1e-9, - u0 = x0, - p = params, - dense = false, - save_everystep = true) - -backsolve_results = backsolve_grad(fwd_sol, params, false) -backsolve_checkpointing_results = backsolve_grad(fwd_sol, params, true) - -@test backsolve_results != backsolve_checkpointing_results + @test backsolve_results != backsolve_checkpointing_results +end int_u0, int_p = adjoint_sensitivities(fwd_sol, Tsit5(), From e55afdf52aa14e9d43117aa04f80e1a12597fa44 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 14 Nov 2025 10:41:51 -0500 Subject: [PATCH 5/6] Fix PredictiveController reference in autodiff_events.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PredictiveController is in OrdinaryDiffEqCore, not OrdinaryDiffEq. Import OrdinaryDiffEqCore and use the correct module reference. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- test/autodiff_events.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/autodiff_events.jl b/test/autodiff_events.jl index c799e0f2c..2b7438aa8 100644 --- a/test/autodiff_events.jl +++ b/test/autodiff_events.jl @@ -1,5 +1,5 @@ using SciMLSensitivity -using OrdinaryDiffEq, Calculus, Test +using OrdinaryDiffEq, OrdinaryDiffEqCore, Calculus, Test using Zygote function f(du, u, p, t) @@ -56,11 +56,11 @@ g4 = Zygote.gradient(θ -> test_f2(θ, ReverseDiffAdjoint(), PIController(7 // 5 p) g6 = Zygote.gradient( θ -> test_f2(θ, ForwardDiffSensitivity(), - OrdinaryDiffEq.PredictiveController(), TRBDF2()), + OrdinaryDiffEqCore.PredictiveController(), TRBDF2()), p) @test_broken g7 = Zygote.gradient( θ -> test_f2(θ, ReverseDiffAdjoint(), - OrdinaryDiffEq.PredictiveController(), + OrdinaryDiffEqCore.PredictiveController(), TRBDF2()), p) From 1b462cc3dfe4fa29c15e343d4ffec9eab119ff4b Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Sat, 15 Nov 2025 10:11:08 -0500 Subject: [PATCH 6/6] Centralize Enzyme version check into ENZYME_ENABLED constant MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Define ENZYME_ENABLED = VERSION < v"1.12" in SciMLSensitivity module - Export ENZYME_ENABLED for use in tests and downstream packages - Replace all @static if VERSION < v"1.12" with @static if ENZYME_ENABLED - Replace all if VERSION < v"1.12" with if ENZYME_ENABLED in tests - Makes it easy to change the Enzyme compatibility version in one place 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/SciMLSensitivity.jl | 7 ++++++- src/concrete_solve.jl | 2 +- src/sensitivity_algorithms.jl | 2 +- test/adjoint.jl | 8 ++++---- test/concrete_solve_derivatives.jl | 12 ++++++------ test/runtests.jl | 2 +- 6 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index faa655b60..16e57ee24 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -1,5 +1,8 @@ module SciMLSensitivity +# Enzyme is not compatible with Julia 1.12+ +const ENZYME_ENABLED = VERSION < v"1.12" + using ADTypes: ADTypes, AutoEnzyme, AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote using Accessors: @reset @@ -45,7 +48,7 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqCore, BrownFullBasicInit, DefaultInit, # AD Backends using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent, AbstractThunk, AbstractTangent -@static if VERSION < v"1.12" +@static if ENZYME_ENABLED using Enzyme: Enzyme end using FiniteDiff: FiniteDiff @@ -88,6 +91,8 @@ include("sde_tools.jl") export extract_local_sensitivities +export ENZYME_ENABLED + export ODEForwardSensitivityFunction, ODEForwardSensitivityProblem, SensitivityFunction, ODEAdjointProblem, AdjointSensitivityIntegrand, SDEAdjointProblem, RODEAdjointProblem, SensitivityAlg, diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index e02f39289..7b35d61c9 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1285,7 +1285,7 @@ function DiffEqBase._concrete_solve_adjoint( p) end -@static if VERSION < v"1.12" +@static if ENZYME_ENABLED function DiffEqBase._concrete_solve_adjoint( prob::Union{SciMLBase.AbstractDiscreteProblem, SciMLBase.AbstractODEProblem, diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index fef683a9b..124378ea9 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -801,7 +801,7 @@ EnzymeAdjoint(mode = nothing) Currently fails on almost every solver. """ -@static if VERSION < v"1.12" +@static if ENZYME_ENABLED struct EnzymeAdjoint{M <: Union{Nothing, Enzyme.EnzymeCore.Mode}} <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} mode::M diff --git a/test/adjoint.jl b/test/adjoint.jl index c2efdeee7..acc1269d4 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -139,7 +139,7 @@ easy_res11 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.ReverseDiffVJP(true))) -@static if VERSION < v"1.12" +@static if SciMLSensitivity.ENZYME_ENABLED _, easy_res12 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, @@ -181,7 +181,7 @@ easy_res143 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true))) -@static if VERSION < v"1.12" +@static if SciMLSensitivity.ENZYME_ENABLED _, easy_res144 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, @@ -216,7 +216,7 @@ easy_res143k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, sensealg = GaussKronrodAdjoint(autojacvec = ReverseDiffVJP(true))) -@static if VERSION < v"1.12" +@static if SciMLSensitivity.ENZYME_ENABLED _, easy_res144k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, @@ -1055,7 +1055,7 @@ function dynamics!(du, u, p, t) du[2] = -u[2] + tanh(p[3] * u[1] + p[4] * u[2]) end -@static if VERSION < v"1.12" +@static if SciMLSensitivity.ENZYME_ENABLED function backsolve_grad(sol, lqr_params, checkpointing) bwd_sol = solve( ODEAdjointProblem(sol, diff --git a/test/concrete_solve_derivatives.jl b/test/concrete_solve_derivatives.jl index 3329a2639..988939e55 100644 --- a/test/concrete_solve_derivatives.jl +++ b/test/concrete_solve_derivatives.jl @@ -93,7 +93,7 @@ dp7 = Zygote.gradient( sensealg = MooncakeAdjoint())), u0, p) -@static if VERSION < v"1.12" +@static if SciMLSensitivity.ENZYME_ENABLED du08, dp8 = Zygote.gradient( (u0, @@ -112,7 +112,7 @@ end #@test ū0 ≈ du05 rtol=1e-12 @test ū0≈du06 rtol=1e-12 @test_broken ū0≈du07 rtol=1e-12 -@static if VERSION < v"1.12" +@static if SciMLSensitivity.ENZYME_ENABLED @test ū0≈du08 rtol=1e-12 end @test adj≈dp1' rtol=1e-12 @@ -122,7 +122,7 @@ end #@test adj ≈ dp5' rtol=1e-12 @test adj≈dp6' rtol=1e-12 @test_broken adj≈dp7' rtol=1e-12 -@static if VERSION < v"1.12" +@static if SciMLSensitivity.ENZYME_ENABLED @test adj≈dp8' rtol=1e-12 end @@ -412,7 +412,7 @@ dp7 = Zygote.gradient( sensealg = MooncakeAdjoint())), u0, p) -@static if VERSION < v"1.12" +@static if SciMLSensitivity.ENZYME_ENABLED du08, dp8 = Zygote.gradient( (u0, @@ -449,7 +449,7 @@ dp10 = Zygote.gradient( #@test ū0 ≈ du05 rtol=1e-12 @test ū0≈du06 rtol=1e-12 @test_broken ū0≈du07 rtol=1e-12 -@static if VERSION < v"1.12" +@static if SciMLSensitivity.ENZYME_ENABLED @test ū0≈du08 rtol=1e-12 end @test ū0≈du09 rtol=1e-12 @@ -461,7 +461,7 @@ end #@test adj ≈ dp5' rtol=1e-12 @test adj≈dp6' rtol=1e-12 @test_broken adj≈dp7' rtol=1e-12 -@static if VERSION < v"1.12" +@static if SciMLSensitivity.ENZYME_ENABLED @test adj≈dp8' rtol=1e-12 end @test adj≈dp9' rtol=1e-12 diff --git a/test/runtests.jl b/test/runtests.jl index 1eea19cd3..13848db88 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,7 +80,7 @@ end if GROUP == "All" || GROUP == "Core6" @testset "Core 6" begin - if VERSION < v"1.12" + if SciMLSensitivity.ENZYME_ENABLED @time @safetestset "Enzyme Closures" include("enzyme_closure.jl") end @time @safetestset "Complex Matrix FiniteDiff Adjoint" include("complex_matrix_finitediff.jl")