@@ -6,15 +6,8 @@ using Zygote
66using LinearAlgebra
77
88seed = 0
9- pow_abs2 (x:: T , y:: T ) where {T<: Real } = abs (x)^ y
10- custom_cos (x:: T ) where {T<: Real } = cos (x)^ 2
11-
12- # Define these custom functions for Node data types:
13- pow_abs2 (l:: Node , r:: Node ):: Node =
14- (l. constant && r. constant) ? Node (pow_abs2 (l. val, r. val):: Real ) : Node (5 , l, r)
15- pow_abs2 (l:: Node , r:: Real ):: Node = l. constant ? Node (pow_abs2 (l. val, r)) : Node (5 , l, r)
16- pow_abs2 (l:: Real , r:: Node ):: Node = r. constant ? Node (pow_abs2 (l, r. val)) : Node (5 , l, r)
17- custom_cos (x:: Node ):: Node = x. constant ? Node (; val= custom_cos (x. val)) : Node (1 , x)
9+ pow_abs2 (x, y) = abs (x)^ y
10+ custom_cos (x) = cos (x)^ 2
1811
1912equation1 (x1, x2, x3) = x1 + x2 + x3 + 3.2
2013equation2 (x1, x2, x3) = pow_abs2 (x1, x2) + x3 + custom_cos (1.0 + x3) + 3.0 / x1
@@ -56,6 +49,7 @@ for type in [Float16, Float32, Float64]
5649 unary_operators= (custom_cos, exp, sin),
5750 enable_autodiff= true ,
5851 )
52+ @extend_operators operators
5953
6054 for j in 1 : 3
6155 equation = [equation1, equation2, equation3][j]
@@ -83,14 +77,17 @@ for type in [Float16, Float32, Float64]
8377 reduce (
8478 hcat, [eval_diff_tree_array (tree, X, operators, i)[2 ] for i in 1 : nfeatures]
8579 )'
80+ predicted_grad3 = tree' (X)
8681
8782 # Print largest difference between predicted_grad, true_grad:
8883 @test array_test (predicted_grad, true_grad)
8984 @test array_test (predicted_grad2, true_grad)
85+ @test array_test (predicted_grad3, true_grad)
9086
9187 # Make sure that the array_test actually works:
9288 @test ! array_test (predicted_grad .* 0 , true_grad)
9389 @test ! array_test (predicted_grad2 .* 0 , true_grad)
90+ @test ! array_test (predicted_grad3 .* 0 , true_grad)
9491 end
9592 println (" Done." )
9693 println (" Testing derivatives with respect to constants, with type=$(type) ." )
@@ -141,6 +138,7 @@ operators = OperatorEnum(;
141138 unary_operators= (custom_cos, exp, sin),
142139 enable_autodiff= true ,
143140)
141+ @extend_operators operators
144142tree = equation3 (nx1, nx2, nx3)
145143
146144""" Check whether the ordering of constant_list is the same as the ordering of node_index."""
0 commit comments