Skip to content

Commit 0ddad8b

Browse files
Merge pull request #633 from ChrisRackauckas-Claude/move-default-sde-algorithm
Move SDE default algorithm from DifferentialEquations.jl to StochasticDiffEq.jl
2 parents e23a84d + 8814ce1 commit 0ddad8b

File tree

4 files changed

+139
-0
lines changed

4 files changed

+139
-0
lines changed

src/StochasticDiffEq.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ include("iterated_integrals.jl")
158158
include("SROCK_utils.jl")
159159
include("composite_algs.jl")
160160
include("weak_utils.jl")
161+
include("default_sde_alg.jl")
161162

162163
export StochasticDiffEqAlgorithm, StochasticDiffEqAdaptiveAlgorithm,
163164
StochasticCompositeAlgorithm

src/default_sde_alg.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Default algorithm selection for SDEs
2+
# Moved from DifferentialEquations.jl as part of modularization effort
3+
4+
using LinearAlgebra: I
5+
6+
# Helper function to extract alg_hints from keyword arguments
7+
function get_alg_hints(o)
8+
:alg_hints keys(o) ? alg_hints = o[:alg_hints] : alg_hints = Symbol[:auto]
9+
end
10+
11+
function default_algorithm(
12+
prob::DiffEqBase.AbstractSDEProblem{uType, tType, isinplace, ND};
13+
kwargs...) where {uType, tType, isinplace, ND}
14+
o = Dict{Symbol, Any}(kwargs)
15+
alg = SOSRI() # Standard default
16+
17+
alg_hints = get_alg_hints(o)
18+
19+
if :commutative alg_hints
20+
alg = RKMilCommute()
21+
end
22+
23+
is_stiff = :stiff alg_hints
24+
is_stratonovich = :stratonovich alg_hints
25+
if is_stiff || prob.f.mass_matrix !== I
26+
alg = ImplicitRKMil(autodiff = false)
27+
end
28+
29+
if is_stratonovich
30+
if is_stiff || prob.f.mass_matrix !== I
31+
alg = ImplicitRKMil(autodiff = false,
32+
interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich)
33+
else
34+
alg = RKMil(interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich)
35+
end
36+
end
37+
38+
if prob.noise_rate_prototype != nothing || prob.noise != nothing
39+
if is_stratonovich
40+
if is_stiff || prob.f.mass_matrix !== I
41+
alg = ImplicitEulerHeun(autodiff = false)
42+
else
43+
alg = LambaEulerHeun()
44+
end
45+
else
46+
if is_stiff || prob.f.mass_matrix !== I
47+
alg = ISSEM(autodiff = false)
48+
else
49+
alg = LambaEM()
50+
end
51+
end
52+
end
53+
54+
if :additive alg_hints
55+
if is_stiff || prob.f.mass_matrix !== I
56+
alg = SKenCarp(autodiff = false)
57+
else
58+
alg = SOSRA()
59+
end
60+
end
61+
62+
return alg
63+
end
64+
65+
# Dispatch for __init with Nothing algorithm - use default
66+
function DiffEqBase.__init(
67+
prob::DiffEqBase.AbstractSDEProblem, ::Nothing, args...; kwargs...)
68+
alg = default_algorithm(prob; kwargs...)
69+
DiffEqBase.__init(prob, alg, args...; kwargs...)
70+
end
71+
72+
# Dispatch for __solve with Nothing algorithm - use default
73+
function DiffEqBase.__solve(
74+
prob::DiffEqBase.AbstractSDEProblem, ::Nothing, args...; kwargs...)
75+
alg = default_algorithm(prob; kwargs...)
76+
DiffEqBase.__solve(prob, alg, args...; kwargs...)
77+
end

test/default_solver_test.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
using StochasticDiffEq, Test
2+
import SciMLBase
3+
using Random
4+
5+
# Additive SDE test problem
6+
f_additive(u, p, t) = @. p[2] / sqrt(1 + t) - u / (2 * (1 + t))
7+
σ_additive(u, p, t) = @. p[1] * p[2] / sqrt(1 + t)
8+
p = (0.1, 0.05)
9+
additive_analytic(u0, p, t, W) = @. u0 / sqrt(1 + t) + p[2] * (t + p[1] * W) / sqrt(1 + t)
10+
ff_additive = SDEFunction(f_additive, σ_additive, analytic = additive_analytic)
11+
prob_sde_additive = SDEProblem(ff_additive, σ_additive, 1.0, (0.0, 1.0), p)
12+
13+
Random.seed!(100)
14+
15+
# Test default (no algorithm specified) - should use SOSRI
16+
prob = prob_sde_additive
17+
sol = solve(prob, dt = 1 / 2^(3))
18+
@test sol.alg isa SOSRI
19+
20+
# Test with :additive hint - should use SOSRA
21+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:additive])
22+
@test sol.alg isa SOSRA
23+
24+
# Test with :stratonovich hint - should use RKMil with Stratonovich interpretation
25+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stratonovich])
26+
@test SciMLBase.alg_interpretation(sol.alg) ==
27+
SciMLBase.AlgorithmInterpretation.Stratonovich
28+
@test sol.alg isa RKMil
29+
30+
# Non-diagonal noise test problem
31+
f = (du, u, p, t) -> du .= 1.01u
32+
g = function (du, u, p, t)
33+
du[1, 1] = 0.3u[1]
34+
du[1, 2] = 0.6u[1]
35+
du[1, 3] = 0.9u[1]
36+
du[1, 4] = 0.12u[2]
37+
du[2, 1] = 1.2u[1]
38+
du[2, 2] = 0.2u[2]
39+
du[2, 3] = 0.3u[2]
40+
du[2, 4] = 1.8u[2]
41+
end
42+
prob = SDEProblem(f, g, ones(2), (0.0, 1.0), noise_rate_prototype = zeros(2, 4))
43+
44+
# Test default with non-diagonal noise - should use LambaEM
45+
sol = solve(prob, dt = 1 / 2^(3))
46+
@test sol.alg isa LambaEM
47+
48+
# Test with :stiff hint - should use ISSEM
49+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stiff])
50+
@test sol.alg isa ISSEM
51+
52+
# Test with :additive hint - should still use SOSRA (overrides non-diagonal)
53+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:additive])
54+
@test sol.alg isa SOSRA
55+
56+
# Test with :stratonovich hint - should use LambaEulerHeun
57+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stratonovich])
58+
@test sol.alg isa LambaEulerHeun

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR")
1515

1616
@time begin
1717
if GROUP == "All" || GROUP == "Interface1"
18+
@time @safetestset "Default Solver Tests" begin
19+
include("default_solver_test.jl")
20+
end
1821
@time @safetestset "First Rand Tests" begin
1922
include("first_rand_test.jl")
2023
end

0 commit comments

Comments
 (0)