@@ -96,21 +96,26 @@ function lookup_op(@nospecialize(f), ::Val{degree}) where {degree}
9696 mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING
9797 if ! haskey (mapping, f)
9898 error (
99- " Convenience constructor using operator `$(f) ` is out-of-date. " *
100- " Please create an `OperatorEnum` (or `GenericOperatorEnum`) with " *
101- " `define_helper_functions=true` and pass `$(f) `." ,
99+ " Convenience constructor for operator `$(f) ` is out-of-date. " *
100+ " Please create an `OperatorEnum` (or `GenericOperatorEnum`) containing " *
101+ " the operator ` $(f) ` which will define the `$(f) ` -> `Int` mapping ." ,
102102 )
103103 end
104104 return mapping[f]
105105end
106106
107- function _extend_unary_operator (f:: Symbol , type_requirements)
107+ function _extend_unary_operator (f:: Symbol , type_requirements, internal )
108108 quote
109109 @gensym _constructorof _AbstractExpressionNode
110110 quote
111- using DynamicExpressions:
112- constructorof as $ _constructorof,
113- AbstractExpressionNode as $ _AbstractExpressionNode
111+ if $$ internal
112+ import .. EquationModule. constructorof as $ _constructorof
113+ import .. EquationModule. AbstractExpressionNode as $ _AbstractExpressionNode
114+ else
115+ using DynamicExpressions:
116+ constructorof as $ _constructorof,
117+ AbstractExpressionNode as $ _AbstractExpressionNode
118+ end
114119
115120 function $ ($ f)(
116121 l:: N
@@ -126,13 +131,18 @@ function _extend_unary_operator(f::Symbol, type_requirements)
126131 end
127132end
128133
129- function _extend_binary_operator (f:: Symbol , type_requirements, build_converters)
134+ function _extend_binary_operator (f:: Symbol , type_requirements, build_converters, internal )
130135 quote
131136 @gensym _constructorof _AbstractExpressionNode
132137 quote
133- using DynamicExpressions:
134- constructorof as $ _constructorof,
135- AbstractExpressionNode as $ _AbstractExpressionNode
138+ if $$ internal
139+ import .. EquationModule. constructorof as $ _constructorof
140+ import .. EquationModule. AbstractExpressionNode as $ _AbstractExpressionNode
141+ else
142+ using DynamicExpressions:
143+ constructorof as $ _constructorof,
144+ AbstractExpressionNode as $ _AbstractExpressionNode
145+ end
136146
137147 function $ ($ f)(
138148 l:: N , r:: N
@@ -191,19 +201,32 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters)
191201end
192202
193203function _extend_operators (operators, skip_user_operators, kws, __module__:: Module )
194- empty_old_operators =
195- if length (kws) == 1 && :empty_old_operators in map (x -> x. args[1 ], kws)
196- @assert kws[1 ]. head == :(= )
197- kws[1 ]. args[2 ]
198- else
199- length (kws) > 0 && error (
200- " You passed the keywords $(kws) , but only `empty_old_operators` is supported." ,
201- )
202- true
203- end
204+ if ! all (x -> first (x. args) ∈ (:empty_old_operators , :internal ), kws)
205+ error (
206+ " You passed the keywords $(kws) , but only `empty_old_operators`, `internal` are supported." ,
207+ )
208+ end
209+
210+ empty_old_operators_idx = findfirst (x -> first (x. args) == :empty_old_operators , kws)
211+ internal_idx = findfirst (x -> first (x. args) == :internal , kws)
212+
213+ empty_old_operators = if empty_old_operators_idx != = nothing
214+ @assert kws[empty_old_operators_idx]. head == :(= )
215+ kws[empty_old_operators_idx]. args[2 ]
216+ else
217+ true
218+ end
219+
220+ internal = if internal_idx != = nothing
221+ @assert kws[internal_idx]. head == :(= )
222+ kws[internal_idx]. args[2 ]:: Bool
223+ else
224+ false
225+ end
226+
204227 @gensym f skip type_requirements build_converters binary_exists unary_exists
205- binary_ex = _extend_binary_operator (f, type_requirements, build_converters)
206- unary_ex = _extend_unary_operator (f, type_requirements)
228+ binary_ex = _extend_binary_operator (f, type_requirements, build_converters, internal )
229+ unary_ex = _extend_unary_operator (f, type_requirements, internal )
207230 return quote
208231 local $ type_requirements
209232 local $ build_converters
292315
293316Similar to `@extend_operators`, but only extends operators already
294317defined in `Base`.
318+ `kws` can include `empty_old_operators` which is default `true`,
319+ and `internal` which is default `false`.
295320"""
296321macro extend_operators_base (operators, kws... )
297322 ex = _extend_operators (operators, true , kws, __module__)
@@ -402,4 +427,25 @@ function GenericOperatorEnum(;
402427 return operators
403428end
404429
430+ # Predefine the most common operators so the errors
431+ # are more informative
432+ function _overload_common_operators ()
433+ # ! format: off
434+ operators = OperatorEnum (
435+ Function[+ , - , * , / , ^ , max, min, mod],
436+ Function[
437+ sin, cos, tan, exp, log, log1p, log2, log10, sqrt, cbrt, abs, sinh,
438+ cosh, tanh, atan, asinh, acosh, round, sign, floor, ceil,
439+ ],
440+ Function[],
441+ Function[],
442+ )
443+ # ! format: on
444+ @extend_operators (operators, empty_old_operators = false , internal = true )
445+ empty! (LATEST_UNARY_OPERATOR_MAPPING)
446+ empty! (LATEST_BINARY_OPERATOR_MAPPING)
447+ return nothing
448+ end
449+ _overload_common_operators ()
450+
405451end
0 commit comments