Skip to content

Commit 4310eca

Browse files
committed
Fix optim tests
1 parent 339c619 commit 4310eca

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

test/test_optim.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
using DynamicExpressions, Optim, Zygote
22
using Random: Xoshiro
33

4-
@testset "Basic optimization" begin
5-
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
6-
x1, x2 = (i -> Node(Float64; feature=i)).(1:2)
4+
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
5+
x1, x2 = (i -> Node(Float64; feature=i)).(1:2)
76

8-
X = rand(Xoshiro(0), Float64, 2, 100)
9-
y = @. exp(X[1, :] * 2.1 - 0.9) + X[2, :] * -0.9
7+
X = rand(Xoshiro(0), Float64, 2, 100)
8+
y = @. exp(X[1, :] * 2.1 - 0.9) + X[2, :] * -0.9
109

11-
original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2
12-
target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2
13-
tree = copy(original_tree)
10+
original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2
11+
target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2
1412

15-
f(tree) = sum(abs2, tree(X, operators) .- y)
13+
f(tree) = sum(abs2, tree(X, operators) .- y)
1614

15+
@testset "Basic optimization" begin
16+
tree = copy(original_tree)
1717
res = optimize(f, tree)
1818

1919
# Should be unchanged by default
@@ -26,6 +26,7 @@ using Random: Xoshiro
2626
end
2727

2828
@testset "With gradients" begin
29+
tree = copy(original_tree)
2930
did_i_run = Ref(false)
3031
# Now, try with gradients too (via Zygote and our hand-rolled forward-mode AD)
3132
g!(G, tree) =
@@ -49,6 +50,7 @@ end
4950

5051
# Now, try combined
5152
@testset "Combined evaluation with gradient" begin
53+
tree = copy(original_tree)
5254
did_i_run_2 = Ref(false)
5355
fg!(F, G, tree) =
5456
let

0 commit comments

Comments
 (0)