Skip to content

Commit cb4cec3

Browse files
committed
Make kernels unspecialized
1 parent 9b9fa9d commit cb4cec3

File tree

1 file changed

+36
-38
lines changed

1 file changed

+36
-38
lines changed

src/EvaluateEquation.jl

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ end
8888
function _eval_tree_array(
8989
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo}
9090
)::Tuple{AbstractVector{T},Bool} where {T<:Real,turbo}
91+
n = size(cX, 2)
9192
# First, we see if there are only constants in the tree - meaning
9293
# we can just return the constant result.
9394
if tree.degree == 0
@@ -107,56 +108,59 @@ function _eval_tree_array(
107108
# op(op2(x)), where x is a constant or variable.
108109
op_l = operators.unaops[tree.l.op]
109110
return deg1_l1_ll0_eval(tree, cX, op, op_l, Val(turbo))
110-
else
111-
# op(x), for any x.
112-
return deg1_eval(tree, cX, op, operators, Val(turbo))
113111
end
112+
113+
# op(x), for any x.
114+
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
115+
@return_on_false complete cumulator
116+
@return_on_nonfinite_array cumulator T n
117+
return deg1_eval(cumulator, op, Val(turbo))
118+
114119
elseif tree.degree == 2
115-
# TODO - add op(op2(x, y), z) and op(x, op2(y, z))
116120
op = operators.binops[tree.op]
121+
# TODO - add op(op2(x, y), z) and op(x, op2(y, z))
122+
# op(x, y), where x, y are constants or variables.
117123
if tree.l.degree == 0 && tree.r.degree == 0
118-
# op(x, y), where x, y are constants or variables.
119124
return deg2_l0_r0_eval(tree, cX, op, Val(turbo))
120-
elseif tree.l.degree == 0
121-
# op(x, y), where x is a constant or variable but y is not.
122-
return deg2_l0_eval(tree, cX, op, operators, Val(turbo))
123125
elseif tree.r.degree == 0
126+
(cumulator_l, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
127+
@return_on_false complete cumulator_l
128+
@return_on_nonfinite_array cumulator_l T n
124129
# op(x, y), where y is a constant or variable but x is not.
125-
return deg2_r0_eval(tree, cX, op, operators, Val(turbo))
126-
else
127-
# op(x, y), for any x or y
128-
return deg2_eval(tree, cX, op, operators, Val(turbo))
130+
return deg2_r0_eval(tree, cumulator_l, cX, op, Val(turbo))
131+
elseif tree.l.degree == 0
132+
(cumulator_r, complete) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
133+
@return_on_false complete cumulator_r
134+
@return_on_nonfinite_array cumulator_r T n
135+
# op(x, y), where x is a constant or variable but y is not.
136+
return deg2_l0_eval(tree, cumulator_r, cX, op, Val(turbo))
129137
end
138+
(cumulator_l, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
139+
@return_on_false complete cumulator_l
140+
@return_on_nonfinite_array cumulator_l T n
141+
(cumulator_r, complete) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
142+
@return_on_false complete cumulator_r
143+
@return_on_nonfinite_array cumulator_r T n
144+
# op(x, y), for any x or y
145+
return deg2_eval(cumulator_l, cumulator_r, op, Val(turbo))
130146
end
131147
end
132148

133149
function deg2_eval(
134-
tree::Node{T}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum, ::Val{turbo}
150+
cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::Val{turbo}
135151
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
136-
n = size(cX, 2)
137-
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
138-
@return_on_false complete cumulator
139-
@return_on_nonfinite_array cumulator T n
140-
(array2, complete2) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
141-
@return_on_false complete2 cumulator
142-
@return_on_nonfinite_array array2 T n
143-
144152
# We check inputs (and intermediates), not outputs.
145-
@maybe_turbo turbo for j in indices(cumulator)
146-
x = op(cumulator[j], array2[j])::T
147-
cumulator[j] = x
153+
@maybe_turbo turbo for j in indices(cumulator_l)
154+
x = op(cumulator_l[j], cumulator_r[j])::T
155+
cumulator_l[j] = x
148156
end
149157
# return (cumulator, finished_loop) #
150-
return (cumulator, true)
158+
return (cumulator_l, true)
151159
end
152160

153161
function deg1_eval(
154-
tree::Node{T}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum, ::Val{turbo}
162+
cumulator::AbstractVector{T}, op::F, ::Val{turbo}
155163
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
156-
n = size(cX, 2)
157-
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
158-
@return_on_false complete cumulator
159-
@return_on_nonfinite_array cumulator T n
160164
@maybe_turbo turbo for j in indices(cumulator)
161165
x = op(cumulator[j])::T
162166
cumulator[j] = x
@@ -294,12 +298,9 @@ end
294298

295299
# op(x, y) for x variable/constant, y arbitrary
296300
function deg2_l0_eval(
297-
tree::Node{T}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum, ::Val{turbo}
301+
tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
298302
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
299303
n = size(cX, 2)
300-
(cumulator, complete) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
301-
@return_on_false complete cumulator
302-
@return_on_nonfinite_array cumulator T n
303304
if tree.l.constant
304305
val = tree.l.val::T
305306
@return_on_check val T n
@@ -319,12 +320,9 @@ end
319320

320321
# op(x, y) for x arbitrary, y variable/constant
321322
function deg2_r0_eval(
322-
tree::Node{T}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum, ::Val{turbo}
323+
tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
323324
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
324325
n = size(cX, 2)
325-
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
326-
@return_on_false complete cumulator
327-
@return_on_nonfinite_array cumulator T n
328326
if tree.r.constant
329327
val = tree.r.val::T
330328
@return_on_check val T n

0 commit comments

Comments
 (0)