Skip to content

Commit 8d7f000

Browse files
committed
refactor: update tree_gen_utils.jl
1 parent 21acfa1 commit 8d7f000

File tree

2 files changed

+57
-46
lines changed

2 files changed

+57
-46
lines changed

src/base.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ import Base:
2222
mapreduce,
2323
reduce,
2424
sum
25-
import Compat: @inline, Returns
26-
import ..UtilsModule: @memoize_on, @with_memoize, Undefined
25+
26+
using Compat: @inline, Returns
27+
using ..UtilsModule: @memoize_on, @with_memoize, Undefined
2728

2829
"""
2930
tree_mapreduce(

test/tree_gen_utils.jl

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,66 @@
1-
import DynamicExpressions:
2-
Node, copy_node, set_node!, count_nodes, has_constants, has_operators
1+
using DynamicExpressions:
2+
AbstractExpressionNode,
3+
AbstractNode,
4+
Node,
5+
NodeSampler,
6+
constructorof,
7+
set_node!,
8+
count_nodes
9+
using Random: AbstractRNG, default_rng
310

4-
# This code is copied from SymbolicRegression.jl and modified
11+
"""
12+
random_node(tree::AbstractNode; filter::F=Returns(true))
513
6-
# Return a random node from the tree
7-
function random_node(tree::Node{T})::Node{T} where {T}
8-
if tree.degree == 0
9-
return tree
10-
end
11-
b = count_nodes(tree.l)
12-
c = if tree.degree == 2
13-
count_nodes(tree.r)
14-
else
15-
0
16-
end
17-
18-
i = rand(1:(1 + b + c))
19-
if i <= b
20-
return random_node(tree.l)
21-
elseif i == b + 1
22-
return tree
23-
end
24-
25-
return random_node(tree.r)
14+
Return a random node from the tree. You may optionally
15+
filter the nodes matching some condition before sampling.
16+
"""
17+
function random_node(
18+
tree::AbstractNode, rng::AbstractRNG=default_rng(); filter::F=Returns(true)
19+
) where {F<:Function}
20+
Base.depwarn(
21+
"Instead of `random_node(tree, filter)`, use `rand(NodeSampler(; tree, filter))`",
22+
:random_node,
23+
)
24+
return rand(rng, NodeSampler(; tree, filter))
2625
end
2726

28-
function make_random_leaf(nfeatures::Integer, ::Type{T})::Node{T} where {T}
29-
if rand() > 0.5
30-
return Node(; val=randn(T))
27+
function make_random_leaf(
28+
nfeatures::Int, ::Type{T}, ::Type{N}, rng::AbstractRNG=default_rng()
29+
) where {T,N<:AbstractExpressionNode}
30+
if rand(rng, Bool)
31+
return constructorof(N)(; val=randn(rng, T))
3132
else
32-
return Node(T; feature=rand(1:nfeatures))
33+
return constructorof(N)(T; feature=rand(rng, 1:nfeatures))
3334
end
3435
end
3536

36-
# Add a random unary/binary operation to the end of a tree
37+
"""Add a random unary/binary operation to the end of a tree"""
3738
function append_random_op(
38-
tree::Node{T}, operators, nfeatures::Integer; makeNewBinOp::Union{Bool,Nothing}=nothing
39-
)::Node{T} where {T}
39+
tree::AbstractExpressionNode{T},
40+
operators,
41+
nfeatures::Int,
42+
rng::AbstractRNG=default_rng();
43+
makeNewBinOp::Union{Bool,Nothing}=nothing,
44+
) where {T}
45+
node = rand(rng, NodeSampler(; tree, filter=t -> t.degree == 0))
4046
nuna = length(operators.unaops)
4147
nbin = length(operators.binops)
4248

43-
node = random_node(tree)
44-
while node.degree != 0
45-
node = random_node(tree)
46-
end
47-
4849
if makeNewBinOp === nothing
49-
choice = rand()
50+
choice = rand(rng)
5051
makeNewBinOp = choice < nbin / (nuna + nbin)
5152
end
5253

5354
if makeNewBinOp
54-
newnode = Node(
55-
rand(1:nbin), make_random_leaf(nfeatures, T), make_random_leaf(nfeatures, T)
55+
newnode = constructorof(typeof(tree))(
56+
rand(rng, 1:nbin),
57+
make_random_leaf(nfeatures, T, typeof(tree), rng),
58+
make_random_leaf(nfeatures, T, typeof(tree), rng),
5659
)
5760
else
58-
newnode = Node(rand(1:nuna), make_random_leaf(nfeatures, T))
61+
newnode = constructorof(typeof(tree))(
62+
rand(rng, 1:nuna), make_random_leaf(nfeatures, T, typeof(tree), rng)
63+
)
5964
end
6065

6166
set_node!(node, newnode)
@@ -64,16 +69,21 @@ function append_random_op(
6469
end
6570

6671
function gen_random_tree_fixed_size(
67-
node_count::Integer, operators, nfeatures::Integer, ::Type{T}
68-
)::Node{T} where {T}
69-
tree = make_random_leaf(nfeatures, T)
72+
node_count::Int,
73+
operators,
74+
nfeatures::Int,
75+
::Type{T},
76+
node_type=Node,
77+
rng::AbstractRNG=default_rng(),
78+
) where {T}
79+
tree = make_random_leaf(nfeatures, T, node_type, rng)
7080
cur_size = count_nodes(tree)
7181
while cur_size < node_count
7282
if cur_size == node_count - 1 # only unary operator allowed.
7383
length(operators.unaops) == 0 && break # We will go over the requested amount, so we must break.
74-
tree = append_random_op(tree, operators, nfeatures; makeNewBinOp=false)
84+
tree = append_random_op(tree, operators, nfeatures, rng; makeNewBinOp=false)
7585
else
76-
tree = append_random_op(tree, operators, nfeatures)
86+
tree = append_random_op(tree, operators, nfeatures, rng)
7787
end
7888
cur_size = count_nodes(tree)
7989
end

0 commit comments

Comments
 (0)