
Commit 9990e2b

Add @turbo mode to gradient calculation

1 parent: 953dd31

File tree: 3 files changed (+97, -42 lines)

src/EvaluateEquation.jl

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ macro return_on_nonfinite_array(array, T, n)
 end
 
 """
-    eval_tree_array(tree::Node, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool)
+    eval_tree_array(tree::Node, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false)
 
 Evaluate a binary tree (equation) over a given input data matrix. The
 operators contain all of the operators used. This function fuses doublets
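
Note: a minimal usage sketch of the new `turbo` keyword follows. The `OperatorEnum` keywords and the `Node` constructor shown here are assumptions about the surrounding package API, not part of this diff:

```julia
using SymbolicRegression: Node, OperatorEnum, eval_tree_array, eval_grad_tree_array

# enable_autodiff must be true for the derivative entry points:
operators = OperatorEnum(;
    binary_operators=(+, *, -), unary_operators=(cos,), enable_autodiff=true
)
x1 = Node(; feature=1)      # hypothetical feature-node constructor
tree = x1 * x1 - cos(x1)    # trees compose via operator overloading
X = randn(Float32, 3, 100)  # nfeatures x nsamples

y, completed = eval_tree_array(tree, X, operators; turbo=true)
y2, grad, completed2 = eval_grad_tree_array(tree, X, operators; variable=true, turbo=true)
```

With `turbo=false` (the default), behavior is unchanged from before this commit.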

src/EvaluateEquationDerivative.jl

Lines changed: 78 additions & 34 deletions

@@ -1,9 +1,9 @@
 module EvaluateEquationDerivativeModule
 
-import LoopVectorization: indices
+import LoopVectorization: indices, @turbo
 import ..EquationModule: Node
 import ..OperatorEnumModule: OperatorEnum
-import ..UtilsModule: @return_on_false2, is_bad_array
+import ..UtilsModule: @return_on_false2, @maybe_turbo, is_bad_array
 import ..EquationUtilsModule: count_constants, index_constants, NodeIndex
 import ..EvaluateEquationModule: deg0_eval
 
@@ -16,7 +16,7 @@ function assert_autodiff_enabled(operators::OperatorEnum)
 end
 
 """
-    eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int)
+    eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int; turbo::Bool=false)
 
 Compute the forward derivative of an expression, using a similar
 structure and optimization to eval_tree_array. `direction` is the index of a particular
@@ -30,33 +30,48 @@ respect to `x1`.
 - `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff`
   must be `true`. This is needed to create the derivative operations.
 - `direction::Int`: The index of the variable to take the derivative with respect to.
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
 
 # Returns
 
 - `(evaluation, derivative, complete)::Tuple{AbstractVector{T}, AbstractVector{T}, Bool}`: the normal evaluation,
   the derivative, and whether the evaluation completed as normal (or encountered a nan or inf).
 """
 function eval_diff_tree_array(
-    tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int
+    tree::Node{T},
+    cX::AbstractMatrix{T},
+    operators::OperatorEnum,
+    direction::Int;
+    turbo::Bool=false,
 )::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real}
     assert_autodiff_enabled(operators)
     # TODO: Implement quick check for whether the variable is actually used
     # in this tree. Otherwise, return zero.
-    return _eval_diff_tree_array(tree, cX, operators, direction)
+    return _eval_diff_tree_array(
+        tree, cX, operators, direction, (turbo ? Val(true) : Val(false))
+    )
 end
 function eval_diff_tree_array(
-    tree::Node{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum, direction::Int;
+    tree::Node{T1},
+    cX::AbstractMatrix{T2},
+    operators::OperatorEnum,
+    direction::Int;
+    turbo::Bool=false,
 ) where {T1<:Real,T2<:Real}
     T = promote_type(T1, T2)
    @warn "Warning: eval_diff_tree_array received mixed types: tree=$(T1) and data=$(T2)."
     tree = convert(Node{T}, tree)
     cX = convert(AbstractMatrix{T}, cX)
-    return eval_diff_tree_array(tree, cX, operators, direction)
+    return eval_diff_tree_array(tree, cX, operators, direction; turbo=turbo)
 end
 
 function _eval_diff_tree_array(
-    tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int
-)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real}
+    tree::Node{T},
+    cX::AbstractMatrix{T},
+    operators::OperatorEnum,
+    direction::Int,
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,turbo}
     evaluation, derivative, complete = if tree.degree == 0
         diff_deg0_eval(tree, cX, direction)
     elseif tree.degree == 1
@@ -67,6 +82,7 @@ function _eval_diff_tree_array(
             operators.diff_unaops[tree.op],
             operators,
             direction,
+            Val(turbo),
         )
     else
         diff_deg2_eval(
@@ -76,6 +92,7 @@ function _eval_diff_tree_array(
             operators.diff_binops[tree.op],
             operators,
             direction,
+            Val(turbo),
        )
    end
    @return_on_false2 complete evaluation derivative
@@ -99,15 +116,16 @@ function diff_deg1_eval(
    diff_op::dF,
    operators::OperatorEnum,
    direction::Int,
-)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,F,dF}
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,F,dF,turbo}
    n = size(cX, 2)
    (cumulator, dcumulator, complete) = _eval_diff_tree_array(
-        tree.l, cX, operators, direction
+        tree.l, cX, operators, direction, Val(turbo)
    )
    @return_on_false2 complete cumulator dcumulator
 
    # TODO - add type assertions to get better speed:
-    @inbounds @simd for j in indices((cumulator, dcumulator))
+    @maybe_turbo turbo for j in indices((cumulator, dcumulator))
        x = op(cumulator[j])::T
        dx = diff_op(cumulator[j])::T * dcumulator[j]
 
@@ -124,21 +142,21 @@ function diff_deg2_eval(
     diff_op::dF,
     operators::OperatorEnum,
     direction::Int,
-)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,F,dF}
-    n = size(cX, 2)
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,F,dF,turbo}
     (cumulator, dcumulator, complete) = _eval_diff_tree_array(
-        tree.l, cX, operators, direction
+        tree.l, cX, operators, direction, Val(turbo)
     )
     @return_on_false2 complete cumulator dcumulator
     (array2, dcumulator2, complete2) = _eval_diff_tree_array(
-        tree.r, cX, operators, direction
+        tree.r, cX, operators, direction, Val(turbo)
     )
     @return_on_false2 complete2 array2 dcumulator2
 
-    @inbounds @simd for j in indices((cumulator, dcumulator, array2, dcumulator2))
-        x = op(cumulator[j], array2[j])
+    @maybe_turbo turbo for j in indices((cumulator, dcumulator, array2, dcumulator2))
+        x = op(cumulator[j], array2[j])::T
 
-        first, second = diff_op(cumulator[j], array2[j])
+        first, second = diff_op(cumulator[j], array2[j])::Tuple{T,T}
         dx = first * dcumulator[j] + second * dcumulator2[j]
 
         cumulator[j] = x
@@ -148,7 +166,7 @@ function diff_deg2_eval(
 end
 
 """
-    eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false)
+    eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false, turbo::Bool=false)
 
 Compute the forward-mode derivative of an expression, using a similar
 structure and optimization to eval_tree_array. `variable` specifies whether
@@ -163,14 +181,19 @@ to every constant in the expression.
   must be `true`. This is needed to create the derivative operations.
 - `variable::Bool`: Whether to take derivatives with respect to features (i.e., `cX` - with `variable=true`),
   or with respect to every constant in the expression (`variable=false`).
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
 
 # Returns
 
 - `(evaluation, gradient, complete)::Tuple{AbstractVector{T}, AbstractMatrix{T}, Bool}`: the normal evaluation,
   the gradient, and whether the evaluation completed as normal (or encountered a nan or inf).
 """
 function eval_grad_tree_array(
-    tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false
+    tree::Node{T},
+    cX::AbstractMatrix{T},
+    operators::OperatorEnum;
+    variable::Bool=false,
+    turbo::Bool=false,
 )::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real}
     assert_autodiff_enabled(operators)
     n = size(cX, 2)
@@ -181,7 +204,14 @@ function eval_grad_tree_array(
     end
     index_tree = index_constants(tree, 0)
     return eval_grad_tree_array(
-        tree, n, n_gradients, index_tree, cX, operators, (variable ? Val(true) : Val(false))
+        tree,
+        n,
+        n_gradients,
+        index_tree,
+        cX,
+        operators,
+        (variable ? Val(true) : Val(false)),
+        (turbo ? Val(true) : Val(false)),
     )
 end
 
@@ -193,20 +223,29 @@ function eval_grad_tree_array(
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
     ::Val{variable},
-)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable}
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable,turbo}
     evaluation, gradient, complete = _eval_grad_tree_array(
-        tree, n, n_gradients, index_tree, cX, operators, Val(variable)
+        tree, n, n_gradients, index_tree, cX, operators, Val(variable), Val(turbo)
     )
     @return_on_false2 complete evaluation gradient
     return evaluation, gradient, !(is_bad_array(evaluation) || is_bad_array(gradient))
 end
 
 function eval_grad_tree_array(
-    tree::Node{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum; variable::Bool=false
+    tree::Node{T1},
+    cX::AbstractMatrix{T2},
+    operators::OperatorEnum;
+    variable::Bool=false,
+    turbo::Bool=false,
 ) where {T1<:Real,T2<:Real}
     T = promote_type(T1, T2)
     return eval_grad_tree_array(
-        convert(Node{T}, tree), convert(AbstractMatrix{T}, cX), operators; variable=variable
+        convert(Node{T}, tree),
+        convert(AbstractMatrix{T}, cX),
+        operators;
+        variable=variable,
+        turbo=turbo,
     )
 end
 
@@ -218,7 +257,8 @@ function _eval_grad_tree_array(
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
     ::Val{variable},
-)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable}
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable,turbo}
     if tree.degree == 0
         grad_deg0_eval(tree, n, n_gradients, index_tree, cX, Val(variable))
     elseif tree.degree == 1
@@ -232,6 +272,7 @@ function _eval_grad_tree_array(
             operators.diff_unaops[tree.op],
             operators,
             Val(variable),
+            Val(turbo),
         )
     else
         grad_deg2_eval(
@@ -244,6 +285,7 @@ function _eval_grad_tree_array(
             operators.diff_binops[tree.op],
             operators,
             Val(variable),
+            Val(turbo),
         )
     end
 end
@@ -278,13 +320,14 @@ function grad_deg1_eval(
     diff_op::dF,
     operators::OperatorEnum,
     ::Val{variable},
-)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,F,dF,variable}
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,F,dF,variable,turbo}
     (cumulator, dcumulator, complete) = eval_grad_tree_array(
-        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable)
+        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable), Val(turbo)
     )
     @return_on_false2 complete cumulator dcumulator
 
-    @inbounds @simd for j in indices((cumulator, dcumulator), (1, 2))
+    @maybe_turbo turbo for j in indices((cumulator, dcumulator), (1, 2))
         x = op(cumulator[j])::T
         dx = diff_op(cumulator[j])::T
 
@@ -306,17 +349,18 @@ function grad_deg2_eval(
     diff_op::dF,
     operators::OperatorEnum,
     ::Val{variable},
-)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,F,dF,variable}
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,F,dF,variable,turbo}
     (cumulator1, dcumulator1, complete) = eval_grad_tree_array(
-        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable)
+        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable), Val(turbo)
    )
     @return_on_false2 complete cumulator1 dcumulator1
     (cumulator2, dcumulator2, complete2) = eval_grad_tree_array(
-        tree.r, n, n_gradients, index_tree.r, cX, operators, Val(variable)
+        tree.r, n, n_gradients, index_tree.r, cX, operators, Val(variable), Val(turbo)
     )
     @return_on_false2 complete2 cumulator1 dcumulator1
 
-    @inbounds @simd for j in indices(
+    @maybe_turbo turbo for j in indices(
         (cumulator1, cumulator2, dcumulator1, dcumulator2), (1, 1, 2, 2)
     )
         c1 = cumulator1[j]
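
Note: `@maybe_turbo` is imported from `UtilsModule`, whose definition is not part of this diff. A minimal sketch of what such a macro could look like, assuming it simply switches between `@turbo` and the previous `@inbounds @simd` loop (an illustration, not the actual `UtilsModule` code):

```julia
import LoopVectorization: @turbo

# Hypothetical sketch of @maybe_turbo; the real definition lives in
# UtilsModule and may differ.
macro maybe_turbo(flag, loop)
    # Emit both loop flavors behind a branch on `flag`. The callers in
    # this diff receive `turbo` as a `Val` type parameter, so within each
    # method specialization `flag` is a compile-time constant and the
    # unused branch is eliminated.
    return esc(quote
        if $flag
            @turbo $loop
        else
            @inbounds @simd $loop
        end
    end)
end
```

Passing the flag as `Val(turbo)` rather than a plain `Bool` is what makes this free at runtime: each `*_eval` method is compiled once per flag value, with only one loop flavor in its body.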

test/test_derivatives.jl

Lines changed: 18 additions & 7 deletions

@@ -6,7 +6,8 @@ using Zygote
 using LinearAlgebra
 
 seed = 0
-pow_abs2(x, y) = abs(x)^y
+# SIMD doesn't like abs(x) ^ y for some reason.
+pow_abs2(x, y) = exp(y * log(abs(x)))
 custom_cos(x) = cos(x)^2
 
 equation1(x1, x2, x3) = x1 + x2 + x3 + 3.2
@@ -35,8 +36,12 @@ function array_test(ar1, ar2; rtol=0.1)
     return isapprox(ar1, ar2; rtol=rtol)
 end
 
-for type in [Float16, Float32, Float64]
-    println("Testing derivatives with respect to variables, with type=$(type).")
+for type in [Float16, Float32, Float64], turbo in [true, false]
+    type == Float16 && turbo && continue
+
+    println(
+        "Testing derivatives with respect to variables, with type=$(type) and turbo=$(turbo).",
+    )
     rng = MersenneTwister(seed)
     nfeatures = 3
     N = 100
@@ -72,10 +77,16 @@ for type in [Float16, Float32, Float64]
     )
     # Convert tuple of vectors to matrix:
     true_grad = reduce(hcat, true_grad)'
-    predicted_grad = eval_grad_tree_array(tree, X, operators; variable=true)[2]
+    predicted_grad = eval_grad_tree_array(
+        tree, X, operators; variable=true, turbo=turbo
+    )[2]
     predicted_grad2 =
         reduce(
-            hcat, [eval_diff_tree_array(tree, X, operators, i)[2] for i in 1:nfeatures]
+            hcat,
+            [
+                eval_diff_tree_array(tree, X, operators, i; turbo=turbo)[2] for
+                i in 1:nfeatures
+            ],
         )'
     predicted_grad3 = tree'(X)
@@ -98,7 +109,7 @@ for type in [Float16, Float32, Float64]
     local tree
     tree = equation4(nx1, nx2, nx3)
     tree = convert(Node{type}, tree)
-    predicted_grad = eval_grad_tree_array(tree, X, operators; variable=false)[2]
+    predicted_grad = eval_grad_tree_array(tree, X, operators; variable=false, turbo=turbo)[2]
     @test array_test(predicted_grad[1, :], X[1, :])
 
     # More complex expression:
@@ -123,7 +134,7 @@ for type in [Float16, Float32, Float64]
         [X[i, :] for i in 1:nfeatures]...,
     )[1:2]
     true_grad = reduce(hcat, true_grad)'
-    predicted_grad = eval_grad_tree_array(tree, X, operators; variable=false)[2]
+    predicted_grad = eval_grad_tree_array(tree, X, operators; variable=false, turbo=turbo)[2]
 
     @test array_test(predicted_grad, true_grad)
     println("Done.")
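
Note on the `pow_abs2` rewrite: for `x != 0`, `abs(x)^y == exp(y * log(abs(x)))` up to floating-point rounding, since `a^y = exp(y * log(a))` for `a > 0`; `exp` and `log` have SIMD-friendly vector implementations, whereas a floating-point power with a runtime exponent presumably does not vectorize under `@turbo`. The two forms do differ at `x = 0` (where `log` returns `-Inf`), which the random test data avoids with probability one. A quick illustrative check:

```julia
pow_abs_old(x, y) = abs(x)^y
pow_abs_new(x, y) = exp(y * log(abs(x)))

# The two definitions agree up to rounding for nonzero x:
for (x, y) in ((-2.0, 3.0), (0.5, -1.25), (3.0, 0.0))
    @assert isapprox(pow_abs_old(x, y), pow_abs_new(x, y); rtol=1e-12)
end
```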
