11module DynamicExpressionsOptimExt
22
3- using DynamicExpressions: AbstractExpressionNode, eval_tree_array
3+ using DynamicExpressions:
4+ AbstractExpressionNode, eval_tree_array, get_constant_refs, set_constant_refs!
45using Compat: @inline
56
67import Optim: Optim, OptimizationResults, NLSolversBase
@@ -31,45 +32,36 @@ function Optim.minimizer(r::ExpressionOptimizationResults)
3132 return r. tree
3233end
3334
34- function set_constant_nodes! (
35- constant_nodes:: AbstractArray{N} , x
36- ) where {T,N<: AbstractExpressionNode{T} }
37- for (ci, xi) in zip (constant_nodes, x)
38- ci. val:: T = xi:: T
39- end
40- return nothing
41- end
42-
4335""" Wrap function or objective with insertion of values of the constant nodes."""
4436function wrap_func (
45- f:: F , tree:: N , constant_nodes :: AbstractArray{N}
37+ f:: F , tree:: N , constant_refs :: AbstractArray
4638) where {F<: Function ,T,N<: AbstractExpressionNode{T} }
4739 function wrapped_f (args:: Vararg{Any,M} ) where {M}
4840 first_args = args[1 : (end - 1 )]
4941 x = last (args)
50- set_constant_nodes! (constant_nodes , x)
42+ set_constant_refs! (constant_refs , x)
5143 return @inline (f (first_args... , tree))
5244 end
5345 return wrapped_f
5446end
5547function wrap_func (
56- :: Nothing , tree:: N , constant_nodes :: AbstractArray{N}
48+ :: Nothing , tree:: N , constant_refs :: AbstractArray
5749) where {N<: AbstractExpressionNode }
5850 return nothing
5951end
6052function wrap_func (
61- f:: NLSolversBase.InplaceObjective , tree:: N , constant_nodes :: AbstractArray{N}
53+ f:: NLSolversBase.InplaceObjective , tree:: N , constant_refs :: AbstractArray
6254) where {N<: AbstractExpressionNode }
6355 # Some objectives, like `Optim.only_fg!(fg!)`, are not functions but instead
6456 # `InplaceObjective`. These contain multiple functions, each of which needs to be
6557 # wrapped. Some functions are `nothing`; those can be left as-is.
6658 @assert fieldnames (NLSolversBase. InplaceObjective) == (:df , :fdf , :fgh , :hv , :fghv )
6759 return NLSolversBase. InplaceObjective (
68- wrap_func (f. df, tree, constant_nodes ),
69- wrap_func (f. fdf, tree, constant_nodes ),
70- wrap_func (f. fgh, tree, constant_nodes ),
71- wrap_func (f. hv, tree, constant_nodes ),
72- wrap_func (f. fghv, tree, constant_nodes ),
60+ wrap_func (f. df, tree, constant_refs ),
61+ wrap_func (f. fdf, tree, constant_refs ),
62+ wrap_func (f. fgh, tree, constant_refs ),
63+ wrap_func (f. hv, tree, constant_refs ),
64+ wrap_func (f. fghv, tree, constant_refs ),
7365 )
7466end
7567
@@ -95,8 +87,8 @@ function Optim.optimize(
9587 if make_copy
9688 tree = copy (tree)
9789 end
98- constant_nodes = filter (t -> t . degree == 0 && t . constant, tree)
99- x0 = T[t . val :: T for t in constant_nodes]
90+ constant_refs = get_constant_refs ( tree)
91+ x0 = map (t -> t . x, constant_refs)
10092 if ! isnothing (h!)
10193 throw (
10294 ArgumentError (
@@ -106,17 +98,17 @@ function Optim.optimize(
10698 )
10799 end
108100 base_res = if isnothing (g!)
109- Optim. optimize (wrap_func (f, tree, constant_nodes ), x0, args... ; kwargs... )
101+ Optim. optimize (wrap_func (f, tree, constant_refs ), x0, args... ; kwargs... )
110102 else
111103 Optim. optimize (
112- wrap_func (f, tree, constant_nodes ),
113- wrap_func (g!, tree, constant_nodes ),
104+ wrap_func (f, tree, constant_refs ),
105+ wrap_func (g!, tree, constant_refs ),
114106 x0,
115107 args... ;
116108 kwargs... ,
117109 )
118110 end
119- set_constant_nodes! (constant_nodes , Optim. minimizer (base_res))
111+ set_constant_refs! (constant_refs , Optim. minimizer (base_res))
120112 return ExpressionOptimizationResults (base_res, tree)
121113end
122114
0 commit comments