Skip to content

Commit 1655a8e

Browse files
committed
Merge branch 'tree-map' into constant-optimization
2 parents 5a8e0b2 + fcc056e commit 1655a8e

File tree

9 files changed

+500
-392
lines changed

9 files changed

+500
-392
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,7 @@ end
2121

2222
using Reexport: @reexport
2323
@reexport import .EquationModule:
24-
Node,
25-
string_tree,
26-
print_tree,
27-
copy_node,
28-
set_node!,
29-
map,
30-
tree_mapreduce,
31-
any,
32-
filter_and_map
24+
Node, string_tree, print_tree, copy_node, set_node!, tree_mapreduce, filter_map
3325
@reexport import .EquationUtilsModule:
3426
count_nodes,
3527
count_constants,

src/Equation.jl

Lines changed: 2 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module EquationModule
22

33
import ..OperatorEnumModule: AbstractOperatorEnum
4-
import ..UtilsModule: @generate_idmap, @use_idmap
4+
import ..UtilsModule: @memoize_on, @with_memoize
55

66
const DEFAULT_NODE_TYPE = Float32
77

@@ -62,53 +62,7 @@ mutable struct Node{T}
6262
end
6363
################################################################################
6464

65-
include("tree_map.jl")
66-
67-
"""
68-
convert(::Type{Node{T1}}, n::Node{T2}) where {T1,T2}
69-
70-
Convert a `Node{T2}` to a `Node{T1}`.
71-
This will recursively convert all children nodes to `Node{T1}`,
72-
using `convert(T1, tree.val)` at constant nodes.
73-
74-
# Arguments
75-
- `::Type{Node{T1}}`: Type to convert to.
76-
- `tree::Node{T2}`: Node to convert.
77-
"""
78-
function Base.convert(
79-
::Type{Node{T1}}, tree::Node{T2}; preserve_sharing::Bool=false
80-
) where {T1,T2}
81-
if T1 == T2
82-
return tree
83-
end
84-
if preserve_sharing
85-
@use_idmap(_convert(Node{T1}, tree), IdDict{Node{T2},Node{T1}}())
86-
else
87-
_convert(Node{T1}, tree)
88-
end
89-
end
90-
91-
@generate_idmap tree function _convert(::Type{Node{T1}}, tree::Node{T2}) where {T1,T2}
92-
if tree.degree == 0
93-
if tree.constant
94-
val = tree.val::T2
95-
if !(T2 <: T1)
96-
# e.g., we don't want to convert Float32 to Union{Float32,Vector{Float32}}!
97-
val = convert(T1, val)
98-
end
99-
Node(T1, 0, tree.constant, val)
100-
else
101-
Node(T1, 0, tree.constant, nothing, tree.feature)
102-
end
103-
elseif tree.degree == 1
104-
l = _convert(Node{T1}, tree.l)
105-
Node(1, tree.constant, nothing, tree.feature, tree.op, l)
106-
else
107-
l = _convert(Node{T1}, tree.l)
108-
r = _convert(Node{T1}, tree.r)
109-
Node(2, tree.constant, nothing, tree.feature, tree.op, l, r)
110-
end
111-
end
65+
include("base.jl")
11266

