Skip to content

Commit 4638851

Browse files
authored
Merge pull request #53 from SymbolicML/abstract-node
Create `AbstractNode` super type
2 parents 65184f9 + 4f621ec commit 4638851

File tree

7 files changed

+130
-48
lines changed

7 files changed

+130
-48
lines changed

docs/src/types.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,9 @@ You can create a copy of a node with `copy_node`:
8383
```@docs
8484
copy_node(tree::Node)
8585
```
86+
87+
There is also an abstract type `AbstractNode` which is a supertype of `Node`:
88+
89+
```@docs
90+
AbstractNode
91+
```

src/DynamicExpressions.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@ include("ExtensionInterface.jl")
1414
import PackageExtensionCompat: @require_extensions
1515
import Reexport: @reexport
1616
@reexport import .EquationModule:
17-
Node, string_tree, print_tree, copy_node, set_node!, tree_mapreduce, filter_map
17+
AbstractNode,
18+
Node,
19+
string_tree,
20+
print_tree,
21+
copy_node,
22+
set_node!,
23+
tree_mapreduce,
24+
filter_map
1825
@reexport import .EquationUtilsModule:
1926
count_nodes,
2027
count_constants,

src/Equation.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,24 @@ import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap
55

66
const DEFAULT_NODE_TYPE = Float32
77

8+
"""
9+
AbstractNode
10+
11+
Abstract type for binary trees. Must have the following fields:
12+
13+
- `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1,
14+
then `l` needs to be defined as the left child. If 2,
15+
then `r` also needs to be defined as the right child.
16+
- `l::AbstractNode`: Left child of the current node. Should only be
17+
defined if `degree >= 1`; otherwise, leave it undefined (see the
18+
the constructors of `Node{T}` for an example).
19+
Don't use `nothing` to represent an undefined value
20+
as it will incur a large performance penalty.
21+
- `r::AbstractNode`: Right child of the current node. Should only
22+
be defined if `degree == 2`.
23+
"""
24+
abstract type AbstractNode end
25+
826
#! format: off
927
"""
1028
Node{T}
@@ -36,7 +54,7 @@ nodes, you can evaluate or print a given expression.
3654
Same type as the parent node. This is to be passed as the right
3755
argument to the binary operator.
3856
"""
39-
mutable struct Node{T}
57+
mutable struct Node{T} <: AbstractNode
4058
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
4159
constant::Bool # false if variable
4260
val::Union{T,Nothing} # If is a constant, this stores the actual value

src/EquationUtils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
module EquationUtilsModule
22

33
import Compat: Returns
4-
import ..EquationModule: Node, copy_node, tree_mapreduce, any, filter_map
4+
import ..EquationModule: AbstractNode, Node, copy_node, tree_mapreduce, any, filter_map
55

66
"""
7-
count_nodes(tree::Node{T})::Int where {T}
7+
count_nodes(tree::AbstractNode)::Int
88
99
Count the number of nodes in the tree.
1010
"""
11-
count_nodes(tree::Node) = tree_mapreduce(_ -> 1, +, tree)
11+
count_nodes(tree::AbstractNode) = tree_mapreduce(_ -> 1, +, tree)
1212
# This code is given as an example. Normally we could just use sum(Returns(1), tree).
1313

1414
"""
15-
count_depth(tree::Node{T})::Int where {T}
15+
count_depth(tree::AbstractNode)::Int
1616
1717
Compute the max depth of the tree.
1818
"""
19-
function count_depth(tree::Node)
19+
function count_depth(tree::AbstractNode)
2020
return tree_mapreduce(Returns(1), (p, child...) -> p + max(child...), tree)
2121
end
2222

src/base.jl

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ import Compat: @inline, Returns
2626
import ..UtilsModule: @memoize_on, @with_memoize
2727

