@@ -2,12 +2,12 @@ module OperatorEnumConstructionModule
22
33import Zygote: gradient
44import .. UtilsModule: max_ops
5- import .. OperatorEnumModule: OperatorEnum, GenericOperatorEnum
5+ import .. OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
66import .. EquationModule: string_tree, Node
77import .. EvaluateEquationModule: eval_tree_array
88import .. EvaluateEquationDerivativeModule: eval_grad_tree_array
99
10- function create_evaluation_helper_functions (operators:: OperatorEnum )
10+ function create_evaluation_helpers! (operators:: OperatorEnum )
1111 @eval begin
1212 Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
1313 Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
@@ -37,7 +37,7 @@ function create_evaluation_helper_functions(operators::OperatorEnum)
3737 end
3838end
3939
40- function create_evaluation_helper_functions (operators:: GenericOperatorEnum )
40+ function create_evaluation_helpers! (operators:: GenericOperatorEnum )
4141 @eval begin
4242 Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
4343 Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
@@ -54,11 +54,14 @@ function create_evaluation_helper_functions(operators::GenericOperatorEnum)
5454 end
5555end
5656
57- function create_node_helper_functions (
57+ function create_construction_helpers! (
5858 operators:: AbstractOperatorEnum ; extend_user_operators:: Bool = false
5959)
60- for (op, f) in enumerate (map (Symbol, binary_operators))
61- if typeof (operators) <: OperatorEnum
60+ is_scalar_operator_enum = typeof (operators) <: OperatorEnum
61+ type_requirements = is_scalar_operator_enum ? Real : Any
62+
63+ for (op, f) in enumerate (map (Symbol, operators. binops))
64+ if is_scalar_operator_enum
6265 f = if f in [:pow , :safe_pow ]
6366 Symbol (^ )
6467 else
@@ -74,7 +77,9 @@ function create_node_helper_functions(
7477 Base. MainInclude. eval (
7578 quote
7679 import DynamicExpressions: Node
77- function $f (l:: Node{T1} , r:: Node{T2} ) where {T1<: Real ,T2<: Real }
80+ function $f (
81+ l:: Node{T1} , r:: Node{T2}
82+ ) where {T1<: $type_requirements ,T2<: $type_requirements }
7883 T = promote_type (T1, T2)
7984 l = convert (Node{T}, l)
8085 r = convert (Node{T}, r)
@@ -84,7 +89,9 @@ function create_node_helper_functions(
8489 return Node ($ op, l, r)
8590 end
8691 end
87- function $f (l:: Node{T1} , r:: T2 ) where {T1<: Real ,T2<: Real }
92+ function $f (
93+ l:: Node{T1} , r:: T2
94+ ) where {T1<: $type_requirements ,T2<: $type_requirements }
8895 T = promote_type (T1, T2)
8996 l = convert (Node{T}, l)
9097 r = convert (T, r)
@@ -94,7 +101,9 @@ function create_node_helper_functions(
94101 Node ($ op, l, Node (; val= r))
95102 end
96103 end
97- function $f (l:: T1 , r:: Node{T2} ) where {T1<: Real ,T2<: Real }
104+ function $f (
105+ l:: T1 , r:: Node{T2}
106+ ) where {T1<: $type_requirements ,T2<: $type_requirements }
98107 T = promote_type (T1, T2)
99108 l = convert (T, l)
100109 r = convert (Node{T}, r)
@@ -108,7 +117,7 @@ function create_node_helper_functions(
108117 )
109118 end
110119 # Redefine Base operations:
111- for (op, f) in enumerate (map (Symbol, unary_operators ))
120+ for (op, f) in enumerate (map (Symbol, operators . unaops ))
112121 if isdefined (Base, f)
113122 f = :(Base.$ (f))
114123 elseif ! extend_user_operators
@@ -118,7 +127,7 @@ function create_node_helper_functions(
118127 Base. MainInclude. eval (
119128 quote
120129 import DynamicExpressions: Node
121- function $f (l:: Node{T} ):: Node{T} where {T<: Real }
130+ function $f (l:: Node{T} ):: Node{T} where {T<: $type_requirements }
122131 return l. constant ? Node (; val= $ f (l. val)) : Node ($ op, l)
123132 end
124133 end ,
@@ -209,8 +218,8 @@ function OperatorEnum(;
209218 )
210219
211220 if define_helper_functions
212- create_node_helper_functions (operators; extend_user_operators= extend_user_operators)
213- create_evaluation_helper_functions (operators)
221+ create_construction_helpers! (operators; extend_user_operators= extend_user_operators)
222+ create_evaluation_helpers! (operators)
214223 end
215224
216225 return operators
@@ -249,8 +258,8 @@ function GenericOperatorEnum(;
249258 operators = GenericOperatorEnum (binary_operators, unary_operators)
250259
251260 if define_helper_functions
252- create_node_helper_functions (operators; extend_user_operators= extend_user_operators)
253- create_evaluation_helper_functions (operators)
261+ create_construction_helpers! (operators; extend_user_operators= extend_user_operators)
262+ create_evaluation_helpers! (operators)
254263 end
255264
256265 return operators
0 commit comments