Skip to content

Commit 3d57bd4

Browse files
committed
Implement AbstractNode super type
1 parent 65184f9 commit 3d57bd4

File tree

4 files changed

+78
-48
lines changed

4 files changed

+78
-48
lines changed

src/DynamicExpressions.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@ include("ExtensionInterface.jl")
1414
import PackageExtensionCompat: @require_extensions
1515
import Reexport: @reexport
1616
@reexport import .EquationModule:
17-
Node, string_tree, print_tree, copy_node, set_node!, tree_mapreduce, filter_map
17+
AbstractNode,
18+
Node,
19+
string_tree,
20+
print_tree,
21+
copy_node,
22+
set_node!,
23+
tree_mapreduce,
24+
filter_map
1825
@reexport import .EquationUtilsModule:
1926
count_nodes,
2027
count_constants,

src/Equation.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,19 @@ import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap
55

66
const DEFAULT_NODE_TYPE = Float32
77

8+
"""
9+
AbstractNode
10+
11+
Abstract type for binary trees. Must have the following fields:
12+
13+
- `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1,
14+
then `l` needs to be defined as the left child. If 2,
15+
then `r` also needs to be defined as the right child.
16+
- `l::AbstractNode`: Left child of the current node.
17+
- `r::AbstractNode`: Right child of the current node.
18+
"""
19+
abstract type AbstractNode end
20+
821
#! format: off
922
"""
1023
Node{T}
@@ -36,7 +49,7 @@ nodes, you can evaluate or print a given expression.
3649
Same type as the parent node. This is to be passed as the right
3750
argument to the binary operator.
3851
"""
39-
mutable struct Node{T}
52+
mutable struct Node{T} <: AbstractNode
4053
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
4154
constant::Bool # false if variable
4255
val::Union{T,Nothing} # If is a constant, this stores the actual value

src/EquationUtils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
module EquationUtilsModule
22

33
import Compat: Returns
4-
import ..EquationModule: Node, copy_node, tree_mapreduce, any, filter_map
4+
import ..EquationModule: AbstractNode, Node, copy_node, tree_mapreduce, any, filter_map
55

66
"""
7-
count_nodes(tree::Node{T})::Int where {T}
7+
count_nodes(tree::AbstractNode)::Int
88
99
Count the number of nodes in the tree.
1010
"""
11-
count_nodes(tree::Node) = tree_mapreduce(_ -> 1, +, tree)
11+
count_nodes(tree::AbstractNode) = tree_mapreduce(_ -> 1, +, tree)
1212
# This code is given as an example. Normally we could just use sum(Returns(1), tree).
1313

1414
"""
15-
count_depth(tree::Node{T})::Int where {T}
15+
count_depth(tree::AbstractNode)::Int
1616
1717
Compute the max depth of the tree.
1818
"""
19-
function count_depth(tree::Node)
19+
function count_depth(tree::AbstractNode)
2020
return tree_mapreduce(Returns(1), (p, child...) -> p + max(child...), tree)
2121
end
2222

src/base.jl

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ import Compat: @inline, Returns
2626
import ..UtilsModule: @memoize_on, @with_memoize
2727

2828
"""
29-
tree_mapreduce(f::Function, op::Function, tree::Node, result_type::Type=Nothing)
30-
tree_mapreduce(f_leaf::Function, f_branch::Function, op::Function, tree::Node, result_type::Type=Nothing)
29+
tree_mapreduce(f::Function, op::Function, tree::AbstractNode, result_type::Type=Nothing)
30+
tree_mapreduce(f_leaf::Function, f_branch::Function, op::Function, tree::AbstractNode, result_type::Type=Nothing)
3131
3232
Map a function over a tree and aggregate the result using an operator `op`.
3333
`op` should be defined with inputs `(parent, child...) ->` so that it can aggregate
@@ -66,23 +66,27 @@ end # Get list of constants. (regular mapreduce also works)
6666
```
6767
"""
6868
function tree_mapreduce(
69-
f::F, op::G, tree::N, result_type::Type{RT}=Nothing; preserve_sharing::Bool=false
70-
) where {T,N<:Node{T},F<:Function,G<:Function,RT}
69+
f::F,
70+
op::G,
71+
tree::AbstractNode,
72+
result_type::Type{RT}=Nothing;
73+
preserve_sharing::Bool=false,
74+
) where {F<:Function,G<:Function,RT}
7175
return tree_mapreduce(f, f, op, tree, result_type; preserve_sharing)
7276
end
7377
function tree_mapreduce(
7478
f_leaf::F1,
7579
f_branch::F2,
7680
op::G,
77-
tree::N,
81+
tree::AbstractNode,
7882
result_type::Type{RT}=Nothing;
7983
preserve_sharing::Bool=false,
80-
) where {T,N<:Node{T},F1<:Function,F2<:Function,G<:Function,RT}
84+
) where {F1<:Function,F2<:Function,G<:Function,RT}
8185

