Skip to content

Commit 24e4922

Browse files
feat: add SemilinearODEFunction and SemilinearODEProblem
1 parent 3961f76 commit 24e4922

File tree

4 files changed

+639
-1
lines changed

4 files changed

+639
-1
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
2626
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
2727
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
2828
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
29+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
2930
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
3031
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3132
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
@@ -44,6 +45,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
4445
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
4546
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
4647
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
48+
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
4749
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
4850
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4951
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -114,6 +116,7 @@ DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
114116
EnumX = "1.0.4"
115117
ExprTools = "0.1.10"
116118
FMI = "0.14"
119+
FillArrays = "1.13.0"
117120
FindFirstFunctions = "1"
118121
ForwardDiff = "0.10.3, 1"
119122
FunctionWrappers = "1.1"
@@ -141,6 +144,7 @@ OrdinaryDiffEq = "6.82.0"
141144
OrdinaryDiffEqCore = "1.34.0"
142145
OrdinaryDiffEqDefault = "1.2"
143146
OrdinaryDiffEqNonlinearSolve = "1.5.0"
147+
PreallocationTools = "0.4.27"
144148
PrecompileTools = "1"
145149
Pyomo = "0.1.0"
146150
REPL = "1"

src/ModelingToolkit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ const DQ = DynamicQuantities
104104
import DifferentiationInterface as DI
105105
using ADTypes: AutoForwardDiff
106106
import SciMLPublic: @public
107+
import PreallocationTools
108+
import PreallocationTools: DiffCache
109+
import FillArrays
107110

108111
export @derivatives
109112

@@ -262,6 +265,7 @@ export IntervalNonlinearProblem
262265
export OptimizationProblem, constraints
263266
export SteadyStateProblem
264267
export JumpProblem
268+
export SemilinearODEFunction, SemilinearODEProblem
265269
export alias_elimination, flatten
266270
export connect, domain_connect, @connector, Connection, AnalysisPoint, Flow, Stream,
267271
instream

src/problems/odeproblem.jl

Lines changed: 155 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,163 @@ end
108108
maybe_codegen_scimlproblem(expression, SteadyStateProblem{iip}, args; kwargs...)
109109
end
110110

