
Commit 6211067

Merge pull request #71 from SymbolicML/chainrules-core

Add ChainRules support

2 parents f68d92e + 749385a

File tree: 6 files changed, +229 −2 lines changed

Project.toml

Lines changed: 3 additions & 1 deletion

@@ -4,6 +4,7 @@ authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
 version = "0.16.0"

 [deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
@@ -29,6 +30,7 @@ DynamicExpressionsZygoteExt = "Zygote"
 [compat]
 Aqua = "0.7"
 Bumper = "0.6"
+ChainRulesCore = "1"
 Compat = "3.37, 4"
 Enzyme = "^0.11.12"
 LoopVectorization = "0.12"
@@ -58,4 +60,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Suppressor", "Zygote"]
+test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "Suppressor", "SymbolicUtils", "Zygote"]

src/ChainRules.jl

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
+module ChainRulesModule
+
+using ChainRulesCore:
+    ChainRulesCore, AbstractTangent, NoTangent, ZeroTangent, Tangent, @thunk, canonicalize
+using ..OperatorEnumModule: OperatorEnum
+using ..NodeModule: AbstractExpressionNode, with_type_parameters, tree_mapreduce
+using ..EvaluateModule: eval_tree_array
+using ..EvaluateDerivativeModule: eval_grad_tree_array
+
+struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: AbstractTangent
+    tree::N
+    gradient::A
+end
+function Base.:+(a::NodeTangent, b::NodeTangent)
+    @assert a.tree == b.tree
+    return NodeTangent(a.tree, a.gradient + b.gradient)
+end
+Base.:*(a::Number, b::NodeTangent) = NodeTangent(b.tree, a * b.gradient)
+Base.:*(a::NodeTangent, b::Number) = NodeTangent(a.tree, a.gradient * b)
+Base.zero(::Union{Type{NodeTangent},NodeTangent}) = ZeroTangent()
+
+function ChainRulesCore.rrule(
+    ::typeof(eval_tree_array),
+    tree::AbstractExpressionNode,
+    X::AbstractMatrix,
+    operators::OperatorEnum;
+    turbo=Val(false),
+    bumper=Val(false),
+)
+    primal, complete = eval_tree_array(tree, X, operators; turbo, bumper)
+
+    if !complete
+        primal .= NaN
+    end
+
+    # TODO: Preferable to use the primal in the pullback somehow
+    function pullback((dY, _))
+        dtree = let X = X, dY = dY, tree = tree, operators = operators
+            @thunk(
+                let
+                    _, gradient, complete = eval_grad_tree_array(
+                        tree, X, operators; variable=Val(false)
+                    )
+                    if !complete
+                        gradient .= NaN
+                    end
+
+                    NodeTangent(
+                        tree,
+                        sum(j -> gradient[:, j] * dY[j], eachindex(dY, axes(gradient, 2))),
+                    )
+                end
+            )
+        end
+        dX = let X = X, dY = dY, tree = tree, operators = operators
+            @thunk(
+                let
+                    _, gradient, complete = eval_grad_tree_array(
+                        tree, X, operators; variable=Val(true)
+                    )
+                    if !complete
+                        gradient .= NaN
+                    end
+
+                    gradient .* reshape(dY, 1, length(dY))
+                end
+            )
+        end
+        return (NoTangent(), dtree, dX, NoTangent())
+    end
+
+    return (primal, complete), pullback
+end
+
+end
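
In plain terms, the pullback contracts the forward-mode Jacobian with the output cotangent: the tree tangent sums gradient[:, j] * dY[j] over samples j (derivatives of each output with respect to the tree's constants), while dX scales each column's per-feature gradient by dY[j], since output j depends only on column j of X. A minimal usage sketch of what this rrule enables (not part of the diff; the operator setup mirrors the test file further below):

using DynamicExpressions, Zygote

operators = OperatorEnum(; binary_operators=(+, *, -), unary_operators=(sin,))
x1 = Node{Float64}(; feature=1)
tree = sin(x1 * 0.5) + 1.2 * x1

X = randn(1, 32)
loss(t) = sum(first(eval_tree_array(t, X, operators)))

# Zygote now routes through the rrule above, returning a NodeTangent
# whose `gradient` field holds one entry per constant in `tree`
# (here: the derivatives with respect to 0.5 and 1.2).
dtree = only(Zygote.gradient(loss, tree))
dtree.gradient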

src/DynamicExpressions.jl

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,7 @@ include("NodeUtils.jl")
 include("Strings.jl")
 include("Evaluate.jl")
 include("EvaluateDerivative.jl")
+include("ChainRules.jl")
 include("EvaluationHelpers.jl")
 include("Simplify.jl")
 include("OperatorEnumConstruction.jl")
@@ -42,6 +43,7 @@ import .NodeModule: constructorof, preserve_sharing
     OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
 @reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array
 @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
+@reexport import .ChainRulesModule: NodeTangent
 @reexport import .SimplifyModule: combine_operators, simplify_tree!
 @reexport import .EvaluationHelpersModule
 @reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node

test/test_chainrules.jl

Lines changed: 132 additions & 0 deletions

@@ -0,0 +1,132 @@
+using Test
+using DynamicExpressions
+using Random: MersenneTwister
+using ChainRulesCore: ChainRulesCore, ZeroTangent, NoTangent
+using ForwardDiff: gradient as fd_gradient
+using Zygote: gradient as zg_gradient
+using Suppressor: @suppress_err
+include("test_params.jl")
+include("tree_gen_utils.jl")
+
+let
+    rng = MersenneTwister(0)
+    n_features = 5
+    operators = OperatorEnum(; binary_operators=(+, *, -), unary_operators=(sin,))
+    tree = gen_random_tree_fixed_size(20, operators, n_features, Float64, Node, rng)
+    X = rand(rng, Float64, n_features, 100)
+
+    function f(X)
+        y, _ = eval_tree_array(tree, X, operators)
+        return sum(i -> y[i]^2, eachindex(y))
+    end
+
+    @suppress_err begin
+        # Check zg_gradient against fd_gradient, the latter of which is computed explicitly
+        @test isapprox([only(zg_gradient(f, X))...], [fd_gradient(f, X)...]; atol=1e-6)
+    end
+end
+
+mean(x) = sum(x) / length(x)
+
+let
+    operators = OperatorEnum(; binary_operators=(+, *, -), unary_operators=(sin,))
+    x1, x2, x3 = [Node{Float64}(; feature=i) for i in 1:3]
+    tree = sin(x1 * 3.2 - 0.9) + 0.2 * x2 - x3
+    X = [
+        1.0 2.0 3.0
+        4.0 5.0 6.0
+        7.0 8.0 9.0
+    ]
+    function eval_tree(X, tree)
+        y, _ = eval_tree_array(tree, X, operators)
+        return mean(y)
+    end
+
+    function true_eval_tree(X, c)
+        y = @. sin(X[1, :] * c[1] - c[2]) + c[3] * X[2, :] - X[3, :]
+        return mean(y)
+    end
+
+    evaluated_gradient = zg_gradient(tree -> eval_tree(X, tree), tree)[1]
+    true_gradient = fd_gradient(c -> true_eval_tree(X, c), [3.2, 0.9, 0.2])
+
+    @test evaluated_gradient.tree == tree
+    @test isapprox(evaluated_gradient.gradient, true_gradient)
+
+    # Misc tests of uncovered portions
+    let tree = tree,
+        X = X,
+        evaluated_gradient = evaluated_gradient,
+        true_gradient = true_gradient
+
+        evaluated_gradient_2 = zg_gradient(tree -> eval_tree(X, tree), tree)[1]
+        true_gradient_2 = fd_gradient(c -> true_eval_tree(X, c), [3.2, 0.9, 0.2])
+
+        evaluated_aggregate = evaluated_gradient + evaluated_gradient_2
+        true_aggregate = true_gradient + true_gradient_2
+        @test evaluated_aggregate.tree == tree
+        @test isapprox(evaluated_aggregate.gradient, true_aggregate)
+
+        scalar_prod = evaluated_gradient * 2.0
+        scalar_prod2 = 2.0 * (1.0 * evaluated_gradient)
+        true_scalar_prod = true_gradient * 2.0
+        @test scalar_prod.tree == tree
+        @test isapprox(scalar_prod.gradient, true_scalar_prod)
+        @test isapprox(scalar_prod2.gradient, true_scalar_prod)
+
+        # Should be able to use with other types
+        @test zero(evaluated_gradient) == ZeroTangent()
+
+        @test evaluated_gradient + ZeroTangent() == evaluated_gradient
+        @test evaluated_gradient + NoTangent() == evaluated_gradient
+    end
+end
+
+# Operator that is NaN for forward pass
+bad_op(x) = x > 0.0 ? log(x) : convert(typeof(x), NaN)
+# And an operator that is undefined for backward pass
+undefined_grad_op(x) = x >= 0.0 ? x : zero(x)
+# And an operator that gives a NaN for backward pass
+bad_grad_op(x) = x
+
+function ChainRulesCore.rrule(::typeof(bad_grad_op), x)
+    return bad_grad_op(x), (_) -> (NoTangent(), convert(typeof(x), NaN))
+end
+
+# Also test NaN modes
+let
+    operators = OperatorEnum(;
+        binary_operators=(+, *, -),
+        unary_operators=(sin, bad_op, bad_grad_op, undefined_grad_op),
+    )
+    @extend_operators operators
+    x1 = Node(Float64; feature=1)
+
+    nan_forward = bad_op(x1 + 0.5)
+    undefined_grad = undefined_grad_op(x1 + 0.5)
+    nan_grad = bad_grad_op(x1)
+
+    function eval_tree(X, tree)
+        y, _ = eval_tree_array(tree, X, operators)
+        return mean(y)
+    end
+    X = ones(1, 1) * -1.0
+
+    # Forward pass is NaN; gradient will also be NaN
+    @test isnan(only(eval_tree(X, nan_forward)))
+    evaluated_gradient = zg_gradient(X -> eval_tree(X, nan_forward), X)[1]
+    @test isnan(only(evaluated_gradient))
+
+    # Both forward and gradient are not NaN despite giving `nothing` back
+    @test !isnan(only(eval_tree(X, undefined_grad)))
+    evaluated_gradient = zg_gradient(X -> eval_tree(X, undefined_grad), X)[1]
+    @test iszero(only(evaluated_gradient))
+
+    # Finally, the operator with a NaN gradient but non-NaN forward
+    @test !isnan(only(eval_tree(X, nan_grad)))
+    evaluated_gradient = zg_gradient(X -> eval_tree(X, nan_grad), X)[1]
+    @test isnan(only(evaluated_gradient))
+    evaluated_gradient = zg_gradient(t -> eval_tree(X, t), nan_grad)[1]
+    @show evaluated_gradient
+    # @test isnan(only(evaluated_gradient.gradient))
+end
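
The NaN behavior exercised above comes from the rrule in src/ChainRules.jl: an incomplete forward or gradient evaluation poisons the result with NaN rather than throwing. A hedged sketch of how a caller might consume that signal, reusing the `operators`, `bad_op`, and `x1` definitions from this test file:

using DynamicExpressions, Zygote

tree = bad_op(x1 + 0.5)  # log(x1 + 0.5): NaN whenever x1 <= -0.5
X = fill(-1.0, 1, 1)

loss(X) = sum(first(eval_tree_array(tree, X, operators)))
dX = only(Zygote.gradient(loss, X))

# A NaN gradient flags an invalid evaluation, so a search loop can
# simply discard this candidate tree instead of stepping on bad values.
any(isnan, dX) && @info "rejecting candidate tree"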

test/test_optim.jl

Lines changed: 13 additions & 1 deletion

@@ -12,6 +12,11 @@ original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2
 target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2

 f(tree) = sum(abs2, tree(X, operators) .- y)
+function g!(G, tree)
+    dy = only(gradient(f, tree))
+    G .= dy.gradient
+    return nothing
+end

 @testset "Basic optimization" begin
     tree = copy(original_tree)
@@ -26,7 +31,14 @@ f(tree) = sum(abs2, tree(X, operators) .- y)
     @test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
 end

-@testset "With gradients" begin
+@testset "With gradients, using Zygote" begin
+    tree = copy(original_tree)
+    res = optimize(f, g!, tree, BFGS())
+    @test tree == original_tree
+    @test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
+end
+
+@testset "With gradients, manually" begin
     tree = copy(original_tree)
     did_i_run = Ref(false)
     # Now, try with gradients too (via Zygote and our hand-rolled forward-mode AD)
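
The new "using Zygote" testset relies on `NodeTangent.gradient` being a flat vector with one entry per constant in the tree (the same ordering as `get_constants`), so it can be copied straight into Optim.jl's gradient buffer. A self-contained sketch of that pattern, under the assumption (as in this test) that the Optim extension accepts a tree as the initial point:

using DynamicExpressions, Optim, Zygote

operators = OperatorEnum(; binary_operators=(+, *, -), unary_operators=(exp,))
x1 = Node{Float64}(; feature=1)
tree = exp(x1 * 0.8) + 0.1 * x1  # the constants 0.8 and 0.1 get optimized

X = randn(1, 50)
y = @. 2.0 * exp(X[1, :] * 1.5) - 0.5 * X[1, :]

f(t) = sum(abs2, first(eval_tree_array(t, X, operators)) .- y)
function g!(G, t)
    # The Zygote gradient of f is a NodeTangent; its `gradient` field
    # lines up with the constants Optim is optimizing.
    G .= only(Zygote.gradient(f, t)).gradient
    return nothing
end

res = optimize(f, g!, tree, BFGS())  # the tree itself is the initial point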

test/unittest.jl

Lines changed: 4 additions & 0 deletions

@@ -31,6 +31,10 @@ end
     include("test_derivatives.jl")
 end

+@safetestset "Test chain rules" begin
+    include("test_chainrules.jl")
+end
+
 @safetestset "Test undefined derivatives" begin
     include("test_undefined_derivatives.jl")
 end