2828
"""
29-
tree_mapreduce(f::Function, op::Function, tree::Node, result_type::Type=Nothing)
30-
tree_mapreduce(f_leaf::Function, f_branch::Function, op::Function, tree::Node, result_type::Type=Nothing)
29+
tree_mapreduce(f::Function, op::Function, tree::AbstractNode, result_type::Type=Nothing)
30+
tree_mapreduce(f_leaf::Function, f_branch::Function, op::Function, tree::AbstractNode, result_type::Type=Nothing)
3131
3232
Map a function over a tree and aggregate the result using an operator `op`.
3333
`op` should be defined with inputs `(parent, child...) ->` so that it can aggregate
@@ -66,23 +66,27 @@ end # Get list of constants. (regular mapreduce also works)
6666
```
6767
"""
6868
function tree_mapreduce(
69-
f::F, op::G, tree::N, result_type::Type{RT}=Nothing; preserve_sharing::Bool=false
70-
) where {T,N<:Node{T},F<:Function,G<:Function,RT}
69+
f::F,
70+
op::G,
71+
tree::AbstractNode,
72+
result_type::Type{RT}=Nothing;
73+
preserve_sharing::Bool=false,
74+
) where {F<:Function,G<:Function,RT}
7175
return tree_mapreduce(f, f, op, tree, result_type; preserve_sharing)
7276
end
7377
function tree_mapreduce(
7478
f_leaf::F1,
7579
f_branch::F2,
7680
op::G,
77-
tree::N,
81+
tree::AbstractNode,
7882
result_type::Type{RT}=Nothing;
7983
preserve_sharing::Bool=false,
80-
) where {T,N<:Node{T},F1<:Function,F2<:Function,G<:Function,RT}
84+
) where {F1<:Function,F2<:Function,G<:Function,RT}
8185

