Skip to content

Commit 27b6199

Browse files
authored
Merge pull request #70 from SymbolicML/type-stability
feat!: simplify expression optimization routine
2 parents 4bd83e9 + 8d7f000 commit 27b6199

File tree

5 files changed

+70
-109
lines changed

5 files changed

+70
-109
lines changed

ext/DynamicExpressionsOptimExt.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module DynamicExpressionsOptimExt
22

3-
using DynamicExpressions:
4-
AbstractExpressionNode, eval_tree_array, get_constant_refs, set_constant_refs!
3+
using DynamicExpressions: AbstractExpressionNode, filter_map, eval_tree_array
54
using Compat: @inline
65

76
import Optim: Optim, OptimizationResults, NLSolversBase
@@ -39,7 +38,9 @@ function wrap_func(
3938
function wrapped_f(args::Vararg{Any,M}) where {M}
4039
first_args = args[1:(end - 1)]
4140
x = last(args)
42-
set_constant_refs!(constant_refs, x)
41+
@inbounds for i in eachindex(constant_refs, x)
42+
constant_refs[i][].val = x[i]
43+
end
4344
return @inline(f(first_args..., tree))
4445
end
4546
return wrapped_f
@@ -87,8 +88,10 @@ function Optim.optimize(
8788
if make_copy
8889
tree = copy(tree)
8990
end
90-
constant_refs = get_constant_refs(tree)
91-
x0 = map(t -> t.x, constant_refs)
91+
constant_refs = filter_map(
92+
t -> t.degree == 0 && t.constant, t -> Ref(t), tree, Ref{typeof(tree)}
93+
)
94+
x0 = T[copy(t[].val) for t in constant_refs]
9295
if !isnothing(h!)
9396
throw(
9497
ArgumentError(
@@ -108,7 +111,10 @@ function Optim.optimize(
108111
kwargs...,
109112
)
110113
end
111-
set_constant_refs!(constant_refs, Optim.minimizer(base_res))
114+
minimizer = Optim.minimizer(base_res)
115+
@inbounds for i in eachindex(constant_refs, minimizer)
116+
constant_refs[i][].val = minimizer[i]
117+
end
112118
return ExpressionOptimizationResults(base_res, tree)
113119
end
114120

src/DynamicExpressions.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ import .EquationModule: constructorof, preserve_sharing
3535
has_operators,
3636
has_constants,
3737
get_constants,
38-
set_constants!,
39-
get_constant_refs,
40-
set_constant_refs!
38+
set_constants!
4139
@reexport import .StringsModule: string_tree, print_tree
4240
@reexport import .OperatorEnumModule: AbstractOperatorEnum
4341
@reexport import .OperatorEnumConstructionModule:

src/EquationUtils.jl

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -98,60 +98,6 @@ function set_constants!(
9898
return nothing
9999
end
100100

101-
"""
102-
NodeConstantRef{T,N<:AbstractExpressionNode{T}}
103-
104-
A reference to a constant in an expression tree. Use `.x` to access
105-
the value of the constant for setting or getting.
106-
"""
107-
struct NodeConstantRef{T,N<:AbstractExpressionNode{T}}
108-
_node::Ref{N}
109-
110-
function NodeConstantRef(node::_N) where {_T,_N<:AbstractExpressionNode{_T}}
111-
return new{_T,_N}(Ref(node))
112-
end
113-
end
114-
function Base.getproperty(cr::NodeConstantRef{T}, s::Symbol) where {T}
115-
s != :x && error("Only :x is a valid property for NodeConstantRef")
116-
117-
return getfield(cr, :_node).x.val
118-
end
119-
function Base.setproperty!(cr::NodeConstantRef{T}, s::Symbol, v) where {T}
120-
s != :x && error("Only :x is a valid property for NodeConstantRef")
121-
122-
return getfield(cr, :_node).x.val = v
123-
end
124-
Base.propertynames(::NodeConstantRef) = (:x,)
125-
126-
"""
127-
get_constant_refs(tree::AbstractExpressionNode)
128-
129-
Get references to all constants in a tree, in depth-first order. Using the output of this lets
130-
you quickly modify the constants in the tree in-place.
131-
"""
132-
function get_constant_refs(tree::AbstractExpressionNode)
133-
return filter_map(
134-
is_node_constant,
135-
t -> NodeConstantRef(t),
136-
tree,
137-
NodeConstantRef{eltype(tree),typeof(tree)},
138-
)
139-
end
140-
141-
"""
142-
set_constant_refs!(crs::AbstractArray{C}, xs::AbstractArray{T}) where {T,C<:NodeConstantRef{T}}
143-
144-
Set the constants in a tree to the values in a vector.
145-
"""
146-
@inline function set_constant_refs!(
147-
constant_refs::AbstractArray{C}, xs::AbstractArray{T}
148-
) where {T,C<:NodeConstantRef{T}}
149-
for (cr, x) in zip(constant_refs, xs)
150-
cr.x = x
151-
end
152-
return nothing
153-
end
154-
155101
## Assign index to nodes of a tree
156102
# This will mirror a Node struct, rather
157103
# than adding a new attribute to Node.

src/base.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ import Base:
2222
mapreduce,
2323
reduce,
2424
sum
25-
import Compat: @inline, Returns
26-
import ..UtilsModule: @memoize_on, @with_memoize, Undefined
25+
26+
using Compat: @inline, Returns
27+
using ..UtilsModule: @memoize_on, @with_memoize, Undefined
2728

2829
"""
2930
tree_mapreduce(

test/tree_gen_utils.jl

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,66 @@
1-
import DynamicExpressions:
2-
Node, copy_node, set_node!, count_nodes, has_constants, has_operators
1+
using DynamicExpressions:
2+
AbstractExpressionNode,
3+
AbstractNode,
4+
Node,
5+
NodeSampler,
6+
constructorof,
7+
set_node!,
8+
count_nodes
9+
using Random: AbstractRNG, default_rng
310

4-
# This code is copied from SymbolicRegression.jl and modified
11+
"""
12+
random_node(tree::AbstractNode; filter::F=Returns(true))
513
6-
# Return a random node from the tree
7-
function random_node(tree::Node{T})::Node{T} where {T}
8-
if tree.degree == 0
9-
return tree
10-
end
11-
b = count_nodes(tree.l)
12-
c = if tree.degree == 2
13-
count_nodes(tree.r)
14-
else
15-
0
16-
end
17-
18-
i = rand(1:(1 + b + c))
19-
if i <= b
20-
return random_node(tree.l)
21-
elseif i == b + 1
22-
return tree
23-
end
24-
25-
return random_node(tree.r)
14+
Return a random node from the tree. You may optionally
15+
filter the nodes matching some condition before sampling.
16+
"""
17+
function random_node(
18+
tree::AbstractNode, rng::AbstractRNG=default_rng(); filter::F=Returns(true)
19+
) where {F<:Function}
20+
Base.depwarn(
21+
"Instead of `random_node(tree, filter)`, use `rand(NodeSampler(; tree, filter))`",
22+
:random_node,
23+
)
24+
return rand(rng, NodeSampler(; tree, filter))
2625
end
2726

28-
function make_random_leaf(nfeatures::Integer, ::Type{T})::Node{T} where {T}
29-
if rand() > 0.5
30-
return Node(; val=randn(T))
27+
function make_random_leaf(
28+
nfeatures::Int, ::Type{T}, ::Type{N}, rng::AbstractRNG=default_rng()
29+
) where {T,N<:AbstractExpressionNode}
30+
if rand(rng, Bool)
31+
return constructorof(N)(; val=randn(rng, T))
3132
else
32-
return Node(T; feature=rand(1:nfeatures))
33+
return constructorof(N)(T; feature=rand(rng, 1:nfeatures))
3334
end
3435
end
3536

36-
# Add a random unary/binary operation to the end of a tree
37+
"""Add a random unary/binary operation to the end of a tree"""
3738
function append_random_op(
38-
tree::Node{T}, operators, nfeatures::Integer; makeNewBinOp::Union{Bool,Nothing}=nothing
39-
)::Node{T} where {T}
39+
tree::AbstractExpressionNode{T},
40+
operators,
41+
nfeatures::Int,
42+
rng::AbstractRNG=default_rng();
43+
makeNewBinOp::Union{Bool,Nothing}=nothing,
44+
) where {T}
45+
node = rand(rng, NodeSampler(; tree, filter=t -> t.degree == 0))
4046
nuna = length(operators.unaops)
4147
nbin = length(operators.binops)
4248

43-
node = random_node(tree)
44-
while node.degree != 0
45-
node = random_node(tree)
46-
end
47-
4849
if makeNewBinOp === nothing
49-
choice = rand()
50+
choice = rand(rng)
5051
makeNewBinOp = choice < nbin / (nuna + nbin)
5152
end
5253

5354
if makeNewBinOp
54-
newnode = Node(
55-
rand(1:nbin), make_random_leaf(nfeatures, T), make_random_leaf(nfeatures, T)
55+
newnode = constructorof(typeof(tree))(
56+
rand(rng, 1:nbin),
57+
make_random_leaf(nfeatures, T, typeof(tree), rng),
58+
make_random_leaf(nfeatures, T, typeof(tree), rng),
5659
)
5760
else
58-
newnode = Node(rand(1:nuna), make_random_leaf(nfeatures, T))
61+
newnode = constructorof(typeof(tree))(
62+
rand(rng, 1:nuna), make_random_leaf(nfeatures, T, typeof(tree), rng)
63+
)
5964
end
6065

6166
set_node!(node, newnode)
@@ -64,16 +69,21 @@ function append_random_op(
6469
end
6570

6671
function gen_random_tree_fixed_size(
67-
node_count::Integer, operators, nfeatures::Integer, ::Type{T}
68-
)::Node{T} where {T}
69-
tree = make_random_leaf(nfeatures, T)
72+
node_count::Int,
73+
operators,
74+
nfeatures::Int,
75+
::Type{T},
76+
node_type=Node,
77+
rng::AbstractRNG=default_rng(),
78+
) where {T}
79+
tree = make_random_leaf(nfeatures, T, node_type, rng)
7080
cur_size = count_nodes(tree)
7181
while cur_size < node_count
7282
if cur_size == node_count - 1 # only unary operator allowed.
7383
length(operators.unaops) == 0 && break # We will go over the requested amount, so we must break.
74-
tree = append_random_op(tree, operators, nfeatures; makeNewBinOp=false)
84+
tree = append_random_op(tree, operators, nfeatures, rng; makeNewBinOp=false)
7585
else
76-
tree = append_random_op(tree, operators, nfeatures)
86+
tree = append_random_op(tree, operators, nfeatures, rng)
7787
end
7888
cur_size = count_nodes(tree)
7989
end

0 commit comments

Comments
 (0)