Skip to content

Commit 21acfa1

Browse files
committed
feat!: simplify expression optimization routine
1 parent 4bd83e9 commit 21acfa1

File tree

3 files changed

+13
-63
lines changed

3 files changed

+13
-63
lines changed

ext/DynamicExpressionsOptimExt.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module DynamicExpressionsOptimExt
22

3-
using DynamicExpressions:
4-
AbstractExpressionNode, eval_tree_array, get_constant_refs, set_constant_refs!
3+
using DynamicExpressions: AbstractExpressionNode, filter_map, eval_tree_array
54
using Compat: @inline
65

76
import Optim: Optim, OptimizationResults, NLSolversBase
@@ -39,7 +38,9 @@ function wrap_func(
3938
function wrapped_f(args::Vararg{Any,M}) where {M}
4039
first_args = args[1:(end - 1)]
4140
x = last(args)
42-
set_constant_refs!(constant_refs, x)
41+
@inbounds for i in eachindex(constant_refs, x)
42+
constant_refs[i][].val = x[i]
43+
end
4344
return @inline(f(first_args..., tree))
4445
end
4546
return wrapped_f
@@ -87,8 +88,10 @@ function Optim.optimize(
8788
if make_copy
8889
tree = copy(tree)
8990
end
90-
constant_refs = get_constant_refs(tree)
91-
x0 = map(t -> t.x, constant_refs)
91+
constant_refs = filter_map(
92+
t -> t.degree == 0 && t.constant, t -> Ref(t), tree, Ref{typeof(tree)}
93+
)
94+
x0 = T[copy(t[].val) for t in constant_refs]
9295
if !isnothing(h!)
9396
throw(
9497
ArgumentError(
@@ -108,7 +111,10 @@ function Optim.optimize(
108111
kwargs...,
109112
)
110113
end
111-
set_constant_refs!(constant_refs, Optim.minimizer(base_res))
114+
minimizer = Optim.minimizer(base_res)
115+
@inbounds for i in eachindex(constant_refs, minimizer)
116+
constant_refs[i][].val = minimizer[i]
117+
end
112118
return ExpressionOptimizationResults(base_res, tree)
113119
end
114120

src/DynamicExpressions.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ import .EquationModule: constructorof, preserve_sharing
3535
has_operators,
3636
has_constants,
3737
get_constants,
38-
set_constants!,
39-
get_constant_refs,
40-
set_constant_refs!
38+
set_constants!
4139
@reexport import .StringsModule: string_tree, print_tree
4240
@reexport import .OperatorEnumModule: AbstractOperatorEnum
4341
@reexport import .OperatorEnumConstructionModule:

src/EquationUtils.jl

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -98,60 +98,6 @@ function set_constants!(
9898
return nothing
9999
end
100100

101-
"""
102-
NodeConstantRef{T,N<:AbstractExpressionNode{T}}
103-
104-
A reference to a constant in an expression tree. Use `.x` to access
105-
the value of the constant for setting or getting.
106-
"""
107-
struct NodeConstantRef{T,N<:AbstractExpressionNode{T}}
108-
_node::Ref{N}
109-
110-
function NodeConstantRef(node::_N) where {_T,_N<:AbstractExpressionNode{_T}}
111-
return new{_T,_N}(Ref(node))
112-
end
113-
end
114-
function Base.getproperty(cr::NodeConstantRef{T}, s::Symbol) where {T}
115-
s != :x && error("Only :x is a valid property for NodeConstantRef")
116-
117-
return getfield(cr, :_node).x.val
118-
end
119-
function Base.setproperty!(cr::NodeConstantRef{T}, s::Symbol, v) where {T}
120-
s != :x && error("Only :x is a valid property for NodeConstantRef")
121-
122-
return getfield(cr, :_node).x.val = v
123-
end
124-
Base.propertynames(::NodeConstantRef) = (:x,)
125-
126-
"""
127-
get_constant_refs(tree::AbstractExpressionNode)
128-
129-
Get references to all constants in a tree, in depth-first order. Using the output of this lets
130-
you quickly modify the constants in the tree in-place.
131-
"""
132-
function get_constant_refs(tree::AbstractExpressionNode)
133-
return filter_map(
134-
is_node_constant,
135-
t -> NodeConstantRef(t),
136-
tree,
137-
NodeConstantRef{eltype(tree),typeof(tree)},
138-
)
139-
end
140-
141-
"""
142-
set_constant_refs!(crs::AbstractArray{C}, xs::AbstractArray{T}) where {T,C<:NodeConstantRef{T}}
143-
144-
Set the constants in a tree to the values in a vector.
145-
"""
146-
@inline function set_constant_refs!(
147-
constant_refs::AbstractArray{C}, xs::AbstractArray{T}
148-
) where {T,C<:NodeConstantRef{T}}
149-
for (cr, x) in zip(constant_refs, xs)
150-
cr.x = x
151-
end
152-
return nothing
153-
end
154-
155101
## Assign index to nodes of a tree
156102
# This will mirror a Node struct, rather
157103
# than adding a new attribute to Node.

0 commit comments

Comments
 (0)