Skip to content

Commit 6ce0d3a

Browse files
committed
Test turbo in evaluation with different types
1 parent 5718dd3 commit 6ce0d3a

File tree

1 file changed

+54
-53
lines changed

1 file changed

+54
-53
lines changed

test/test_evaluation.jl

Lines changed: 54 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,50 +8,42 @@ operators = OperatorEnum(;
88
default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin)
99
)
1010

11-
# Here, we unittest the fast function evaluation scheme
12-
# We need to trigger all possible fused functions, with all their logic.
13-
# These are as follows:
14-
15-
## We fuse (and compile) the following:
16-
## - op(op2(x, y)), where x, y, z are constants or variables.
17-
## - op(op2(x)), where x is a constant or variable.
18-
## - op(x), for any x.
19-
## We fuse (and compile) the following:
20-
## - op(x, y), where x, y are constants or variables.
21-
## - op(x, y), where x is a constant or variable but y is not.
22-
## - op(x, y), where y is a constant or variable but x is not.
23-
## - op(x, y), for any x or y
24-
for fnc in [
11+
functions = [
2512
# deg2_l0_r0_eval
2613
(x1, x2, x3) -> x1 * x2,
27-
(x1, x2, x3) -> x1 * 3.0f0,
28-
(x1, x2, x3) -> 3.0f0 * x2,
29-
(((x1, x2, x3) -> 3.0f0 * 6.0f0), ((x1, x2, x3) -> Node(; val=3.0f0) * 6.0f0)),
14+
(x1, x2, x3) -> x1 * 3.0,
15+
(x1, x2, x3) -> 3.0 * x2,
16+
(((x1, x2, x3) -> 3.0 * 6.0), ((x1, x2, x3) -> Node(; val=3.0) * 6.0)),
3017
# deg2_l0_eval
3118
(x1, x2, x3) -> x1 * sin(x2),
32-
(x1, x2, x3) -> 3.0f0 * sin(x2),
19+
(x1, x2, x3) -> 3.0 * sin(x2),
3320

3421
# deg2_r0_eval
3522
(x1, x2, x3) -> sin(x1) * x2,
36-
(x1, x2, x3) -> sin(x1) * 3.0f0,
23+
(x1, x2, x3) -> sin(x1) * 3.0,
3724

3825
# deg1_l2_ll0_lr0_eval
3926
(x1, x2, x3) -> cos(x1 * x2),
40-
(x1, x2, x3) -> cos(x1 * 3.0f0),
41-
(x1, x2, x3) -> cos(3.0f0 * x2),
27+
(x1, x2, x3) -> cos(x1 * 3.0),
28+
(x1, x2, x3) -> cos(3.0 * x2),
4229
(
43-
((x1, x2, x3) -> cos(3.0f0 * -0.5f0)),
44-
((x1, x2, x3) -> cos(Node(2, Node(; val=3.0f0), Node(; val=-0.5f0)))),
30+
((x1, x2, x3) -> cos(3.0 * -0.5)),
31+
((x1, x2, x3) -> cos(Node(2, Node(; val=3.0), Node(; val=-0.5)))),
4532
),
4633

4734
# deg1_l1_ll0_eval
4835
(x1, x2, x3) -> cos(sin(x1)),
49-
(((x1, x2, x3) -> cos(sin(3.0f0))), ((x1, x2, x3) -> cos(sin(Node(; val=3.0f0))))),
36+
(((x1, x2, x3) -> cos(sin(3.0))), ((x1, x2, x3) -> cos(sin(Node(; val=3.0))))),
5037

5138
# everything else:
52-
(x1, x2, x3) -> (sin(cos(sin(cos(x1) * x3) * 3.0f0) * -0.5f0) + 2.0f0) * 5.0f0,
39+
(x1, x2, x3) -> (sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0,
5340
]
5441

42+
for turbo in [false, true], T in [Float16, Float32, Float64], fnc in functions
43+
44+
# Float16 not implemented:
45+
turbo && T == Float16 && continue
46+
5547
# check if fnc is tuple
5648
if typeof(fnc) <: Tuple
5749
realfnc = fnc[1]
@@ -61,40 +53,49 @@ for fnc in [
6153
nodefnc = fnc
6254
end
6355

64-
global tree = nodefnc(Node("x1"), Node("x2"), Node("x3"))
56+
local tree, X
57+
tree = nodefnc(Node("x1"), Node("x2"), Node("x3"))
58+
tree = convert(Node{T}, tree)
6559

6660
N = 100
6761
nfeatures = 3
68-
X = randn(MersenneTwister(0), Float32, nfeatures, N)
62+
X = randn(MersenneTwister(0), T, nfeatures, N)
6963

70-
test_y = eval_tree_array(tree, X, operators)[1]
64+
test_y = eval_tree_array(tree, X, operators; turbo=turbo)[1]
7165
true_y = realfnc.(X[1, :], X[2, :], X[3, :])
7266

73-
zero_tolerance = 1e-6
67+
zero_tolerance = (T == Float16 ? 1e-4 : 1e-6)
7468
@test all(abs.(test_y .- true_y) / N .< zero_tolerance)
7569
end
7670

77-
# Test specific branches of evaluation code:
78-
# op(op(<constant>))
79-
tree = Node(1, Node(1, Node(; val=3.0f0)))
80-
@test repr(tree) == "cos(cos(3.0))"
81-
truth = cos(cos(3.0f0))
82-
@test DynamicExpressions.EvaluateEquationModule.deg1_l1_ll0_eval(
83-
tree, [0.0f0]', Val(1), Val(1), operators
84-
)[1][1] truth
85-
86-
# op(<constant>, <constant>)
87-
tree = Node(1, Node(; val=3.0f0), Node(; val=4.0f0))
88-
@test repr(tree) == "(3.0 + 4.0)"
89-
truth = 3.0f0 + 4.0f0
90-
@test DynamicExpressions.EvaluateEquationModule.deg2_l0_r0_eval(
91-
tree, [0.0f0]', Val(1), operators
92-
)[1][1] truth
93-
94-
# op(op(<constant>, <constant>))
95-
tree = Node(1, Node(1, Node(; val=3.0f0), Node(; val=4.0f0)))
96-
@test repr(tree) == "cos(3.0 + 4.0)"
97-
truth = cos(3.0f0 + 4.0f0)
98-
@test DynamicExpressions.EvaluateEquationModule.deg1_l2_ll0_lr0_eval(
99-
tree, [0.0f0]', Val(1), Val(1), operators
100-
)[1][1] truth
71+
for turbo in [false, true], T in [Float16, Float32, Float64]
72+
turbo && T == Float16 && continue
73+
# Test specific branches of evaluation code:
74+
# op(op(<constant>))
75+
local tree
76+
tree = Node(1, Node(1, Node(; val=3.0f0)))
77+
@test repr(tree) == "cos(cos(3.0))"
78+
tree = convert(Node{T}, tree)
79+
truth = cos(cos(T(3.0f0)))
80+
@test DynamicExpressions.EvaluateEquationModule.deg1_l1_ll0_eval(
81+
tree, [zero(T)]', Val(1), Val(1), operators, Val(turbo)
82+
)[1][1] truth
83+
84+
# op(<constant>, <constant>)
85+
tree = Node(1, Node(; val=3.0f0), Node(; val=4.0f0))
86+
@test repr(tree) == "(3.0 + 4.0)"
87+
tree = convert(Node{T}, tree)
88+
truth = T(3.0f0) + T(4.0f0)
89+
@test DynamicExpressions.EvaluateEquationModule.deg2_l0_r0_eval(
90+
tree, [zero(T)]', Val(1), operators, Val(turbo)
91+
)[1][1] truth
92+
93+
# op(op(<constant>, <constant>))
94+
tree = Node(1, Node(1, Node(; val=3.0f0), Node(; val=4.0f0)))
95+
@test repr(tree) == "cos(3.0 + 4.0)"
96+
tree = convert(Node{T}, tree)
97+
truth = cos(T(3.0f0) + T(4.0f0))
98+
@test DynamicExpressions.EvaluateEquationModule.deg1_l2_ll0_lr0_eval(
99+
tree, [zero(T)]', Val(1), Val(1), operators, Val(turbo)
100+
)[1][1] truth
101+
end

0 commit comments

Comments
 (0)