|
1 | | -using DynamicExpressions, Optim, Zygote |
2 | | -using Random: MersenneTwister as RNG |
3 | | -using Test |
| 1 | +@testitem "Basic optimization" begin |
| 2 | + using DynamicExpressions, Optim |
4 | 3 |
|
5 | | -operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,)) |
6 | | -x1, x2 = (i -> Node(Float64; feature=i)).(1:2) |
7 | | - |
8 | | -X = rand(RNG(0), Float64, 2, 100) |
9 | | -y = @. exp(X[1, :] * 2.1 - 0.9) + X[2, :] * -0.9 |
10 | | - |
11 | | -original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2 |
12 | | -target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2 |
13 | | - |
14 | | -f(tree) = sum(abs2, tree(X, operators) .- y) |
15 | | -function g!(G, tree) |
16 | | - dy = only(gradient(f, tree)) |
17 | | - G .= dy.gradient |
18 | | - return nothing |
19 | | -end |
20 | | - |
21 | | -@testset "Basic optimization" begin |
| 4 | + include("test_optim_setup.jl") |
22 | 5 | tree = copy(original_tree) |
23 | 6 | res = optimize(f, tree) |
24 | 7 |
|
|
33 | 16 | ) |
34 | 17 | end |
35 | 18 |
|
36 | | -@testset "With gradients, using Zygote" begin |
| 19 | +@testitem "With gradients, using Zygote" begin |
| 20 | + using DynamicExpressions, Optim, Zygote |
| 21 | + |
| 22 | + include("test_optim_setup.jl") |
| 23 | + |
37 | 24 | tree = copy(original_tree) |
38 | 25 | res = optimize(f, g!, tree, BFGS()) |
39 | 26 | @test tree == original_tree |
|
42 | 29 | ) |
43 | 30 | end |
44 | 31 |
|
45 | | -@testset "With gradients, manually" begin |
| 32 | +@testitem "With gradients, manually" begin |
| 33 | + using DynamicExpressions, Optim, Zygote |
| 34 | + |
| 35 | + include("test_optim_setup.jl") |
| 36 | + |
46 | 37 | tree = copy(original_tree) |
47 | 38 | did_i_run = Ref(false) |
48 | 39 | # Now, try with gradients too (via Zygote and our hand-rolled forward-mode AD) |
|
79 | 70 | end |
80 | 71 |
|
81 | 72 | # Now, try combined |
82 | | -@testset "Combined evaluation with gradient" begin |
| 73 | +@testitem "Combined evaluation with gradient" begin |
| 74 | + using DynamicExpressions, Optim, Zygote |
| 75 | + include("test_optim_setup.jl") |
| 76 | + |
83 | 77 | tree = copy(original_tree) |
84 | 78 | did_i_run_2 = Ref(false) |
85 | 79 | fg!(F, G, tree) = |
|
0 commit comments