Skip to content

Commit 3013549

Browse files
committed
Proper Optim interface of results struct
1 parent a968ac1 commit 3013549

File tree

3 files changed

+90
-18
lines changed

3 files changed

+90
-18
lines changed

ext/DynamicExpressionsOptimExt.jl

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,51 @@ module DynamicExpressionsOptimExt
33
using DynamicExpressions: AbstractExpressionNode, eval_tree_array
44
using Compat: @inline
55

6-
import Optim: optimize
6+
import Optim: Optim, OptimizationResults
7+
8+
#! format: off
9+
"""
10+
ExpressionOptimizationResults{R,N<:AbstractExpressionNode}
11+
12+
Optimization results for an expression, which wraps the base optimization results
13+
on a vector of constants.
14+
"""
15+
struct ExpressionOptimizationResults{R<:OptimizationResults,N<:AbstractExpressionNode} <: OptimizationResults
16+
_results::R # The raw results from Optim.
17+
tree::N # The final expression tree
18+
end
19+
#! format: on
20+
function Base.getproperty(r::ExpressionOptimizationResults, s::Symbol)
21+
if s == :tree || s == :minimizer
22+
return getfield(r, :tree)
23+
else
24+
return getproperty(getfield(r, :_results), s)
25+
end
26+
end
27+
function Base.propertynames(r::ExpressionOptimizationResults)
28+
return (:tree, propertynames(getfield(r, :_results))...)
29+
end
30+
function base_results(r::ExpressionOptimizationResults)
31+
return getfield(r, :_results)
32+
end
33+
function Optim.minimizer(r::ExpressionOptimizationResults)
34+
return r.tree
35+
end
36+
37+
function set_constant_nodes!(
38+
constant_nodes::AbstractArray{N}, x
39+
) where {T,N<:AbstractExpressionNode{T}}
40+
for (ci, xi) in zip(constant_nodes, x)
41+
ci.val::T = xi::T
42+
end
43+
end
744

845
"""Wrap f with insertion of values of the constant nodes."""
946
function get_wrapped_f(
1047
f::F, tree::N, constant_nodes::AbstractArray{N}
1148
) where {F,T,N<:AbstractExpressionNode{T}}
1249
function wrapped_f(x)
13-
for (ci, xi) in zip(constant_nodes, x)
14-
ci.val::T = xi::T
15-
end
50+
set_constant_nodes!(constant_nodes, x)
1651
return @inline(f(tree))
1752
end
1853
return wrapped_f
@@ -23,31 +58,52 @@ function get_wrapped_gh!(
2358
gh!::GH, tree::N, constant_nodes::AbstractArray{N}
2459
) where {GH,T,N<:AbstractExpressionNode{T}}
2560
function wrapped_gh!(G, x)
26-
for (ci, xi) in zip(constant_nodes, x)
27-
ci.val::T = xi::T
28-
end
61+
set_constant_nodes!(constant_nodes, x)
2962
@inline(gh!(G, tree))
3063
return nothing
3164
end
3265
return wrapped_gh!
3366
end
3467

35-
function optimize(f::F, g!::G, h!::H, tree::AbstractExpressionNode{T}, args...; kwargs...) where {F,G,H,T}
68+
"""
69+
optimize(f, [g!, [h!,]] tree, args...; kwargs...)
70+
71+
Optimize an expression tree with respect to the constants in the tree.
72+
Returns an `ExpressionOptimizationResults` object, which wraps the base
73+
optimization results on a vector of constants. You may use `res.minimizer`
74+
to view the optimized expression tree.
75+
"""
76+
function Optim.optimize(
77+
f::F, tree::AbstractExpressionNode, args...; kwargs...
78+
) where {F<:Function}
79+
return Optim.optimize(f, nothing, tree, args...; kwargs...)
80+
end
81+
function Optim.optimize(
82+
f::F, g!::G, tree::AbstractExpressionNode, args...; kwargs...
83+
) where {F,G<:Union{Function,Nothing}}
84+
return Optim.optimize(f, g!, nothing, tree, args...; kwargs...)
85+
end
86+
function Optim.optimize(
87+
f::F, g!::G, h!::H, tree::AbstractExpressionNode{T}, args...; make_copy=true, kwargs...
88+
) where {F,G<:Union{Function,Nothing},H<:Union{Function,Nothing},T}
89+
if make_copy
90+
tree = copy(tree)
91+
end
3692
constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)
3793
x0 = T[t.val::T for t in constant_nodes]
38-
if g! === nothing
94+
base_res = if g! === nothing
3995
@assert h! === nothing
40-
return optimize(get_wrapped_f(f, tree, constant_nodes), x0, args...; kwargs...)
96+
Optim.optimize(get_wrapped_f(f, tree, constant_nodes), x0, args...; kwargs...)
4197
elseif h! === nothing
42-
return optimize(
98+
Optim.optimize(
4399
get_wrapped_f(f, tree, constant_nodes),
44100
get_wrapped_gh!(g!, tree, constant_nodes),
45101
x0,
46102
args...;
47103
kwargs...,
48104
)
49105
else
50-
return optimize(
106+
Optim.optimize(
51107
get_wrapped_f(f, tree, constant_nodes),
52108
get_wrapped_gh!(g!, tree, constant_nodes),
53109
get_wrapped_gh!(h!, tree, constant_nodes),
@@ -56,12 +112,8 @@ function optimize(f::F, g!::G, h!::H, tree::AbstractExpressionNode{T}, args...;
56112
kwargs...,
57113
)
58114
end
59-
end
60-
function optimize(f::F, g!::G, tree::AbstractExpressionNode, args...; kwargs...) where {F,G}
61-
return optimize(f, g!, nothing, tree, args...; kwargs...)
62-
end
63-
function optimize(f::F, tree::AbstractExpressionNode, args...; kwargs...) where {F}
64-
return optimize(f, nothing, tree, args...; kwargs...)
115+
set_constant_nodes!(constant_nodes, Optim.minimizer(base_res))
116+
return ExpressionOptimizationResults(base_res, tree)
65117
end
66118

67119
end

test/test_optim.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using DynamicExpressions, Optim, Zygote
2+
using Random: Xoshiro
3+
4+
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(sin, cos))
5+
x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3);
6+
7+
X = randn(Xoshiro(0), Float64, 3, 100);
8+
y = @. cos(X[1, :] * 2.1 - 0.9) + X[3, :] * -0.9
9+
10+
original_tree = cos(x1 * 0.8 - 0.0) + 5.2 * x3
11+
tree = copy(original_tree)
12+
13+
res = optimize(t -> sum(abs2, t(X, operators) .- y), tree)
14+
15+
# Should be unchanged by default
16+
@test tree == original_tree

test/unittest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ VERSION >= v"1.9" && @safetestset "Test Aqua.jl" begin
44
include("test_aqua.jl")
55
end
66

7+
@safetestset "Test Optim.jl" begin
8+
include("test_optim.jl")
9+
end
10+
711
@safetestset "Initial error handling test" begin
812
include("test_initial_errors.jl")
913
end

0 commit comments

Comments
 (0)