Skip to content

Commit 12df3e4

Browse files
committed
Fix error in benchmark
1 parent 3d8769c commit 12df3e4

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

benchmark/benchmarks.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,23 @@ function benchmark_evaluation()
3737

3838
(turbo || bumper) && !(T in (Float32, Float64)) && continue
3939
turbo && bumper && continue
40-
bumper && PACKAGE_VERSION < v"0.15.0" && continue
40+
if bumper
41+
try
42+
eval_tree_array(Node{T}(val=1.0), ones(T, 5, n), operators; turbo, bumper)
43+
catch e
44+
isa(e, MethodError) || rethrow(e)
45+
@warn "Skipping bumper tests"
46+
continue # Assume its not available
47+
end
48+
end
4149

4250
extra_key = turbo ? "_turbo" : (bumper ? "_bumper" : "")
4351
extra_kws = bumper ? (; bumper=Val(true)) : ()
4452
eval_tree_array(
4553
gen_random_tree_fixed_size(20, operators, 5, T),
4654
randn(MersenneTwister(0), T, 5, n),
4755
operators;
48-
turbo=turbo,
56+
turbo,
4957
extra_kws...
5058
)
5159
suite[T]["evaluation$(extra_key)"] = @benchmarkable(

src/EvaluateEquation.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,14 @@ function eval_tree_array(
6464
cX::AbstractMatrix{T},
6565
operators::OperatorEnum;
6666
turbo::Union{Bool,Val}=Val(false),
67-
bumper::Val=Val(false),
67+
bumper::Union{Bool,Val}=Val(false),
6868
) where {T<:Number}
69-
v_turbo = if isa(turbo, Val)
70-
turbo
71-
else
72-
turbo ? Val(true) : Val(false)
73-
end
74-
if v_turbo isa Val{true} || bumper isa Val{true}
69+
v_turbo = isa(turbo, Val) ? turbo : (turbo ? Val(true) : Val(false))
70+
v_bumper = isa(bumper, Val) ? bumper : (bumper ? Val(true) : Val(false))
71+
if v_turbo isa Val{true} || v_bumper isa Val{true}
7572
@assert T in (Float32, Float64)
7673
end
77-
if bumper isa Val{true}
74+
if v_bumper isa Val{true}
7875
return bumper_eval_tree_array(tree, cX, operators)
7976
end
8077
if v_turbo isa Val{true}
@@ -89,13 +86,14 @@ function eval_tree_array(
8986
tree::AbstractExpressionNode{T1},
9087
cX::AbstractMatrix{T2},
9188
operators::OperatorEnum;
92-
kws...,
89+
turbo::Union{Bool,Val}=Val(false),
90+
bumper::Union{Bool,Val}=Val(false),
9391
) where {T1<:Number,T2<:Number}
9492
T = promote_type(T1, T2)
9593
@warn "Warning: eval_tree_array received mixed types: tree=$(T1) and data=$(T2)."
9694
tree = convert(constructorof(typeof(tree)){T}, tree)
9795
cX = Base.Fix1(convert, T).(cX)
98-
return eval_tree_array(tree, cX, operators; kws...)
96+
return eval_tree_array(tree, cX, operators; turbo, bumper)
9997
end
10098

10199
get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U)

0 commit comments

Comments
 (0)