Skip to content

Commit a900883

Browse files
committed
Make overload-able interface for specifying constants in tree
1 parent c48c2c2 commit a900883

File tree

4 files changed

+75
-26
lines changed

4 files changed

+75
-26
lines changed

ext/DynamicExpressionsOptimExt.jl

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module DynamicExpressionsOptimExt
22

3-
using DynamicExpressions: AbstractExpressionNode, eval_tree_array
3+
using DynamicExpressions:
4+
AbstractExpressionNode, eval_tree_array, get_constant_refs, set_constant_refs!
45
using Compat: @inline
56

67
import Optim: Optim, OptimizationResults, NLSolversBase
@@ -31,45 +32,36 @@ function Optim.minimizer(r::ExpressionOptimizationResults)
3132
return r.tree
3233
end
3334

34-
function set_constant_nodes!(
35-
constant_nodes::AbstractArray{N}, x
36-
) where {T,N<:AbstractExpressionNode{T}}
37-
for (ci, xi) in zip(constant_nodes, x)
38-
ci.val::T = xi::T
39-
end
40-
return nothing
41-
end
42-
4335
"""Wrap function or objective with insertion of values of the constant nodes."""
4436
function wrap_func(
45-
f::F, tree::N, constant_nodes::AbstractArray{N}
37+
f::F, tree::N, constant_refs::AbstractArray
4638
) where {F<:Function,T,N<:AbstractExpressionNode{T}}
4739
function wrapped_f(args::Vararg{Any,M}) where {M}
4840
first_args = args[1:(end - 1)]
4941
x = last(args)
50-
set_constant_nodes!(constant_nodes, x)
42+
set_constant_refs!(constant_refs, x)
5143
return @inline(f(first_args..., tree))
5244
end
5345
return wrapped_f
5446
end
5547
function wrap_func(
56-
::Nothing, tree::N, constant_nodes::AbstractArray{N}
48+
::Nothing, tree::N, constant_refs::AbstractArray
5749
) where {N<:AbstractExpressionNode}
5850
return nothing
5951
end
6052
function wrap_func(
61-
f::NLSolversBase.InplaceObjective, tree::N, constant_nodes::AbstractArray{N}
53+
f::NLSolversBase.InplaceObjective, tree::N, constant_refs::AbstractArray
6254
) where {N<:AbstractExpressionNode}
6355
# Some objectives, like `Optim.only_fg!(fg!)`, are not functions but instead
6456
# `InplaceObjective`. These contain multiple functions, each of which needs to be
6557
# wrapped. Some functions are `nothing`; those can be left as-is.
6658
@assert fieldnames(NLSolversBase.InplaceObjective) == (:df, :fdf, :fgh, :hv, :fghv)
6759
return NLSolversBase.InplaceObjective(
68-
wrap_func(f.df, tree, constant_nodes),
69-
wrap_func(f.fdf, tree, constant_nodes),
70-
wrap_func(f.fgh, tree, constant_nodes),
71-
wrap_func(f.hv, tree, constant_nodes),
72-
wrap_func(f.fghv, tree, constant_nodes),
60+
wrap_func(f.df, tree, constant_refs),
61+
wrap_func(f.fdf, tree, constant_refs),
62+
wrap_func(f.fgh, tree, constant_refs),
63+
wrap_func(f.hv, tree, constant_refs),
64+
wrap_func(f.fghv, tree, constant_refs),
7365
)
7466
end
7567

