Skip to content

Commit e535746

Browse files
committed
add tune_parameters for JuMPDynamicOptProblem
1 parent 7b4a284 commit e535746

File tree

2 files changed

+51
-22
lines changed

2 files changed

+51
-22
lines changed

ext/MTKInfiniteOptExt.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module MTKInfiniteOptExt
22
using ModelingToolkit
33
using InfiniteOpt
44
using DiffEqBase
5+
using SciMLStructures
56
using LinearAlgebra
67
using StaticArrays
78
using UnPack
@@ -13,12 +14,13 @@ struct InfiniteOptModel
1314
model::InfiniteModel
1415
U::Vector{<:AbstractVariableRef}
1516
V::Vector{<:AbstractVariableRef}
17+
P::Vector{<:AbstractVariableRef}
1618
tₛ::AbstractVariableRef
1719
is_free_final::Bool
1820
end
1921

2022
struct JuMPDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
21-
AbstractDynamicOptProblem{uType, tType, isinplace}
23+
SciMLBase.AbstractDynamicOptProblem{uType, tType, isinplace}
2224
f::F
2325
u0::uType
2426
tspan::tType
@@ -33,7 +35,7 @@ struct JuMPDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
3335
end
3436

3537
struct InfiniteOptDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
36-
AbstractDynamicOptProblem{uType, tType, isinplace}
38+
SciMLBase.AbstractDynamicOptProblem{uType, tType, isinplace}
3739
f::F
3840
u0::uType
3941
tspan::tType
@@ -57,6 +59,9 @@ end
5759
function MTK.generate_input_variable!(m::InfiniteModel, c0, nc, ts)
5860
@variable(m, V[i = 1:nc], Infinite(m[:t]), start=c0[i])
5961
end
62+
function MTK.generate_tunable_params!(m::InfiniteModel, p0, np)
63+
@variable(m, P[i=1:np], start=p0[i])
64+
end
6065

6166
function MTK.generate_timescale!(m::InfiniteModel, guess, is_free_t)
6267
@variable(m, tₛ 0, start = guess)
@@ -81,10 +86,11 @@ MTK.set_objective!(m::InfiniteOptModel, expr) = @objective(m.model, Min, expr)
8186
function MTK.JuMPDynamicOptProblem(sys::System, op, tspan;
8287
dt = nothing,
8388
steps = nothing,
89+
tune_parameters = false,
8490
guesses = Dict(), kwargs...)
8591
prob,
8692
_ = MTK.process_DynamicOptProblem(JuMPDynamicOptProblem, InfiniteOptModel, sys,
87-
op, tspan; dt, steps, guesses, kwargs...)
93+
op, tspan; dt, steps, tune_parameters, guesses, kwargs...)
8894
prob
8995
end
9096

@@ -125,13 +131,24 @@ function MTK.lowered_var(m::InfiniteOptModel, uv, i, t)
125131
t isa Union{Num, Symbolics.Symbolic} ? X[i] : X[i](t)
126132
end
127133

134+
function f_wrapper(f, Uₙ, Vₙ, p, P, t)
135+
if SciMLStructures.isscimlstructure(p)
136+
_, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
137+
p′ = repack(P)
138+
f(Uₙ, Vₙ, p′, t)
139+
else
140+
f(Uₙ, Vₙ, P, t)
141+
end
142+
end
143+
128144
function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau)
129145
@unpack A, α, c = tableau
130146
@unpack wrapped_model, f, p = prob
131-
@unpack tₛ, U, V, model = wrapped_model
147+
@unpack tₛ, U, V, P, model = wrapped_model
132148
t = model[:t]
133149
tsteps = supports(t)
134-
dt = tsteps[2] - tsteps[1]
150+
151+
dt = (tsteps[end] - tsteps[1]) / (length(tsteps) - 1)
135152

