Commit f45ade4

Turn off @turbo for derivatives; as Zygote won't fix
1 parent e0196b2 commit f45ade4
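
Context for the change: the derivative kernels in this file fill preallocated output buffers inside `@turbo` loops, and Zygote.jl cannot differentiate through code that mutates arrays; the commit title suggests that this incompatibility will not be fixed on Zygote's side, so the `@turbo`/`@maybe_turbo` machinery is dropped in favor of plain `@inbounds @simd` loops. A minimal sketch of the general `@turbo`-versus-Zygote conflict (the function and data are illustrative, not from this repository):

```julia
using Zygote
using LoopVectorization: @turbo

# A mutating, @turbo-vectorized kernel, the same shape as the
# derivative loops touched by this commit.
function turbo_square_sum(x::Vector{Float64})
    out = similar(x)
    @turbo for i in eachindex(x)
        out[i] = x[i] * x[i]  # writes into a preallocated buffer
    end
    return sum(out)
end

# Zygote's reverse pass cannot handle the mutation of `out`:
# Zygote.gradient(turbo_square_sum, [1.0, 2.0, 3.0])
# ERROR: Mutating arrays is not supported ... (wording varies by version)
```
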

1 file changed: src/EvaluateEquationDerivative.jl (+40 additions, -91 deletions)
@@ -1,9 +1,9 @@
 module EvaluateEquationDerivativeModule

-import LoopVectorization: @turbo, indices
+import LoopVectorization: indices
 import ..EquationModule: Node
 import ..OperatorEnumModule: OperatorEnum
-import ..UtilsModule: @return_on_false2, is_bad_array, vals, @maybe_turbo
+import ..UtilsModule: @return_on_false2, is_bad_array, vals
 import ..EquationUtilsModule: count_constants, index_constants, NodeIndex
 import ..EvaluateEquationModule: deg0_eval
@@ -37,27 +37,15 @@ respect to `x1`.
 the derivative, and whether the evaluation completed as normal (or encountered a nan or inf).
 """
 function eval_diff_tree_array(
-    tree::Node{T},
-    cX::AbstractMatrix{T},
-    operators::OperatorEnum,
-    direction::Int;
-    turbo::Bool=false,
+    tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int
 )::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real}
     assert_autodiff_enabled(operators)
     # TODO: Implement quick check for whether the variable is actually used
     # in this tree. Otherwise, return zero.
-    evaluation, derivative, complete = _eval_diff_tree_array(
-        tree, cX, operators, direction; turbo=turbo
-    )
-    @return_on_false2 complete evaluation derivative
-    return evaluation, derivative, !(is_bad_array(evaluation) || is_bad_array(derivative))
+    return _eval_diff_tree_array(tree, cX, operators, direction)
 end
 function eval_diff_tree_array(
-    tree::Node{T1},
-    cX::AbstractMatrix{T2},
-    operators::OperatorEnum,
-    direction::Int;
-    turbo::Bool=false,
+    tree::Node{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum, direction::Int;
 ) where {T1<:Real,T2<:Real}
     T = promote_type(T1, T2)
     @warn "Warning: eval_diff_tree_array received mixed types: tree=$(T1) and data=$(T2)."
@@ -67,26 +55,24 @@ function eval_diff_tree_array(
 end

 function _eval_diff_tree_array(
-    tree::Node{T},
-    cX::AbstractMatrix{T},
-    operators::OperatorEnum,
-    direction::Int;
-    turbo::Bool,
+    tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int
 )::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real}
-    if tree.degree == 0
-        diff_deg0_eval(tree, cX, operators, direction)
+    evaluation, derivative, complete = if tree.degree == 0
+        diff_deg0_eval(tree, cX, direction)
     elseif tree.degree == 1
-        diff_deg1_eval(tree, cX, vals[tree.op], operators, direction; turbo=turbo)
+        diff_deg1_eval(tree, cX, vals[tree.op], operators, direction)
     else
-        diff_deg2_eval(tree, cX, vals[tree.op], operators, direction; turbo=turbo)
+        diff_deg2_eval(tree, cX, vals[tree.op], operators, direction)
     end
+    @return_on_false2 complete evaluation derivative
+    return evaluation, derivative, !(is_bad_array(evaluation) || is_bad_array(derivative))
 end

 function diff_deg0_eval(
-    tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int
+    tree::Node{T}, cX::AbstractMatrix{T}, direction::Int
 )::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real}
     n = size(cX, 2)
-    const_part = deg0_eval(tree, cX, operators)[1]
+    const_part = deg0_eval(tree, cX)[1]
     derivative_part =
         ((!tree.constant) && tree.feature == direction) ? ones(T, n) : zeros(T, n)
     return (const_part, derivative_part, true)
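
This hunk also reroutes the recursion through `_eval_diff_tree_array` directly: the `is_bad_array` validity check moves into the internal function, so the public wrapper (and its `assert_autodiff_enabled` call) runs once per evaluation rather than once per node, while NaN/Inf results still short-circuit at every level. A stripped-down sketch of that (value, ok) early-exit pattern (`Leaf`/`Branch` are hypothetical types, not the package's `Node`):

```julia
# Each level returns (value, ok); a false `ok` propagates up
# without doing any further work.
struct Leaf
    value::Float64
end
struct Branch
    op::Function
    child::Union{Leaf,Branch}
end

eval_checked(t::Leaf) = (t.value, true)
function eval_checked(t::Branch)
    x, ok = eval_checked(t.child)
    ok || return (x, false)   # short-circuit: skip applying the op
    y = t.op(x)
    return (y, isfinite(y))   # re-check validity at this level
end

# eval_checked(Branch(inv, Leaf(0.0)))  # -> (Inf, false), caught immediately
```
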
@@ -97,20 +83,19 @@ function diff_deg1_eval(
     cX::AbstractMatrix{T},
     ::Val{op_idx},
     operators::OperatorEnum,
-    direction::Int;
-    turbo::Bool,
+    direction::Int,
 )::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,op_idx}
     n = size(cX, 2)
-    (cumulator, dcumulator, complete) = eval_diff_tree_array(
-        tree.l, cX, operators, direction; turbo=turbo
+    (cumulator, dcumulator, complete) = _eval_diff_tree_array(
+        tree.l, cX, operators, direction
     )
     @return_on_false2 complete cumulator dcumulator

     op = operators.unaops[op_idx]
     diff_op = operators.diff_unaops[op_idx]

     # TODO - add type assertions to get better speed:
-    @maybe_turbo turbo for j in indices((cumulator, dcumulator), (1, 1))
+    @inbounds @simd for j in indices((cumulator, dcumulator))
         x = op(cumulator[j])::T
         dx = diff_op(cumulator[j])::T * dcumulator[j]

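
Note that `indices` is still imported from LoopVectorization even though `@turbo` is gone: it returns an index range shared by the listed arrays (checking that their axes agree), which is what keeps the `@inbounds` annotation safe. A small standalone sketch of the same loop shape (arrays are illustrative):

```julia
using LoopVectorization: indices

a, b = rand(100), rand(100)
out = similar(a)

# `indices((a, b))` verifies that `a` and `b` share an axis and returns
# it, so `@inbounds` cannot silently read or write out of bounds here.
@inbounds @simd for j in indices((a, b))
    out[j] = a[j] * b[j]
end
```
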
@@ -125,25 +110,22 @@ function diff_deg2_eval(
     cX::AbstractMatrix{T},
     ::Val{op_idx},
     operators::OperatorEnum,
-    direction::Int;
-    turbo::Bool,
+    direction::Int,
 )::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,op_idx}
     n = size(cX, 2)
-    (cumulator, dcumulator, complete) = eval_diff_tree_array(
+    (cumulator, dcumulator, complete) = _eval_diff_tree_array(
         tree.l, cX, operators, direction
     )
     @return_on_false2 complete cumulator dcumulator
-    (array2, dcumulator2, complete2) = eval_diff_tree_array(
+    (array2, dcumulator2, complete2) = _eval_diff_tree_array(
         tree.r, cX, operators, direction
     )
     @return_on_false2 complete2 array2 dcumulator2

     op = operators.binops[op_idx]
     diff_op = operators.diff_binops[op_idx]

-    @maybe_turbo turbo for j in indices(
-        (cumulator, dcumulator, array2, dcumulator2), (1, 1, 1, 1)
-    )
+    @inbounds @simd for j in indices((cumulator, dcumulator, array2, dcumulator2))
         x = op(cumulator[j], array2[j])

         first, second = diff_op(cumulator[j], array2[j])
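
For a binary node, `diff_op` returns both partial derivatives at once; the rest of the loop body (just past the end of this hunk) combines them with the children's derivatives via the chain rule. A hand-worked example with multiplication (a sketch, not the package's generated code):

```julia
# Chain rule for a binary node f(l, r): with `first` = df/dl and
# `second` = df/dr, we get df/dx = first * dl/dx + second * dr/dx.
mul_diff(l, r) = (r, l)   # partials of l * r

l, dl = 3.0, 1.0   # left child value and its derivative w.r.t. x
r, dr = 5.0, 0.0   # right child value and its derivative w.r.t. x

first, second = mul_diff(l, r)
dx = first * dl + second * dr   # == 5.0, since d(l*r)/dx = r*dl/dx + l*dr/dx
```
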
@@ -178,11 +160,7 @@ to every constant in the expression.
 the gradient, and whether the evaluation completed as normal (or encountered a nan or inf).
 """
 function eval_grad_tree_array(
-    tree::Node{T},
-    cX::AbstractMatrix{T},
-    operators::OperatorEnum;
-    variable::Bool=false,
-    turbo::Bool=false,
+    tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false
 )::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real}
     assert_autodiff_enabled(operators)
     n = size(cX, 2)
@@ -193,7 +171,7 @@ function eval_grad_tree_array(
     end
     index_tree = index_constants(tree, 0)
     return eval_grad_tree_array(
-        tree, n, n_gradients, index_tree, cX, operators, Val(variable); turbo=turbo
+        tree, n, n_gradients, index_tree, cX, operators, (variable ? Val(true) : Val(false))
     )
 end

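
The `(variable ? Val(true) : Val(false))` rewrite is a standard Julia type-stability idiom: `Val(variable)` on a runtime `Bool` yields a value whose type the compiler cannot predict, whereas branching first gives each path a compile-time-constant `Val{true}` or `Val{false}`, so the callee is dispatched statically. In miniature:

```julia
kernel(::Val{true})  = "gradient w.r.t. variables"
kernel(::Val{false}) = "gradient w.r.t. constants"

# The type of `Val(b)` depends on a runtime value, forcing dynamic dispatch:
unstable(b::Bool) = kernel(Val(b))

# Branching first makes the Val type a constant on each path:
stable(b::Bool) = b ? kernel(Val(true)) : kernel(Val(false))
```
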
@@ -204,30 +182,21 @@ function eval_grad_tree_array(
     index_tree::NodeIndex,
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
-    ::Val{variable};
-    turbo::Bool,
+    ::Val{variable},
 )::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable}
     evaluation, gradient, complete = _eval_grad_tree_array(
-        tree, n, n_gradients, index_tree, cX, operators, Val(variable); turbo=turbo
+        tree, n, n_gradients, index_tree, cX, operators, Val(variable)
     )
     @return_on_false2 complete evaluation gradient
     return evaluation, gradient, !(is_bad_array(evaluation) || is_bad_array(gradient))
 end

 function eval_grad_tree_array(
-    tree::Node{T1},
-    cX::AbstractMatrix{T2},
-    operators::OperatorEnum;
-    variable::Bool=false,
-    turbo::Bool=false,
+    tree::Node{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum; variable::Bool=false
 ) where {T1<:Real,T2<:Real}
     T = promote_type(T1, T2)
     return eval_grad_tree_array(
-        convert(Node{T}, tree),
-        convert(AbstractMatrix{T}, cX),
-        operators;
-        variable=variable,
-        turbo=turbo,
+        convert(Node{T}, tree), convert(AbstractMatrix{T}, cX), operators; variable=variable
     )
 end

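
The mixed-type method above is the usual promote-then-recurse idiom: convert both arguments to `promote_type(T1, T2)` and re-dispatch into the single-type method, so the hot kernels only ever compile for matched element types. The same idiom in miniature (`my_dot` is a hypothetical stand-in):

```julia
# Matched types hit the specific method; mixed types promote and recurse.
my_dot(x::Vector{T}, y::Vector{T}) where {T<:Real} = sum(x .* y)

function my_dot(x::Vector{T1}, y::Vector{T2}) where {T1<:Real,T2<:Real}
    T = promote_type(T1, T2)
    return my_dot(convert(Vector{T}, x), convert(Vector{T}, y))
end

my_dot([1.0, 2.0], Float32[3, 4])  # promotes to Float64, returns 11.0
```
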
@@ -238,34 +207,17 @@ function _eval_grad_tree_array(
     index_tree::NodeIndex,
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
-    ::Val{variable};
-    turbo::Bool,
+    ::Val{variable},
 )::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable}
     if tree.degree == 0
-        grad_deg0_eval(tree, n, n_gradients, index_tree, cX, operators, Val(variable))
+        grad_deg0_eval(tree, n, n_gradients, index_tree, cX, Val(variable))
     elseif tree.degree == 1
         grad_deg1_eval(
-            tree,
-            n,
-            n_gradients,
-            index_tree,
-            cX,
-            vals[tree.op],
-            operators,
-            Val(variable);
-            turbo=turbo,
+            tree, n, n_gradients, index_tree, cX, vals[tree.op], operators, Val(variable)
         )
     else
         grad_deg2_eval(
-            tree,
-            n,
-            n_gradients,
-            index_tree,
-            cX,
-            vals[tree.op],
-            operators,
-            Val(variable);
-            turbo=turbo,
+            tree, n, n_gradients, index_tree, cX, vals[tree.op], operators, Val(variable)
         )
     end
 end
@@ -276,10 +228,9 @@ function grad_deg0_eval(
     n_gradients::Int,
     index_tree::NodeIndex,
     cX::AbstractMatrix{T},
-    operators::OperatorEnum,
     ::Val{variable},
 )::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable}
-    const_part = deg0_eval(tree, cX, operators)[1]
+    const_part = deg0_eval(tree, cX)[1]

     if variable == tree.constant
         return (const_part, zeros(T, n_gradients, n), true)
@@ -299,18 +250,17 @@ function grad_deg1_eval(
     cX::AbstractMatrix{T},
     ::Val{op_idx},
     operators::OperatorEnum,
-    ::Val{variable};
-    turbo::Bool,
+    ::Val{variable},
 )::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,op_idx,variable}
     (cumulator, dcumulator, complete) = eval_grad_tree_array(
-        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable); turbo=turbo
+        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable)
     )
     @return_on_false2 complete cumulator dcumulator

     op = operators.unaops[op_idx]
     diff_op = operators.diff_unaops[op_idx]

-    @maybe_turbo turbo for j in indices((cumulator, dcumulator), (1, 2))
+    @inbounds @simd for j in indices((cumulator, dcumulator), (1, 2))
         x = op(cumulator[j])::T
         dx = diff_op(cumulator[j])

@@ -330,22 +280,21 @@ function grad_deg2_eval(
     cX::AbstractMatrix{T},
     ::Val{op_idx},
     operators::OperatorEnum,
-    ::Val{variable};
-    turbo::Bool,
+    ::Val{variable},
 )::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,op_idx,variable}
     (cumulator1, dcumulator1, complete) = eval_grad_tree_array(
-        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable); turbo=turbo
+        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable)
     )
     @return_on_false2 complete cumulator1 dcumulator1
     (cumulator2, dcumulator2, complete2) = eval_grad_tree_array(
-        tree.r, n, n_gradients, index_tree.r, cX, operators, Val(variable); turbo=turbo
+        tree.r, n, n_gradients, index_tree.r, cX, operators, Val(variable)
     )
     @return_on_false2 complete2 cumulator1 dcumulator1

     op = operators.binops[op_idx]
     diff_op = operators.diff_binops[op_idx]

-    @maybe_turbo turbo for j in indices(
+    @inbounds @simd for j in indices(
         (cumulator1, cumulator2, dcumulator1, dcumulator2), (1, 1, 2, 2)
     )
         c1 = cumulator1[j]
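
After this commit, callers simply drop the `turbo` keyword. A hedged usage sketch, assuming the package's usual `OperatorEnum`/`Node` constructors and operator overloading (keyword names here may differ by version):

```julia
# Assumed constructors; `enable_autodiff=true` is needed because these
# functions call assert_autodiff_enabled(operators).
operators = OperatorEnum(;
    binary_operators=(+, *, -), unary_operators=(cos,), enable_autodiff=true
)
x1 = Node(; feature=1)
tree = cos(x1 * 3.2f0) + x1

cX = randn(Float32, 2, 100)
evaluation, derivative, ok = eval_diff_tree_array(tree, cX, operators, 1)
evaluation, gradient, ok = eval_grad_tree_array(tree, cX, operators; variable=true)
```
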
