Skip to content

Commit 977e0c7

Browse files
committed
Include test for derivative from adjoint symbol
1 parent fb7c5a4 commit 977e0c7

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

test/test_derivatives.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,8 @@ using Zygote
66
using LinearAlgebra
77

88
seed = 0
9-
pow_abs2(x::T, y::T) where {T<:Real} = abs(x)^y
10-
custom_cos(x::T) where {T<:Real} = cos(x)^2
11-
12-
# Define these custom functions for Node data types:
13-
pow_abs2(l::Node, r::Node)::Node =
14-
(l.constant && r.constant) ? Node(pow_abs2(l.val, r.val)::Real) : Node(5, l, r)
15-
pow_abs2(l::Node, r::Real)::Node = l.constant ? Node(pow_abs2(l.val, r)) : Node(5, l, r)
16-
pow_abs2(l::Real, r::Node)::Node = r.constant ? Node(pow_abs2(l, r.val)) : Node(5, l, r)
17-
custom_cos(x::Node)::Node = x.constant ? Node(; val=custom_cos(x.val)) : Node(1, x)
9+
pow_abs2(x, y) = abs(x)^y
10+
custom_cos(x) = cos(x)^2
1811

1912
equation1(x1, x2, x3) = x1 + x2 + x3 + 3.2
2013
equation2(x1, x2, x3) = pow_abs2(x1, x2) + x3 + custom_cos(1.0 + x3) + 3.0 / x1
@@ -56,6 +49,7 @@ for type in [Float16, Float32, Float64]
5649
unary_operators=(custom_cos, exp, sin),
5750
enable_autodiff=true,
5851
)
52+
@extend_operators operators
5953

6054
for j in 1:3
6155
equation = [equation1, equation2, equation3][j]
@@ -83,14 +77,17 @@ for type in [Float16, Float32, Float64]
8377
reduce(
8478
hcat, [eval_diff_tree_array(tree, X, operators, i)[2] for i in 1:nfeatures]
8579
)'
80+
predicted_grad3 = tree'(X)
8681

8782
# Print largest difference between predicted_grad, true_grad:
8883
@test array_test(predicted_grad, true_grad)
8984
@test array_test(predicted_grad2, true_grad)
85+
@test array_test(predicted_grad3, true_grad)
9086

9187
# Make sure that the array_test actually works:
9288
@test !array_test(predicted_grad .* 0, true_grad)
9389
@test !array_test(predicted_grad2 .* 0, true_grad)
90+
@test !array_test(predicted_grad3 .* 0, true_grad)
9491
end
9592
println("Done.")
9693
println("Testing derivatives with respect to constants, with type=$(type).")
@@ -141,6 +138,7 @@ operators = OperatorEnum(;
141138
unary_operators=(custom_cos, exp, sin),
142139
enable_autodiff=true,
143140
)
141+
@extend_operators operators
144142
tree = equation3(nx1, nx2, nx3)
145143

146144
"""Check whether the ordering of constant_list is the same as the ordering of node_index."""

0 commit comments

Comments
 (0)