@@ -7,38 +7,63 @@ import ..EquationModule: string_tree, Node
77import .. EvaluateEquationModule: eval_tree_array
88import .. EvaluateEquationDerivativeModule: eval_grad_tree_array
99
10- """
11- OperatorEnum(; binary_operators=[], unary_operators=[], enable_autodiff::Bool=false, extend_user_operators::Bool=false)
10+ function create_evaluation_helper_functions (operators:: OperatorEnum )
11+ @eval begin
12+ Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
13+ Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
14+ function (tree:: Node{T} )(X:: AbstractArray{T,2} ):: AbstractArray{T,1} where {T<: Real }
15+ out, did_finish = eval_tree_array (tree, X, $ operators)
16+ if ! did_finish
17+ out .= T (NaN )
18+ end
19+ return out
20+ end
21+ function (tree:: Node{T1} )(X:: AbstractArray{T2,2} ) where {T1<: Real ,T2<: Real }
22+ if T1 != T2
23+ T = promote_type (T1, T2)
24+ tree = convert (Node{T}, tree)
25+ X = T .(X)
26+ end
27+ return tree (X)
28+ end
29+ # Gradients:
30+ function Base. adjoint (tree:: Node{T} ) where {T}
31+ return X -> begin
32+ _, grad, did_complete = eval_grad_tree_array (tree, X, $ operators; variable= true )
33+ ! did_complete && (grad .= T (NaN ))
34+ grad
35+ end
36+ end
37+ end
38+ end
1239
13- Construct an `OperatorEnum` object, defining the possible expressions. This will also
14- redefine operators for `Node` types, as well as `show`, `print`, and `(::Node)(X)`.
15- It will automatically compute derivatives with `Zygote.jl`.
40+ function create_evaluation_helper_functions (operators:: GenericOperatorEnum )
41+ @eval begin
42+ Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
43+ Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
1644
17- # Arguments
18- - `binary_operators::Vector{Function}`: A vector of functions, each of which is a binary
19- operator.
20- - `unary_operators::Vector{Function}`: A vector of functions, each of which is a unary
21- operator.
22- - `enable_autodiff::Bool=false`: Whether to enable automatic differentiation.
23- - `extend_user_operators::Bool=false`: Whether to extend the user's operators to
24- `Node` types. All operators defined in `Base` will already be extended automatically.
25- """
26- function OperatorEnum (;
27- binary_operators= [],
28- unary_operators= [],
29- enable_autodiff:: Bool = false ,
30- extend_user_operators:: Bool = false ,
31- )
32- @assert length (binary_operators) > 0 || length (unary_operators) > 0
33- @assert length (binary_operators) <= max_ops && length (unary_operators) <= max_ops
34- binary_operators = Tuple (binary_operators)
35- unary_operators = Tuple (unary_operators)
45+ function (tree:: Node )(X; throw_errors:: Bool = true )
46+ out, did_finish = eval_tree_array (
47+ tree, X, $ operators; throw_errors= throw_errors
48+ )
49+ if ! did_finish
50+ return nothing
51+ end
52+ return out
53+ end
54+ end
55+ end
3656
57+ function create_node_helper_functions (
58+ operators:: AbstractOperatorEnum ; extend_user_operators:: Bool = false
59+ )
3760 for (op, f) in enumerate (map (Symbol, binary_operators))
38- f = if f in [:pow , :safe_pow ]
39- Symbol (^ )
40- else
41- f
61+ if typeof (operators) <: OperatorEnum
62+ f = if f in [:pow , :safe_pow ]
63+ Symbol (^ )
64+ else
65+ f
66+ end
4267 end
4368 if isdefined (Base, f)
4469 f = :(Base.$ (f))
@@ -99,6 +124,37 @@ function OperatorEnum(;
99124 end ,
100125 )
101126 end
127+ end
128+
129+ """
130+ OperatorEnum(; binary_operators=[], unary_operators=[], enable_autodiff::Bool=false, extend_user_operators::Bool=false)
131+
132+ Construct an `OperatorEnum` object, defining the possible expressions. This will also
133+ redefine operators for `Node` types, as well as `show`, `print`, and `(::Node)(X)`.
134+ It will automatically compute derivatives with `Zygote.jl`.
135+
136+ # Arguments
137+ - `binary_operators::Vector{Function}`: A vector of functions, each of which is a binary
138+ operator.
139+ - `unary_operators::Vector{Function}`: A vector of functions, each of which is a unary
140+ operator.
141+ - `enable_autodiff::Bool=false`: Whether to enable automatic differentiation.
142+ - `extend_user_operators::Bool=false`: Whether to extend the user's operators to
143+ `Node` types. All operators defined in `Base` will already be extended automatically.
144+ - `define_helper_functions::Bool=true`: Whether to define helper functions for creating
145+ and evaluating node types. Turn this off when doing precompilation.
146+ """
147+ function OperatorEnum (;
148+ binary_operators= [],
149+ unary_operators= [],
150+ enable_autodiff:: Bool = false ,
151+ extend_user_operators:: Bool = false ,
152+ define_helper_functions:: Bool = true ,
153+ )
154+ @assert length (binary_operators) > 0 || length (unary_operators) > 0
155+ @assert length (binary_operators) <= max_ops && length (unary_operators) <= max_ops
156+ binary_operators = Tuple (binary_operators)
157+ unary_operators = Tuple (unary_operators)
102158
103159 if enable_autodiff
104160 diff_binary_operators = Any[]
@@ -152,35 +208,9 @@ function OperatorEnum(;
152208 binary_operators, unary_operators, diff_binary_operators, diff_unary_operators
153209 )
154210
155- @eval begin
156- Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
157- Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
158- import DynamicExpressions: Node
159-
160- function (tree:: Node{T} )(X:: AbstractArray{T,2} ):: AbstractArray{T,1} where {T<: Real }
161- out, did_finish = eval_tree_array (tree, X, $ operators)
162- if ! did_finish
163- out .= T (NaN )
164- end
165- return out
166- end
167- function (tree:: Node{T1} )(X:: AbstractArray{T2,2} ) where {T1<: Real ,T2<: Real }
168- if T1 != T2
169- T = promote_type (T1, T2)
170- tree = convert (Node{T}, tree)
171- X = T .(X)
172- end
173- return tree (X)
174- end
175-
176- # Gradients:
177- function Base. adjoint (tree:: Node{T} ) where {T}
178- return X -> begin
179- _, grad, did_complete = eval_grad_tree_array (tree, X, $ operators; variable= true )
180- ! did_complete && (grad .= T (NaN ))
181- grad
182- end
183- end
211+ if define_helper_functions
212+ create_node_helper_functions (operators; extend_user_operators= extend_user_operators)
213+ create_evaluation_helper_functions (operators)
184214 end
185215
186216 return operators
@@ -201,97 +231,26 @@ and `(::Node)(X)`.
201231 operator on real scalars.
202232- `extend_user_operators::Bool=false`: Whether to extend the user's operators to
203233 `Node` types. All operators defined in `Base` will already be extended automatically.
234+ - `define_helper_functions::Bool=true`: Whether to define helper functions for creating
235+ and evaluating node types. Turn this off when doing precompilation.
204236"""
205237function GenericOperatorEnum (;
206- binary_operators= [], unary_operators= [], extend_user_operators:: Bool = false
238+ binary_operators= [],
239+ unary_operators= [],
240+ extend_user_operators:: Bool = false ,
241+ define_helper_functions:: Bool = true ,
207242)
208243 binary_operators = Tuple (binary_operators)
209244 unary_operators = Tuple (unary_operators)
210245
211246 @assert length (binary_operators) > 0 || length (unary_operators) > 0
212247 @assert length (binary_operators) <= max_ops && length (unary_operators) <= max_ops
213248
214- for (op, f) in enumerate (map (Symbol, binary_operators))
215- f = if f in [:pow , :safe_pow ]
216- Symbol (^ )
217- else
218- f
219- end
220- if isdefined (Base, f)
221- f = :(Base.$ f)
222- elseif ! extend_user_operators
223- # Skip non-Base operators!
224- continue
225- end
226- Base. MainInclude. eval (
227- quote
228- import DynamicExpressions: Node
229- function $f (l:: Node{T1} , r:: Node{T2} ) where {T1,T2}
230- T = promote_type (T1, T2)
231- l = convert (Node{T}, l)
232- r = convert (Node{T}, r)
233- if (l. constant && r. constant)
234- return Node (; val= $ f (l. val, r. val))
235- else
236- return Node ($ op, l, r)
237- end
238- end
239- function $f (l:: Node{T1} , r:: T2 ) where {T1,T2}
240- T = promote_type (T1, T2)
241- l = convert (Node{T}, l)
242- r = convert (T, r)
243- return if l. constant
244- Node (; val= $ f (l. val, r))
245- else
246- Node ($ op, l, Node (; val= r))
247- end
248- end
249- function $f (l:: T1 , r:: Node{T2} ) where {T1,T2}
250- T = promote_type (T1, T2)
251- l = convert (T, l)
252- r = convert (Node{T}, r)
253- return if r. constant
254- Node (; val= $ f (l, r. val))
255- else
256- Node ($ op, Node (; val= l), r)
257- end
258- end
259- end ,
260- )
261- end
262- # Redefine Base operations:
263- for (op, f) in enumerate (map (Symbol, unary_operators))
264- if isdefined (Base, f)
265- f = :(Base.$ f)
266- elseif ! extend_user_operators
267- # Skip non-Base operators!
268- continue
269- end
270- Base. MainInclude. eval (
271- quote
272- import DynamicExpressions: Node
273- function $f (l:: Node{T} ):: Node{T} where {T}
274- return l. constant ? Node (; val= $ f (l. val)) : Node ($ op, l)
275- end
276- end ,
277- )
278- end
279-
280249 operators = GenericOperatorEnum (binary_operators, unary_operators)
281250
282- @eval begin
283- Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
284- Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
285-
286- function (tree:: Node )(X; throw_errors:: Bool = true )
287- out, did_finish = eval_tree_array (
288- tree, X, $ operators; throw_errors= throw_errors
289- )
290- if ! did_finish
291- return nothing
292- end
293- return out
294- end
251+ if define_helper_functions
252+ create_node_helper_functions (operators; extend_user_operators= extend_user_operators)
253+ create_evaluation_helper_functions (operators)
295254 end
296255
297256 return operators
0 commit comments