Skip to content

Commit f1af402

Browse files
committed
Overload common operators like +, -, *, so errors are more informative
1 parent 5bc2a2a commit f1af402

File tree

1 file changed

+69
-23
lines changed

1 file changed

+69
-23
lines changed

src/OperatorEnumConstruction.jl

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -96,21 +96,26 @@ function lookup_op(@nospecialize(f), ::Val{degree}) where {degree}
9696
mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING
9797
if !haskey(mapping, f)
9898
error(
99-
"Convenience constructor using operator `$(f)` is out-of-date. " *
100-
"Please create an `OperatorEnum` (or `GenericOperatorEnum`) with " *
101-
"`define_helper_functions=true` and pass `$(f)`.",
99+
"Convenience constructor for operator `$(f)` is out-of-date. " *
100+
"Please create an `OperatorEnum` (or `GenericOperatorEnum`) containing " *
101+
"the operator `$(f)` which will define the `$(f)` -> `Int` mapping.",
102102
)
103103
end
104104
return mapping[f]
105105
end
106106

107-
function _extend_unary_operator(f::Symbol, type_requirements)
107+
function _extend_unary_operator(f::Symbol, type_requirements, internal)
108108
quote
109109
@gensym _constructorof _AbstractExpressionNode
110110
quote
111-
using DynamicExpressions:
112-
constructorof as $_constructorof,
113-
AbstractExpressionNode as $_AbstractExpressionNode
111+
if $$internal
112+
import ..EquationModule.constructorof as $_constructorof
113+
import ..EquationModule.AbstractExpressionNode as $_AbstractExpressionNode
114+
else
115+
using DynamicExpressions:
116+
constructorof as $_constructorof,
117+
AbstractExpressionNode as $_AbstractExpressionNode
118+
end
114119

115120
function $($f)(
116121
l::N
@@ -126,13 +131,18 @@ function _extend_unary_operator(f::Symbol, type_requirements)
126131
end
127132
end
128133

129-
function _extend_binary_operator(f::Symbol, type_requirements, build_converters)
134+
function _extend_binary_operator(f::Symbol, type_requirements, build_converters, internal)
130135
quote
131136
@gensym _constructorof _AbstractExpressionNode
132137
quote
133-
using DynamicExpressions:
134-
constructorof as $_constructorof,
135-
AbstractExpressionNode as $_AbstractExpressionNode
138+
if $$internal
139+
import ..EquationModule.constructorof as $_constructorof
140+
import ..EquationModule.AbstractExpressionNode as $_AbstractExpressionNode
141+
else
142+
using DynamicExpressions:
143+
constructorof as $_constructorof,
144+
AbstractExpressionNode as $_AbstractExpressionNode
145+
end
136146

137147
function $($f)(
138148
l::N, r::N
@@ -191,19 +201,32 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters)
191201
end
192202

193203
function _extend_operators(operators, skip_user_operators, kws, __module__::Module)
194-
empty_old_operators =
195-
if length(kws) == 1 && :empty_old_operators in map(x -> x.args[1], kws)
196-
@assert kws[1].head == :(=)
197-
kws[1].args[2]
198-
else
199-
length(kws) > 0 && error(
200-
"You passed the keywords $(kws), but only `empty_old_operators` is supported.",
201-
)
202-
true
203-
end
204+
if !all(x -> first(x.args) (:empty_old_operators, :internal), kws)
205+
error(
206+
"You passed the keywords $(kws), but only `empty_old_operators`, `internal` are supported.",
207+
)
208+
end
209+
210+
empty_old_operators_idx = findfirst(x -> first(x.args) == :empty_old_operators, kws)
211+
internal_idx = findfirst(x -> first(x.args) == :internal, kws)
212+
213+
empty_old_operators = if empty_old_operators_idx !== nothing
214+
@assert kws[empty_old_operators_idx].head == :(=)
215+
kws[empty_old_operators_idx].args[2]
216+
else
217+
true
218+
end
219+
220+
internal = if internal_idx !== nothing
221+
@assert kws[internal_idx].head == :(=)
222+
kws[internal_idx].args[2]::Bool
223+
else
224+
false
225+
end
226+
204227
@gensym f skip type_requirements build_converters binary_exists unary_exists
205-
binary_ex = _extend_binary_operator(f, type_requirements, build_converters)
206-
unary_ex = _extend_unary_operator(f, type_requirements)
228+
binary_ex = _extend_binary_operator(f, type_requirements, build_converters, internal)
229+
unary_ex = _extend_unary_operator(f, type_requirements, internal)
207230
return quote
208231
local $type_requirements
209232
local $build_converters
@@ -292,6 +315,8 @@ end
292315
293316
Similar to `@extend_operators`, but only extends operators already
294317
defined in `Base`.
318+
`kws` can include `empty_old_operators` which is default `true`,
319+
and `internal` which is default `false`.
295320
"""
296321
macro extend_operators_base(operators, kws...)
297322
ex = _extend_operators(operators, true, kws, __module__)
@@ -402,4 +427,25 @@ function GenericOperatorEnum(;
402427
return operators
403428
end
404429

430+
# Predefine the most common operators so the errors
431+
# are more informative
432+
function _overload_common_operators()
433+
#! format: off
434+
operators = OperatorEnum(
435+
Function[+, -, *, /, ^, max, min, mod],
436+
Function[
437+
sin, cos, tan, exp, log, log1p, log2, log10, sqrt, cbrt, abs, sinh,
438+
cosh, tanh, atan, asinh, acosh, round, sign, floor, ceil,
439+
],
440+
Function[],
441+
Function[],
442+
)
443+
#! format: on
444+
@extend_operators(operators, empty_old_operators = false, internal = true)
445+
empty!(LATEST_UNARY_OPERATOR_MAPPING)
446+
empty!(LATEST_BINARY_OPERATOR_MAPPING)
447+
return nothing
448+
end
449+
_overload_common_operators()
450+
405451
end

0 commit comments

Comments
 (0)