@@ -3,16 +3,51 @@ module DynamicExpressionsOptimExt
33using DynamicExpressions: AbstractExpressionNode, eval_tree_array
44using Compat: @inline
55
6- import Optim: optimize
6+ import Optim: Optim, OptimizationResults
7+
8+ # ! format: off
9+ """
10+ ExpressionOptimizationResults{R,N<:AbstractExpressionNode}
11+
12+ Optimization results for an expression, which wraps the base optimization results
13+ on a vector of constants.
14+ """
15+ struct ExpressionOptimizationResults{R<: OptimizationResults ,N<: AbstractExpressionNode } <: OptimizationResults
16+ _results:: R # The raw results from Optim.
17+ tree:: N # The final expression tree
18+ end
19+ # ! format: on
20+ function Base. getproperty (r:: ExpressionOptimizationResults , s:: Symbol )
21+ if s == :tree || s == :minimizer
22+ return getfield (r, :tree )
23+ else
24+ return getproperty (getfield (r, :_results ), s)
25+ end
26+ end
27+ function Base. propertynames (r:: ExpressionOptimizationResults )
28+ return (:tree , propertynames (getfield (r, :_results ))... )
29+ end
30+ function base_results (r:: ExpressionOptimizationResults )
31+ return getfield (r, :_results )
32+ end
33+ function Optim. minimizer (r:: ExpressionOptimizationResults )
34+ return r. tree
35+ end
36+
37+ function set_constant_nodes! (
38+ constant_nodes:: AbstractArray{N} , x
39+ ) where {T,N<: AbstractExpressionNode{T} }
40+ for (ci, xi) in zip (constant_nodes, x)
41+ ci. val:: T = xi:: T
42+ end
43+ end
744
845""" Wrap f with insertion of values of the constant nodes."""
946function get_wrapped_f (
1047 f:: F , tree:: N , constant_nodes:: AbstractArray{N}
1148) where {F,T,N<: AbstractExpressionNode{T} }
1249 function wrapped_f (x)
13- for (ci, xi) in zip (constant_nodes, x)
14- ci. val:: T = xi:: T
15- end
50+ set_constant_nodes! (constant_nodes, x)
1651 return @inline (f (tree))
1752 end
1853 return wrapped_f
@@ -23,31 +58,52 @@ function get_wrapped_gh!(
2358 gh!:: GH , tree:: N , constant_nodes:: AbstractArray{N}
2459) where {GH,T,N<: AbstractExpressionNode{T} }
2560 function wrapped_gh! (G, x)
26- for (ci, xi) in zip (constant_nodes, x)
27- ci. val:: T = xi:: T
28- end
61+ set_constant_nodes! (constant_nodes, x)
2962 @inline (gh! (G, tree))
3063 return nothing
3164 end
3265 return wrapped_gh!
3366end
3467
35- function optimize (f:: F , g!:: G , h!:: H , tree:: AbstractExpressionNode{T} , args... ; kwargs... ) where {F,G,H,T}
68+ """
69+ optimize(f, [g!, [h!,]] tree, args...; kwargs...)
70+
71+ Optimize an expression tree with respect to the constants in the tree.
72+ Returns an `ExpressionOptimizationResults` object, which wraps the base
73+ optimization results on a vector of constants. You may use `res.minimizer`
74+ to view the optimized expression tree.
75+ """
76+ function Optim. optimize (
77+ f:: F , tree:: AbstractExpressionNode , args... ; kwargs...
78+ ) where {F<: Function }
79+ return Optim. optimize (f, nothing , tree, args... ; kwargs... )
80+ end
81+ function Optim. optimize (
82+ f:: F , g!:: G , tree:: AbstractExpressionNode , args... ; kwargs...
83+ ) where {F,G<: Union{Function,Nothing} }
84+ return Optim. optimize (f, g!, nothing , tree, args... ; kwargs... )
85+ end
86+ function Optim. optimize (
87+ f:: F , g!:: G , h!:: H , tree:: AbstractExpressionNode{T} , args... ; make_copy= true , kwargs...
88+ ) where {F,G<: Union{Function,Nothing} ,H<: Union{Function,Nothing} ,T}
89+ if make_copy
90+ tree = copy (tree)
91+ end
3692 constant_nodes = filter (t -> t. degree == 0 && t. constant, tree)
3793 x0 = T[t. val:: T for t in constant_nodes]
38- if g! === nothing
94+ base_res = if g! === nothing
3995 @assert h! === nothing
40- return optimize (get_wrapped_f (f, tree, constant_nodes), x0, args... ; kwargs... )
96+ Optim . optimize (get_wrapped_f (f, tree, constant_nodes), x0, args... ; kwargs... )
4197 elseif h! === nothing
42- return optimize (
98+ Optim . optimize (
4399 get_wrapped_f (f, tree, constant_nodes),
44100 get_wrapped_gh! (g!, tree, constant_nodes),
45101 x0,
46102 args... ;
47103 kwargs... ,
48104 )
49105 else
50- return optimize (
106+ Optim . optimize (
51107 get_wrapped_f (f, tree, constant_nodes),
52108 get_wrapped_gh! (g!, tree, constant_nodes),
53109 get_wrapped_gh! (h!, tree, constant_nodes),
@@ -56,12 +112,8 @@ function optimize(f::F, g!::G, h!::H, tree::AbstractExpressionNode{T}, args...;
56112 kwargs... ,
57113 )
58114 end
59- end
60- function optimize (f:: F , g!:: G , tree:: AbstractExpressionNode , args... ; kwargs... ) where {F,G}
61- return optimize (f, g!, nothing , tree, args... ; kwargs... )
62- end
63- function optimize (f:: F , tree:: AbstractExpressionNode , args... ; kwargs... ) where {F}
64- return optimize (f, nothing , tree, args... ; kwargs... )
115+ set_constant_nodes! (constant_nodes, Optim. minimizer (base_res))
116+ return ExpressionOptimizationResults (base_res, tree)
65117end
66118
67119end
0 commit comments