@@ -26,8 +26,8 @@ import Compat: @inline, Returns
2626import .. UtilsModule: @memoize_on , @with_memoize
2727
2828"""
29- tree_mapreduce(f::Function, op::Function, tree::Node , result_type::Type=Nothing)
30- tree_mapreduce(f_leaf::Function, f_branch::Function, op::Function, tree::Node , result_type::Type=Nothing)
29+ tree_mapreduce(f::Function, op::Function, tree::AbstractNode , result_type::Type=Nothing)
30+ tree_mapreduce(f_leaf::Function, f_branch::Function, op::Function, tree::AbstractNode , result_type::Type=Nothing)
3131
3232Map a function over a tree and aggregate the result using an operator `op`.
3333`op` should be defined with inputs `(parent, child...) ->` so that it can aggregate
@@ -66,23 +66,27 @@ end # Get list of constants. (regular mapreduce also works)
6666```
6767"""
6868function tree_mapreduce (
69- f:: F , op:: G , tree:: N , result_type:: Type{RT} = Nothing; preserve_sharing:: Bool = false
70- ) where {T,N<: Node{T} ,F<: Function ,G<: Function ,RT}
69+ f:: F ,
70+ op:: G ,
71+ tree:: AbstractNode ,
72+ result_type:: Type{RT} = Nothing;
73+ preserve_sharing:: Bool = false ,
74+ ) where {F<: Function ,G<: Function ,RT}
7175 return tree_mapreduce (f, f, op, tree, result_type; preserve_sharing)
7276end
7377function tree_mapreduce (
7478 f_leaf:: F1 ,
7579 f_branch:: F2 ,
7680 op:: G ,
77- tree:: N ,
81+ tree:: AbstractNode ,
7882 result_type:: Type{RT} = Nothing;
7983 preserve_sharing:: Bool = false ,
80- ) where {T,N <: Node{T} , F1<: Function ,F2<: Function ,G<: Function ,RT}
84+ ) where {F1<: Function ,F2<: Function ,G<: Function ,RT}
8185
8286 # Trick taken from here:
8387 # https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
8488 # to speed up recursive closure
85- @memoize_on t function inner (inner, t:: Node )
89+ @memoize_on t function inner (inner, t)
8690 if t. degree == 0
8791 return @inline (f_leaf (t))
8892 elseif t. degree == 1
@@ -97,19 +101,19 @@ function tree_mapreduce(
97101 throw (ArgumentError (" Need to specify `result_type` if you use `preserve_sharing`." ))
98102
99103 if preserve_sharing && RT != Nothing
100- return @with_memoize inner (inner, tree) IdDict {N ,RT} ()
104+ return @with_memoize inner (inner, tree) IdDict {typeof(tree) ,RT} ()
101105 else
102106 return inner (inner, tree)
103107 end
104108end
105109
106110"""
107- any(f::Function, tree::Node )
111+ any(f::Function, tree::AbstractNode )
108112
109113Reduce a flag function over a tree, returning `true` if the function returns `true` for any node.
110114By using this instead of tree_mapreduce, we can take advantage of early exits.
111115"""
112- function any (f:: F , tree:: Node ) where {F<: Function }
116+ function any (f:: F , tree:: AbstractNode ) where {F<: Function }
113117 if tree. degree == 0
114118 return @inline (f (tree)):: Bool
115119 elseif tree. degree == 1
@@ -119,19 +123,25 @@ function any(f::F, tree::Node) where {F<:Function}
119123 end
120124end
121125
122- function Base.:(== )(a:: Node{T1} , b:: Node{T2} ):: Bool where {T1,T2}
126+ function Base.:(== )(a:: AbstractNode , b:: AbstractNode ):: Bool
123127 (degree = a. degree) != b. degree && return false
124128 if degree == 0
125- (constant = a. constant) != b. constant && return false
126- if constant
127- return a. val:: T1 == b. val:: T2
128- else
129- return a. feature == b. feature
130- end
129+ return isequal_deg0 (a, b)
131130 elseif degree == 1
132- return a. op == b. op && a. l == b. l
131+ return isequal_deg1 (a, b) && a. l == b. l
132+ else
133+ return isequal_deg2 (a, b) && a. l == b. l && a. r == b. r
134+ end
135+ end
136+
137+ @inline isequal_deg1 (a:: Node , b:: Node ) = a. op == b. op
138+ @inline isequal_deg2 (a:: Node , b:: Node ) = a. op == b. op
139+ @inline function isequal_deg0 (a:: Node{T1} , b:: Node{T2} ) where {T1,T2}
140+ (constant = a. constant) != b. constant && return false
141+ if constant
142+ return a. val:: T1 == b. val:: T2
133143 else
134- return a. op == b. op && a . l == b . l && a . r == b . r
144+ return a. feature == b. feature
135145 end
136146end
137147
@@ -144,33 +154,33 @@ end
144154
145155Apply a function to each node in a tree.
146156"""
147- function foreach (f:: Function , tree:: Node )
157+ function foreach (f:: Function , tree:: AbstractNode )
148158 return tree_mapreduce (t -> (@inline (f (t)); nothing ), Returns (nothing ), tree)
149159end
150160
151161"""
152- filter_map(filter_fnc::Function, map_fnc::Function, tree::Node , result_type::Type)
162+ filter_map(filter_fnc::Function, map_fnc::Function, tree::AbstractNode , result_type::Type)
153163
154164A faster equivalent to `map(map_fnc, filter(filter_fnc, tree))`
155165that avoids the intermediate allocation. However, using this requires
156166specifying the `result_type` of `map_fnc` so the resultant array can
157167be preallocated.
158168"""
159169function filter_map (
160- filter_fnc:: F , map_fnc:: G , tree:: Node , result_type:: Type{GT}
170+ filter_fnc:: F , map_fnc:: G , tree:: AbstractNode , result_type:: Type{GT}
161171) where {F<: Function ,G<: Function ,GT}
162172 stack = Array {GT} (undef, count (filter_fnc, tree))
163173 filter_map! (filter_fnc, map_fnc, stack, tree)
164174 return stack:: Vector{GT}
165175end
166176
167177"""
168- filter_map!(filter_fnc::Function, map_fnc::Function, stack::Vector{GT}, tree::Node )
178+ filter_map!(filter_fnc::Function, map_fnc::Function, stack::Vector{GT}, tree::AbstractNode )
169179
170180Equivalent to `filter_map`, but stores the results in a preallocated array.
171181"""
172182function filter_map! (
173- filter_fnc:: Function , map_fnc:: Function , destination:: Vector{GT} , tree:: Node
183+ filter_fnc:: Function , map_fnc:: Function , destination:: Vector{GT} , tree:: AbstractNode
174184) where {GT}
175185 pointer = Ref (0 )
176186 foreach (tree) do t
@@ -183,49 +193,49 @@ function filter_map!(
183193end
184194
185195"""
186- filter(f::Function, tree::Node )
196+ filter(f::Function, tree::AbstractNode )
187197
188198Filter nodes of a tree, returning a flat array of the nodes for which the function returns `true`.
189199"""
190- function filter (f:: F , tree:: Node{T} ) where {F<: Function ,T }
191- return filter_map (f, identity, tree, Node{T} )
200+ function filter (f:: F , tree:: AbstractNode ) where {F<: Function }
201+ return filter_map (f, identity, tree, typeof (tree) )
192202end
193203
194- collect (tree:: Node ) = filter (Returns (true ), tree)
204+ collect (tree:: AbstractNode ) = filter (Returns (true ), tree)
195205
196206"""
197- map(f::Function, tree::Node , result_type::Type{RT}=Nothing)
207+ map(f::Function, tree::AbstractNode , result_type::Type{RT}=Nothing)
198208
199209Map a function over a tree and return a flat array of the results in depth-first order.
200210Pre-specifying the `result_type` of the function can be used to avoid extra allocations,
201211"""
202- function map (f:: F , tree:: Node , result_type:: Type{RT} = Nothing) where {F<: Function ,RT}
212+ function map (f:: F , tree:: AbstractNode , result_type:: Type{RT} = Nothing) where {F<: Function ,RT}
203213 if RT == Nothing
204214 return f .(collect (tree))
205215 else
206216 return filter_map (Returns (true ), f, tree, result_type)
207217 end
208218end
209219
210- function count (f:: F , tree:: Node ; init= 0 ) where {F<: Function }
220+ function count (f:: F , tree:: AbstractNode ; init= 0 ) where {F<: Function }
211221 return tree_mapreduce (t -> @inline (f (t)) ? 1 : 0 , + , tree) + init
212222end
213223
214- function sum (f:: F , tree:: Node ; init= 0 ) where {F<: Function }
224+ function sum (f:: F , tree:: AbstractNode ; init= 0 ) where {F<: Function }
215225 return tree_mapreduce (f, + , tree) + init
216226end
217227
218- all (f:: F , tree:: Node ) where {F<: Function } = ! any (t -> ! @inline (f (t)), tree)
228+ all (f:: F , tree:: AbstractNode ) where {F<: Function } = ! any (t -> ! @inline (f (t)), tree)
219229
220- function mapreduce (f:: F , op:: G , tree:: Node ) where {F<: Function ,G<: Function }
230+ function mapreduce (f:: F , op:: G , tree:: AbstractNode ) where {F<: Function ,G<: Function }
221231 return tree_mapreduce (f, (n... ) -> reduce (op, n), tree)
222232end
223233
224- isempty (:: Node ) = false
225- iterate (root:: Node ) = (root, collect (root)[(begin + 1 ): end ])
226- iterate (:: Node , stack) = isempty (stack) ? nothing : (popfirst! (stack), stack)
227- in (item, tree:: Node ) = any (t -> t == item, tree)
228- length (tree:: Node ) = sum (Returns (1 ), tree)
234+ isempty (:: AbstractNode ) = false
235+ iterate (root:: AbstractNode ) = (root, collect (root)[(begin + 1 ): end ])
236+ iterate (:: AbstractNode , stack) = isempty (stack) ? nothing : (popfirst! (stack), stack)
237+ in (item, tree:: AbstractNode ) = any (t -> t == item, tree)
238+ length (tree:: AbstractNode ) = sum (Returns (1 ), tree)
229239function hash (tree:: Node{T} ) where {T}
230240 return tree_mapreduce (
231241 t -> t. constant ? hash ((0 , t. val:: T )) : hash ((1 , t. feature)),
@@ -299,11 +309,11 @@ end
299309
300310for func in (:reduce , :foldl , :foldr , :mapfoldl , :mapfoldr )
301311 @eval begin
302- function $func (f, tree:: Node ; kws... )
312+ function $func (f, tree:: AbstractNode ; kws... )
303313 throw (
304314 error (
305315 string ($ func) *
306- " not implemented for Node . Use `tree_mapreduce` instead." ,
316+ " not implemented for AbstractNode . Use `tree_mapreduce` instead." ,
307317 ),
308318 )
309319 end
0 commit comments