Skip to content

Commit 749385a

Browse files
committed
test: Zygote optimization within optim
1 parent 64b803e commit 749385a

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

test/test_optim.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2
1212
target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2
1313

1414
f(tree) = sum(abs2, tree(X, operators) .- y)
15+
function g!(G, tree)
16+
dy = only(gradient(f, tree))
17+
G .= dy.gradient
18+
return nothing
19+
end
1520

1621
@testset "Basic optimization" begin
1722
tree = copy(original_tree)
@@ -26,7 +31,14 @@ f(tree) = sum(abs2, tree(X, operators) .- y)
2631
@test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
2732
end
2833

29-
@testset "With gradients" begin
34+
@testset "With gradients, using Zygote" begin
35+
tree = copy(original_tree)
36+
res = optimize(f, g!, tree, BFGS())
37+
@test tree == original_tree
38+
@test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
39+
end
40+
41+
@testset "With gradients, manually" begin
3042
tree = copy(original_tree)
3143
did_i_run = Ref(false)
3244
# Now, try with gradients too (via Zygote and our hand-rolled forward-mode AD)

0 commit comments

Comments
 (0)