Skip to content

Commit 9737763

Browse files
committed
Backport with Compat.jl
1 parent f3f3ad7 commit 9737763

File tree

3 files changed

+28
-34
lines changed

3 files changed

+28
-34
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,24 @@ authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
44
version = "0.7.0"
55

66
[deps]
7+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
910
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
11+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1012
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1113
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1214
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
13-
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1415
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1516
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
1617
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

1819
[compat]
20+
Compat = "3.37, 4"
1921
LoopVectorization = "0.12"
2022
MacroTools = "0.4, 0.5"
21-
Reexport = "1"
2223
PrecompileTools = "1"
24+
Reexport = "1"
2325
SymbolicUtils = "0.19, ^1.0.5"
2426
Zygote = "0.6"
2527
julia = "1.6"

src/EquationUtils.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module EquationUtilsModule
22

3+
import Compat: Returns
34
import ..EquationModule: Node, copy_node, tree_mapreduce, any, filter_and_map
45

56
"""
@@ -15,7 +16,9 @@ count_nodes(tree::Node) = tree_mapreduce(_ -> 1, +, tree)
1516
1617
Compute the max depth of the tree.
1718
"""
18-
count_depth(tree::Node) = tree_mapreduce(_ -> 1, (p, child...) -> p + max(child...), tree)
19+
function count_depth(tree::Node)
20+
return tree_mapreduce(Returns(1), (p, child...) -> p + max(child...), tree)
21+
end
1922

2023
"""
2124
is_node_constant(tree::Node)::Bool

src/tree_map.jl

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import Base:
2222
reduce,
2323
setindex!,
2424
sum
25+
import Compat: @inline, Returns
2526

2627
function reduce(f, tree::Node; init=nothing)
2728
throw(ArgumentError("reduce is not supported for trees. Use tree_mapreduce instead."))
@@ -39,20 +40,6 @@ function mapfoldr(f, tree::Node; init=nothing)
3940
throw(ArgumentError("mapfoldr is not supported for trees. Use tree_mapreduce instead."))
4041
end
4142

42-
"""Internal macro to fix @inline on Julia versions before 1.8"""
43-
macro _inline(ex)
44-
ex = _fix_inline(ex)
45-
return :($(esc(ex)))
46-
end
47-
48-
function _fix_inline(ex)
49-
if VERSION >= v"1.8"
50-
return Expr(:macrocall, Symbol("@inline"), LineNumberNode(@__LINE__), ex)
51-
else
52-
return ex
53-
end
54-
end
55-
5643
#! format: off
5744
"""
5845
tree_mapreduce(f::Function, op::Function, tree::Node)
@@ -92,21 +79,21 @@ end # Get list of constants. (regular mapreduce also works)
9279
"""
9380
function tree_mapreduce(f::F, op::G, tree::Node) where {F<:Function,G<:Function}
9481
if tree.degree == 0
95-
return @_inline(f(tree))
82+
return @inline(f(tree))
9683
elseif tree.degree == 1
97-
return op(@_inline(f(tree)), tree_mapreduce(f, op, tree.l))
84+
return op(@inline(f(tree)), tree_mapreduce(f, op, tree.l))
9885
else
99-
return op(@_inline(f(tree)), tree_mapreduce(f, op, tree.l), tree_mapreduce(f, op, tree.r))
86+
return op(@inline(f(tree)), tree_mapreduce(f, op, tree.l), tree_mapreduce(f, op, tree.r))
10087
end
10188
end
10289

10390
function mapreduce(f::F, op::G, tree::Node; init=nothing) where {F<:Function,G<:Function}
10491
if tree.degree == 0
105-
return @_inline(f(tree))
92+
return @inline(f(tree))
10693
elseif tree.degree == 1
107-
return op(@_inline(f(tree)), mapreduce(f, op, tree.l; init))
94+
return op(@inline(f(tree)), mapreduce(f, op, tree.l; init))
10895
else
109-
return op(op(@_inline(f(tree)), mapreduce(f, op, tree.l; init)), mapreduce(f, op, tree.r; init))
96+
return op(op(@inline(f(tree)), mapreduce(f, op, tree.l; init)), mapreduce(f, op, tree.r; init))
11097
end
11198
end
11299
#! format: on
@@ -132,8 +119,8 @@ end
132119
function _filter_and_map(
133120
filter_fnc::F, map_fnc::G, tree::Node, stack::Vector{GT}, pointer::Ref
134121
) where {F<:Function,G<:Function,GT}
135-
if @_inline(filter_fnc(tree))
136-
map_result = @_inline(map_fnc(tree))::GT
122+
if @inline(filter_fnc(tree))
123+
map_result = @inline(map_fnc(tree))::GT
137124
@inbounds stack[pointer.x += 1] = map_result
138125
end
139126
if tree.degree == 1
@@ -153,11 +140,11 @@ By using this instead of tree_mapreduce, we can take advantage of early exits.
153140
"""
154141
function any(f::F, tree::Node) where {F<:Function}
155142
if tree.degree == 0
156-
return @_inline(f(tree))::Bool
143+
return @inline(f(tree))::Bool
157144
elseif tree.degree == 1
158-
return @_inline(f(tree))::Bool || any(f, tree.l)
145+
return @inline(f(tree))::Bool || any(f, tree.l)
159146
else
160-
return @_inline(f(tree))::Bool || any(f, tree.l) || any(f, tree.r)
147+
return @inline(f(tree))::Bool || any(f, tree.l) || any(f, tree.r)
161148
end
162149
end
163150