8286
# Trick taken from here:
8387
# https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
8488
# to speed up recursive closure
85-
@memoize_on t function inner(inner, t::Node)
89+
@memoize_on t function inner(inner, t)
8690
if t.degree == 0
8791
return @inline(f_leaf(t))
8892
elseif t.degree == 1
@@ -97,19 +101,19 @@ function tree_mapreduce(
97101
throw(ArgumentError("Need to specify `result_type` if you use `preserve_sharing`."))
98102

99103
if preserve_sharing && RT != Nothing
100-
return @with_memoize inner(inner, tree) IdDict{N,RT}()
104+
return @with_memoize inner(inner, tree) IdDict{typeof(tree),RT}()
101105
else
102106
return inner(inner, tree)
103107
end
104108
end
105109

106110
"""
107-
any(f::Function, tree::Node)
111+
any(f::Function, tree::AbstractNode)
108112
109113
Reduce a flag function over a tree, returning `true` if the function returns `true` for any node.
110114
By using this instead of tree_mapreduce, we can take advantage of early exits.
111115
"""
112-
function any(f::F, tree::Node) where {F<:Function}
116+
function any(f::F, tree::AbstractNode) where {F<:Function}
113117
if tree.degree == 0
114118
return @inline(f(tree))::Bool
115119
elseif tree.degree == 1
@@ -119,19 +123,25 @@ function any(f::F, tree::Node) where {F<:Function}
119123
end
120124
end
121125

122-
function Base.:(==)(a::Node{T1}, b::Node{T2})::Bool where {T1,T2}
126+
function Base.:(==)(a::AbstractNode, b::AbstractNode)::Bool
123127
(degree = a.degree) != b.degree && return false
124128
if degree == 0
125-
(constant = a.constant) != b.constant && return false
126-
if constant
127-
return a.val::T1 == b.val::T2
128-
else
129-
return a.feature == b.feature
130-
end
129+
return isequal_deg0(a, b)
131130
elseif degree == 1
132-
return a.op == b.op && a.l == b.l
131+
return isequal_deg1(a, b) && a.l == b.l
132+
else
133+
return isequal_deg2(a, b) && a.l == b.l && a.r == b.r
134+
end
135+
end
136+
137+
@inline isequal_deg1(a::Node, b::Node) = a.op == b.op
138+
@inline isequal_deg2(a::Node, b::Node) = a.op == b.op
139+
@inline function isequal_deg0(a::Node{T1}, b::Node{T2}) where {T1,T2}
140+
(constant = a.constant) != b.constant && return false
141+
if constant
142+
return a.val::T1 == b.val::T2
133143
else
134-
return a.op == b.op && a.l == b.l && a.r == b.r
144+
return a.feature == b.feature
135145
end
136146
end
137147

@@ -144,33 +154,33 @@ end
144154
145155
Apply a function to each node in a tree.
146156
"""
147-
function foreach(f::Function, tree::Node)
157+
function foreach(f::Function, tree::AbstractNode)
148158
return tree_mapreduce(t -> (@inline(f(t)); nothing), Returns(nothing), tree)
149159
end
150160

151161
"""
152-
filter_map(filter_fnc::Function, map_fnc::Function, tree::Node, result_type::Type)
162+
filter_map(filter_fnc::Function, map_fnc::Function, tree::AbstractNode, result_type::Type)
153163
154164
A faster equivalent to `map(map_fnc, filter(filter_fnc, tree))`
155165
that avoids the intermediate allocation. However, using this requires
156166
specifying the `result_type` of `map_fnc` so the resultant array can
157167
be preallocated.
158168
"""
159169
function filter_map(
160-
filter_fnc::F, map_fnc::G, tree::Node, result_type::Type{GT}
170+
filter_fnc::F, map_fnc::G, tree::AbstractNode, result_type::Type{GT}
161171
) where {F<:Function,G<:Function,GT}
162172
stack = Array{GT}(undef, count(filter_fnc, tree))
163173
filter_map!(filter_fnc, map_fnc, stack, tree)
164174
return stack::Vector{GT}
165175
end
166176

167177
"""
168-
filter_map!(filter_fnc::Function, map_fnc::Function, stack::Vector{GT}, tree::Node)
178+
filter_map!(filter_fnc::Function, map_fnc::Function, stack::Vector{GT}, tree::AbstractNode)
169179
170180
Equivalent to `filter_map`, but stores the results in a preallocated array.
171181
"""
172182
function filter_map!(
173-
filter_fnc::Function, map_fnc::Function, destination::Vector{GT}, tree::Node
183+
filter_fnc::Function, map_fnc::Function, destination::Vector{GT}, tree::AbstractNode
174184
) where {GT}
175185
pointer = Ref(0)
176186
foreach(tree) do t
@@ -183,49 +193,49 @@ function filter_map!(
183193
end
184194

185195
"""
186-
filter(f::Function, tree::Node)
196+
filter(f::Function, tree::AbstractNode)
187197
188198
Filter nodes of a tree, returning a flat array of the nodes for which the function returns `true`.
189199
"""
190-
function filter(f::F, tree::Node{T}) where {F<:Function,T}
191-
return filter_map(f, identity, tree, Node{T})
200+
function filter(f::F, tree::AbstractNode) where {F<:Function}
201+
return filter_map(f, identity, tree, typeof(tree))
192202
end
193203

194-
collect(tree::Node) = filter(Returns(true), tree)
204+
collect(tree::AbstractNode) = filter(Returns(true), tree)
195205

196206
"""
197-
map(f::Function, tree::Node, result_type::Type{RT}=Nothing)
207+
map(f::Function, tree::AbstractNode, result_type::Type{RT}=Nothing)
198208
199209
Map a function over a tree and return a flat array of the results in depth-first order.
200210
Pre-specifying the `result_type` of the function can be used to avoid extra allocations,
201211
"""
202-
function map(f::F, tree::Node, result_type::Type{RT}=Nothing) where {F<:Function,RT}
212+
function map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing) where {F<:Function,RT}
203213
if RT == Nothing
204214
return f.(collect(tree))
205215
else
206216
return filter_map(Returns(true), f, tree, result_type)
207217
end
208218
end
209219

210-
function count(f::F, tree::Node; init=0) where {F<:Function}
220+
function count(f::F, tree::AbstractNode; init=0) where {F<:Function}
211221
return tree_mapreduce(t -> @inline(f(t)) ? 1 : 0, +, tree) + init
212222
end
213223

214-
function sum(f::F, tree::Node; init=0) where {F<:Function}
224+
function sum(f::F, tree::AbstractNode; init=0) where {F<:Function}
215225
return tree_mapreduce(f, +, tree) + init
216226
end
217227

218-
all(f::F, tree::Node) where {F<:Function} = !any(t -> !@inline(f(t)), tree)
228+
all(f::F, tree::AbstractNode) where {F<:Function} = !any(t -> !@inline(f(t)), tree)
219229

220-
function mapreduce(f::F, op::G, tree::Node) where {F<:Function,G<:Function}
230+
function mapreduce(f::F, op::G, tree::AbstractNode) where {F<:Function,G<:Function}
221231
return tree_mapreduce(f, (n...) -> reduce(op, n), tree)
222232
end
223233

224-
isempty(::Node) = false
225-
iterate(root::Node) = (root, collect(root)[(begin + 1):end])
226-
iterate(::Node, stack) = isempty(stack) ? nothing : (popfirst!(stack), stack)
227-
in(item, tree::Node) = any(t -> t == item, tree)
228-
length(tree::Node) = sum(Returns(1), tree)
234+
isempty(::AbstractNode) = false
235+
iterate(root::AbstractNode) = (root, collect(root)[(begin + 1):end])
236+
iterate(::AbstractNode, stack) = isempty(stack) ? nothing : (popfirst!(stack), stack)
237+
in(item, tree::AbstractNode) = any(t -> t == item, tree)
238+
length(tree::AbstractNode) = sum(Returns(1), tree)
229239
function hash(tree::Node{T}) where {T}
230240
return tree_mapreduce(
231241
t -> t.constant ? hash((0, t.val::T)) : hash((1, t.feature)),
@@ -299,11 +309,11 @@ end
299309

300310
for func in (:reduce, :foldl, :foldr, :mapfoldl, :mapfoldr)
301311
@eval begin
302-
function $func(f, tree::Node; kws...)
312+
function $func(f, tree::AbstractNode; kws...)
303313
throw(
304314
error(
305315
string($func) *
306-
" not implemented for Node. Use `tree_mapreduce` instead.",
316+
" not implemented for AbstractNode. Use `tree_mapreduce` instead.",
307317
),
308318
)
309319
end

test/test_custom_node_type.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using DynamicExpressions
2+
using Test
3+
4+
mutable struct MyCustomNode{A,B} <: AbstractNode
5+
degree::Int
6+
val1::A
7+
val2::B
8+
l::MyCustomNode{A,B}
9+
r::MyCustomNode{A,B}
10+
11+
MyCustomNode(val1, val2) = new{typeof(val1),typeof(val2)}(0, val1, val2)
12+
MyCustomNode(val1, val2, l) = new{typeof(val1),typeof(val2)}(1, val1, val2, l)
13+
MyCustomNode(val1, val2, l, r) = new{typeof(val1),typeof(val2)}(2, val1, val2, l, r)
14+
end
15+
16+
node1 = MyCustomNode(1.0, 2)
17+
18+
@test typeof(node1) == MyCustomNode{Float64,Int}
19+
@test node1.degree == 0
20+
@test count_depth(node1) == 1
21+
@test count_nodes(node1) == 1
22+
23+
node2 = MyCustomNode(1.5, 3, node1)
24+
25+
@test typeof(node2) == MyCustomNode{Float64,Int}
26+
@test node2.degree == 1
27+
@test node2.l.degree == 0
28+
@test count_depth(node2) == 2
29+
@test count_nodes(node2) == 2
30+
31+
node2 = MyCustomNode(1.5, 3, node1, node1)
32+
33+
@test count_depth(node2) == 2
34+
@test count_nodes(node2) == 3
35+
@test sum(t -> t.val1, node2) == 1.5 + 1.0 + 1.0
36+
@test sum(t -> t.val2, node2) == 3 + 2 + 2
37+
@test count(t -> t.degree == 0, node2) == 2

test/unittest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,7 @@ end
9191
@safetestset "Test helpers break upon redefining" begin
9292
include("test_safe_helpers.jl")
9393
end
94+
95+
@safetestset "Test custom node type" begin
96+
include("test_custom_node_type.jl")
97+
end

0 commit comments

Comments
 (0)