Skip to content

Commit a968ac1

Browse files
committed
Add Optim to tests
1 parent 471578b commit a968ac1

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ DynamicExpressionsZygoteExt = "Zygote"
2929
Aqua = "0.7"
3030
Compat = "3.37, 4"
3131
LoopVectorization = "0.12"
32+
Optim = "0.19, 1"
3233
MacroTools = "0.4, 0.5"
3334
PackageExtensionCompat = "1"
3435
PrecompileTools = "1"
@@ -40,6 +41,7 @@ julia = "1.6"
4041
[extras]
4142
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4243
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
44+
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
4345
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4446
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4547
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -48,4 +50,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4850
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4951

5052
[targets]
51-
test = ["Test", "SafeTestsets", "Aqua", "SpecialFunctions", "ForwardDiff", "StaticArrays", "SymbolicUtils", "Zygote"]
53+
test = ["Test", "SafeTestsets", "Aqua", "Optim", "SpecialFunctions", "ForwardDiff", "StaticArrays", "SymbolicUtils", "Zygote"]

ext/DynamicExpressionsOptimExt.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
module DynamicExpressionsOptimExt
22

3-
using DynamicExpressions: Node, eval_tree_array
3+
using DynamicExpressions: AbstractExpressionNode, eval_tree_array
4+
using Compat: @inline
45

5-
import Compat: @inline
66
import Optim: optimize
77

88
"""Wrap f with insertion of values of the constant nodes."""
99
function get_wrapped_f(
1010
f::F, tree::N, constant_nodes::AbstractArray{N}
11-
) where {F,T,N<:Node{T}}
11+
) where {F,T,N<:AbstractExpressionNode{T}}
1212
function wrapped_f(x)
1313
for (ci, xi) in zip(constant_nodes, x)
1414
ci.val::T = xi::T
@@ -21,7 +21,7 @@ end
2121
"""Wrap g! or h! with insertion of values of the constant nodes."""
2222
function get_wrapped_gh!(
2323
gh!::GH, tree::N, constant_nodes::AbstractArray{N}
24-
) where {GH,T,N<:Node{T}}
24+
) where {GH,T,N<:AbstractExpressionNode{T}}
2525
function wrapped_gh!(G, x)
2626
for (ci, xi) in zip(constant_nodes, x)
2727
ci.val::T = xi::T
@@ -32,9 +32,9 @@ function get_wrapped_gh!(
3232
return wrapped_gh!
3333
end
3434

35-
function optimize(f::F, g!::G, h!::H, tree::Node{T}, args...; kwargs...) where {F,G,H,T}
35+
function optimize(f::F, g!::G, h!::H, tree::AbstractExpressionNode{T}, args...; kwargs...) where {F,G,H,T}
3636
constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)
37-
x0 = [t.val::T for t in constant_nodes]
37+
x0 = T[t.val::T for t in constant_nodes]
3838
if g! === nothing
3939
@assert h! === nothing
4040
return optimize(get_wrapped_f(f, tree, constant_nodes), x0, args...; kwargs...)
@@ -57,10 +57,10 @@ function optimize(f::F, g!::G, h!::H, tree::Node{T}, args...; kwargs...) where {
5757
)
5858
end
5959
end
60-
function optimize(f::F, g!::G, tree::Node, args...; kwargs...) where {F,G}
60+
function optimize(f::F, g!::G, tree::AbstractExpressionNode, args...; kwargs...) where {F,G}
6161
return optimize(f, g!, nothing, tree, args...; kwargs...)
6262
end
63-
function optimize(f::F, tree::Node, args...; kwargs...) where {F}
63+
function optimize(f::F, tree::AbstractExpressionNode, args...; kwargs...) where {F}
6464
return optimize(f, nothing, tree, args...; kwargs...)
6565
end
6666

0 commit comments

Comments
 (0)