Skip to content

Commit 8814ce1

Browse files
Move SDE default algorithm from DifferentialEquations.jl to StochasticDiffEq.jl
This PR moves the default SDE solver implementation from DifferentialEquations.jl to StochasticDiffEq.jl, following the pattern established in SciML/DelayDiffEq.jl#326 and SciML/DelayDiffEq.jl#334. ## Changes - Added `src/default_sde_alg.jl` containing the default algorithm selection logic - Implemented `__init` and `__solve` dispatches for `SDEProblem` with `Nothing` algorithm - Added `get_alg_hints` helper function for extracting algorithm hints from kwargs - Added comprehensive tests in `test/default_solver_test.jl` - Updated module to include the new default algorithm file ## Default Algorithm Behavior When no algorithm is specified, the solver now automatically selects: - SOSRI() as the standard default - RKMilCommute() for commutative noise - ImplicitRKMil() for stiff problems or non-identity mass matrices - RKMil() for Stratonovich interpretation - LambaEM() / LambaEulerHeun() for non-diagonal noise - ISSEM() / ImplicitEulerHeun() for stiff non-diagonal problems - SOSRA() / SKenCarp() for additive noise ## Test Plan - [x] Added tests verifying default solver dispatch - [x] Tests verify correct algorithm selection for various problem types - [x] All tests pass locally This is part of the ongoing effort to modularize DifferentialEquations.jl by moving default solvers to their respective packages. 🤖 Generated with [Claude Code](https://claude.ai/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent e23a84d commit 8814ce1

File tree

4 files changed

+139
-0
lines changed

4 files changed

+139
-0
lines changed

src/StochasticDiffEq.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ include("iterated_integrals.jl")
158158
include("SROCK_utils.jl")
159159
include("composite_algs.jl")
160160
include("weak_utils.jl")
161+
include("default_sde_alg.jl")
161162

162163
export StochasticDiffEqAlgorithm, StochasticDiffEqAdaptiveAlgorithm,
163164
StochasticCompositeAlgorithm

src/default_sde_alg.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Default algorithm selection for SDEs
2+
# Moved from DifferentialEquations.jl as part of modularization effort
3+
4+
using LinearAlgebra: I
5+
6+
# Helper function to extract alg_hints from keyword arguments
7+
function get_alg_hints(o)
8+
:alg_hints keys(o) ? alg_hints = o[:alg_hints] : alg_hints = Symbol[:auto]
9+
end
10+
11+
function default_algorithm(
12+
prob::DiffEqBase.AbstractSDEProblem{uType, tType, isinplace, ND};
13+
kwargs...) where {uType, tType, isinplace, ND}
14+
o = Dict{Symbol, Any}(kwargs)
15+
alg = SOSRI() # Standard default
16+
17+
alg_hints = get_alg_hints(o)
18+
19+
if :commutative alg_hints
20+
alg = RKMilCommute()
21+
end
22+
23+
is_stiff = :stiff alg_hints
24+
is_stratonovich = :stratonovich alg_hints
25+
if is_stiff || prob.f.mass_matrix !== I
26+
alg = ImplicitRKMil(autodiff = false)
27+
end
28+
29+
if is_stratonovich
30+
if is_stiff || prob.f.mass_matrix !== I
31+
alg = ImplicitRKMil(autodiff = false,
32+
interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich)
33+
else
34+
alg = RKMil(interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich)
35+
end
36+
end
37+
38+
if prob.noise_rate_prototype != nothing || prob.noise != nothing
39+
if is_stratonovich
40+
if is_stiff || prob.f.mass_matrix !== I
41+
alg = ImplicitEulerHeun(autodiff = false)
42+
else
43+
alg = LambaEulerHeun()
44+
end
45+
else
46+
if is_stiff || prob.f.mass_matrix !== I
47+
alg = ISSEM(autodiff = false)
48+
else
49+
alg = LambaEM()
50+
end
51+
end
52+
end
53+
54+
if :additive alg_hints
55+
if is_stiff || prob.f.mass_matrix !== I
56+
alg = SKenCarp(autodiff = false)
57+
else
58+
alg = SOSRA()
59+
end
60+
end
61+
62+
return alg
63+
end
64+
65+
# Dispatch for __init with Nothing algorithm - use default
66+
function DiffEqBase.__init(
67+
prob::DiffEqBase.AbstractSDEProblem, ::Nothing, args...; kwargs...)
68+
alg = default_algorithm(prob; kwargs...)
69+
DiffEqBase.__init(prob, alg, args...; kwargs...)
70+
end
71+
72+
# Dispatch for __solve with Nothing algorithm - use default
73+
function DiffEqBase.__solve(
74+
prob::DiffEqBase.AbstractSDEProblem, ::Nothing, args...; kwargs...)
75+
alg = default_algorithm(prob; kwargs...)
76+
DiffEqBase.__solve(prob, alg, args...; kwargs...)
77+
end