11367
"""
11468
Node([::Type{T}]; val=nothing, feature::Int=nothing) where {T}
@@ -226,45 +180,6 @@ function set_node!(tree::Node{T}, new_tree::Node{T}) where {T}
226180
return nothing
227181
end
228182

229-
"""
230-
copy_node(tree::Node; preserve_sharing::Bool=false)
231-
232-
Copy a node, recursively copying all children nodes.
233-
This is more efficient than the built-in copy.
234-
With `preserve_sharing=true`, this will also
235-
preserve linkage between a node and
236-
multiple parents, whereas without, this would create
237-
duplicate child node copies.
238-
239-
id_map is a map from `objectid(tree)` to `copy(tree)`.
240-
We check against the map before making a new copy; otherwise
241-
we can simply reference the existing copy.
242-
[Thanks to Ted Hopp.](https://stackoverflow.com/questions/49285475/how-to-copy-a-full-non-binary-tree-including-loops)
243-
244-
Note that this will *not* preserve loops in graphs.
245-
"""
246-
function copy_node(tree::Node{T}; preserve_sharing::Bool=false)::Node{T} where {T}
247-
if preserve_sharing
248-
@use_idmap(_copy_node(tree), IdDict{Node{T},Node{T}}())
249-
else
250-
_copy_node(tree)
251-
end
252-
end
253-
254-
@generate_idmap tree function _copy_node(tree::Node{T})::Node{T} where {T}
255-
if tree.degree == 0
256-
if tree.constant
257-
Node(; val=copy(tree.val::T))
258-
else
259-
Node(T; feature=copy(tree.feature))
260-
end
261-
elseif tree.degree == 1
262-
Node(copy(tree.op), _copy_node(tree.l))
263-
else
264-
Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r))
265-
end
266-
end
267-
268183
const OP_NAMES = Dict(
269184
"safe_log" => "log",
270185
"safe_log2" => "log2",
@@ -365,51 +280,4 @@ function print_tree(
365280
return println(string_tree(tree, operators; varMap=varMap))
366281
end
367282

368-
function Base.hash(tree::Node{T})::UInt where {T}
369-
if tree.degree == 0
370-
if tree.constant
371-
# tree.val used.
372-
return hash((0, tree.val::T))
373-
else
374-
# tree.feature used.
375-
return hash((1, tree.feature))
376-
end
377-
elseif tree.degree == 1
378-
return hash((1, tree.op, hash(tree.l)))
379-
else
380-
return hash((2, tree.op, hash(tree.l), hash(tree.r)))
381-
end
382-
end
383-
384-
function is_equal(a::Node{T}, b::Node{T})::Bool where {T}
385-
if a.degree == 0
386-
b.degree != 0 && return false
387-
if a.constant
388-
!(b.constant) && return false
389-
return a.val::T == b.val::T
390-
else
391-
b.constant && return false
392-
return a.feature == b.feature
393-
end
394-
elseif a.degree == 1
395-
b.degree != 1 && return false
396-
a.op != b.op && return false
397-
return is_equal(a.l, b.l)
398-
else
399-
b.degree != 2 && return false
400-
a.op != b.op && return false
401-
return is_equal(a.l, b.l) && is_equal(a.r, b.r)
402-
end
403-
end
404-
405-
function Base.:(==)(a::Node{T}, b::Node{T})::Bool where {T}
406-
return is_equal(a, b)
407-
end
408-
409-
function Base.:(==)(a::Node{T1}, b::Node{T2})::Bool where {T1,T2}
410-
T = promote_type(T1, T2)
411-
# TODO: Should also have preserve_sharing check...
412-
return is_equal(convert(Node{T}, a), convert(Node{T}, b))
413-
end
414-
415283
end

src/EquationUtils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module EquationUtilsModule
22

33
import Compat: Returns
4-
import ..EquationModule: Node, copy_node, tree_mapreduce, any, filter_and_map
4+
import ..EquationModule: Node, copy_node, tree_mapreduce, any, filter_map
55

66
"""
77
count_nodes(tree::Node{T})::Int where {T}
@@ -64,7 +64,7 @@ The function `set_constants!` sets them in the same order,
6464
given the output of this function.
6565
"""
6666
function get_constants(tree::Node{T}) where {T}
67-
return filter_and_map(is_node_constant, t -> (t.val::T), tree; result_type=T)
67+
return filter_map(is_node_constant, t -> (t.val::T), tree, T)
6868
end
6969

7070
"""

src/Utils.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ isgood(x) = true
8181
isbad(x) = !isgood(x)
8282

8383
"""
84-
@generate_idmap tree function my_function_on_tree(tree::Node)
84+
@memoize_on tree function my_function_on_tree(tree::Node)
8585
...
8686
end
8787
@@ -91,14 +91,14 @@ IdDict()), it will use use the `id_map` to avoid recomputing the same value
9191
for the same node in a tree. Use this to automatically create functions that
9292
work with trees that have shared child nodes.
9393
"""
94-
macro generate_idmap(tree, def)
95-
idmap_def = _generate_idmap(tree, def)
94+
macro memoize_on(tree, def)
95+
idmap_def = _memoize_on(tree, def)
9696
return quote
9797
$(esc(def)) # The normal function
9898
$(esc(idmap_def)) # The function with an id_map argument
9999
end
100100
end
101-
function _generate_idmap(tree::Symbol, def::Expr)
101+
function _memoize_on(tree::Symbol, def::Expr)
102102
sdef = splitdef(def)
103103

104104
# Add an id_map argument
@@ -127,13 +127,13 @@ function _generate_idmap(tree::Symbol, def::Expr)
127127
end
128128

129129
"""
130-
@use_idmap(call, id_map)
130+
@with_memoize(call, id_map)
131131
132132
This simple macro simply puts the `id_map`
133-
into the call, to be consistent with the `@generate_idmap` macro.
133+
into the call, to be consistent with the `@memoize_on` macro.
134134
135135
```
136-
@use_idmap(_copy_node(tree), IdDict{Any,Any}())
136+
@with_memoize(_copy_node(tree), IdDict{Any,Any}())
137137
````
138138
139139
is converted to
@@ -143,7 +143,7 @@ _copy_node(tree, IdDict{Any,Any}())
143143
```
144144
145145
"""
146-
macro use_idmap(def, id_map)
146+
macro with_memoize(def, id_map)
147147
idmap_def = _add_idmap_to_call(def, id_map)
148148
return quote
149149
$(esc(idmap_def))

0 commit comments

Comments
 (0)