|
| 1 | +using Test |
| 2 | +using DynamicExpressions |
| 3 | +using Random: MersenneTwister |
| 4 | +using ForwardDiff: gradient as fd_gradient |
| 5 | +using Zygote: gradient as zg_gradient |
| 6 | +using Suppressor: @suppress_err |
| 7 | +include("test_params.jl") |
| 8 | +include("tree_gen_utils.jl") |
| 9 | + |
| 10 | +let |
| 11 | + rng = MersenneTwister(0) |
| 12 | + n_features = 5 |
| 13 | + operators = OperatorEnum(; binary_operators=(+, *, -), unary_operators=(sin,)) |
| 14 | + tree = gen_random_tree_fixed_size(20, operators, n_features, Float64, Node, rng) |
| 15 | + X = rand(rng, Float64, n_features, 100) |
| 16 | + |
| 17 | + function f(X) |
| 18 | + y, _ = eval_tree_array(tree, X, operators) |
| 19 | + return sum(i -> y[i]^2, eachindex(y)) |
| 20 | + end |
| 21 | + |
| 22 | + @suppress_err begin |
| 23 | + # Check zg_gradient against fd_gradient; the latter of which is computed explicitly |
| 24 | + @test isapprox([only(zg_gradient(f, X))...], [fd_gradient(f, X)...]; atol=1e-6) |
| 25 | + end |
| 26 | +end |
| 27 | + |
| 28 | +mean(x) = sum(x) / length(x) |
| 29 | + |
| 30 | +let |
| 31 | + operators = OperatorEnum(; binary_operators=(+, *, -), unary_operators=(sin,)) |
| 32 | + x1, x2, x3 = [Node{Float64}(; feature=i) for i in 1:3] |
| 33 | + tree = sin(x1 * 3.2 - 0.9) + 0.2 * x2 - x3 |
| 34 | + X = [ |
| 35 | + 1.0 2.0 3.0 |
| 36 | + 4.0 5.0 6.0 |
| 37 | + 7.0 8.0 9.0 |
| 38 | + ] |
| 39 | + function eval_tree(X, tree) |
| 40 | + y, _ = eval_tree_array(tree, X, operators) |
| 41 | + return mean(y) |
| 42 | + end |
| 43 | + |
| 44 | + function true_eval_tree(X, c) |
| 45 | + y = @. sin(X[1, :] * c[1] - c[2]) + c[3] * X[2, :] - X[3, :] |
| 46 | + return mean(y) |
| 47 | + end |
| 48 | + |
| 49 | + evaluated_gradient = zg_gradient(tree -> eval_tree(X, tree), tree)[1] |
| 50 | + true_gradient = fd_gradient(c -> true_eval_tree(X, c), [3.2, 0.9, 0.2]) |
| 51 | + |
| 52 | + @test evaluated_gradient.tree == tree |
| 53 | + @test isapprox(evaluated_gradient.gradient, true_gradient) |
| 54 | +end |
0 commit comments