8286
# Trick taken from here:
8387
# https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
8488
# to speed up recursive closure
85-
@memoize_on t function inner(inner, t::Node)
89+
@memoize_on t function inner(inner, t)
8690
if t.degree == 0
8791
return @inline(f_leaf(t))
8892
elseif t.degree == 1
@@ -97,19 +101,19 @@ function tree_mapreduce(
97101
throw(ArgumentError("Need to specify `result_type` if you use `preserve_sharing`."))
98102

99103
if preserve_sharing && RT != Nothing
100-
return @with_memoize inner(inner, tree) IdDict{N,RT}()
104+
return @with_memoize inner(inner, tree) IdDict{typeof(tree),RT}()
101105
else
102106
return inner(inner, tree)
103107
end
104108
end
105109

106110
"""
107-
any(f::Function, tree::Node)
111+
any(f::Function, tree::AbstractNode)
108112
109113
Reduce a flag function over a tree, returning `true` if the function returns `true` for any node.
110114
By using this instead of tree_mapreduce, we can take advantage of early exits.
111115
"""
112-
function any(f::F, tree::Node) where {F<:Function}
116+
function any(f::F, tree::AbstractNode) where {F<:Function}
113117
if tree.degree == 0
114118
return @inline(f(tree))::Bool
115119
elseif tree.degree == 1
@@ -119,19 +123,25 @@ function any(f::F, tree::Node) where {F<:Function}
119123
end
120124
end
121125

122-
function Base.:(==)(a::Node{T1}, b::Node{T2})::Bool where {T1,T2}
126+
function Base.:(==)(a::AbstractNode, b::AbstractNode)::Bool
123127
(degree = a.degree) != b.degree && return false
124128
if degree == 0
125-
(constant = a.constant) != b.constant && return false
126-
if constant
127-
return a.val::T1 == b.val::T2
128-
else
129-
return a.feature == b.feature
130-
end
129+
return isequal_deg0(a, b)
131130
elseif degree == 1
132-
return a.op == b.op && a.l == b.l
131+
return isequal_deg1(a, b) && a.l == b.l
132+
else
133+
return isequal_deg2(a, b) && a.l == b.l && a.r == b.r
134+
end
135+
end
136+
137+
@inline isequal_deg1(a::Node, b::Node) = a.op == b.op
138+
@inline isequal_deg2(a::Node, b::Node) = a.op == b.op
139+
@inline function isequal_deg0(a::Node{T1}, b::Node{T2}) where {T1,T2}
140+
(constant = a.constant) != b.constant && return false
141+
if constant
142+
return a.val::T1 == b.val::T2
133143
else
134-
return a.op == b.op && a.l == b.l && a.r == b.r
144+
return a.feature == b.feature
135145
end
136146
end
137147

@@ -144,33 +154,33 @@ end
144154
145155
Apply a function to each node in a tree.
146156
"""
147-
function foreach(f::Function, tree::Node)
157+
function foreach(f::Function, tree::AbstractNode)
148158
return tree_mapreduce(t -> (@inline(f(t)); nothing), Returns(nothing), tree)
149159
end
150160

151161
"""
152-
filter_map(filter_fnc::Function, map_fnc::Function, tree::Node, result_type::Type)
162+
filter_map(filter_fnc::Function, map_fnc::Function, tree::AbstractNode, result_type::Type)
153163
154164
A faster equivalent to `map(map_fnc, filter(filter_fnc, tree))`
155165
that avoids the intermediate allocation. However, using this requires
156166
specifying the `result_type` of `map_fnc` so the resultant array can
157167
be preallocated.
158168
"""
159169
function filter_map(
160-
filter_fnc::F, map_fnc::G, tree::Node, result_type::Type{GT}
170+
filter_fnc::F, map_fnc::G, tree::AbstractNode, result_type::Type{GT}
161171
) where {F<:Function,G<:Function,GT}
162172
stack = Array{GT}(undef, count(filter_fnc, tree))
163173
filter_map!(filter_fnc, map_fnc, stack, tree)
164174
return stack::Vector{GT}
165175
end
166176

167177
"""
168-
filter_map!(filter_fnc::Function, map_fnc::Function, stack::Vector{GT}, tree::Node)
178+
filter_map!(filter_fnc::Function, map_fnc::Function, stack::Vector{GT}, tree::AbstractNode)
169179
170180
Equivalent to `filter_map`, but stores the results in a preallocated array.
171181
"""
172182
function filter_map!(
173-
filter_fnc::Function, map_fnc::Function, destination::Vector{GT}, tree::Node
183+
filter_fnc::Function, map_fnc::Function, destination::Vector{GT}, tree::AbstractNode
174184
) where {GT}
175185
pointer = Ref(0)
176186
foreach(tree) do t
@@ -183,49 +193,49 @@ function filter_map!(
183193
end
184194

185195
"""
186-
filter(f::Function, tree::Node)
196+
filter(f::Function, tree::AbstractNode)
187197
188198
Filter nodes of a tree, returning a flat array of the nodes for which the function returns `true`.
189199
"""
190-
function filter(f::F, tree::Node{T}) where {F<:Function,T}
191-
return filter_map(f, identity, tree, Node{T})
200+
function filter(f::F, tree::AbstractNode) where {F<:Function}
201+
return filter_map(f, identity, tree, typeof(tree))
192202
end
193203

194-
collect(tree::Node) = filter(Returns(true), tree)
204+
collect(tree::AbstractNode) = filter(Returns(true), tree)
195205

196206
"""
197-
map(f::Function, tree::Node, result_type::Type{RT}=Nothing)
207+
map(f::Function, tree::AbstractNode, result_type::Type{RT}=Nothing)
198208
199209
Map a function over a tree and return a flat array of the results in depth-first order.
200210
Pre-specifying the `result_type` of the function can be used to avoid extra allocations,
201211
"""
202-
function map(f::F, tree::Node, result_type::Type{RT}=Nothing) where {F<:Function,RT}
212+
function map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing) where {F<:Function,RT}
203213
if RT == Nothing
204214
return f.(collect(tree))
205215
else
206216
return filter_map(Returns(true), f, tree, result_type)
207217
end
208218
end
209219

210-
function count(f::F, tree::Node; init=0) where {F<:Function}
220+
function count(f::F, tree::AbstractNode; init=0) where {F<:Function}
211221
return tree_mapreduce(t -> @inline(f(t)) ? 1 : 0, +, tree) + init
212222
end
213223

214-
function sum(f::F, tree::Node; init=0) where {F<:Function}
224+
function sum(f::F, tree::AbstractNode; init=0) where {F<:Function}
215225
return tree_mapreduce(f, +, tree) + init
216226
end
217227

218-
all(f::F, tree::Node) where {F<:Function} = !any(t -> !@inline(f(t)), tree)
228+
all(f::F, tree::AbstractNode) where {F<:Function} = !any(t -> !@inline(f(t)), tree)
219229

220-
function mapreduce(f::F, op::G, tree::Node) where {F<:Function,G<:Function}
230+
function mapreduce(f::F, op::G, tree::AbstractNode) where {F<:Function,G<:Function}
221231
return tree_mapreduce(f, (n...) -> reduce(op, n), tree)
222232
end
223233

224-
isempty(::Node) = false
225-
iterate(root::Node) = (root, collect(root)[(begin + 1):end])
226-
iterate(::Node, stack) = isempty(stack) ? nothing : (popfirst!(stack), stack)
227-
in(item, tree::Node) = any(t -> t == item, tree)
228-
length(tree::Node) = sum(Returns(1), tree)
234+
isempty(::AbstractNode) = false
235+
iterate(root::AbstractNode) = (root, collect(root)[(begin + 1):end])
236+
iterate(::AbstractNode, stack) = isempty(stack) ? nothing : (popfirst!(stack), stack)
237+
in(item, tree::AbstractNode) = any(t -> t == item, tree)
238+
length(tree::AbstractNode) = sum(Returns(1), tree)
229239
function hash(tree::Node{T}) where {T}
230240
return tree_mapreduce(
231241
t -> t.constant ? hash((0, t.val::T)) : hash((1, t.feature)),
@@ -299,11 +309,11 @@ end
299309

300310
for func in (:reduce, :foldl, :foldr, :mapfoldl, :mapfoldr)
301311
@eval begin
302-
function $func(f, tree::Node; kws...)
312+
function $func(f, tree::AbstractNode; kws...)
303313
throw(
304314
error(
305315
string($func) *
306-
" not implemented for Node. Use `tree_mapreduce` instead.",
316+
" not implemented for AbstractNode. Use `tree_mapreduce` instead.",
307317
),
308318
)
309319
end

0 commit comments

Comments
 (0)