|
162 | 162 | @testset "Test many operators" begin |
163 | 163 | # Since we use `@nif` in evaluating expressions, |
164 | 164 | # we can see if there are any issues with LARGE numbers of operators. |
165 | | - num_ops = 20 |
| 165 | + num_ops = 100 |
166 | 166 | binary_operators = [@eval function (x, y) |
167 | 167 | return x + y |
168 | 168 | end for i in 1:num_ops] |
|
180 | 180 | # = (3.0 + x2)^2 |
181 | 181 | X = randn(Float64, 2, 10) |
182 | 182 | truth = @. (3.0 + X[2, :])^2 |
183 | | - @test all(truth .≈ tree(X, operators)) |
| 183 | + @test truth ≈ tree(X, operators) |
184 | 184 |
|
185 | 185 | VERSION >= v"1.9" && |
186 | 186 | @test_warn "You have passed over 15 unary" OperatorEnum(; unary_operators) |
| 187 | + |
| 188 | + # This OperatorEnum will trigger the fallback code for fast compilation. |
| 189 | + many_ops_operators = OperatorEnum(; |
| 190 | + binary_operators=cat([+, -, *, /], binary_operators; dims=1), |
| 191 | + unary_operators=cat([sin, cos], unary_operators; dims=1), |
| 192 | + ) |
| 193 | + |
| 194 | + # This OperatorEnum will go through the regular evaluation code. |
| 195 | + only_basic_ops_operator = OperatorEnum(; |
| 196 | + binary_operators=[+, -, *, /], unary_operators=[sin, cos] |
| 197 | + ) |
| 198 | + |
| 199 | + # We want to compare them: |
| 200 | + num_tests = 100 |
| 201 | + n_features = 3 |
| 202 | + for _ in 1:num_tests |
| 203 | + tree = gen_random_tree_fixed_size(20, only_basic_ops_operator, n_features, Float64) |
| 204 | + X = randn(Float64, n_features, 10) |
| 205 | + basic_eval = tree(X, only_basic_ops_operator) |
| 206 | + many_ops_eval = tree(X, many_ops_operators) |
| 207 | + @test (all(isnan, basic_eval) && all(isnan, many_ops_eval)) || |
| 208 | + basic_eval ≈ many_ops_eval |
| 209 | + end |
187 | 210 | end |
0 commit comments