Skip to content

Commit 4a231b9

Browse files
committed
test: simple test of chain rules
1 parent 1aabe1d commit 4a231b9

File tree

3 files changed

+61
-1
lines changed

3 files changed

+61
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ DynamicExpressionsZygoteExt = "Zygote"
3030
[compat]
3131
Aqua = "0.7"
3232
Bumper = "0.6"
33+
ChainRulesCore = "1"
3334
Compat = "3.37, 4"
3435
Enzyme = "^0.11.12"
3536
LoopVectorization = "0.12"
@@ -51,11 +52,12 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5152
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
5253
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
5354
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
55+
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
5456
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
5557
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
5658
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
5759
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5860
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5961

6062
[targets]
61-
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]
63+
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "Suppressor", "SymbolicUtils", "Zygote"]

test/test_chainrules.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using Test
2+
using DynamicExpressions
3+
using Random: MersenneTwister
4+
using ForwardDiff: gradient as fd_gradient
5+
using Zygote: gradient as zg_gradient
6+
using Suppressor: @suppress_err
7+
include("test_params.jl")
8+
include("tree_gen_utils.jl")
9+
10+
let
11+
rng = MersenneTwister(0)
12+
n_features = 5
13+
operators = OperatorEnum(; binary_operators=(+, *, -), unary_operators=(sin,))
14+
tree = gen_random_tree_fixed_size(20, operators, n_features, Float64, Node, rng)
15+
X = rand(rng, Float64, n_features, 100)
16+
17+
function f(X)
18+
y, _ = eval_tree_array(tree, X, operators)
19+
return sum(i -> y[i]^2, eachindex(y))
20+
end
21+
22+
@suppress_err begin
23+
# Check zg_gradient against fd_gradient; the latter of which is computed explicitly
24+
@test isapprox([only(zg_gradient(f, X))...], [fd_gradient(f, X)...]; atol=1e-6)
25+
end
26+
end
27+
28+
mean(x) = sum(x) / length(x)
29+
30+
let
31+
operators = OperatorEnum(; binary_operators=(+, *, -), unary_operators=(sin,))
32+
x1, x2, x3 = [Node{Float64}(; feature=i) for i in 1:3]
33+
tree = sin(x1 * 3.2 - 0.9) + 0.2 * x2 - x3
34+
X = [
35+
1.0 2.0 3.0
36+
4.0 5.0 6.0
37+
7.0 8.0 9.0
38+
]
39+
function eval_tree(X, tree)
40+
y, _ = eval_tree_array(tree, X, operators)
41+
return mean(y)
42+
end
43+
44+
function true_eval_tree(X, c)
45+
y = @. sin(X[1, :] * c[1] - c[2]) + c[3] * X[2, :] - X[3, :]
46+
return mean(y)
47+
end
48+
49+
evaluated_gradient = zg_gradient(tree -> eval_tree(X, tree), tree)[1]
50+
true_gradient = fd_gradient(c -> true_eval_tree(X, c), [3.2, 0.9, 0.2])
51+
52+
@test evaluated_gradient.tree == tree
53+
@test isapprox(evaluated_gradient.gradient, true_gradient)
54+
end

test/unittest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ end
3131
include("test_derivatives.jl")
3232
end
3333

34+
@safetestset "Test chain rules" begin
35+
include("test_chainrules.jl")
36+
end
37+
3438
@safetestset "Test undefined derivatives" begin
3539
include("test_undefined_derivatives.jl")
3640
end

0 commit comments

Comments
 (0)