1- import DynamicExpressions:
2- Node, copy_node, set_node!, count_nodes, has_constants, has_operators
1+ using DynamicExpressions:
2+ AbstractExpressionNode,
3+ AbstractNode,
4+ Node,
5+ NodeSampler,
6+ constructorof,
7+ set_node!,
8+ count_nodes
9+ using Random: AbstractRNG, default_rng
310
4- # This code is copied from SymbolicRegression.jl and modified
11+ """
12+ random_node(tree::AbstractNode; filter::F=Returns(true))
513
6- # Return a random node from the tree
7- function random_node (tree:: Node{T} ):: Node{T} where {T}
8- if tree. degree == 0
9- return tree
10- end
11- b = count_nodes (tree. l)
12- c = if tree. degree == 2
13- count_nodes (tree. r)
14- else
15- 0
16- end
17-
18- i = rand (1 : (1 + b + c))
19- if i <= b
20- return random_node (tree. l)
21- elseif i == b + 1
22- return tree
23- end
24-
25- return random_node (tree. r)
14+ Return a random node from the tree. You may optionally
15+ filter the nodes matching some condition before sampling.
16+ """
17+ function random_node (
18+ tree:: AbstractNode , rng:: AbstractRNG = default_rng (); filter:: F = Returns (true )
19+ ) where {F<: Function }
20+ Base. depwarn (
21+ " Instead of `random_node(tree, filter)`, use `rand(NodeSampler(; tree, filter))`" ,
22+ :random_node ,
23+ )
24+ return rand (rng, NodeSampler (; tree, filter))
2625end
2726
28- function make_random_leaf (nfeatures:: Integer , :: Type{T} ):: Node{T} where {T}
29- if rand () > 0.5
30- return Node (; val= randn (T))
27+ function make_random_leaf (
28+ nfeatures:: Int , :: Type{T} , :: Type{N} , rng:: AbstractRNG = default_rng ()
29+ ) where {T,N<: AbstractExpressionNode }
30+ if rand (rng, Bool)
31+ return constructorof (N)(; val= randn (rng, T))
3132 else
32- return Node ( T; feature= rand (1 : nfeatures))
33+ return constructorof (N)( T; feature= rand (rng, 1 : nfeatures))
3334 end
3435end
3536
36- # Add a random unary/binary operation to the end of a tree
37+ """ Add a random unary/binary operation to the end of a tree"""
3738function append_random_op (
38- tree:: Node{T} , operators, nfeatures:: Integer ; makeNewBinOp:: Union{Bool,Nothing} = nothing
39- ):: Node{T} where {T}
39+ tree:: AbstractExpressionNode{T} ,
40+ operators,
41+ nfeatures:: Int ,
42+ rng:: AbstractRNG = default_rng ();
43+ makeNewBinOp:: Union{Bool,Nothing} = nothing ,
44+ ) where {T}
45+ node = rand (rng, NodeSampler (; tree, filter= t -> t. degree == 0 ))
4046 nuna = length (operators. unaops)
4147 nbin = length (operators. binops)
4248
43- node = random_node (tree)
44- while node. degree != 0
45- node = random_node (tree)
46- end
47-
4849 if makeNewBinOp === nothing
49- choice = rand ()
50+ choice = rand (rng )
5051 makeNewBinOp = choice < nbin / (nuna + nbin)
5152 end
5253
5354 if makeNewBinOp
54- newnode = Node (
55- rand (1 : nbin), make_random_leaf (nfeatures, T), make_random_leaf (nfeatures, T)
55+ newnode = constructorof (typeof (tree))(
56+ rand (rng, 1 : nbin),
57+ make_random_leaf (nfeatures, T, typeof (tree), rng),
58+ make_random_leaf (nfeatures, T, typeof (tree), rng),
5659 )
5760 else
58- newnode = Node (rand (1 : nuna), make_random_leaf (nfeatures, T))
61+ newnode = constructorof (typeof (tree))(
62+ rand (rng, 1 : nuna), make_random_leaf (nfeatures, T, typeof (tree), rng)
63+ )
5964 end
6065
6166 set_node! (node, newnode)
@@ -64,16 +69,21 @@ function append_random_op(
6469end
6570
6671function gen_random_tree_fixed_size (
67- node_count:: Integer , operators, nfeatures:: Integer , :: Type{T}
68- ):: Node{T} where {T}
69- tree = make_random_leaf (nfeatures, T)
72+ node_count:: Int ,
73+ operators,
74+ nfeatures:: Int ,
75+ :: Type{T} ,
76+ node_type= Node,
77+ rng:: AbstractRNG = default_rng (),
78+ ) where {T}
79+ tree = make_random_leaf (nfeatures, T, node_type, rng)
7080 cur_size = count_nodes (tree)
7181 while cur_size < node_count
7282 if cur_size == node_count - 1 # only unary operator allowed.
7383 length (operators. unaops) == 0 && break # We will go over the requested amount, so we must break.
74- tree = append_random_op (tree, operators, nfeatures; makeNewBinOp= false )
84+ tree = append_random_op (tree, operators, nfeatures, rng ; makeNewBinOp= false )
7585 else
76- tree = append_random_op (tree, operators, nfeatures)
86+ tree = append_random_op (tree, operators, nfeatures, rng )
7787 end
7888 cur_size = count_nodes (tree)
7989 end
0 commit comments