test/default_solver_test.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
using StochasticDiffEq, Test
2+
import SciMLBase
3+
using Random
4+
5+
# Additive SDE test problem
6+
f_additive(u, p, t) = @. p[2] / sqrt(1 + t) - u / (2 * (1 + t))
7+
σ_additive(u, p, t) = @. p[1] * p[2] / sqrt(1 + t)
8+
p = (0.1, 0.05)
9+
additive_analytic(u0, p, t, W) = @. u0 / sqrt(1 + t) + p[2] * (t + p[1] * W) / sqrt(1 + t)
10+
ff_additive = SDEFunction(f_additive, σ_additive, analytic = additive_analytic)
11+
prob_sde_additive = SDEProblem(ff_additive, σ_additive, 1.0, (0.0, 1.0), p)
12+
13+
Random.seed!(100)
14+
15+
# Test default (no algorithm specified) - should use SOSRI
16+
prob = prob_sde_additive
17+
sol = solve(prob, dt = 1 / 2^(3))
18+
@test sol.alg isa SOSRI
19+
20+
# Test with :additive hint - should use SOSRA
21+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:additive])
22+
@test sol.alg isa SOSRA
23+
24+
# Test with :stratonovich hint - should use RKMil with Stratonovich interpretation
25+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stratonovich])
26+
@test SciMLBase.alg_interpretation(sol.alg) ==
27+
SciMLBase.AlgorithmInterpretation.Stratonovich
28+
@test sol.alg isa RKMil
29+
30+
# Non-diagonal noise test problem
31+
f = (du, u, p, t) -> du .= 1.01u
32+
g = function (du, u, p, t)
33+
du[1, 1] = 0.3u[1]
34+
du[1, 2] = 0.6u[1]
35+
du[1, 3] = 0.9u[1]
36+
du[1, 4] = 0.12u[2]
37+
du[2, 1] = 1.2u[1]
38+
du[2, 2] = 0.2u[2]
39+
du[2, 3] = 0.3u[2]
40+
du[2, 4] = 1.8u[2]
41+
end
42+
prob = SDEProblem(f, g, ones(2), (0.0, 1.0), noise_rate_prototype = zeros(2, 4))
43+
44+
# Test default with non-diagonal noise - should use LambaEM
45+
sol = solve(prob, dt = 1 / 2^(3))
46+
@test sol.alg isa LambaEM
47+
48+
# Test with :stiff hint - should use ISSEM
49+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stiff])
50+
@test sol.alg isa ISSEM
51+
52+
# Test with :additive hint - should still use SOSRA (overrides non-diagonal)
53+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:additive])
54+
@test sol.alg isa SOSRA
55+
56+
# Test with :stratonovich hint - should use LambaEulerHeun
57+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stratonovich])
58+
@test sol.alg isa LambaEulerHeun

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR")
1515

1616
@time begin
1717
if GROUP == "All" || GROUP == "Interface1"
18+
@time @safetestset "Default Solver Tests" begin
19+
include("default_solver_test.jl")
20+
end
1821
@time @safetestset "First Rand Tests" begin
1922
include("first_rand_test.jl")
2023
end

0 commit comments

Comments
 (0)