Skip to content

Commit 62c6db7

Browse files
committed
Use function barriers not specialized to operators
1 parent 7334f79 commit 62c6db7

File tree

5 files changed

+84
-119
lines changed

5 files changed

+84
-119
lines changed

src/EvaluateEquation.jl

Lines changed: 47 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module EvaluateEquationModule
33
import LoopVectorization: @turbo, indices
44
import ..EquationModule: Node, string_tree
55
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
6-
import ..UtilsModule: @return_on_false, @maybe_turbo, is_bad_array, vals
6+
import ..UtilsModule: @return_on_false, @maybe_turbo, is_bad_array
77
import ..EquationUtilsModule: is_constant
88

99
macro return_on_check(val, T, n)
@@ -98,53 +98,48 @@ function _eval_tree_array(
9898
!flag && return Array{T,1}(undef, size(cX, 2)), false
9999
return fill(result, size(cX, 2)), true
100100
elseif tree.degree == 1
101+
op = operators.unaops[tree.op]
101102
if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
102103
# op(op2(x, y)), where x, y, z are constants or variables.
103-
return deg1_l2_ll0_lr0_eval(
104-
tree, cX, vals[tree.op], vals[tree.l.op], operators, Val(turbo)
105-
)
104+
op_l = operators.binops[tree.l.op]
105+
return deg1_l2_ll0_lr0_eval(tree, cX, op, op_l, Val(turbo))
106106
elseif tree.l.degree == 1 && tree.l.l.degree == 0
107107
# op(op2(x)), where x is a constant or variable.
108-
return deg1_l1_ll0_eval(
109-
tree, cX, vals[tree.op], vals[tree.l.op], operators, Val(turbo)
110-
)
108+
op_l = operators.unaops[tree.l.op]
109+
return deg1_l1_ll0_eval(tree, cX, op, op_l, Val(turbo))
111110
else
112111
# op(x), for any x.
113-
return deg1_eval(tree, cX, vals[tree.op], operators, Val(turbo))
112+
return deg1_eval(tree, cX, op, operators, Val(turbo))
114113
end
115114
elseif tree.degree == 2
116115
# TODO - add op(op2(x, y), z) and op(x, op2(y, z))
116+
op = operators.binops[tree.op]
117117
if tree.l.degree == 0 && tree.r.degree == 0
118118
# op(x, y), where x, y are constants or variables.
119-
return deg2_l0_r0_eval(tree, cX, vals[tree.op], operators, Val(turbo))
119+
return deg2_l0_r0_eval(tree, cX, op, Val(turbo))
120120
elseif tree.l.degree == 0
121121
# op(x, y), where x is a constant or variable but y is not.
122-
return deg2_l0_eval(tree, cX, vals[tree.op], operators, Val(turbo))
122+
return deg2_l0_eval(tree, cX, op, operators, Val(turbo))
123123
elseif tree.r.degree == 0
124124
# op(x, y), where y is a constant or variable but x is not.
125-
return deg2_r0_eval(tree, cX, vals[tree.op], operators, Val(turbo))
125+
return deg2_r0_eval(tree, cX, op, operators, Val(turbo))
126126
else
127127
# op(x, y), for any x or y
128-
return deg2_eval(tree, cX, vals[tree.op], operators, Val(turbo))
128+
return deg2_eval(tree, cX, op, operators, Val(turbo))
129129
end
130130
end
131131
end
132132

133133
function deg2_eval(
134-
tree::Node{T},
135-
cX::AbstractMatrix{T},
136-
::Val{op_idx},
137-
operators::OperatorEnum,
138-
::Val{turbo},
139-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
134+
tree::Node{T}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum, ::Val{turbo}
135+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
140136
n = size(cX, 2)
141137
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
142138
@return_on_false complete cumulator
143139
@return_on_nonfinite_array cumulator T n
144140
(array2, complete2) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
145141
@return_on_false complete2 cumulator
146142
@return_on_nonfinite_array array2 T n
147-
op = operators.binops[op_idx]
148143

149144
# We check inputs (and intermediates), not outputs.
150145
@maybe_turbo turbo for j in indices(cumulator)
@@ -156,22 +151,17 @@ function deg2_eval(
156151
end
157152

158153
function deg1_eval(
159-
tree::Node{T},
160-
cX::AbstractMatrix{T},
161-
::Val{op_idx},
162-
operators::OperatorEnum,
163-
::Val{turbo},
164-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
154+
tree::Node{T}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum, ::Val{turbo}
155+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
165156
n = size(cX, 2)
166157
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
167158
@return_on_false complete cumulator
168159
@return_on_nonfinite_array cumulator T n
169-
op = operators.unaops[op_idx]
170160
@maybe_turbo turbo for j in indices(cumulator)
171161
x = op(cumulator[j])::T
172162
cumulator[j] = x
173163
end
174-
return (cumulator, true) #
164+
return (cumulator, true)
175165
end
176166

177167
function deg0_eval(
@@ -188,14 +178,11 @@ end
188178
function deg1_l2_ll0_lr0_eval(
189179
tree::Node{T},
190180
cX::AbstractMatrix{T},
191-
::Val{op_idx},
192-
::Val{op_l_idx},
193-
operators::OperatorEnum,
181+
op::F,
182+
op_l::F2,
194183
::Val{turbo},
195-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,op_l_idx,turbo}
184+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,F2,turbo}
196185
n = size(cX, 2)
197-
op = operators.unaops[op_idx]
198-
op_l = operators.binops[op_l_idx]
199186
if tree.l.l.constant && tree.l.r.constant
200187
val_ll = tree.l.l.val::T
201188
val_lr = tree.l.r.val::T
@@ -245,14 +232,11 @@ end
245232
function deg1_l1_ll0_eval(
246233
tree::Node{T},
247234
cX::AbstractMatrix{T},
248-
::Val{op_idx},
249-
::Val{op_l_idx},
250-
operators::OperatorEnum,
235+
op::F,
236+
op_l::F2,
251237
::Val{turbo},
252-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,op_l_idx,turbo}
238+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,F2,turbo}
253239
n = size(cX, 2)
254-
op = operators.unaops[op_idx]
255-
op_l = operators.unaops[op_l_idx]
256240
if tree.l.l.constant
257241
val_ll = tree.l.l.val::T
258242
@return_on_check val_ll T n
@@ -275,14 +259,9 @@ end
275259

276260
# op(x, y) for x and y variable/constant
277261
function deg2_l0_r0_eval(
278-
tree::Node{T},
279-
cX::AbstractMatrix{T},
280-
::Val{op_idx},
281-
operators::OperatorEnum,
282-
::Val{turbo},
283-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
262+
tree::Node{T}, cX::AbstractMatrix{T}, op::F, ::Val{turbo}
263+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
284264
n = size(cX, 2)
285-
op = operators.binops[op_idx]
286265
if tree.l.constant && tree.r.constant
287266
val_l = tree.l.val::T
288267
@return_on_check val_l T n
@@ -323,17 +302,12 @@ end
323302

324303
# op(x, y) for x variable/constant, y arbitrary
325304
function deg2_l0_eval(
326-
tree::Node{T},
327-
cX::AbstractMatrix{T},
328-
::Val{op_idx},
329-
operators::OperatorEnum,
330-
::Val{turbo},
331-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
305+
tree::Node{T}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum, ::Val{turbo}
306+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
332307
n = size(cX, 2)
333308
(cumulator, complete) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
334309
@return_on_false complete cumulator
335310
@return_on_nonfinite_array cumulator T n
336-
op = operators.binops[op_idx]
337311
if tree.l.constant
338312
val = tree.l.val::T
339313
@return_on_check val T n
@@ -353,17 +327,12 @@ end
353327

354328
# op(x, y) for x arbitrary, y variable/constant
355329
function deg2_r0_eval(
356-
tree::Node{T},
357-
cX::AbstractMatrix{T},
358-
::Val{op_idx},
359-
operators::OperatorEnum,
360-
::Val{turbo},
361-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
330+
tree::Node{T}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum, ::Val{turbo}
331+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
362332
n = size(cX, 2)
363333
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
364334
@return_on_false complete cumulator
365335
@return_on_nonfinite_array cumulator T n
366-
op = operators.binops[op_idx]
367336
if tree.r.constant
368337
val = tree.r.val::T
369338
@return_on_check val T n
@@ -394,9 +363,9 @@ function _eval_constant_tree(
394363
if tree.degree == 0
395364
return deg0_eval_constant(tree)
396365
elseif tree.degree == 1
397-
return deg1_eval_constant(tree, vals[tree.op], operators)
366+
return deg1_eval_constant(tree, operators.unaops[tree.op], operators)
398367
else
399-
return deg2_eval_constant(tree, vals[tree.op], operators)
368+
return deg2_eval_constant(tree, operators.binops[tree.op], operators)
400369
end
401370
end
402371

@@ -405,19 +374,17 @@ end
405374
end
406375

407376
function deg1_eval_constant(
408-
tree::Node{T}, ::Val{op_idx}, operators::OperatorEnum
409-
)::Tuple{T,Bool} where {T<:Real,op_idx}
410-
op = operators.unaops[op_idx]
377+
tree::Node{T}, op::F, operators::OperatorEnum
378+
)::Tuple{T,Bool} where {T<:Real,F}
411379
(cumulator, complete) = _eval_constant_tree(tree.l, operators)
412380
!complete && return zero(T), false
413381
output = op(cumulator)::T
414382
return output, isfinite(output)
415383
end
416384

417385
function deg2_eval_constant(
418-
tree::Node{T}, ::Val{op_idx}, operators::OperatorEnum
419-
)::Tuple{T,Bool} where {T<:Real,op_idx}
420-
op = operators.binops[op_idx]
386+
tree::Node{T}, op::F, operators::OperatorEnum
387+
)::Tuple{T,Bool} where {T<:Real,F}
421388
(cumulator, complete) = _eval_constant_tree(tree.l, operators)
422389
!complete && return zero(T), false
423390
(cumulator2, complete2) = _eval_constant_tree(tree.r, operators)
@@ -442,31 +409,29 @@ function differentiable_eval_tree_array(
442409
return (cX[tree.feature, :], true)
443410
end
444411
elseif tree.degree == 1
445-
return deg1_diff_eval(tree, cX, vals[tree.op], operators)
412+
return deg1_diff_eval(tree, cX, operators.unaops[tree.op], operators)
446413
else
447-
return deg2_diff_eval(tree, cX, vals[tree.op], operators)
414+
return deg2_diff_eval(tree, cX, operators.binops[tree.op], operators)
448415
end
449416
end
450417

451418
function deg1_diff_eval(
452-
tree::Node{T1}, cX::AbstractMatrix{T}, ::Val{op_idx}, operators::OperatorEnum
453-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,T1}
419+
tree::Node{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum
420+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,T1}
454421
(left, complete) = differentiable_eval_tree_array(tree.l, cX, operators)
455422
@return_on_false complete left
456-
op = operators.unaops[op_idx]
457423
out = op.(left)
458424
no_nans = !any(x -> (!isfinite(x)), out)
459425
return (out, no_nans)
460426
end
461427

462428
function deg2_diff_eval(
463-
tree::Node{T1}, cX::AbstractMatrix{T}, ::Val{op_idx}, operators::OperatorEnum
464-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,T1}
429+
tree::Node{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum
430+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,T1}
465431
(left, complete) = differentiable_eval_tree_array(tree.l, cX, operators)
466432
@return_on_false complete left
467433
(right, complete2) = differentiable_eval_tree_array(tree.r, cX, operators)
468434
@return_on_false complete2 left
469-
op = operators.binops[op_idx]
470435
out = op.(left, right)
471436
no_nans = !any(x -> (!isfinite(x)), out)
472437
return (out, no_nans)
@@ -557,30 +522,28 @@ function _eval_tree_array_generic(
557522
end
558523
end
559524
elseif tree.degree == 1
560-
return deg1_eval_generic(tree, cX, vals[tree.op], operators, Val(throw_errors))
525+
return deg1_eval_generic(tree, cX, operators.unaops[tree.op], operators, Val(throw_errors))
561526
else
562-
return deg2_eval_generic(tree, cX, vals[tree.op], operators, Val(throw_errors))
527+
return deg2_eval_generic(tree, cX, operators.binops[tree.op], operators, Val(throw_errors))
563528
end
564529
end
565530

566531
function deg1_eval_generic(
567-
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
568-
) where {op_idx,throw_errors}
532+
tree, cX, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
533+
) where {F,throw_errors}
569534
left, complete = eval_tree_array(tree.l, cX, operators)
570535
!throw_errors && !complete && return nothing, false
571-
op = operators.unaops[op_idx]
572536
!throw_errors && !hasmethod(op, Tuple{typeof(left)}) && return nothing, false
573537
return op(left), true
574538
end
575539

576540
function deg2_eval_generic(
577-
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
578-
) where {op_idx,throw_errors}
541+
tree, cX, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
542+
) where {F,throw_errors}
579543
left, complete = eval_tree_array(tree.l, cX, operators)
580544
!throw_errors && !complete && return nothing, false
581545
right, complete = eval_tree_array(tree.r, cX, operators)
582546
!throw_errors && !complete && return nothing, false
583-
op = operators.binops[op_idx]
584547
!throw_errors &&
585548
!hasmethod(op, Tuple{typeof(left),typeof(right)}) &&
586549
return nothing, false

0 commit comments

Comments
 (0)