
Commit 7dd069a

refactor: optim tests
1 parent 666ad1b commit 7dd069a

4 files changed: 37 additions, 27 deletions

ext/DynamicExpressionsOptimExt.jl

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ function wrap_func(
     f::F, tree::N, refs
 ) where {F<:Function,T,N<:Union{AbstractExpressionNode{T},AbstractExpression{T}}}
     function wrapped_f(args::Vararg{Any,M}) where {M}
-        first_args = args[begin:end-1]
+        first_args = args[begin:(end - 1)]
         x = args[end]
         set_constants!(tree, x, refs)
         return @inline(f(first_args..., tree))
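
This hunk is formatting only: the upper slice bound is parenthesized as (end - 1), matching the formatter style used elsewhere in the repo, with no change in behavior. A minimal sanity check, using a made-up tuple standing in for args:

    args = (1.0, 2.0, :tree)
    # Both spellings drop the last element; only parentheses and spacing differ.
    @assert args[begin:end-1] === args[begin:(end - 1)]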

test/test_optim.jl

Lines changed: 17 additions & 23 deletions
@@ -1,24 +1,7 @@
-using DynamicExpressions, Optim, Zygote
-using Random: MersenneTwister as RNG
-using Test
+@testitem "Basic optimization" begin
+    using DynamicExpressions, Optim
 
-operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
-x1, x2 = (i -> Node(Float64; feature=i)).(1:2)
-
-X = rand(RNG(0), Float64, 2, 100)
-y = @. exp(X[1, :] * 2.1 - 0.9) + X[2, :] * -0.9
-
-original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2
-target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2
-
-f(tree) = sum(abs2, tree(X, operators) .- y)
-function g!(G, tree)
-    dy = only(gradient(f, tree))
-    G .= dy.gradient
-    return nothing
-end
-
-@testset "Basic optimization" begin
+    include("test_optim_setup.jl")
     tree = copy(original_tree)
     res = optimize(f, tree)
 
@@ -33,7 +16,11 @@ end
     )
 end
 
-@testset "With gradients, using Zygote" begin
+@testitem "With gradients, using Zygote" begin
+    using DynamicExpressions, Optim, Zygote
+
+    include("test_optim_setup.jl")
+
     tree = copy(original_tree)
     res = optimize(f, g!, tree, BFGS())
     @test tree == original_tree
@@ -42,7 +29,11 @@ end
     )
 end
 
-@testset "With gradients, manually" begin
+@testitem "With gradients, manually" begin
+    using DynamicExpressions, Optim, Zygote
+
+    include("test_optim_setup.jl")
+
     tree = copy(original_tree)
     did_i_run = Ref(false)
     # Now, try with gradients too (via Zygote and our hand-rolled forward-mode AD)
@@ -79,7 +70,10 @@ end
 end
 
 # Now, try combined
-@testset "Combined evaluation with gradient" begin
+@testitem "Combined evaluation with gradient" begin
+    using DynamicExpressions, Optim, Zygote
+    include("test_optim_setup.jl")
+
     tree = copy(original_tree)
     did_i_run_2 = Ref(false)
     fg!(F, G, tree) =
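
The pattern throughout this file: each former @testset becomes a self-contained @testitem. Under the TestItems framework, each item runs in its own module, so every item declares its own using statements and includes the shared setup file. A minimal sketch of the pattern (the item name and final test line are illustrative, not from the repo):

    @testitem "Example item" begin
        using DynamicExpressions, Optim
        include("test_optim_setup.jl")  # defines operators, X, y, f, original_tree, ...
        # f is the sum-of-squares loss from the setup file, so it is nonnegative.
        @test f(copy(original_tree)) >= 0
    end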

test/test_optim_setup.jl

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+using DynamicExpressions
+using Random: MersenneTwister as RNG
+
+operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
+x1, x2 = (i -> Node(Float64; feature=i)).(1:2)
+
+X = rand(RNG(0), Float64, 2, 100)
+y = @. exp(X[1, :] * 2.1 - 0.9) + X[2, :] * -0.9
+
+original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2
+target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2
+
+f(tree) = sum(abs2, tree(X, operators) .- y)
+function g!(G, tree)
+    dy = only(gradient(f, tree))
+    G .= dy.gradient
+    return nothing
+end
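
Note that this setup file loads only DynamicExpressions and Random: g! calls gradient, which resolves only when the including @testitem has already run using Zygote, as the gradient-based items above do. A sketch of standalone use, assuming Optim and Zygote are available:

    using DynamicExpressions, Optim, Zygote
    include("test_optim_setup.jl")  # defines operators, f, g!, original_tree, ...
    res = optimize(f, g!, copy(original_tree), BFGS())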

test/unittest.jl

Lines changed: 1 addition & 3 deletions
@@ -16,9 +16,7 @@ using Zygote, SymbolicUtils, LoopVectorization, Bumper, Optim
     include("test_deprecations.jl")
 end
 
-@testitem "Test Optim.jl" begin
-    include("test_optim.jl")
-end
+include("test_optim.jl")
 
 @testitem "Test tree construction and scoring" begin
     include("test_tree_construction.jl")