@@ -201,7 +188,7 @@ function filter(f::F, tree::Node{T}) where {F<:Function,T}
201188
return filter_and_map(f, identity, tree; result_type=Node{T})
202189
end
203190

204-
collect(tree::Node) = filter(_ -> true, tree)
191+
collect(tree::Node) = filter(Returns(true), tree)
205192

206193
"""
207194
map(f::Function, tree::Node; result_type::Type{RT}=Nothing)
@@ -213,19 +200,19 @@ function map(f::F, tree::Node; result_type::Type{RT}=Nothing) where {F<:Function
213200
if RT == Nothing
214201
return f.(collect(tree))
215202
else
216-
return filter_and_map(_ -> true, f, tree; result_type=result_type)
203+
return filter_and_map(Returns(true), f, tree; result_type=result_type)
217204
end
218205
end
219206

220207
function count(f::F, tree::Node; init=0) where {F}
221-
return tree_mapreduce(t -> @_inline(f(t)) ? 1 : 0, +, tree) + init
208+
return tree_mapreduce(t -> @inline(f(t)) ? 1 : 0, +, tree) + init
222209
end
223210

224211
function sum(f::F, tree::Node; init=0) where {F}
225212
return tree_mapreduce(f, +, tree) + init
226213
end
227214

228-
all(f::F, tree::Node) where {F<:Function} = !any(t -> !@_inline(f(t)), tree)
215+
all(f::F, tree::Node) where {F<:Function} = !any(t -> !@inline(f(t)), tree)
229216

230217
function setindex!(root::Node{T}, insert::Node{T}, i::Int) where {T}
231218
set_node!(getindex(root, i), insert)
@@ -239,8 +226,10 @@ isempty(::Node) = false
239226
iterate(root::Node) = (root, collect(root)[(begin + 1):end])
240227
iterate(::Node, stack) = isempty(stack) ? nothing : (popfirst!(stack), stack)
241228
in(item, tree::Node) = any(t -> t == item, tree)
242-
length(tree::Node) = sum(_ -> 1, tree)
229+
length(tree::Node) = sum(Returns(1), tree)
243230
firstindex(::Node) = 1
244231
lastindex(tree::Node) = length(tree)
245232
keys(tree::Node) = Base.OneTo(length(tree))
246-
foreach(f::Function, tree::Node) = mapreduce(t -> (@_inline(f(t)); nothing), Returns(nothing), tree)
233+
function foreach(f::Function, tree::Node)
234+
return mapreduce(t -> (@inline(f(t)); nothing), Returns(nothing), tree)
235+
end

0 commit comments

Comments
 (0)