@@ -108,9 +108,163 @@ end
108108 maybe_codegen_scimlproblem (expression, SteadyStateProblem{iip}, args; kwargs... )
109109end
110110
111+ @fallback_iip_specialize function SemilinearODEFunction {iip, specialize} (
112+ sys:: System ; u0 = nothing , p = nothing , t = nothing ,
113+ semiquadratic_form = nothing ,
114+ stiff_linear = true , stiff_quadratic = false , stiff_nonlinear = false ,
115+ eval_expression = false , eval_module = @__MODULE__ ,
116+ expression = Val{false }, sparse = false , check_compatibility = true ,
117+ jac = false , checkbounds = false , cse = true , initialization_data = nothing ,
118+ analytic = nothing , kwargs... ) where {iip, specialize}
119+ check_complete (sys, SemilinearODEFunction)
120+ check_compatibility && check_compatible_system (SemilinearODEFunction, sys)
121+
122+ if semiquadratic_form === nothing
123+ semiquadratic_form = calculate_semiquadratic_form (sys; sparse)
124+ sys = add_semiquadratic_parameters (sys, semiquadratic_form... )
125+ end
126+
127+ A, B, C = semiquadratic_form
128+ M = calculate_massmatrix (sys)
129+ _M = concrete_massmatrix (M; sparse, u0)
130+ dvs = unknowns (sys)
131+
132+ f1,
133+ f2 = generate_semiquadratic_functions (
134+ sys, A, B, C; stiff_linear, stiff_quadratic,
135+ stiff_nonlinear, expression, wrap_gfw = Val{true },
136+ eval_expression, eval_module, kwargs... )
137+
138+ if jac
139+ Cjac = (C === nothing || ! stiff_nonlinear) ? nothing : Symbolics. jacobian (C, dvs)
140+ _jac = generate_semiquadratic_jacobian (
141+ sys, A, B, C, Cjac; sparse, expression,
142+ wrap_gfw = Val{true }, eval_expression, eval_module, kwargs... )
143+ _W_sparsity = get_semiquadratic_W_sparsity (
144+ sys, A, B, C, Cjac; stiff_linear, stiff_quadratic, stiff_nonlinear, mm = M)
145+ W_prototype = calculate_W_prototype (_W_sparsity; u0, sparse)
146+ else
147+ _jac = nothing
148+ W_prototype = nothing
149+ end
150+
151+ observedfun = ObservedFunctionCache (
152+ sys; expression, steady_state = false , eval_expression, eval_module, checkbounds, cse)
153+
154+ args = (; f1)
155+ kwargs = (; jac = _jac, jac_prototype = W_prototype)
156+ f1 = maybe_codegen_scimlfn (expression, ODEFunction{iip, specialize}, args; kwargs... )
157+
158+ args = (; f1, f2)
159+ kwargs = (;
160+ sys = sys,
161+ jac = _jac,
162+ mass_matrix = _M,
163+ jac_prototype = W_prototype,
164+ observed = observedfun,
165+ analytic,
166+ initialization_data)
167+
168+ return maybe_codegen_scimlfn (
169+ expression, SplitFunction{iip, specialize}, args; kwargs... )
170+ end
171+
172+ @fallback_iip_specialize function SemilinearODEProblem {iip, spec} (
173+ sys:: System , op, tspan; check_compatibility = true , u0_eltype = nothing ,
174+ expression = Val{false }, callback = nothing , sparse = false ,
175+ stiff_linear = true , stiff_quadratic = false , stiff_nonlinear = false ,
176+ jac = false , kwargs... ) where {
177+ iip, spec}
178+ check_complete (sys, SemilinearODEProblem)
179+ check_compatibility && check_compatible_system (SemilinearODEProblem, sys)
180+
181+ A, B, C = semiquadratic_form = calculate_semiquadratic_form (sys; sparse)
182+ eqs = equations (sys)
183+ dvs = unknowns (sys)
184+
185+ sys = add_semiquadratic_parameters (sys, A, B, C)
186+ if A != = nothing
187+ linear_matrix_param = unwrap (getproperty (sys, LINEAR_MATRIX_PARAM_NAME))
188+ else
189+ linear_matrix_param = nothing
190+ end
191+ if B != = nothing
192+ quadratic_forms = [unwrap (getproperty (sys, get_quadratic_form_name (i)))
193+ for i in 1 : length (eqs)]
194+ diffcache_par = unwrap (getproperty (sys, DIFFCACHE_PARAM_NAME))
195+ else
196+ quadratic_forms = diffcache_par = nothing
197+ end
198+
199+ op = to_varmap (op, dvs)
200+ floatT = calculate_float_type (op, typeof (op))
201+ _u0_eltype = something (u0_eltype, floatT)
202+
203+ guess = copy (guesses (sys))
204+ defs = copy (defaults (sys))
205+ if A != = nothing
206+ guess[linear_matrix_param] = fill (NaN , size (A))
207+ defs[linear_matrix_param] = A
208+ end
209+ if B != = nothing
210+ for (par, mat) in zip (quadratic_forms, B)
211+ guess[par] = fill (NaN , size (mat))
212+ defs[par] = mat
213+ end
214+ cachelen = jac ? length (dvs) * length (eqs) : length (dvs)
215+ defs[diffcache_par] = DiffCache (zeros (DiffEqBase. value (_u0_eltype), cachelen))
216+ end
217+ @set! sys. guesses = guess
218+ @set! sys. defaults = defs
219+
220+ f, u0,
221+ p = process_SciMLProblem (SemilinearODEFunction{iip, spec}, sys, op;
222+ t = tspan != = nothing ? tspan[1 ] : tspan, expression, check_compatibility,
223+ semiquadratic_form, sparse, u0_eltype, stiff_linear, stiff_quadratic, stiff_nonlinear, jac, kwargs... )
224+
225+ kwargs = process_kwargs (sys; expression, callback, kwargs... )
226+
227+ args = (; f, u0, tspan, p)
228+ maybe_codegen_scimlproblem (expression, SplitODEProblem{iip}, args; kwargs... )
229+ end
230+
231+ """
232+ $(TYPEDSIGNATURES)
233+
234+ Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices
235+ `A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref).
236+ """
237+ function add_semiquadratic_parameters (sys:: System , A, B, C)
238+ eqs = equations (sys)
239+ n = length (eqs)
240+ var_to_name = copy (get_var_to_name (sys))
241+ if B != = nothing
242+ for i in eachindex (B)
243+ B[i] === nothing && continue
244+ par = get_quadratic_form_param ((n, n), i)
245+ var_to_name[get_quadratic_form_name (i)] = par
246+ sys = with_additional_constant_parameter (sys, par)
247+ end
248+ par = get_diffcache_param (Float64)
249+ var_to_name[DIFFCACHE_PARAM_NAME] = par
250+ sys = with_additional_nonnumeric_parameter (sys, par)
251+ end
252+ if A != = nothing
253+ par = get_linear_matrix_param ((n, n))
254+ var_to_name[LINEAR_MATRIX_PARAM_NAME] = par
255+ sys = with_additional_constant_parameter (sys, par)
256+ end
257+ @set! sys. var_to_name = var_to_name
258+ if get_parent (sys) != = nothing
259+ @set! sys. parent = add_semiquadratic_parameters (get_parent (sys), A, B, C)
260+ end
261+ return sys
262+ end
263+
111264function check_compatible_system (
112265 T:: Union {Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
113- Type{DAEProblem}, Type{SteadyStateProblem}},
266+ Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
267+ Type{SemilinearODEProblem}},
114268 sys:: System )
115269 check_time_dependent (sys, T)
116270 check_not_dde (sys)
0 commit comments