Skip to content

Commit e82db51

Browse files
committed
More type stability throughout package
1 parent c00186b commit e82db51

File tree

5 files changed

+61
-41
lines changed

5 files changed

+61
-41
lines changed

src/Equation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,14 +337,14 @@ Convert an equation to a string.
337337
to print for each feature.
338338
"""
339339
function string_tree(
340-
tree::Node,
340+
tree::Node{T},
341341
operators::AbstractOperatorEnum;
342342
bracketed::Bool=false,
343343
varMap::Union{Array{String,1},Nothing}=nothing,
344-
)::String
344+
)::String where {T}
345345
if tree.degree == 0
346346
if tree.constant
347-
return string(tree.val)
347+
return string(tree.val::T)
348348
else
349349
if varMap === nothing
350350
return "x$(tree.feature)"

src/EquationUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ end
7373
function get_constants(tree::Node{T})::AbstractVector{T} where {T}
7474
if tree.degree == 0
7575
if tree.constant
76-
return [tree.val]
76+
return [tree.val::T]
7777
else
7878
return T[]
7979
end

src/InterfaceSymbolicUtils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ end
1919
subs_bad(x) = isgood(x) ? x : Inf
2020

2121
function parse_tree_to_eqs(
22-
tree::Node, operators::AbstractOperatorEnum, index_functions::Bool=false
23-
)
22+
tree::Node{T}, operators::AbstractOperatorEnum, index_functions::Bool=false
23+
) where {T}
2424
if tree.degree == 0
2525
# Return constant if needed
26-
tree.constant && return subs_bad(tree.val)
26+
tree.constant && return subs_bad(tree.val::T)
2727
return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)"))
2828
end
2929
# Collect the next children

src/OperatorEnumConstruction.jl

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -77,41 +77,53 @@ function create_construction_helpers!(
7777
Base.MainInclude.eval(
7878
quote
7979
import DynamicExpressions: Node
80+
81+
function $f(l::Node{T}, r::Node{T}) where {T<:$type_requirements}
82+
if (l.degree == 0 && l.constant && r.degree == 0 && r.constant)
83+
Node(T; val=$f(l.val::T, r.val::T))
84+
else
85+
Node($op, l, r)
86+
end
87+
end
88+
function $f(l::Node{T}, r::T) where {T<:$type_requirements}
89+
if l.degree == 0 && l.constant
90+
Node(T; val=$f(l.val::T, r))
91+
else
92+
Node($op, l, Node(T; val=r))
93+
end
94+
end
95+
function $f(l::T, r::Node{T}) where {T<:$type_requirements}
96+
if r.degree == 0 && r.constant
97+
Node(T; val=$f(l, r.val::T))
98+
else
99+
Node($op, Node(T; val=l), r)
100+
end
101+
end
102+
103+
# Converters:
80104
function $f(
81105
l::Node{T1}, r::Node{T2}
82106
) where {T1<:$type_requirements,T2<:$type_requirements}
83107
T = promote_type(T1, T2)
84108
l = convert(Node{T}, l)
85109
r = convert(Node{T}, r)
86-
if (l.constant && r.constant)
87-
return Node(; val=$f(l.val, r.val))
88-
else
89-
return Node($op, l, r)
90-
end
110+
return $f(l, r)
91111
end
92112
function $f(
93113
l::Node{T1}, r::T2
94114
) where {T1<:$type_requirements,T2<:$type_requirements}
95115
T = promote_type(T1, T2)
96116
l = convert(Node{T}, l)
97117
r = convert(T, r)
98-
return if l.constant
99-
Node(; val=$f(l.val, r))
100-
else
101-
Node($op, l, Node(; val=r))
102-
end
118+
return $f(l, r)
103119
end
104120
function $f(
105121
l::T1, r::Node{T2}
106122
) where {T1<:$type_requirements,T2<:$type_requirements}
107123
T = promote_type(T1, T2)
108124
l = convert(T, l)
109125
r = convert(Node{T}, r)
110-
return if r.constant
111-
Node(; val=$f(l, r.val))
112-
else
113-
Node($op, Node(; val=l), r)
114-
end
126+
return $f(l, r)
115127
end
116128
end,
117129
)
@@ -128,7 +140,11 @@ function create_construction_helpers!(
128140
quote
129141
import DynamicExpressions: Node
130142
function $f(l::Node{T})::Node{T} where {T<:$type_requirements}
131-
return l.constant ? Node(; val=$f(l.val)) : Node($op, l)
143+
return if (l.degree == 0 && l.constant)
144+
Node(T; val=$f(l.val::T))
145+
else
146+
Node($op, l)
147+
end
132148
end
133149
end,
134150
)
@@ -151,7 +167,8 @@ It will automatically compute derivatives with `Zygote.jl`.
151167
- `extend_user_operators::Bool=false`: Whether to extend the user's operators to
152168
`Node` types. All operators defined in `Base` will already be extended automatically.
153169
- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
154-
and evaluating node types. Turn this off when doing precompilation.
170+
and evaluating node types. Turn this off when doing precompilation. Note that these
171+
are *not* needed for the package to work; they are purely for convenience.
155172
"""
156173
function OperatorEnum(;
157174
binary_operators=[],
@@ -241,7 +258,8 @@ and `(::Node)(X)`.
241258
- `extend_user_operators::Bool=false`: Whether to extend the user's operators to
242259
`Node` types. All operators defined in `Base` will already be extended automatically.
243260
- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
244-
and evaluating node types. Turn this off when doing precompilation.
261+
and evaluating node types. Turn this off when doing precompilation. Note that these
262+
are *not* needed for the package to work; they are purely for convenience.
245263
"""
246264
function GenericOperatorEnum(;
247265
binary_operators=[],

src/SimplifyEquation.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,16 @@ function combine_operators(
3737
tree.r = tree.l
3838
tree.l = tmp
3939
end
40-
topconstant = tree.r.val
40+
topconstant = tree.r.val::T
4141
# Simplify down first
4242
below = tree.l
4343
if below.degree == 2 && below.op == op
4444
if below.l.constant
4545
tree = below
46-
tree.l.val = operators.binops[op](tree.l.val, topconstant)
46+
tree.l.val = operators.binops[op](tree.l.val::T, topconstant)
4747
elseif below.r.constant
4848
tree = below
49-
tree.r.val = operators.binops[op](tree.r.val, topconstant)
49+
tree.r.val = operators.binops[op](tree.r.val::T, topconstant)
5050
end
5151
end
5252
end
@@ -60,15 +60,15 @@ function combine_operators(
6060
#(const - (const - var)) => (var - const)
6161
l = tree.l
6262
r = tree.r
63-
simplified_const = -(l.val - r.l.val) #neg(sub(l.val, r.l.val))
63+
simplified_const = -(l.val::T - r.l.val::T) #neg(sub(l.val, r.l.val))
6464
tree.l = tree.r.r
6565
tree.r = l
6666
tree.r.val = simplified_const
6767
elseif tree.r.r.constant
6868
#(const - (var - const)) => (const - var)
6969
l = tree.l
7070
r = tree.r
71-
simplified_const = l.val + r.r.val #plus(l.val, r.r.val)
71+
simplified_const = l.val::T + r.r.val::T #plus(l.val, r.r.val)
7272
tree.r = tree.r.l
7373
tree.l.val = simplified_const
7474
end
@@ -79,15 +79,15 @@ function combine_operators(
7979
#((const - var) - const) => (const - var)
8080
l = tree.l
8181
r = tree.r
82-
simplified_const = l.l.val - r.val#sub(l.l.val, r.val)
82+
simplified_const = l.l.val::T - r.val::T#sub(l.l.val, r.val)
8383
tree.r = tree.l.r
8484
tree.l = r
8585
tree.l.val = simplified_const
8686
elseif tree.l.r.constant
8787
#((var - const) - const) => (var - const)
8888
l = tree.l
8989
r = tree.r
90-
simplified_const = r.val + l.r.val #plus(r.val, l.r.val)
90+
simplified_const = r.val::T + l.r.val::T #plus(r.val, l.r.val)
9191
tree.l = tree.l.l
9292
tree.r.val = simplified_const
9393
end
@@ -107,13 +107,15 @@ function simplify_tree(
107107
get!(id_map, tree) do
108108
if tree.degree == 1
109109
tree.l = simplify_tree(tree.l, operators, id_map)
110-
l = tree.l.val
111-
if tree.l.degree == 0 && tree.l.constant && isgood(l)
112-
out = operators.unaops[tree.op](l)
113-
if isbad(out)
114-
return tree
110+
if tree.l.degree == 0 && tree.l.constant
111+
l = tree.l.val::T
112+
if isgood(l)
113+
out = operators.unaops[tree.op](l)
114+
if isbad(out)
115+
return tree
116+
end
117+
return Node(T; val=convert(T, out))
115118
end
116-
return Node(; val=convert(T, out))
117119
end
118120
elseif tree.degree == 2
119121
tree.l = simplify_tree(tree.l, operators, id_map)
@@ -126,8 +128,8 @@ function simplify_tree(
126128
)
127129
if constantsBelow
128130
# NaN checks:
129-
l = tree.l.val
130-
r = tree.r.val
131+
l = tree.l.val::T
132+
r = tree.r.val::T
131133
if isbad(l) || isbad(r)
132134
return tree
133135
end
@@ -137,7 +139,7 @@ function simplify_tree(
137139
if isbad(out)
138140
return tree
139141
end
140-
return Node(; val=convert(T, out))
142+
return Node(T; val=convert(T, out))
141143
end
142144
end
143145
return tree

0 commit comments

Comments
 (0)