Skip to content

Commit bee3abf

Browse files
committed
Overload Optim.optimize for Node as x0
1 parent 9737763 commit bee3abf

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1212
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
15+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1516
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1617
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
1718
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

src/ConstantOptimization.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
module ConstantOptimizationModule
2+
3+
import Compat: @inline
4+
import Optim: optimize
5+
import ..EquationModule: Node
6+
import ..EvaluateEquationModule: eval_tree_array
7+
8+
"""Wrap f with insertion of values of the constant nodes."""
9+
function get_wrapped_f(
10+
f::F, tree::N, constant_nodes::AbstractArray{N}
11+
) where {F,T,N<:Node{T}}
12+
function wrapped_f(x)
13+
for (ci, xi) in zip(constant_nodes, x)
14+
ci.val::T = xi::T
15+
end
16+
return @inline(f(tree))
17+
end
18+
return wrapped_f
19+
end
20+
21+
"""Wrap g! or h! with insertion of values of the constant nodes."""
22+
function get_wrapped_gh!(
23+
gh!::GH, tree::N, constant_nodes::AbstractArray{N}
24+
) where {GH,T,N<:Node{T}}
25+
function wrapped_gh!(G, x)
26+
for (ci, xi) in zip(constant_nodes, x)
27+
ci.val::T = xi::T
28+
end
29+
@inline(gh!(G, tree))
30+
return nothing
31+
end
32+
return wrapped_gh!
33+
end
34+
35+
function optimize(f::F, g!::G, h!::H, tree::Node{T}, args...; kwargs...) where {F,G,H,T}
36+
constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)
37+
x0 = [t.val::T for t in constant_nodes]
38+
if g! === nothing
39+
@assert h! === nothing
40+
return optimize(get_wrapped_f(f, tree, constant_nodes), x0, args...; kwargs...)
41+
elseif h! === nothing
42+
return optimize(
43+
get_wrapped_f(f, tree, constant_nodes),
44+
get_wrapped_gh!(g!, tree, constant_nodes),
45+
x0,
46+
args...;
47+
kwargs...,
48+
)
49+
else
50+
return optimize(
51+
get_wrapped_f(f, tree, constant_nodes),
52+
get_wrapped_gh!(g!, tree, constant_nodes),
53+
get_wrapped_gh!(h!, tree, constant_nodes),
54+
x0,
55+
args...;
56+
kwargs...,
57+
)
58+
end
59+
end
60+
function optimize(f::F, g!::G, tree::Node, args...; kwargs...) where {F,G}
61+
return optimize(f, g!, nothing, tree, args...; kwargs...)
62+
end
63+
function optimize(f::F, tree::Node, args...; kwargs...) where {F}
64+
return optimize(f, nothing, tree, args...; kwargs...)
65+
end
66+
67+
end

src/DynamicExpressions.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module DynamicExpressions
22

3+
using Requires: @require
4+
35
include("Utils.jl")
46
include("OperatorEnum.jl")
57
include("Equation.jl")
@@ -10,8 +12,9 @@ include("EvaluationHelpers.jl")
1012
include("InterfaceSymbolicUtils.jl")
1113
include("SimplifyEquation.jl")
1214
include("OperatorEnumConstruction.jl")
15+
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" include("ConstantOptimization.jl")
1316

14-
using Reexport
17+
using Reexport: @reexport
1518
@reexport import .EquationModule:
1619
Node,
1720
string_tree,
@@ -41,6 +44,8 @@ using Reexport
4144
@reexport import .InterfaceSymbolicUtilsModule: node_to_symbolic, symbolic_to_node
4245
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree
4346
@reexport import .EvaluationHelpersModule
47+
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" @reexport import .ConstantOptimizationModule:
48+
optimize
4449

4550
include("deprecated.jl")
4651

0 commit comments

Comments
 (0)