1+ module ConstantOptimizationModule
2+
3+ import Compat: @inline
4+ import Optim: optimize
5+ import .. EquationModule: Node
6+ import .. EvaluateEquationModule: eval_tree_array
7+
8+ """ Wrap f with insertion of values of the constant nodes."""
9+ function get_wrapped_f (
10+ f:: F , tree:: N , constant_nodes:: AbstractArray{N}
11+ ) where {F,T,N<: Node{T} }
12+ function wrapped_f (x)
13+ for (ci, xi) in zip (constant_nodes, x)
14+ ci. val:: T = xi:: T
15+ end
16+ return @inline (f (tree))
17+ end
18+ return wrapped_f
19+ end
20+
21+ """ Wrap g! or h! with insertion of values of the constant nodes."""
22+ function get_wrapped_gh! (
23+ gh!:: GH , tree:: N , constant_nodes:: AbstractArray{N}
24+ ) where {GH,T,N<: Node{T} }
25+ function wrapped_gh! (G, x)
26+ for (ci, xi) in zip (constant_nodes, x)
27+ ci. val:: T = xi:: T
28+ end
29+ @inline (gh! (G, tree))
30+ return nothing
31+ end
32+ return wrapped_gh!
33+ end
34+
35+ function optimize (f:: F , g!:: G , h!:: H , tree:: Node{T} , args... ; kwargs... ) where {F,G,H,T}
36+ constant_nodes = filter (t -> t. degree == 0 && t. constant, tree)
37+ x0 = [t. val:: T for t in constant_nodes]
38+ if g! === nothing
39+ @assert h! === nothing
40+ return optimize (get_wrapped_f (f, tree, constant_nodes), x0, args... ; kwargs... )
41+ elseif h! === nothing
42+ return optimize (
43+ get_wrapped_f (f, tree, constant_nodes),
44+ get_wrapped_gh! (g!, tree, constant_nodes),
45+ x0,
46+ args... ;
47+ kwargs... ,
48+ )
49+ else
50+ return optimize (
51+ get_wrapped_f (f, tree, constant_nodes),
52+ get_wrapped_gh! (g!, tree, constant_nodes),
53+ get_wrapped_gh! (h!, tree, constant_nodes),
54+ x0,
55+ args... ;
56+ kwargs... ,
57+ )
58+ end
59+ end
60+ function optimize (f:: F , g!:: G , tree:: Node , args... ; kwargs... ) where {F,G}
61+ return optimize (f, g!, nothing , tree, args... ; kwargs... )
62+ end
63+ function optimize (f:: F , tree:: Node , args... ; kwargs... ) where {F}
64+ return optimize (f, nothing , tree, args... ; kwargs... )
65+ end
66+
67+ end
0 commit comments