Skip to content

Commit 061c198

Browse files
committed
Make it optional to define helper functions
1 parent af68fa8 commit 061c198

File tree

1 file changed

+96
-137
lines changed

1 file changed

+96
-137
lines changed

src/OperatorEnumConstruction.jl

Lines changed: 96 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -7,38 +7,63 @@ import ..EquationModule: string_tree, Node
77
import ..EvaluateEquationModule: eval_tree_array
88
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array
99

10-
"""
11-
OperatorEnum(; binary_operators=[], unary_operators=[], enable_autodiff::Bool=false, extend_user_operators::Bool=false)
10+
function create_evaluation_helper_functions(operators::OperatorEnum)
11+
@eval begin
12+
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
13+
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
14+
function (tree::Node{T})(X::AbstractArray{T,2})::AbstractArray{T,1} where {T<:Real}
15+
out, did_finish = eval_tree_array(tree, X, $operators)
16+
if !did_finish
17+
out .= T(NaN)
18+
end
19+
return out
20+
end
21+
function (tree::Node{T1})(X::AbstractArray{T2,2}) where {T1<:Real,T2<:Real}
22+
if T1 != T2
23+
T = promote_type(T1, T2)
24+
tree = convert(Node{T}, tree)
25+
X = T.(X)
26+
end
27+
return tree(X)
28+
end
29+
# Gradients:
30+
function Base.adjoint(tree::Node{T}) where {T}
31+
return X -> begin
32+
_, grad, did_complete = eval_grad_tree_array(tree, X, $operators; variable=true)
33+
!did_complete && (grad .= T(NaN))
34+
grad
35+
end
36+
end
37+
end
38+
end
1239

13-
Construct an `OperatorEnum` object, defining the possible expressions. This will also
14-
redefine operators for `Node` types, as well as `show`, `print`, and `(::Node)(X)`.
15-
It will automatically compute derivatives with `Zygote.jl`.
40+
function create_evaluation_helper_functions(operators::GenericOperatorEnum)
41+
@eval begin
42+
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
43+
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
1644

17-
# Arguments
18-
- `binary_operators::Vector{Function}`: A vector of functions, each of which is a binary
19-
operator.
20-
- `unary_operators::Vector{Function}`: A vector of functions, each of which is a unary
21-
operator.
22-
- `enable_autodiff::Bool=false`: Whether to enable automatic differentiation.
23-
- `extend_user_operators::Bool=false`: Whether to extend the user's operators to
24-
`Node` types. All operators defined in `Base` will already be extended automatically.
25-
"""
26-
function OperatorEnum(;
27-
binary_operators=[],
28-
unary_operators=[],
29-
enable_autodiff::Bool=false,
30-
extend_user_operators::Bool=false,
31-
)
32-
@assert length(binary_operators) > 0 || length(unary_operators) > 0
33-
@assert length(binary_operators) <= max_ops && length(unary_operators) <= max_ops
34-
binary_operators = Tuple(binary_operators)
35-
unary_operators = Tuple(unary_operators)
45+
function (tree::Node)(X; throw_errors::Bool=true)
46+
out, did_finish = eval_tree_array(
47+
tree, X, $operators; throw_errors=throw_errors
48+
)
49+
if !did_finish
50+
return nothing
51+
end
52+
return out
53+
end
54+
end
55+
end
3656

57+
function create_node_helper_functions(
58+
operators::AbstractOperatorEnum; extend_user_operators::Bool=false
59+
)
3760
for (op, f) in enumerate(map(Symbol, binary_operators))
38-
f = if f in [:pow, :safe_pow]
39-
Symbol(^)
40-
else
41-
f
61+
if typeof(operators) <: OperatorEnum
62+
f = if f in [:pow, :safe_pow]
63+
Symbol(^)
64+
else
65+
f
66+
end
4267
end
4368
if isdefined(Base, f)
4469
f = :(Base.$(f))
@@ -99,6 +124,37 @@ function OperatorEnum(;
99124
end,
100125
)
101126
end
127+
end
128+
129+
"""
130+
OperatorEnum(; binary_operators=[], unary_operators=[], enable_autodiff::Bool=false, extend_user_operators::Bool=false)
131+
132+
Construct an `OperatorEnum` object, defining the possible expressions. This will also
133+
redefine operators for `Node` types, as well as `show`, `print`, and `(::Node)(X)`.
134+
It will automatically compute derivatives with `Zygote.jl`.
135+
136+
# Arguments
137+
- `binary_operators::Vector{Function}`: A vector of functions, each of which is a binary
138+
operator.
139+
- `unary_operators::Vector{Function}`: A vector of functions, each of which is a unary
140+
operator.
141+
- `enable_autodiff::Bool=false`: Whether to enable automatic differentiation.
142+
- `extend_user_operators::Bool=false`: Whether to extend the user's operators to
143+
`Node` types. All operators defined in `Base` will already be extended automatically.
144+
- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
145+
and evaluating node types. Turn this off when doing precompilation.
146+
"""
147+
function OperatorEnum(;
148+
binary_operators=[],
149+
unary_operators=[],
150+
enable_autodiff::Bool=false,
151+
extend_user_operators::Bool=false,
152+
define_helper_functions::Bool=true,
153+
)
154+
@assert length(binary_operators) > 0 || length(unary_operators) > 0
155+
@assert length(binary_operators) <= max_ops && length(unary_operators) <= max_ops
156+
binary_operators = Tuple(binary_operators)
157+
unary_operators = Tuple(unary_operators)
102158

