Skip to content

Commit 5718dd3

Browse files
committed
Make @turbo parameterizable by user
1 parent 093b6b1 commit 5718dd3

File tree

4 files changed

+225
-101
lines changed

4 files changed

+225
-101
lines changed

src/EvaluateEquation.jl

Lines changed: 80 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module EvaluateEquationModule
33
import LoopVectorization: @turbo, indices
44
import ..EquationModule: Node, string_tree
55
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
6-
import ..UtilsModule: @return_on_false, is_bad_array, vals
6+
import ..UtilsModule: @return_on_false, @maybe_turbo, is_bad_array, vals
77
import ..EquationUtilsModule: is_constant
88

99
macro return_on_check(val, T, n)
@@ -28,7 +28,7 @@ macro return_on_nonfinite_array(array, T, n)
2828
end
2929

3030
"""
31-
eval_tree_array(tree::Node, cX::AbstractMatrix{T}, operators::OperatorEnum)
31+
eval_tree_array(tree::Node, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool)
3232
3333
Evaluate a binary tree (equation) over a given input data matrix. The
3434
operators contain all of the operators used. This function fuses doublets
@@ -52,6 +52,7 @@ which speed up evaluation significantly.
5252
- `tree::Node`: The root node of the tree to evaluate.
5353
- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on.
5454
- `operators::OperatorEnum`: The operators used in the tree.
55+
- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
5556
5657
# Returns
5758
- `(output, complete)::Tuple{AbstractVector{T}, Bool}`: the result,
@@ -61,31 +62,31 @@ which speed up evaluation significantly.
6162
to the equation.
6263
"""
6364
function eval_tree_array(
64-
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum
65+
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false
6566
)::Tuple{AbstractVector{T},Bool} where {T<:Real}
6667
n = size(cX, 2)
67-
result, finished = _eval_tree_array(tree, cX, operators)
68+
result, finished = _eval_tree_array(tree, cX, operators; turbo=turbo)
6869
@return_on_false finished result
6970
@return_on_nonfinite_array result T n
7071
return result, finished
7172
end
7273
function eval_tree_array(
73-
tree::Node{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum
74+
tree::Node{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum; turbo::Bool=false
7475
) where {T1<:Real,T2<:Real}
7576
T = promote_type(T1, T2)
7677
@warn "Warning: eval_tree_array received mixed types: tree=$(T1) and data=$(T2)."
7778
tree = convert(Node{T}, tree)
7879
cX = convert(AbstractMatrix{T}, cX)
79-
return eval_tree_array(tree, cX, operators)
80+
return eval_tree_array(tree, cX, operators; turbo=turbo)
8081
end
8182

8283
function _eval_tree_array(
83-
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum
84+
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false
8485
)::Tuple{AbstractVector{T},Bool} where {T<:Real}
8586
# First, we see if there are only constants in the tree - meaning
8687
# we can just return the constant result.
8788
if tree.degree == 0
88-
return deg0_eval(tree, cX, operators)
89+
return deg0_eval(tree, cX)
8990
elseif is_constant(tree)
9091
# Speed hack for constant trees.
9192
result, flag = _eval_constant_tree(tree, operators)
@@ -94,34 +95,42 @@ function _eval_tree_array(
9495
elseif tree.degree == 1
9596
if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
9697
# op(op2(x, y)), where x, y, z are constants or variables.
97-
return deg1_l2_ll0_lr0_eval(tree, cX, vals[tree.op], vals[tree.l.op], operators)
98+
return deg1_l2_ll0_lr0_eval(
99+
tree, cX, vals[tree.op], vals[tree.l.op], operators; turbo=turbo
100+
)
98101
elseif tree.l.degree == 1 && tree.l.l.degree == 0
99102
# op(op2(x)), where x is a constant or variable.
100-
return deg1_l1_ll0_eval(tree, cX, vals[tree.op], vals[tree.l.op], operators)
103+
return deg1_l1_ll0_eval(
104+
tree, cX, vals[tree.op], vals[tree.l.op], operators; turbo=turbo
105+
)
101106
else
102107
# op(x), for any x.
103-
return deg1_eval(tree, cX, vals[tree.op], operators)
108+
return deg1_eval(tree, cX, vals[tree.op], operators; turbo=turbo)
104109
end
105110
elseif tree.degree == 2
106111
# TODO - add op(op2(x, y), z) and op(x, op2(y, z))
107112
if tree.l.degree == 0 && tree.r.degree == 0
108113
# op(x, y), where x, y are constants or variables.
109-
return deg2_l0_r0_eval(tree, cX, vals[tree.op], operators)
114+
return deg2_l0_r0_eval(tree, cX, vals[tree.op], operators; turbo=turbo)
110115
elseif tree.l.degree == 0
111116
# op(x, y), where x is a constant or variable but y is not.
112-
return deg2_l0_eval(tree, cX, vals[tree.op], operators)
117+
return deg2_l0_eval(tree, cX, vals[tree.op], operators; turbo=turbo)
113118
elseif tree.r.degree == 0
114119
# op(x, y), where y is a constant or variable but x is not.
115-
return deg2_r0_eval(tree, cX, vals[tree.op], operators)
120+
return deg2_r0_eval(tree, cX, vals[tree.op], operators; turbo=turbo)
116121
else
117122
# op(x, y), for any x or y
118-
return deg2_eval(tree, cX, vals[tree.op], operators)
123+
return deg2_eval(tree, cX, vals[tree.op], operators; turbo=turbo)
119124
end
120125
end
121126
end
122127

123128
function deg2_eval(
124-
tree::Node{T}, cX::AbstractMatrix{T}, ::Val{op_idx}, operators::OperatorEnum
129+
tree::Node{T},
130+
cX::AbstractMatrix{T},
131+
::Val{op_idx},
132+
operators::OperatorEnum;
133+
turbo::Bool,
125134
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx}
126135
n = size(cX, 2)
127136
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators)
@@ -133,31 +142,35 @@ function deg2_eval(
133142
op = operators.binops[op_idx]
134143

135144
# We check inputs (and intermediates), not outputs.
136-
@turbo for j in indices(cumulator)
137-
x = op(cumulator[j], array2[j])
145+
@maybe_turbo turbo for j in indices(cumulator)
146+
x = op(cumulator[j], array2[j])::T
138147
cumulator[j] = x
139148
end
140149
# return (cumulator, finished_loop) #
141150
return (cumulator, true)
142151
end
143152

144153
function deg1_eval(
145-
tree::Node{T}, cX::AbstractMatrix{T}, ::Val{op_idx}, operators::OperatorEnum
154+
tree::Node{T},
155+
cX::AbstractMatrix{T},
156+
::Val{op_idx},
157+
operators::OperatorEnum;
158+
turbo::Bool,
146159
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx}
147160
n = size(cX, 2)
148161
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators)
149162
@return_on_false complete cumulator
150163
@return_on_nonfinite_array cumulator T n
151164
op = operators.unaops[op_idx]
152-
@turbo for j in indices(cumulator)
153-
x = op(cumulator[j])
165+
@maybe_turbo turbo for j in indices(cumulator)
166+
x = op(cumulator[j])::T
154167
cumulator[j] = x
155168
end
156169
return (cumulator, true) #
157170
end
158171

159172
function deg0_eval(
160-
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum
173+
tree::Node{T}, cX::AbstractMatrix{T}
161174
)::Tuple{AbstractVector{T},Bool} where {T<:Real}
162175
n = size(cX, 2)
163176
if tree.constant
@@ -172,7 +185,8 @@ function deg1_l2_ll0_lr0_eval(
172185
cX::AbstractMatrix{T},
173186
::Val{op_idx},
174187
::Val{op_l_idx},
175-
operators::OperatorEnum,
188+
operators::OperatorEnum;
189+
turbo::Bool,
176190
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,op_l_idx}
177191
n = size(cX, 2)
178192
op = operators.unaops[op_idx]
@@ -192,9 +206,9 @@ function deg1_l2_ll0_lr0_eval(
192206
@return_on_check val_ll T n
193207
feature_lr = tree.l.r.feature
194208
cumulator = Array{T,1}(undef, n)
195-
@turbo for j in indices((cX, cumulator), (2, 1))
196-
x_l = op_l(val_ll, cX[feature_lr, j])
197-
x = isfinite(x_l) ? op(x_l) : T(Inf) # These will get discovered by _eval_tree_array at end.
209+
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
210+
x_l = op_l(val_ll, cX[feature_lr, j])::T
211+
x = isfinite(x_l) ? op(x_l)::T : T(Inf) # These will get discovered by _eval_tree_array at end.
198212
cumulator[j] = x
199213
end
200214
return (cumulator, true)
@@ -203,19 +217,19 @@ function deg1_l2_ll0_lr0_eval(
203217
val_lr = tree.l.r.val::T
204218
@return_on_check val_lr T n
205219
cumulator = Array{T,1}(undef, n)
206-
@turbo for j in indices((cX, cumulator), (2, 1))
207-
x_l = op_l(cX[feature_ll, j], val_lr)
208-
x = isfinite(x_l) ? op(x_l) : T(Inf)
220+
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
221+
x_l = op_l(cX[feature_ll, j], val_lr)::T
222+
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
209223
cumulator[j] = x
210224
end
211225
return (cumulator, true)
212226
else
213227
feature_ll = tree.l.l.feature
214228
feature_lr = tree.l.r.feature
215229
cumulator = Array{T,1}(undef, n)
216-
@turbo for j in indices((cX, cumulator), (2, 1))
217-
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])
218-
x = isfinite(x_l) ? op(x_l) : T(Inf)
230+
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
231+
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])::T
232+
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
219233
cumulator[j] = x
220234
end
221235
return (cumulator, true)
@@ -228,7 +242,8 @@ function deg1_l1_ll0_eval(
228242
cX::AbstractMatrix{T},
229243
::Val{op_idx},
230244
::Val{op_l_idx},
231-
operators::OperatorEnum,
245+
operators::OperatorEnum;
246+
turbo::Bool,
232247
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,op_l_idx}
233248
n = size(cX, 2)
234249
op = operators.unaops[op_idx]
@@ -244,17 +259,21 @@ function deg1_l1_ll0_eval(
244259
else
245260
feature_ll = tree.l.l.feature
246261
cumulator = Array{T,1}(undef, n)
247-
@turbo for j in indices((cX, cumulator), (2, 1))
248-
x_l = op_l(cX[feature_ll, j])
249-
x = isfinite(x_l) ? op(x_l) : T(Inf)
262+
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
263+
x_l = op_l(cX[feature_ll, j])::T
264+
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
250265
cumulator[j] = x
251266
end
252267
return (cumulator, true)
253268
end
254269
end
255270

256271
function deg2_l0_r0_eval(
257-
tree::Node{T}, cX::AbstractMatrix{T}, ::Val{op_idx}, operators::OperatorEnum
272+
tree::Node{T},
273+
cX::AbstractMatrix{T},
274+
::Val{op_idx},
275+
operators::OperatorEnum;
276+
turbo::Bool,
258277
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx}
259278
n = size(cX, 2)
260279
op = operators.binops[op_idx]
@@ -271,33 +290,37 @@ function deg2_l0_r0_eval(
271290
val_l = tree.l.val::T
272291
@return_on_check val_l T n
273292
feature_r = tree.r.feature
274-
@turbo for j in indices((cX, cumulator), (2, 1))
275-
x = op(val_l, cX[feature_r, j])
293+
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
294+
x = op(val_l, cX[feature_r, j])::T
276295
cumulator[j] = x
277296
end
278297
elseif tree.r.constant
279298
cumulator = Array{T,1}(undef, n)
280299
feature_l = tree.l.feature
281300
val_r = tree.r.val::T
282301
@return_on_check val_r T n
283-
@turbo for j in indices((cX, cumulator), (2, 1))
284-
x = op(cX[feature_l, j], val_r)
302+
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
303+
x = op(cX[feature_l, j], val_r)::T
285304
cumulator[j] = x
286305
end
287306
else
288307
cumulator = Array{T,1}(undef, n)
289308
feature_l = tree.l.feature
290309
feature_r = tree.r.feature
291-
@turbo for j in indices((cX, cumulator), (2, 1))
292-
x = op(cX[feature_l, j], cX[feature_r, j])
310+
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
311+
x = op(cX[feature_l, j], cX[feature_r, j])::T
293312
cumulator[j] = x
294313
end
295314
end
296315
return (cumulator, true)
297316
end
298317

299318
function deg2_l0_eval(
300-
tree::Node{T}, cX::AbstractMatrix{T}, ::Val{op_idx}, operators::OperatorEnum
319+
tree::Node{T},
320+
cX::AbstractMatrix{T},
321+
::Val{op_idx},
322+
operators::OperatorEnum;
323+
turbo::Bool,
301324
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx}
302325
n = size(cX, 2)
303326
(cumulator, complete) = _eval_tree_array(tree.r, cX, operators)
@@ -307,22 +330,26 @@ function deg2_l0_eval(
307330
if tree.l.constant
308331
val = tree.l.val::T
309332
@return_on_check val T n
310-
@turbo for j in indices(cumulator)
311-
x = op(val, cumulator[j])
333+
@maybe_turbo turbo for j in indices(cumulator)
334+
x = op(val, cumulator[j])::T
312335
cumulator[j] = x
313336
end
314337
else
315338
feature = tree.l.feature
316-
@turbo for j in indices((cX, cumulator), (2, 1))
317-
x = op(cX[feature, j], cumulator[j])
339+
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
340+
x = op(cX[feature, j], cumulator[j])::T
318341
cumulator[j] = x
319342
end
320343
end
321344
return (cumulator, true)
322345
end
323346

324347
function deg2_r0_eval(
325-
tree::Node{T}, cX::AbstractMatrix{T}, ::Val{op_idx}, operators::OperatorEnum
348+
tree::Node{T},
349+
cX::AbstractMatrix{T},
350+
::Val{op_idx},
351+
operators::OperatorEnum;
352+
turbo::Bool,
326353
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx}
327354
n = size(cX, 2)
328355
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators)
@@ -332,14 +359,14 @@ function deg2_r0_eval(
332359
if tree.r.constant
333360
val = tree.r.val::T
334361
@return_on_check val T n
335-
@turbo for j in indices(cumulator)
336-
x = op(cumulator[j], val)
362+
@maybe_turbo turbo for j in indices(cumulator)
363+
x = op(cumulator[j], val)::T
337364
cumulator[j] = x
338365
end
339366
else
340367
feature = tree.r.feature
341-
@turbo for j in indices((cX, cumulator), (2, 1))
342-
x = op(cumulator[j], cX[feature, j])
368+
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
369+
x = op(cumulator[j], cX[feature, j])::T
343370
cumulator[j] = x
344371
end
345372
end

0 commit comments

Comments
 (0)