@@ -206,17 +206,29 @@ function eval_grad_tree_array(
206206 variable:: Union{Bool,Val} = Val (false ),
207207 turbo:: Union{Bool,Val} = Val (false ),
208208) where {T<: Number }
209- n_gradients = if isa (variable, Val{true }) || (isa (variable, Bool) && variable)
209+ variable_mode = isa (variable, Val{true }) || (isa (variable, Bool) && variable)
210+ constant_mode = isa (variable, Val{false }) || (isa (variable, Bool) && ! variable)
211+ both_mode = isa (variable, Val{:both })
212+
213+ n_gradients = if variable_mode
210214 size (cX, 1 ):: Int
211- else
215+ elseif constant_mode
212216 count_constants (tree):: Int
217+ elseif both_mode
218+ size (cX, 1 ) + count_constants (tree)
213219 end
214- result = if isa (variable, Val{true }) || (variable isa Bool && variable)
220+
221+ result = if variable_mode
215222 eval_grad_tree_array (tree, n_gradients, nothing , cX, operators, Val (true ))
216- else
223+ elseif constant_mode
217224 index_tree = index_constants (tree)
218225 eval_grad_tree_array (tree, n_gradients, index_tree, cX, operators, Val (false ))
226+ elseif both_mode
227+ # features come first because we can use size(cX, 1) to skip them
228+ index_tree = index_constants (tree)
229+ eval_grad_tree_array (tree, n_gradients, index_tree, cX, operators, Val (:both ))
219230 end
231+
220232 return (result. x, result. dx, result. ok)
221233end
222234
@@ -226,11 +238,9 @@ function eval_grad_tree_array(
226238 index_tree:: Union{NodeIndex,Nothing} ,
227239 cX:: AbstractMatrix{T} ,
228240 operators:: OperatorEnum ,
229- :: Val{variable} ,
230- ):: ResultOk2 where {T<: Number ,variable}
231- result = _eval_grad_tree_array (
232- tree, n_gradients, index_tree, cX, operators, Val (variable)
233- )
241+ :: Val{mode} ,
242+ ):: ResultOk2 where {T<: Number ,mode}
243+ result = _eval_grad_tree_array (tree, n_gradients, index_tree, cX, operators, Val (mode))
234244 ! result. ok && return result
235245 return ResultOk2 (
236246 result. x, result. dx, ! (is_bad_array (result. x) || is_bad_array (result. dx))
@@ -260,30 +270,18 @@ end
260270 index_tree:: Union{NodeIndex,Nothing} ,
261271 cX:: AbstractMatrix{T} ,
262272 operators:: OperatorEnum ,
263- :: Val{variable } ,
264- ):: ResultOk2 where {T<: Number ,variable }
273+ :: Val{mode } ,
274+ ):: ResultOk2 where {T<: Number ,mode }
265275 nuna = get_nuna (operators)
266276 nbin = get_nbin (operators)
267277 deg1_branch_skeleton = quote
268278 grad_deg1_eval (
269- tree,
270- n_gradients,
271- index_tree,
272- cX,
273- operators. unaops[i],
274- operators,
275- Val (variable),
279+ tree, n_gradients, index_tree, cX, operators. unaops[i], operators, Val (mode)
276280 )
277281 end
278282 deg2_branch_skeleton = quote
279283 grad_deg2_eval (
280- tree,
281- n_gradients,
282- index_tree,
283- cX,
284- operators. binops[i],
285- operators,
286- Val (variable),
284+ tree, n_gradients, index_tree, cX, operators. binops[i], operators, Val (mode)
287285 )
288286 end
289287 deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
310308 end
311309 quote
312310 if tree. degree == 0
313- grad_deg0_eval (tree, n_gradients, index_tree, cX, Val (variable ))
311+ grad_deg0_eval (tree, n_gradients, index_tree, cX, Val (mode ))
314312 elseif tree. degree == 1
315313 $ deg1_branch
316314 else
@@ -324,8 +322,8 @@ function grad_deg0_eval(
324322 n_gradients,
325323 index_tree:: Union{NodeIndex,Nothing} ,
326324 cX:: AbstractMatrix{T} ,
327- :: Val{variable } ,
328- ):: ResultOk2 where {T<: Number ,variable }
325+ :: Val{mode } ,
326+ ):: ResultOk2 where {T<: Number ,mode }
329327 const_part = deg0_eval (tree, cX). x
330328
331329 zero_mat = if isa (cX, Array)
@@ -334,15 +332,24 @@ function grad_deg0_eval(
334332 hcat ([fill_similar (zero (T), cX, axes (cX, 2 )) for _ in 1 : n_gradients]. .. )'
335333 end
336334
337- if variable == tree. constant
335+ if (mode isa Bool && mode == tree. constant)
336+ # No gradients at this leaf node
338337 return ResultOk2 (const_part, zero_mat, true )
339338 end
340339
341- index = if variable
342- tree. feature
343- else
340+ index = if (mode isa Bool && mode)
341+ tree. feature:: UInt16
342+ elseif (mode isa Bool && ! mode)
344343 (index_tree === nothing ? zero (UInt16) : index_tree. val:: UInt16 )
344+ elseif mode == :both
345+ index_tree:: NodeIndex
346+ if tree. constant
347+ index_tree. val:: UInt16 + UInt16 (size (cX, 1 ))
348+ else
349+ tree. feature:: UInt16
350+ end
345351 end
352+
346353 derivative_part = zero_mat
347354 derivative_part[index, :] .= one (T)
348355 return ResultOk2 (const_part, derivative_part, true )
@@ -355,15 +362,15 @@ function grad_deg1_eval(
355362 cX:: AbstractMatrix{T} ,
356363 op:: F ,
357364 operators:: OperatorEnum ,
358- :: Val{variable } ,
359- ):: ResultOk2 where {T<: Number ,F,variable }
365+ :: Val{mode } ,
366+ ):: ResultOk2 where {T<: Number ,F,mode }
360367 result = eval_grad_tree_array (
361368 tree. l,
362369 n_gradients,
363370 index_tree === nothing ? index_tree : index_tree. l,
364371 cX,
365372 operators,
366- Val (variable ),
373+ Val (mode ),
367374 )
368375 ! result. ok && return result
369376
@@ -389,15 +396,15 @@ function grad_deg2_eval(
389396 cX:: AbstractMatrix{T} ,
390397 op:: F ,
391398 operators:: OperatorEnum ,
392- :: Val{variable } ,
393- ):: ResultOk2 where {T<: Number ,F,variable }
399+ :: Val{mode } ,
400+ ):: ResultOk2 where {T<: Number ,F,mode }
394401 result_l = eval_grad_tree_array (
395402 tree. l,
396403 n_gradients,
397404 index_tree === nothing ? index_tree : index_tree. l,
398405 cX,
399406 operators,
400- Val (variable ),
407+ Val (mode ),
401408 )
402409 ! result_l. ok && return result_l
403410 result_r = eval_grad_tree_array (
@@ -406,7 +413,7 @@ function grad_deg2_eval(
406413 index_tree === nothing ? index_tree : index_tree. r,
407414 cX,
408415 operators,
409- Val (variable ),
416+ Val (mode ),
410417 )
411418 ! result_r. ok && return result_r
412419
0 commit comments