103159
if enable_autodiff
104160
diff_binary_operators = Any[]
@@ -152,35 +208,9 @@ function OperatorEnum(;
152208
binary_operators, unary_operators, diff_binary_operators, diff_unary_operators
153209
)
154210

155-
@eval begin
156-
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
157-
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
158-
import DynamicExpressions: Node
159-
160-
function (tree::Node{T})(X::AbstractArray{T,2})::AbstractArray{T,1} where {T<:Real}
161-
out, did_finish = eval_tree_array(tree, X, $operators)
162-
if !did_finish
163-
out .= T(NaN)
164-
end
165-
return out
166-
end
167-
function (tree::Node{T1})(X::AbstractArray{T2,2}) where {T1<:Real,T2<:Real}
168-
if T1 != T2
169-
T = promote_type(T1, T2)
170-
tree = convert(Node{T}, tree)
171-
X = T.(X)
172-
end
173-
return tree(X)
174-
end
175-
176-
# Gradients:
177-
function Base.adjoint(tree::Node{T}) where {T}
178-
return X -> begin
179-
_, grad, did_complete = eval_grad_tree_array(tree, X, $operators; variable=true)
180-
!did_complete && (grad .= T(NaN))
181-
grad
182-
end
183-
end
211+
if define_helper_functions
212+
create_node_helper_functions(operators; extend_user_operators=extend_user_operators)
213+
create_evaluation_helper_functions(operators)
184214
end
185215

186216
return operators
@@ -201,97 +231,26 @@ and `(::Node)(X)`.
201231
operator on real scalars.
202232
- `extend_user_operators::Bool=false`: Whether to extend the user's operators to
203233
`Node` types. All operators defined in `Base` will already be extended automatically.
234+
- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
235+
and evaluating node types. Turn this off when doing precompilation.
204236
"""
205237
function GenericOperatorEnum(;
206-
binary_operators=[], unary_operators=[], extend_user_operators::Bool=false
238+
binary_operators=[],
239+
unary_operators=[],
240+
extend_user_operators::Bool=false,
241+
define_helper_functions::Bool=true,
207242
)
208243
binary_operators = Tuple(binary_operators)
209244
unary_operators = Tuple(unary_operators)
210245

211246
@assert length(binary_operators) > 0 || length(unary_operators) > 0
212247
@assert length(binary_operators) <= max_ops && length(unary_operators) <= max_ops
213248

214-
for (op, f) in enumerate(map(Symbol, binary_operators))
215-
f = if f in [:pow, :safe_pow]
216-
Symbol(^)
217-
else
218-
f
219-
end
220-
if isdefined(Base, f)
221-
f = :(Base.$f)
222-
elseif !extend_user_operators
223-
# Skip non-Base operators!
224-
continue
225-
end
226-
Base.MainInclude.eval(
227-
quote
228-
import DynamicExpressions: Node
229-
function $f(l::Node{T1}, r::Node{T2}) where {T1,T2}
230-
T = promote_type(T1, T2)
231-
l = convert(Node{T}, l)
232-
r = convert(Node{T}, r)
233-
if (l.constant && r.constant)
234-
return Node(; val=$f(l.val, r.val))
235-
else
236-
return Node($op, l, r)
237-
end
238-
end
239-
function $f(l::Node{T1}, r::T2) where {T1,T2}
240-
T = promote_type(T1, T2)
241-
l = convert(Node{T}, l)
242-
r = convert(T, r)
243-
return if l.constant
244-
Node(; val=$f(l.val, r))
245-
else
246-
Node($op, l, Node(; val=r))
247-
end
248-
end
249-
function $f(l::T1, r::Node{T2}) where {T1,T2}
250-
T = promote_type(T1, T2)
251-
l = convert(T, l)
252-
r = convert(Node{T}, r)
253-
return if r.constant
254-
Node(; val=$f(l, r.val))
255-
else
256-
Node($op, Node(; val=l), r)
257-
end
258-
end
259-
end,
260-
)
261-
end
262-
# Redefine Base operations:
263-
for (op, f) in enumerate(map(Symbol, unary_operators))
264-
if isdefined(Base, f)
265-
f = :(Base.$f)
266-
elseif !extend_user_operators
267-
# Skip non-Base operators!
268-
continue
269-
end
270-
Base.MainInclude.eval(
271-
quote
272-
import DynamicExpressions: Node
273-
function $f(l::Node{T})::Node{T} where {T}
274-
return l.constant ? Node(; val=$f(l.val)) : Node($op, l)
275-
end
276-
end,
277-
)
278-
end
279-
280249
operators = GenericOperatorEnum(binary_operators, unary_operators)
281250

282-
@eval begin
283-
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
284-
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
285-
286-
function (tree::Node)(X; throw_errors::Bool=true)
287-
out, did_finish = eval_tree_array(
288-
tree, X, $operators; throw_errors=throw_errors
289-
)
290-
if !did_finish
291-
return nothing
292-
end
293-
return out
294-
end
251+
if define_helper_functions
252+
create_node_helper_functions(operators; extend_user_operators=extend_user_operators)
253+
create_evaluation_helper_functions(operators)
295254
end
296255

297256
return operators

0 commit comments

Comments
 (0)