@@ -8,50 +8,42 @@ operators = OperatorEnum(;
88 default_params... , binary_operators= (+ , * , / , - ), unary_operators= (cos, sin)
99)
1010
11- # Here, we unittest the fast function evaluation scheme
12- # We need to trigger all possible fused functions, with all their logic.
13- # These are as follows:
14-
15- # # We fuse (and compile) the following:
16- # # - op(op2(x, y)), where x, y, z are constants or variables.
17- # # - op(op2(x)), where x is a constant or variable.
18- # # - op(x), for any x.
19- # # We fuse (and compile) the following:
20- # # - op(x, y), where x, y are constants or variables.
21- # # - op(x, y), where x is a constant or variable but y is not.
22- # # - op(x, y), where y is a constant or variable but x is not.
23- # # - op(x, y), for any x or y
24- for fnc in [
11+ functions = [
2512 # deg2_l0_r0_eval
2613 (x1, x2, x3) -> x1 * x2,
27- (x1, x2, x3) -> x1 * 3.0f0 ,
28- (x1, x2, x3) -> 3.0f0 * x2,
29- (((x1, x2, x3) -> 3.0f0 * 6.0f0 ), ((x1, x2, x3) -> Node (; val= 3.0f0 ) * 6.0f0 )),
14+ (x1, x2, x3) -> x1 * 3.0 ,
15+ (x1, x2, x3) -> 3.0 * x2,
16+ (((x1, x2, x3) -> 3.0 * 6.0 ), ((x1, x2, x3) -> Node (; val= 3.0 ) * 6.0 )),
3017 # deg2_l0_eval
3118 (x1, x2, x3) -> x1 * sin (x2),
32- (x1, x2, x3) -> 3.0f0 * sin (x2),
19+ (x1, x2, x3) -> 3.0 * sin (x2),
3320
3421 # deg2_r0_eval
3522 (x1, x2, x3) -> sin (x1) * x2,
36- (x1, x2, x3) -> sin (x1) * 3.0f0 ,
23+ (x1, x2, x3) -> sin (x1) * 3.0 ,
3724
3825 # deg1_l2_ll0_lr0_eval
3926 (x1, x2, x3) -> cos (x1 * x2),
40- (x1, x2, x3) -> cos (x1 * 3.0f0 ),
41- (x1, x2, x3) -> cos (3.0f0 * x2),
27+ (x1, x2, x3) -> cos (x1 * 3.0 ),
28+ (x1, x2, x3) -> cos (3.0 * x2),
4229 (
43- ((x1, x2, x3) -> cos (3.0f0 * - 0.5f0 )),
44- ((x1, x2, x3) -> cos (Node (2 , Node (; val= 3.0f0 ), Node (; val= - 0.5f0 )))),
30+ ((x1, x2, x3) -> cos (3.0 * - 0.5 )),
31+ ((x1, x2, x3) -> cos (Node (2 , Node (; val= 3.0 ), Node (; val= - 0.5 )))),
4532 ),
4633
4734 # deg1_l1_ll0_eval
4835 (x1, x2, x3) -> cos (sin (x1)),
49- (((x1, x2, x3) -> cos (sin (3.0f0 ))), ((x1, x2, x3) -> cos (sin (Node (; val= 3.0f0 ))))),
36+ (((x1, x2, x3) -> cos (sin (3.0 ))), ((x1, x2, x3) -> cos (sin (Node (; val= 3.0 ))))),
5037
5138 # everything else:
52- (x1, x2, x3) -> (sin (cos (sin (cos (x1) * x3) * 3.0f0 ) * - 0.5f0 ) + 2.0f0 ) * 5.0f0 ,
39+ (x1, x2, x3) -> (sin (cos (sin (cos (x1) * x3) * 3.0 ) * - 0.5 ) + 2.0 ) * 5.0 ,
5340]
5441
42+ for turbo in [false , true ], T in [Float16, Float32, Float64], fnc in functions
43+
44+ # Float16 not implemented:
45+ turbo && T == Float16 && continue
46+
5547 # check if fnc is tuple
5648 if typeof (fnc) <: Tuple
5749 realfnc = fnc[1 ]
@@ -61,40 +53,49 @@ for fnc in [
6153 nodefnc = fnc
6254 end
6355
64- global tree = nodefnc (Node (" x1" ), Node (" x2" ), Node (" x3" ))
56+ local tree, X
57+ tree = nodefnc (Node (" x1" ), Node (" x2" ), Node (" x3" ))
58+ tree = convert (Node{T}, tree)
6559
6660 N = 100
6761 nfeatures = 3
68- X = randn (MersenneTwister (0 ), Float32 , nfeatures, N)
62+ X = randn (MersenneTwister (0 ), T , nfeatures, N)
6963
70- test_y = eval_tree_array (tree, X, operators)[1 ]
64+ test_y = eval_tree_array (tree, X, operators; turbo = turbo )[1 ]
7165 true_y = realfnc .(X[1 , :], X[2 , :], X[3 , :])
7266
73- zero_tolerance = 1e-6
67+ zero_tolerance = (T == Float16 ? 1e-4 : 1e-6 )
7468 @test all (abs .(test_y .- true_y) / N .< zero_tolerance)
7569end
7670
77- # Test specific branches of evaluation code:
78- # op(op(<constant>))
79- tree = Node (1 , Node (1 , Node (; val= 3.0f0 )))
80- @test repr (tree) == " cos(cos(3.0))"
81- truth = cos (cos (3.0f0 ))
82- @test DynamicExpressions. EvaluateEquationModule. deg1_l1_ll0_eval (
83- tree, [0.0f0 ]' , Val (1 ), Val (1 ), operators
84- )[1 ][1 ] ≈ truth
85-
86- # op(<constant>, <constant>)
87- tree = Node (1 , Node (; val= 3.0f0 ), Node (; val= 4.0f0 ))
88- @test repr (tree) == " (3.0 + 4.0)"
89- truth = 3.0f0 + 4.0f0
90- @test DynamicExpressions. EvaluateEquationModule. deg2_l0_r0_eval (
91- tree, [0.0f0 ]' , Val (1 ), operators
92- )[1 ][1 ] ≈ truth
93-
94- # op(op(<constant>, <constant>))
95- tree = Node (1 , Node (1 , Node (; val= 3.0f0 ), Node (; val= 4.0f0 )))
96- @test repr (tree) == " cos(3.0 + 4.0)"
97- truth = cos (3.0f0 + 4.0f0 )
98- @test DynamicExpressions. EvaluateEquationModule. deg1_l2_ll0_lr0_eval (
99- tree, [0.0f0 ]' , Val (1 ), Val (1 ), operators
100- )[1 ][1 ] ≈ truth
71+ for turbo in [false , true ], T in [Float16, Float32, Float64]
72+ turbo && T == Float16 && continue
73+ # Test specific branches of evaluation code:
74+ # op(op(<constant>))
75+ local tree
76+ tree = Node (1 , Node (1 , Node (; val= 3.0f0 )))
77+ @test repr (tree) == " cos(cos(3.0))"
78+ tree = convert (Node{T}, tree)
79+ truth = cos (cos (T (3.0f0 )))
80+ @test DynamicExpressions. EvaluateEquationModule. deg1_l1_ll0_eval (
81+ tree, [zero (T)]' , Val (1 ), Val (1 ), operators, Val (turbo)
82+ )[1 ][1 ] ≈ truth
83+
84+ # op(<constant>, <constant>)
85+ tree = Node (1 , Node (; val= 3.0f0 ), Node (; val= 4.0f0 ))
86+ @test repr (tree) == " (3.0 + 4.0)"
87+ tree = convert (Node{T}, tree)
88+ truth = T (3.0f0 ) + T (4.0f0 )
89+ @test DynamicExpressions. EvaluateEquationModule. deg2_l0_r0_eval (
90+ tree, [zero (T)]' , Val (1 ), operators, Val (turbo)
91+ )[1 ][1 ] ≈ truth
92+
93+ # op(op(<constant>, <constant>))
94+ tree = Node (1 , Node (1 , Node (; val= 3.0f0 ), Node (; val= 4.0f0 )))
95+ @test repr (tree) == " cos(3.0 + 4.0)"
96+ tree = convert (Node{T}, tree)
97+ truth = cos (T (3.0f0 ) + T (4.0f0 ))
98+ @test DynamicExpressions. EvaluateEquationModule. deg1_l2_ll0_lr0_eval (
99+ tree, [zero (T)]' , Val (1 ), Val (1 ), operators, Val (turbo)
100+ )[1 ][1 ] ≈ truth
101+ end
0 commit comments