Skip to content

Commit 9312a78

Browse files
committed
Fix speed regression due to type union
1 parent 94019c6 commit 9312a78

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

src/Equation.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ function Base.convert(
8282
get!(id_map, tree) do
8383
if tree.degree == 0
8484
if tree.constant
85-
Node(0, tree.constant, convert(T1, tree.val))
85+
Node(0, tree.constant, convert(T1, (tree.val::T2)))
8686
else
8787
Node(T1, 0, tree.constant, nothing, tree.feature)
8888
end
@@ -180,7 +180,7 @@ function set_node!(tree::Node{T}, new_tree::Node{T}) where {T}
180180
if new_tree.degree == 0
181181
tree.constant = new_tree.constant
182182
if new_tree.constant
183-
tree.val = new_tree.val
183+
tree.val = new_tree.val::T
184184
else
185185
tree.feature = new_tree.feature
186186
end
@@ -215,7 +215,7 @@ end
215215
function copy_node_break_topology(tree::Node{T})::Node{T} where {T}
216216
if tree.degree == 0
217217
if tree.constant
218-
Node(; val=copy(tree.val))
218+
Node(; val=copy(tree.val::T))
219219
else
220220
Node(T; feature=copy(tree.feature))
221221
end
@@ -248,7 +248,7 @@ function copy_node_with_topology(
248248
get!(id_map, tree) do
249249
if tree.degree == 0
250250
if tree.constant
251-
Node(; val=copy(tree.val))
251+
Node(; val=copy(tree.val::T))
252252
else
253253
Node(T; feature=copy(tree.feature))
254254
end
@@ -355,11 +355,11 @@ function print_tree(
355355
return println(string_tree(tree, operators; varMap=varMap))
356356
end
357357

358-
function Base.hash(tree::Node)::UInt
358+
function Base.hash(tree::Node{T})::UInt where {T}
359359
if tree.degree == 0
360360
if tree.constant
361361
# tree.val used.
362-
return hash((0, tree.val))
362+
return hash((0, tree.val::T))
363363
else
364364
# tree.feature used.
365365
return hash((1, tree.feature))

src/EvaluateEquation.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ function deg0_eval(
156156
)::Tuple{AbstractVector{T},Bool} where {T<:Real}
157157
n = size(cX, 2)
158158
if tree.constant
159-
return (fill(tree.val, n), true)
159+
return (fill(tree.val::T, n), true)
160160
else
161161
return (cX[tree.feature, :], true)
162162
end
@@ -173,8 +173,8 @@ function deg1_l2_ll0_lr0_eval(
173173
op = operators.unaops[op_idx]
174174
op_l = operators.binops[op_l_idx]
175175
if tree.l.l.constant && tree.l.r.constant
176-
val_ll = tree.l.l.val
177-
val_lr = tree.l.r.val
176+
val_ll = tree.l.l.val::T
177+
val_lr = tree.l.r.val::T
178178
@return_on_check val_ll T n
179179
@return_on_check val_lr T n
180180
x_l = op_l(val_ll, val_lr)::T
@@ -183,7 +183,7 @@ function deg1_l2_ll0_lr0_eval(
183183
@return_on_check x T n
184184
return (fill(x, n), true)
185185
elseif tree.l.l.constant
186-
val_ll = tree.l.l.val
186+
val_ll = tree.l.l.val::T
187187
@return_on_check val_ll T n
188188
feature_lr = tree.l.r.feature
189189
cumulator = Array{T,1}(undef, n)
@@ -195,7 +195,7 @@ function deg1_l2_ll0_lr0_eval(
195195
return (cumulator, true)
196196
elseif tree.l.r.constant
197197
feature_ll = tree.l.l.feature
198-
val_lr = tree.l.r.val
198+
val_lr = tree.l.r.val::T
199199
@return_on_check val_lr T n
200200
cumulator = Array{T,1}(undef, n)
201201
@inbounds @simd for j in 1:n
@@ -229,7 +229,7 @@ function deg1_l1_ll0_eval(
229229
op = operators.unaops[op_idx]
230230
op_l = operators.unaops[op_l_idx]
231231
if tree.l.l.constant
232-
val_ll = tree.l.l.val
232+
val_ll = tree.l.l.val::T
233233
@return_on_check val_ll T n
234234
x_l = op_l(val_ll)::T
235235
@return_on_check x_l T n
@@ -254,16 +254,16 @@ function deg2_l0_r0_eval(
254254
n = size(cX, 2)
255255
op = operators.binops[op_idx]
256256
if tree.l.constant && tree.r.constant
257-
val_l = tree.l.val
257+
val_l = tree.l.val::T
258258
@return_on_check val_l T n
259-
val_r = tree.r.val
259+
val_r = tree.r.val::T
260260
@return_on_check val_r T n
261261
x = op(val_l, val_r)::T
262262
@return_on_check x T n
263263
return (fill(x, n), true)
264264
elseif tree.l.constant
265265
cumulator = Array{T,1}(undef, n)
266-
val_l = tree.l.val
266+
val_l = tree.l.val::T
267267
@return_on_check val_l T n
268268
feature_r = tree.r.feature
269269
@inbounds @simd for j in 1:n
@@ -273,7 +273,7 @@ function deg2_l0_r0_eval(
273273
elseif tree.r.constant
274274
cumulator = Array{T,1}(undef, n)
275275
feature_l = tree.l.feature
276-
val_r = tree.r.val
276+
val_r = tree.r.val::T
277277
@return_on_check val_r T n
278278
@inbounds @simd for j in 1:n
279279
x = op(cX[feature_l, j], val_r)::T
@@ -300,7 +300,7 @@ function deg2_l0_eval(
300300
@return_on_nonfinite_array cumulator T n
301301
op = operators.binops[op_idx]
302302
if tree.l.constant
303-
val = tree.l.val
303+
val = tree.l.val::T
304304
@return_on_check val T n
305305
@inbounds @simd for j in 1:n
306306
x = op(val, cumulator[j])::T
@@ -325,7 +325,7 @@ function deg2_r0_eval(
325325
@return_on_nonfinite_array cumulator T n
326326
op = operators.binops[op_idx]
327327
if tree.r.constant
328-
val = tree.r.val
328+
val = tree.r.val::T
329329
@return_on_check val T n
330330
@inbounds @simd for j in 1:n
331331
x = op(cumulator[j], val)::T
@@ -361,7 +361,7 @@ function _eval_constant_tree(
361361
end
362362

363363
@inline function deg0_eval_constant(tree::Node{T})::Tuple{T,Bool} where {T<:Real}
364-
return tree.val, true
364+
return tree.val::T, true
365365
end
366366

367367
function deg1_eval_constant(
@@ -397,7 +397,7 @@ function differentiable_eval_tree_array(
397397
n = size(cX, 2)
398398
if tree.degree == 0
399399
if tree.constant
400-
return (ones(T, n) .* convert(T, tree.val), true)
400+
return (ones(T, n) .* (tree.val::T), true)
401401
else
402402
return (cX[tree.feature, :], true)
403403
end
@@ -476,7 +476,7 @@ function eval_tree_array(
476476
) where {T,N}
477477
if tree.degree == 0
478478
if tree.constant
479-
return tree.val, true
479+
return (tree.val::T), true
480480
else
481481
if N == 1
482482
return cX[tree.feature], true

0 commit comments

Comments
 (0)