@@ -95,8 +87,8 @@ function Optim.optimize(
9587
if make_copy
9688
tree = copy(tree)
9789
end
98-
constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)
99-
x0 = T[t.val::T for t in constant_nodes]
90+
constant_refs = get_constant_refs(tree)
91+
x0 = map(t -> t.x, constant_refs)
10092
if !isnothing(h!)
10193
throw(
10294
ArgumentError(
@@ -106,17 +98,17 @@ function Optim.optimize(
10698
)
10799
end
108100
base_res = if isnothing(g!)
109-
Optim.optimize(wrap_func(f, tree, constant_nodes), x0, args...; kwargs...)
101+
Optim.optimize(wrap_func(f, tree, constant_refs), x0, args...; kwargs...)
110102
else
111103
Optim.optimize(
112-
wrap_func(f, tree, constant_nodes),
113-
wrap_func(g!, tree, constant_nodes),
104+
wrap_func(f, tree, constant_refs),
105+
wrap_func(g!, tree, constant_refs),
114106
x0,
115107
args...;
116108
kwargs...,
117109
)
118110
end
119-
set_constant_nodes!(constant_nodes, Optim.minimizer(base_res))
111+
set_constant_refs!(constant_refs, Optim.minimizer(base_res))
120112
return ExpressionOptimizationResults(base_res, tree)
121113
end
122114

src/DynamicExpressions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ import .EquationModule: constructorof, preserve_sharing
3636
has_operators,
3737
has_constants,
3838
get_constants,
39-
set_constants!
39+
set_constants!,
40+
get_constant_refs,
41+
set_constant_refs!
4042
@reexport import .OperatorEnumModule: AbstractOperatorEnum
4143
@reexport import .OperatorEnumConstructionModule:
4244
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!

src/EquationUtils.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,60 @@ function set_constants!(
9898
return nothing
9999
end
100100

101+
"""
102+
NodeConstantRef{T,N<:AbstractExpressionNode{T}}
103+
104+
A reference to a constant in an expression tree. Use `.x` to access
105+
the value of the constant for setting or getting.
106+
"""
107+
struct NodeConstantRef{T,N<:AbstractExpressionNode{T}}
108+
_node::Ref{N}
109+
110+
function NodeConstantRef(node::_N) where {_T,_N<:AbstractExpressionNode{_T}}
111+
return new{_T,_N}(Ref(node))
112+
end
113+
end
114+
function Base.getproperty(cr::NodeConstantRef{T}, s::Symbol) where {T}
115+
s != :x && error("Only :x is a valid property for NodeConstantRef")
116+
117+
return getfield(cr, :_node).x.val::T
118+
end
119+
function Base.setproperty!(cr::NodeConstantRef{T}, s::Symbol, v) where {T}
120+
s != :x && error("Only :x is a valid property for NodeConstantRef")
121+
122+
return getfield(cr, :_node).x.val::T = v::T
123+
end
124+
Base.propertynames(::NodeConstantRef) = (:x,)
125+
126+
"""
127+
get_constant_refs(tree::AbstractExpressionNode)
128+
129+
Get references to all constants in a tree, in depth-first order. Using the output of this lets
130+
you quickly modify the constants in the tree in-place.
131+
"""
132+
function get_constant_refs(tree::AbstractExpressionNode)
133+
return filter_map(
134+
is_node_constant,
135+
t -> NodeConstantRef(t),
136+
tree,
137+
NodeConstantRef{eltype(tree),typeof(tree)},
138+
)
139+
end
140+
141+
"""
142+
set_constant_refs!(crs::AbstractArray{C}, xs::AbstractArray{T}) where {T,C<:NodeConstantRef{T}}
143+
144+
Set the constants in a tree to the values in a vector.
145+
"""
146+
@inline function set_constant_refs!(
147+
constant_refs::AbstractArray{C}, xs::AbstractArray{T}
148+
) where {T,C<:NodeConstantRef{T}}
149+
for (cr, x) in zip(constant_refs, xs)
150+
cr.x = x
151+
end
152+
return nothing
153+
end
154+
101155
## Assign index to nodes of a tree
102156
# This will mirror a Node struct, rather
103157
# than adding a new attribute to Node.

test/test_optim.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using DynamicExpressions, Optim, Zygote
22
using Random: MersenneTwister as RNG
3+
using Test
34

45
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
56
x1, x2 = (i -> Node(Float64; feature=i)).(1:2)

0 commit comments

Comments
 (0)