111+
@fallback_iip_specialize function SemilinearODEFunction{iip, specialize}(
112+
sys::System; u0 = nothing, p = nothing, t = nothing,
113+
semiquadratic_form = nothing,
114+
stiff_linear = true, stiff_quadratic = false, stiff_nonlinear = false,
115+
eval_expression = false, eval_module = @__MODULE__,
116+
expression = Val{false}, sparse = false, check_compatibility = true,
117+
jac = false, checkbounds = false, cse = true, initialization_data = nothing,
118+
analytic = nothing, kwargs...) where {iip, specialize}
119+
check_complete(sys, SemilinearODEFunction)
120+
check_compatibility && check_compatible_system(SemilinearODEFunction, sys)
121+
122+
if semiquadratic_form === nothing
123+
semiquadratic_form = calculate_semiquadratic_form(sys; sparse)
124+
sys = add_semiquadratic_parameters(sys, semiquadratic_form...)
125+
end
126+
127+
A, B, C = semiquadratic_form
128+
M = calculate_massmatrix(sys)
129+
_M = concrete_massmatrix(M; sparse, u0)
130+
dvs = unknowns(sys)
131+
132+
f1,
133+
f2 = generate_semiquadratic_functions(
134+
sys, A, B, C; stiff_linear, stiff_quadratic,
135+
stiff_nonlinear, expression, wrap_gfw = Val{true},
136+
eval_expression, eval_module, kwargs...)
137+
138+
if jac
139+
Cjac = (C === nothing || !stiff_nonlinear) ? nothing : Symbolics.jacobian(C, dvs)
140+
_jac = generate_semiquadratic_jacobian(
141+
sys, A, B, C, Cjac; sparse, expression,
142+
wrap_gfw = Val{true}, eval_expression, eval_module, kwargs...)
143+
_W_sparsity = get_semiquadratic_W_sparsity(
144+
sys, A, B, C, Cjac; stiff_linear, stiff_quadratic, stiff_nonlinear, mm = M)
145+
W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse)
146+
else
147+
_jac = nothing
148+
W_prototype = nothing
149+
end
150+
151+
observedfun = ObservedFunctionCache(
152+
sys; expression, steady_state = false, eval_expression, eval_module, checkbounds, cse)
153+
154+
args = (; f1)
155+
kwargs = (; jac = _jac, jac_prototype = W_prototype)
156+
f1 = maybe_codegen_scimlfn(expression, ODEFunction{iip, specialize}, args; kwargs...)
157+
158+
args = (; f1, f2)
159+
kwargs = (;
160+
sys = sys,
161+
jac = _jac,
162+
mass_matrix = _M,
163+
jac_prototype = W_prototype,
164+
observed = observedfun,
165+
analytic,
166+
initialization_data)
167+
168+
return maybe_codegen_scimlfn(
169+
expression, SplitFunction{iip, specialize}, args; kwargs...)
170+
end
171+
172+
@fallback_iip_specialize function SemilinearODEProblem{iip, spec}(
173+
sys::System, op, tspan; check_compatibility = true, u0_eltype = nothing,
174+
expression = Val{false}, callback = nothing, sparse = false,
175+
stiff_linear = true, stiff_quadratic = false, stiff_nonlinear = false,
176+
jac = false, kwargs...) where {
177+
iip, spec}
178+
check_complete(sys, SemilinearODEProblem)
179+
check_compatibility && check_compatible_system(SemilinearODEProblem, sys)
180+
181+
A, B, C = semiquadratic_form = calculate_semiquadratic_form(sys; sparse)
182+
eqs = equations(sys)
183+
dvs = unknowns(sys)
184+
185+
sys = add_semiquadratic_parameters(sys, A, B, C)
186+
if A !== nothing
187+
linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME))
188+
else
189+
linear_matrix_param = nothing
190+
end
191+
if B !== nothing
192+
quadratic_forms = [unwrap(getproperty(sys, get_quadratic_form_name(i)))
193+
for i in 1:length(eqs)]
194+
diffcache_par = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))
195+
else
196+
quadratic_forms = diffcache_par = nothing
197+
end
198+
199+
op = to_varmap(op, dvs)
200+
floatT = calculate_float_type(op, typeof(op))
201+
_u0_eltype = something(u0_eltype, floatT)
202+
203+
guess = copy(guesses(sys))
204+
defs = copy(defaults(sys))
205+
if A !== nothing
206+
guess[linear_matrix_param] = fill(NaN, size(A))
207+
defs[linear_matrix_param] = A
208+
end
209+
if B !== nothing
210+
for (par, mat) in zip(quadratic_forms, B)
211+
guess[par] = fill(NaN, size(mat))
212+
defs[par] = mat
213+
end
214+
cachelen = jac ? length(dvs) * length(eqs) : length(dvs)
215+
defs[diffcache_par] = DiffCache(zeros(DiffEqBase.value(_u0_eltype), cachelen))
216+
end
217+
@set! sys.guesses = guess
218+
@set! sys.defaults = defs
219+
220+
f, u0,
221+
p = process_SciMLProblem(SemilinearODEFunction{iip, spec}, sys, op;
222+
t = tspan !== nothing ? tspan[1] : tspan, expression, check_compatibility,
223+
semiquadratic_form, sparse, u0_eltype, stiff_linear, stiff_quadratic, stiff_nonlinear, jac, kwargs...)
224+
225+
kwargs = process_kwargs(sys; expression, callback, kwargs...)
226+
227+
args = (; f, u0, tspan, p)
228+
maybe_codegen_scimlproblem(expression, SplitODEProblem{iip}, args; kwargs...)
229+
end
230+
231+
"""
232+
$(TYPEDSIGNATURES)
233+
234+
Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices
235+
`A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref).
236+
"""
237+
function add_semiquadratic_parameters(sys::System, A, B, C)
238+
eqs = equations(sys)
239+
n = length(eqs)
240+
var_to_name = copy(get_var_to_name(sys))
241+
if B !== nothing
242+
for i in eachindex(B)
243+
B[i] === nothing && continue
244+
par = get_quadratic_form_param((n, n), i)
245+
var_to_name[get_quadratic_form_name(i)] = par
246+
sys = with_additional_constant_parameter(sys, par)
247+
end
248+
par = get_diffcache_param(Float64)
249+
var_to_name[DIFFCACHE_PARAM_NAME] = par
250+
sys = with_additional_nonnumeric_parameter(sys, par)
251+
end
252+
if A !== nothing
253+
par = get_linear_matrix_param((n, n))
254+
var_to_name[LINEAR_MATRIX_PARAM_NAME] = par
255+
sys = with_additional_constant_parameter(sys, par)
256+
end
257+
@set! sys.var_to_name = var_to_name
258+
if get_parent(sys) !== nothing
259+
@set! sys.parent = add_semiquadratic_parameters(get_parent(sys), A, B, C)
260+
end
261+
return sys
262+
end
263+
111264
function check_compatible_system(
112265
T::Union{Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
113-
Type{DAEProblem}, Type{SteadyStateProblem}},
266+
Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
267+
Type{SemilinearODEProblem}},
114268
sys::System)
115269
check_time_dependent(sys, T)
116270
check_not_dde(sys)

0 commit comments

Comments
 (0)