136153
nᵤ = length(U)
137154
nᵥ = length(V)
@@ -142,7 +159,7 @@ function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau)
142159
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = zeros(nᵤ))
143160
Uₙ = [U[i](τ) + ΔU[i] * dt for i in 1:nᵤ]
144161
Vₙ = [V[i](τ) for i in 1:nᵥ]
145-
Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt)
162+
Kₙ = tₛ * f_wrapper(f, Uₙ, Vₙ, p, P, τ + h * dt)
146163
push!(K, Kₙ)
147164
end
148165
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
@@ -158,7 +175,7 @@ function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau)
158175
for (i, h) in enumerate(c)
159176
ΔU = @view ΔUs[i, :]
160177
Uₙ = U + ΔU * dt
161-
@constraint(model, [j = 1:nᵤ], K[i, j]==(tₛ * f(Uₙ, V, p, τ + h * dt)[j]),
178+
@constraint(model, [j = 1:nᵤ], K[i, j]==(tₛ * f_wrapper(f, Uₙ, V, p, P, τ + h * dt)[j]),
162179
DomainRestrictions(t => τ), base_name="solve_K$i()")
163180
end
164181
@constraint(model,
@@ -233,6 +250,7 @@ function MTK.get_U_values(m::InfiniteModel)
233250
U_vals = value.(m[:U])
234251
U_vals = [[U_vals[i][j] for i in 1:length(U_vals)] for j in 1:nt]
235252
end
253+
MTK.get_P_values(m::InfiniteModel) = value(m[:P])
236254
MTK.get_t_values(m::InfiniteModel) = value(m[:tₛ]) * supports(m[:t])
237255
MTK.objective_value(m::InfiniteModel) = InfiniteOpt.objective_value(m)
238256

src/systems/optimal_control_interface.jl

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
abstract type AbstractDynamicOptProblem{uType, tType, isinplace} <:
2-
SciMLBase.AbstractODEProblem{uType, tType, isinplace} end
3-
41
abstract type AbstractCollocation end
52

63
struct DynamicOptSolution
@@ -22,8 +19,8 @@ end
2219
JuMPDynamicOptProblem(sys::System, op, tspan; dt, steps, guesses, kwargs...)
2320
2421
Convert an System representing an optimal control system into a JuMP model
25-
for solving using optimization. Must provide either `dt`, the timestep between collocation
26-
points (which, along with the timespan, determines the number of points), or directly
22+
for solving using optimization. Must provide either `dt`, the timestep between collocation
23+
points (which, along with the timespan, determines the number of points), or directly
2724
provide the number of points as `steps`.
2825
2926
To construct the problem, please load InfiniteOpt along with ModelingToolkit.
@@ -33,7 +30,7 @@ function JuMPDynamicOptProblem end
3330
InfiniteOptDynamicOptProblem(sys::System, op, tspan; dt)
3431
3532
Convert an System representing an optimal control system into a InfiniteOpt model
36-
for solving using optimization. Must provide `dt` for determining the length
33+
for solving using optimization. Must provide `dt` for determining the length
3734
of the interpolation arrays.
3835
3936
Related to `JuMPDynamicOptProblem`, but directly adds the differential equations
@@ -46,8 +43,8 @@ function InfiniteOptDynamicOptProblem end
4643
CasADiDynamicOptProblem(sys::System, op, tspan; dt, steps, guesses, kwargs...)
4744
4845
Convert an System representing an optimal control system into a CasADi model
49-
for solving using optimization. Must provide either `dt`, the timestep between collocation
50-
points (which, along with the timespan, determines the number of points), or directly
46+
for solving using optimization. Must provide either `dt`, the timestep between collocation
47+
points (which, along with the timespan, determines the number of points), or directly
5148
provide the number of points as `steps`.
5249
5350
To construct the problem, please load CasADi along with ModelingToolkit.
@@ -57,8 +54,8 @@ function CasADiDynamicOptProblem end
5754
PyomoDynamicOptProblem(sys::System, op, tspan; dt, steps)
5855
5956
Convert an System representing an optimal control system into a Pyomo model
60-
for solving using optimization. Must provide either `dt`, the timestep between collocation
61-
points (which, along with the timespan, determines the number of points), or directly
57+
for solving using optimization. Must provide either `dt`, the timestep between collocation
58+
points (which, along with the timespan, determines the number of points), or directly
6259
provide the number of points as `steps`.
6360
6461
To construct the problem, please load Pyomo along with ModelingToolkit.
@@ -229,13 +226,15 @@ end
229226
### MODEL CONSTRUCTION ###
230227
##########################
231228
function process_DynamicOptProblem(
232-
prob_type::Type{<:AbstractDynamicOptProblem}, model_type, sys::System, op, tspan;
229+
prob_type::Type{<:SciMLBase.AbstractDynamicOptProblem}, model_type, sys::System, op, tspan;
233230
dt = nothing,
234231
steps = nothing,
232+
tune_parameters = false,
235233
guesses = Dict(), kwargs...)
236234
warn_overdetermined(sys, op)
237235
ctrls = unbound_inputs(sys)
238236
states = unknowns(sys)
237+
params = tune_parameters ? tunable_parameters(sys) : []
239238

240239
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
241240
op = Dict([default_toterm(value(k)) => v for (k, v) in op])
@@ -253,14 +252,16 @@ function process_DynamicOptProblem(
253252
pmap = recursive_unwrap(AnyDict(pmap))
254253
evaluate_varmap!(pmap, keys(pmap))
255254
c0 = value.([pmap[c] for c in ctrls])
255+
p0, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
256256

257257
tsteps = LinRange(model_tspan[1], model_tspan[2], steps)
258258
model = generate_internal_model(model_type)
259259
generate_time_variable!(model, model_tspan, tsteps)
260260
U = generate_state_variable!(model, u0, length(states), tsteps)
261261
V = generate_input_variable!(model, c0, length(ctrls), tsteps)
262+
P = generate_tunable_params!(model, p0, length(params))
262263
tₛ = generate_timescale!(model, get(pmap, tspan[2], tspan[2]), is_free_t)
263-
fullmodel = model_type(model, U, V, tₛ, is_free_t)
264+
fullmodel = model_type(model, U, V, P, tₛ, is_free_t)
264265

265266
set_variable_bounds!(fullmodel, sys, pmap, tspan[2])
266267
add_cost_function!(fullmodel, sys, tspan, pmap)
@@ -274,6 +275,7 @@ function generate_time_variable! end
274275
function generate_internal_model end
275276
function generate_state_variable! end
276277
function generate_input_variable! end
278+
function generate_tunable_params! end
277279
function generate_timescale! end
278280
function add_initial_constraints! end
279281
function add_constraint! end
@@ -467,24 +469,33 @@ function prepare_and_optimize! end
467469
function get_t_values end
468470
function get_U_values end
469471
function get_V_values end
472+
function get_P_values end
470473
function successful_solve end
471474

472475
"""
473476
solve(prob::AbstractDynamicOptProblem, solver::AbstractCollocation; verbose = false, kwargs...)
474477
475478
- kwargs are used for other options. For example, the `plugin_options` and `solver_options` will propagated to the Opti object in CasADi.
476479
"""
477-
function DiffEqBase.solve(prob::AbstractDynamicOptProblem,
480+
function DiffEqBase.solve(prob::SciMLBase.AbstractDynamicOptProblem,
478481
solver::AbstractCollocation; verbose = false, kwargs...)
479482
solved_model = prepare_and_optimize!(prob, solver; verbose, kwargs...)
480483

481484
ts = get_t_values(solved_model)
482485
Us = get_U_values(solved_model)
483486
Vs = get_V_values(solved_model)
487+
Ps = get_P_values(solved_model)
484488
is_free_final(prob.wrapped_model) && (ts .+ prob.tspan[1])
485489

486-
ode_sol = DiffEqBase.build_solution(prob, solver, ts, Us)
487-
input_sol = isnothing(Vs) ? nothing : DiffEqBase.build_solution(prob, solver, ts, Vs)
490+
# update the parameters with the ones in the solved_model
491+
if !isempty(Ps)
492+
new_p = SciMLStructures.replace(SciMLStructures.Tunable(), prob.p, Ps)
493+
new_prob = remake(prob, p=new_p)
494+
else
495+
new_prob = prob
496+
end
497+
ode_sol = SciMLBase.build_solution(new_prob, solver, ts, Us)
498+
input_sol = isnothing(Vs) ? nothing : SciMLBase.build_solution(new_prob, solver, ts, Vs)
488499

489500
if !successful_solve(solved_model)
490501
ode_sol = SciMLBase.solution_new_retcode(

0 commit comments

Comments
 (0)