Skip to content

Commit e0196b2

Browse files
committed
Use Val for turbo mode to speed up
1 parent a1f1c90 commit e0196b2

File tree

1 file changed

+41
-36
lines changed

1 file changed

+41
-36
lines changed

src/EvaluateEquation.jl

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,12 @@ function eval_tree_array(
6565
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false
6666
)::Tuple{AbstractVector{T},Bool} where {T<:Real}
6767
n = size(cX, 2)
68-
result, finished = _eval_tree_array(tree, cX, operators; turbo=turbo)
68+
if turbo
69+
@assert T in (Float32, Float64)
70+
end
71+
result, finished = _eval_tree_array(
72+
tree, cX, operators, (turbo ? Val(true) : Val(false))
73+
)
6974
@return_on_false finished result
7075
@return_on_nonfinite_array result T n
7176
return result, finished
@@ -81,8 +86,8 @@ function eval_tree_array(
8186
end
8287

8388
function _eval_tree_array(
84-
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false
85-
)::Tuple{AbstractVector{T},Bool} where {T<:Real}
89+
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo}
90+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,turbo}
8691
# First, we see if there are only constants in the tree - meaning
8792
# we can just return the constant result.
8893
if tree.degree == 0
@@ -96,31 +101,31 @@ function _eval_tree_array(
96101
if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
97102
# op(op2(x, y)), where x, y, z are constants or variables.
98103
return deg1_l2_ll0_lr0_eval(
99-
tree, cX, vals[tree.op], vals[tree.l.op], operators; turbo=turbo
104+
tree, cX, vals[tree.op], vals[tree.l.op], operators, Val(turbo)
100105
)
101106
elseif tree.l.degree == 1 && tree.l.l.degree == 0
102107
# op(op2(x)), where x is a constant or variable.
103108
return deg1_l1_ll0_eval(
104-
tree, cX, vals[tree.op], vals[tree.l.op], operators; turbo=turbo
109+
tree, cX, vals[tree.op], vals[tree.l.op], operators, Val(turbo)
105110
)
106111
else
107112
# op(x), for any x.
108-
return deg1_eval(tree, cX, vals[tree.op], operators; turbo=turbo)
113+
return deg1_eval(tree, cX, vals[tree.op], operators, Val(turbo))
109114
end
110115
elseif tree.degree == 2
111116
# TODO - add op(op2(x, y), z) and op(x, op2(y, z))
112117
if tree.l.degree == 0 && tree.r.degree == 0
113118
# op(x, y), where x, y are constants or variables.
114-
return deg2_l0_r0_eval(tree, cX, vals[tree.op], operators; turbo=turbo)
119+
return deg2_l0_r0_eval(tree, cX, vals[tree.op], operators, Val(turbo))
115120
elseif tree.l.degree == 0
116121
# op(x, y), where x is a constant or variable but y is not.
117-
return deg2_l0_eval(tree, cX, vals[tree.op], operators; turbo=turbo)
122+
return deg2_l0_eval(tree, cX, vals[tree.op], operators, Val(turbo))
118123
elseif tree.r.degree == 0
119124
# op(x, y), where y is a constant or variable but x is not.
120-
return deg2_r0_eval(tree, cX, vals[tree.op], operators; turbo=turbo)
125+
return deg2_r0_eval(tree, cX, vals[tree.op], operators, Val(turbo))
121126
else
122127
# op(x, y), for any x or y
123-
return deg2_eval(tree, cX, vals[tree.op], operators; turbo=turbo)
128+
return deg2_eval(tree, cX, vals[tree.op], operators, Val(turbo))
124129
end
125130
end
126131
end
@@ -129,14 +134,14 @@ function deg2_eval(
129134
tree::Node{T},
130135
cX::AbstractMatrix{T},
131136
::Val{op_idx},
132-
operators::OperatorEnum;
133-
turbo::Bool,
134-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx}
137+
operators::OperatorEnum,
138+
::Val{turbo},
139+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
135140
n = size(cX, 2)
136-
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators)
141+
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
137142
@return_on_false complete cumulator
138143
@return_on_nonfinite_array cumulator T n
139-
(array2, complete2) = _eval_tree_array(tree.r, cX, operators)
144+
(array2, complete2) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
140145
@return_on_false complete2 cumulator
141146
@return_on_nonfinite_array array2 T n
142147
op = operators.binops[op_idx]
@@ -154,11 +159,11 @@ function deg1_eval(
154159
tree::Node{T},
155160
cX::AbstractMatrix{T},
156161
::Val{op_idx},
157-
operators::OperatorEnum;
158-
turbo::Bool,
159-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx}
162+
operators::OperatorEnum,
163+
::Val{turbo},
164+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
160165
n = size(cX, 2)
161-
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators)
166+
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
162167
@return_on_false complete cumulator
163168
@return_on_nonfinite_array cumulator T n
164169
op = operators.unaops[op_idx]
@@ -185,9 +190,9 @@ function deg1_l2_ll0_lr0_eval(
185190
cX::AbstractMatrix{T},
186191
::Val{op_idx},
187192
::Val{op_l_idx},
188-
operators::OperatorEnum;
189-
turbo::Bool,
190-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,op_l_idx}
193+
operators::OperatorEnum,
194+
::Val{turbo},
195+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,op_l_idx,turbo}
191196
n = size(cX, 2)
192197
op = operators.unaops[op_idx]
193198
op_l = operators.binops[op_l_idx]
@@ -242,9 +247,9 @@ function deg1_l1_ll0_eval(
242247
cX::AbstractMatrix{T},
243248
::Val{op_idx},
244249
::Val{op_l_idx},
245-
operators::OperatorEnum;
246-
turbo::Bool,
247-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,op_l_idx}
250+
operators::OperatorEnum,
251+
::Val{turbo},
252+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,op_l_idx,turbo}
248253
n = size(cX, 2)
249254
op = operators.unaops[op_idx]
250255
op_l = operators.unaops[op_l_idx]
@@ -272,9 +277,9 @@ function deg2_l0_r0_eval(
272277
tree::Node{T},
273278
cX::AbstractMatrix{T},
274279
::Val{op_idx},
275-
operators::OperatorEnum;
276-
turbo::Bool,
277-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx}
280+
operators::OperatorEnum,
281+
::Val{turbo},
282+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
278283
n = size(cX, 2)
279284
op = operators.binops[op_idx]
280285
if tree.l.constant && tree.r.constant
@@ -319,11 +324,11 @@ function deg2_l0_eval(
319324
tree::Node{T},
320325
cX::AbstractMatrix{T},
321326
::Val{op_idx},
322-
operators::OperatorEnum;
323-
turbo::Bool,
324-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx}
327+
operators::OperatorEnum,
328+
::Val{turbo},
329+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
325330
n = size(cX, 2)
326-
(cumulator, complete) = _eval_tree_array(tree.r, cX, operators)
331+
(cumulator, complete) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
327332
@return_on_false complete cumulator
328333
@return_on_nonfinite_array cumulator T n
329334
op = operators.binops[op_idx]
@@ -348,11 +353,11 @@ function deg2_r0_eval(
348353
tree::Node{T},
349354
cX::AbstractMatrix{T},
350355
::Val{op_idx},
351-
operators::OperatorEnum;
352-
turbo::Bool,
353-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx}
356+
operators::OperatorEnum,
357+
::Val{turbo},
358+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
354359
n = size(cX, 2)
355-
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators)
360+
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
356361
@return_on_false complete cumulator
357362
@return_on_nonfinite_array cumulator T n
358363
op = operators.binops[op_idx]

0 commit comments

Comments
 (0)