Skip to content

Commit 73cbeae

Browse files
committed
Rigorous testing of fast compilation branches
1 parent 3152c9c commit 73cbeae

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

test/test_evaluation.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ end
162162
@testset "Test many operators" begin
163163
# Since we use `@nif` in evaluating expressions,
164164
# we can see if there are any issues with LARGE numbers of operators.
165-
num_ops = 20
165+
num_ops = 100
166166
binary_operators = [@eval function (x, y)
167167
return x + y
168168
end for i in 1:num_ops]
@@ -180,8 +180,31 @@ end
180180
# = (3.0 + x2)^2
181181
X = randn(Float64, 2, 10)
182182
truth = @. (3.0 + X[2, :])^2
183-
@test all(truth . tree(X, operators))
183+
@test truth tree(X, operators)
184184

185185
VERSION >= v"1.9" &&
186186
@test_warn "You have passed over 15 unary" OperatorEnum(; unary_operators)
187+
188+
# This OperatorEnum will trigger the fallback code for fast compilation.
189+
many_ops_operators = OperatorEnum(;
190+
binary_operators=cat([+, -, *, /], binary_operators; dims=1),
191+
unary_operators=cat([sin, cos], unary_operators; dims=1),
192+
)
193+
194+
# This OperatorEnum will go through the regular evaluation code.
195+
only_basic_ops_operator = OperatorEnum(;
196+
binary_operators=[+, -, *, /], unary_operators=[sin, cos]
197+
)
198+
199+
# We want to compare them:
200+
num_tests = 100
201+
n_features = 3
202+
for _ in 1:num_tests
203+
tree = gen_random_tree_fixed_size(20, only_basic_ops_operator, n_features, Float64)
204+
X = randn(Float64, n_features, 10)
205+
basic_eval = tree(X, only_basic_ops_operator)
206+
many_ops_eval = tree(X, many_ops_operators)
207+
@test (all(isnan, basic_eval) && all(isnan, many_ops_eval)) ||
208+
basic_eval many_ops_eval
209+
end
187210
end

0 commit comments

Comments
 (0)