module EvaluateEquationDerivativeModule

-import LoopVectorization: indices
+import LoopVectorization: indices, @turbo
import ..EquationModule: Node
import ..OperatorEnumModule: OperatorEnum
-import ..UtilsModule: @return_on_false2, is_bad_array
+import ..UtilsModule: @return_on_false2, @maybe_turbo, is_bad_array
import ..EquationUtilsModule: count_constants, index_constants, NodeIndex
import ..EvaluateEquationModule: deg0_eval

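`@maybe_turbo` comes from `UtilsModule` and its definition is not part of this diff. A minimal sketch of what such a macro might look like (an assumption for illustration, not the actual `UtilsModule` code): it emits both loop flavors behind a flag.

```julia
using LoopVectorization: @turbo

# Hypothetical sketch, not the actual UtilsModule definition: expand to an
# if/else that picks between a `@turbo` loop and an `@inbounds @simd` loop.
# Everything is escaped, so `flag`, the loop, and `@turbo` itself resolve in
# the caller's scope; note this diff also adds `@turbo` to the module imports.
macro maybe_turbo(flag, ex)
    turbo_loop = :(@turbo $ex)
    simd_loop = :(@inbounds @simd $ex)
    return esc(quote
        if $flag
            $turbo_loop
        else
            $simd_loop
        end
    end)
end
```

Because `turbo` is threaded through as a `Val` type parameter below, the flag is a compile-time constant and the dead branch is eliminated during specialization.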
@@ -16,7 +16,7 @@ function assert_autodiff_enabled(operators::OperatorEnum)
end

"""
-    eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int)
+    eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int; turbo::Bool=false)

Compute the forward derivative of an expression, using a similar
structure and optimization to eval_tree_array. `direction` is the index of a particular
@@ -30,33 +30,48 @@ respect to `x1`.
- `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff`
  must be `true`. This is needed to create the derivative operations.
- `direction::Int`: The index of the variable to take the derivative with respect to.
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.

# Returns

- `(evaluation, derivative, complete)::Tuple{AbstractVector{T}, AbstractVector{T}, Bool}`: the normal evaluation,
  the derivative, and whether the evaluation completed as normal (or encountered a nan or inf).
"""
function eval_diff_tree_array(
-    tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int
+    tree::Node{T},
+    cX::AbstractMatrix{T},
+    operators::OperatorEnum,
+    direction::Int;
+    turbo::Bool=false,
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real}
    assert_autodiff_enabled(operators)
    # TODO: Implement quick check for whether the variable is actually used
    # in this tree. Otherwise, return zero.
-    return _eval_diff_tree_array(tree, cX, operators, direction)
+    return _eval_diff_tree_array(
+        tree, cX, operators, direction, (turbo ? Val(true) : Val(false))
+    )
end
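Note the `(turbo ? Val(true) : Val(false))` step: it lifts the runtime keyword argument into the type domain, so the internal methods specialize on the flag and no branch survives in the hot loops. A minimal illustration of the pattern (hypothetical names, not from this module):

```julia
# Hypothetical illustration of Val-based dispatch: the Bool becomes part of
# the method signature, so each flavor compiles to its own method and the
# flag check vanishes after specialization.
inner(::Val{true}) = "turbo loop"
inner(::Val{false}) = "plain simd loop"
outer(turbo::Bool) = inner(turbo ? Val(true) : Val(false))

outer(true)   # returns "turbo loop"
outer(false)  # returns "plain simd loop"
```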
function eval_diff_tree_array(
-    tree::Node{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum, direction::Int;
+    tree::Node{T1},
+    cX::AbstractMatrix{T2},
+    operators::OperatorEnum,
+    direction::Int;
+    turbo::Bool=false,
) where {T1<:Real,T2<:Real}
    T = promote_type(T1, T2)
    @warn "Warning: eval_diff_tree_array received mixed types: tree=$(T1) and data=$(T2)."
    tree = convert(Node{T}, tree)
    cX = convert(AbstractMatrix{T}, cX)
-    return eval_diff_tree_array(tree, cX, operators, direction)
+    return eval_diff_tree_array(tree, cX, operators, direction; turbo=turbo)
end

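A hypothetical end-to-end call, for orientation; the `OperatorEnum` and `Node` constructions below are assumptions about the surrounding package's API rather than part of this diff:

```julia
# Hypothetical usage sketch. Per the docstring above, the OperatorEnum must
# be built with enable_autodiff=true; the Node constructors and operator
# overloads are assumed conveniences of the surrounding package.
operators = OperatorEnum(;
    binary_operators=[+, -, *], unary_operators=[cos], enable_autodiff=true
)
x1 = Node(; feature=1)
x2 = Node(; feature=2)
tree = x1 * cos(x2)          # represents f(x1, x2) = x1 * cos(x2)
X = randn(Float64, 2, 100)   # 2 features, 100 data points

# direction=1 requests the derivative with respect to x1;
# turbo=true opts into the @turbo loops added by this diff.
evaluation, derivative, complete = eval_diff_tree_array(
    tree, X, operators, 1; turbo=true
)
# Expect derivative[j] ≈ cos(X[2, j]), since ∂f/∂x1 = cos(x2).
```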
function _eval_diff_tree_array(
-    tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int
-)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real}
+    tree::Node{T},
+    cX::AbstractMatrix{T},
+    operators::OperatorEnum,
+    direction::Int,
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,turbo}
    evaluation, derivative, complete = if tree.degree == 0
        diff_deg0_eval(tree, cX, direction)
    elseif tree.degree == 1
@@ -67,6 +82,7 @@ function _eval_diff_tree_array(
            operators.diff_unaops[tree.op],
            operators,
            direction,
+            Val(turbo),
        )
    else
        diff_deg2_eval(
@@ -76,6 +92,7 @@ function _eval_diff_tree_array(
            operators.diff_binops[tree.op],
            operators,
            direction,
+            Val(turbo),
        )
    end
    @return_on_false2 complete evaluation derivative
@@ -99,15 +116,16 @@ function diff_deg1_eval(
    diff_op::dF,
    operators::OperatorEnum,
    direction::Int,
-)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,F,dF}
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,F,dF,turbo}
    n = size(cX, 2)
    (cumulator, dcumulator, complete) = _eval_diff_tree_array(
-        tree.l, cX, operators, direction
+        tree.l, cX, operators, direction, Val(turbo)
    )
    @return_on_false2 complete cumulator dcumulator

    # TODO - add type assertions to get better speed:
-    @inbounds @simd for j in indices((cumulator, dcumulator))
+    @maybe_turbo turbo for j in indices((cumulator, dcumulator))
        x = op(cumulator[j])::T
        dx = diff_op(cumulator[j])::T * dcumulator[j]

@@ -124,21 +142,21 @@ function diff_deg2_eval(
    diff_op::dF,
    operators::OperatorEnum,
    direction::Int,
-)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,F,dF}
-    n = size(cX, 2)
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,F,dF,turbo}
    (cumulator, dcumulator, complete) = _eval_diff_tree_array(
-        tree.l, cX, operators, direction
+        tree.l, cX, operators, direction, Val(turbo)
    )
    @return_on_false2 complete cumulator dcumulator
    (array2, dcumulator2, complete2) = _eval_diff_tree_array(
-        tree.r, cX, operators, direction
+        tree.r, cX, operators, direction, Val(turbo)
    )
    @return_on_false2 complete2 array2 dcumulator2

-    @inbounds @simd for j in indices((cumulator, dcumulator, array2, dcumulator2))
-        x = op(cumulator[j], array2[j])
+    @maybe_turbo turbo for j in indices((cumulator, dcumulator, array2, dcumulator2))
+        x = op(cumulator[j], array2[j])::T

-        first, second = diff_op(cumulator[j], array2[j])
+        first, second = diff_op(cumulator[j], array2[j])::Tuple{T,T}
        dx = first * dcumulator[j] + second * dcumulator2[j]

        cumulator[j] = x
@@ -148,7 +166,7 @@ function diff_deg2_eval(
end

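The loop in `diff_deg2_eval` applies the bivariate chain rule: `diff_op` returns both partial derivatives of the binary operator, which are then weighted by the children's derivatives. A standalone numeric illustration, using a hypothetical `diff_op` for multiplication:

```julia
# For y = f(a, b): dy/dx = (∂f/∂a) * da/dx + (∂f/∂b) * db/dx.
# A diff_op for `*` returns the tuple of partials (∂(a*b)/∂a, ∂(a*b)/∂b):
mul_diff(a, b) = (b, a)

a, b = 2.0, 3.0
first, second = mul_diff(a, b)
da, db = 1.0, 0.0              # a varies with x; b is constant
dx = first * da + second * db  # 3.0, i.e. d(a*b)/dx = b
```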
"""
-    eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false)
+    eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false, turbo::Bool=false)

Compute the forward-mode derivative of an expression, using a similar
structure and optimization to eval_tree_array. `variable` specifies whether
@@ -163,14 +181,19 @@ to every constant in the expression.
  must be `true`. This is needed to create the derivative operations.
- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `cX` - with `variable=true`),
  or with respect to every constant in the expression (`variable=false`).
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.

# Returns

- `(evaluation, gradient, complete)::Tuple{AbstractVector{T}, AbstractMatrix{T}, Bool}`: the normal evaluation,
  the gradient, and whether the evaluation completed as normal (or encountered a nan or inf).
"""
function eval_grad_tree_array(
-    tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false
+    tree::Node{T},
+    cX::AbstractMatrix{T},
+    operators::OperatorEnum;
+    variable::Bool=false,
+    turbo::Bool=false,
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real}
    assert_autodiff_enabled(operators)
    n = size(cX, 2)
@@ -181,7 +204,14 @@ function eval_grad_tree_array(
    end
    index_tree = index_constants(tree, 0)
    return eval_grad_tree_array(
-        tree, n, n_gradients, index_tree, cX, operators, (variable ? Val(true) : Val(false))
+        tree,
+        n,
+        n_gradients,
+        index_tree,
+        cX,
+        operators,
+        (variable ? Val(true) : Val(false)),
+        (turbo ? Val(true) : Val(false)),
    )
end

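Continuing the hypothetical example from `eval_diff_tree_array` above (same assumed `operators`, `tree`, and `X`), the gradient variant returns all derivatives in one pass:

```julia
# With variable=true, derivatives are taken with respect to the features,
# so the gradient should have one row per feature; with variable=false,
# one row per constant in the expression.
evaluation, gradient, complete = eval_grad_tree_array(
    tree, X, operators; variable=true, turbo=true
)
size(gradient)  # expected (2, 100): rows are ∂f/∂x1 and ∂f/∂x2
```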
@@ -193,20 +223,29 @@ function eval_grad_tree_array(
    cX::AbstractMatrix{T},
    operators::OperatorEnum,
    ::Val{variable},
-)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable}
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable,turbo}
    evaluation, gradient, complete = _eval_grad_tree_array(
-        tree, n, n_gradients, index_tree, cX, operators, Val(variable)
+        tree, n, n_gradients, index_tree, cX, operators, Val(variable), Val(turbo)
    )
    @return_on_false2 complete evaluation gradient
    return evaluation, gradient, !(is_bad_array(evaluation) || is_bad_array(gradient))
end

function eval_grad_tree_array(
-    tree::Node{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum; variable::Bool=false
+    tree::Node{T1},
+    cX::AbstractMatrix{T2},
+    operators::OperatorEnum;
+    variable::Bool=false,
+    turbo::Bool=false,
) where {T1<:Real,T2<:Real}
    T = promote_type(T1, T2)
    return eval_grad_tree_array(
-        convert(Node{T}, tree), convert(AbstractMatrix{T}, cX), operators; variable=variable
+        convert(Node{T}, tree),
+        convert(AbstractMatrix{T}, cX),
+        operators;
+        variable=variable,
+        turbo=turbo,
    )
end

@@ -218,7 +257,8 @@ function _eval_grad_tree_array(
    cX::AbstractMatrix{T},
    operators::OperatorEnum,
    ::Val{variable},
-)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable}
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable,turbo}
    if tree.degree == 0
        grad_deg0_eval(tree, n, n_gradients, index_tree, cX, Val(variable))
    elseif tree.degree == 1
@@ -232,6 +272,7 @@ function _eval_grad_tree_array(
            operators.diff_unaops[tree.op],
            operators,
            Val(variable),
+            Val(turbo),
        )
    else
        grad_deg2_eval(
@@ -244,6 +285,7 @@ function _eval_grad_tree_array(
            operators.diff_binops[tree.op],
            operators,
            Val(variable),
+            Val(turbo),
        )
    end
end
@@ -278,13 +320,14 @@ function grad_deg1_eval(
    diff_op::dF,
    operators::OperatorEnum,
    ::Val{variable},
-)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,F,dF,variable}
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,F,dF,variable,turbo}
    (cumulator, dcumulator, complete) = eval_grad_tree_array(
-        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable)
+        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable), Val(turbo)
    )
    @return_on_false2 complete cumulator dcumulator

-    @inbounds @simd for j in indices((cumulator, dcumulator), (1, 2))
+    @maybe_turbo turbo for j in indices((cumulator, dcumulator), (1, 2))
        x = op(cumulator[j])::T
        dx = diff_op(cumulator[j])::T

@@ -306,17 +349,18 @@ function grad_deg2_eval(
    diff_op::dF,
    operators::OperatorEnum,
    ::Val{variable},
-)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,F,dF,variable}
+    ::Val{turbo},
+)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,F,dF,variable,turbo}
    (cumulator1, dcumulator1, complete) = eval_grad_tree_array(
-        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable)
+        tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable), Val(turbo)
    )
    @return_on_false2 complete cumulator1 dcumulator1
    (cumulator2, dcumulator2, complete2) = eval_grad_tree_array(
-        tree.r, n, n_gradients, index_tree.r, cX, operators, Val(variable)
+        tree.r, n, n_gradients, index_tree.r, cX, operators, Val(variable), Val(turbo)
    )
    @return_on_false2 complete2 cumulator1 dcumulator1

-    @inbounds @simd for j in indices(
+    @maybe_turbo turbo for j in indices(
        (cumulator1, cumulator2, dcumulator1, dcumulator2), (1, 1, 2, 2)
    )
        c1 = cumulator1[j]