Skip to content

Commit 906b847

Browse files
committed
test: NaN branches of rrule
1 parent 4a231b9 commit 906b847

File tree

2 files changed

+80
-2
lines changed

2 files changed

+80
-2
lines changed

src/ChainRules.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function ChainRulesCore.rrule(
3535

3636
# TODO: Preferable to use the primal in the pullback somehow
3737
function pullback((dY, _))
38-
dtree = let dY = dY, tree = tree, operators = operators
38+
dtree = let X = X, dY = dY, tree = tree, operators = operators
3939
@thunk(
4040
let
4141
_, gradient, complete = eval_grad_tree_array(
@@ -52,7 +52,7 @@ function ChainRulesCore.rrule(
5252
end
5353
)
5454
end
55-
dX = let dY = dY, tree = tree, operators = operators
55+
dX = let X = X, dY = dY, tree = tree, operators = operators
5656
@thunk(
5757
let
5858
_, gradient, complete = eval_grad_tree_array(

test/test_chainrules.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test
22
using DynamicExpressions
33
using Random: MersenneTwister
4+
using ChainRulesCore: ChainRulesCore, ZeroTangent, NoTangent
45
using ForwardDiff: gradient as fd_gradient
56
using Zygote: gradient as zg_gradient
67
using Suppressor: @suppress_err
@@ -51,4 +52,81 @@ let
5152

5253
@test evaluated_gradient.tree == tree
5354
@test isapprox(evaluated_gradient.gradient, true_gradient)
55+
56+
# Misc tests of uncovered portions: arithmetic on tree-gradient objects (which
# expose `.tree` and `.gradient`) — addition, scalar multiplication, `zero`,
# and combination with `ZeroTangent` / `NoTangent`.
57+
let tree = tree,
58+
X = X,
59+
evaluated_gradient = evaluated_gradient,
60+
true_gradient = true_gradient
61+
# Differentiate with respect to the tree a second time via Zygote, and compute
# the matching ForwardDiff gradient of `true_eval_tree` at `[3.2, 0.9, 0.2]`.
62+
evaluated_gradient_2 = zg_gradient(tree -> eval_tree(X, tree), tree)[1]
63+
true_gradient_2 = fd_gradient(c -> true_eval_tree(X, c), [3.2, 0.9, 0.2])
64+
# Adding two gradient objects should keep the tree and sum the gradients.
65+
evaluated_aggregate = evaluated_gradient + evaluated_gradient_2
66+
true_aggregate = true_gradient + true_gradient_2
67+
@test evaluated_aggregate.tree == tree
68+
@test isapprox(evaluated_aggregate.gradient, true_aggregate)
69+
# Scalar multiplication from the right and from the left should scale the
# gradient and keep the tree.
70+
scalar_prod = evaluated_gradient * 2.0
71+
scalar_prod2 = 2.0 * (1.0 * evaluated_gradient)
72+
true_scalar_prod = true_gradient * 2.0
73+
@test scalar_prod.tree == tree
74+
@test isapprox(scalar_prod.gradient, true_scalar_prod)
75+
@test isapprox(scalar_prod2.gradient, true_scalar_prod)
76+
77+
# Should be able to use with other types: `zero` of the gradient object is
# expected to compare equal to `ZeroTangent()`.
78+
@test zero(evaluated_gradient) == ZeroTangent()
79+
# Adding either `ZeroTangent()` or `NoTangent()` should leave it unchanged.
80+
@test evaluated_gradient + ZeroTangent() == evaluated_gradient
81+
@test evaluated_gradient + NoTangent() == evaluated_gradient
82+
end
83+
end
84+
85+
# Operator that returns NaN in the forward pass for non-positive inputs
86+
function bad_op(x)
    # Only strictly positive inputs get a real logarithm.
    x > 0.0 && return log(x)
    # Everything else (zero included) maps to NaN of the input's own type.
    return convert(typeof(x), NaN)
end
87+
# And an operator that is undefined for the backward pass
88+
function undefined_grad_op(x)
    # Non-negative inputs pass through untouched.
    if x >= 0.0
        return x
    end
    # Anything else collapses to a constant zero of the same type as `x`.
    return zero(x)
end
89+
# And an operator whose forward pass is fine but whose backward pass gives NaN
90+
function bad_grad_op(x)
    # The forward pass simply echoes the input back.
    return x
end
91+
92+
function ChainRulesCore.rrule(::typeof(bad_grad_op), x)
    # The pullback ignores the incoming cotangent and always reports NaN
    # (converted to the type of `x`) as the derivative with respect to `x`.
    nan_pullback(_) = (NoTangent(), convert(typeof(x), NaN))
    return bad_grad_op(x), nan_pullback
end
95+
96+
# Also test NaN modes: a NaN in the forward pass, an operator whose backward
# pass is undefined, and an operator that yields NaN only in the backward pass.
97+
let
98+
operators = OperatorEnum(;
99+
binary_operators=(+, *, -),
100+
unary_operators=(sin, bad_op, bad_grad_op, undefined_grad_op),
101+
)
102+
@extend_operators operators
103+
x1 = Node(Float64; feature=1)
104+
# One expression tree per failure mode, each built on feature 1.
105+
nan_forward = bad_op(x1 + 0.5)
106+
undefined_grad = undefined_grad_op(x1 + 0.5)
107+
nan_grad = bad_grad_op(x1)
108+
# Scalar objective: the mean of the tree's output evaluated on `X`.
109+
function eval_tree(X, tree)
110+
y, _ = eval_tree_array(tree, X, operators)
111+
return mean(y)
112+
end
113+
X = ones(1, 1) * -1.0
114+
# `X` is a 1x1 matrix holding -1.0, so `x1 + 0.5` is negative in the trees above.
115+
# Forward pass is NaN; the gradient with respect to `X` will also be NaN
116+
@test isnan(only(eval_tree(X, nan_forward)))
117+
evaluated_gradient = zg_gradient(X -> eval_tree(X, nan_forward), X)[1]
118+
@test isnan(only(evaluated_gradient))
119+
120+
# Neither the forward pass nor the gradient is NaN despite `nothing` being
# given back for the derivative; the gradient with respect to `X` is zero
121+
@test !isnan(only(eval_tree(X, undefined_grad)))
122+
evaluated_gradient = zg_gradient(X -> eval_tree(X, undefined_grad), X)[1]
123+
@test iszero(only(evaluated_gradient))
124+
125+
# Finally, the operator with a NaN gradient but non-NaN forward
126+
@test !isnan(only(eval_tree(X, nan_grad)))
127+
evaluated_gradient = zg_gradient(X -> eval_tree(X, nan_grad), X)[1]
128+
@test isnan(only(evaluated_gradient))
129+
evaluated_gradient = zg_gradient(t -> eval_tree(X, t), nan_grad)[1]
130+
@show evaluated_gradient
131+
# NOTE(review): the gradient with respect to the tree is only printed by the
# `@show` above and never asserted; the check below is disabled — confirm what
# the intended expectation is before re-enabling or removing it.
# @test isnan(only(evaluated_gradient.gradient))
54132
end

0 commit comments

Comments
 (0)