
Commit fc2497b

feat: avoid double descent in rrule
1 parent d7d8802 commit fc2497b

File tree

src/ChainRules.jl
src/EvaluateDerivative.jl

2 files changed: +73 -72 lines

src/ChainRules.jl

Lines changed: 27 additions & 33 deletions
@@ -42,43 +42,37 @@ function CRC.rrule(
         primal .= NaN
     end
 
-    # TODO: Preferable to use the primal in the pullback somehow
-    function pullback((dY, _))
-        dtree = let X = X, dY = dY, tree = tree, operators = operators
-            @thunk(
-                let
-                    _, gradient, complete2 = eval_grad_tree_array(
-                        tree, X, operators; variable=Val(false)
-                    )
-                    if !complete2
-                        gradient .= NaN
-                    end
+    return (primal, complete), EvalPullback(tree, X, operators)
+end
+
+# Wrap in a struct rather than a closure so captured variables are not boxed
+struct EvalPullback{N,A,O} <: Function
+    tree::N
+    X::A
+    operators::O
+end
 
-                    NodeTangent(
-                        tree,
-                        sum(j -> gradient[:, j] * dY[j], eachindex(dY, axes(gradient, 2))),
-                    )
-                end
-            )
-        end
-        dX = let X = X, dY = dY, tree = tree, operators = operators
-            @thunk(
-                let
-                    _, gradient2, complete3 = eval_grad_tree_array(
-                        tree, X, operators; variable=Val(true)
-                    )
-                    if !complete3
-                        gradient2 .= NaN
-                    end
+# TODO: Preferable to use the primal in the pullback somehow
+function (e::EvalPullback)((dY, _))
+    _, dX_constants_dY, complete = eval_grad_tree_array(
+        e.tree, e.X, e.operators; variable=Val(:both)
+    )
 
-                    gradient2 .* reshape(dY, 1, length(dY))
-                end
-            )
-        end
-        return (NoTangent(), dtree, dX, NoTangent())
+    if !complete
+        dX_constants_dY .= NaN
     end
 
-    return (primal, complete), pullback
+    nfeatures = size(e.X, 1)
+    dX_dY = @view dX_constants_dY[1:nfeatures, :]
+    dconstants_dY = @view dX_constants_dY[(nfeatures + 1):end, :]
+
+    dtree = NodeTangent(
+        e.tree, sum(j -> dconstants_dY[:, j] * dY[j], eachindex(dY, axes(dconstants_dY, 2)))
+    )
+
+    dX = dX_dY .* reshape(dY, 1, length(dY))
+
+    return (NoTangent(), dtree, dX, NoTangent())
 end
 
 end
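
The change replaces the closure-based pullback, which descended the tree twice (once with variable=Val(false) for constant gradients, once with variable=Val(true) for feature gradients), with a single eval_grad_tree_array call in variable=Val(:both) mode whose output stacks the feature rows on top of the constant rows. A minimal standalone sketch of the resulting split-and-contract step, using plain arrays (split_and_contract and the shapes below are illustrative, not part of the package):

# `J` plays the role of dX_constants_dY above: one Jacobian-like matrix
# with the nfeatures feature rows first and the constant rows after them.
function split_and_contract(J::AbstractMatrix, dY::AbstractVector, nfeatures::Int)
    dX_dY = @view J[1:nfeatures, :]                  # rows for the input features
    dconstants_dY = @view J[(nfeatures + 1):end, :]  # rows for the constants

    # Tangent for the constants: contract each data column with dY[j] and sum.
    dconstants = sum(j -> dconstants_dY[:, j] * dY[j], eachindex(dY))

    # Tangent for X: scale each data column of the feature block by dY[j].
    dX = dX_dY .* reshape(dY, 1, length(dY))
    return dconstants, dX
end

# Example: 2 features and 1 constant across 3 data points.
J = [1.0 2.0 3.0;   # dy/dx1 per data point
     4.0 5.0 6.0;   # dy/dx2 per data point
     7.0 8.0 9.0]   # dy/dc1 per data point
dconstants, dX = split_and_contract(J, [0.1, 0.2, 0.3], 2)

Wrapping the pullback in the callable EvalPullback struct, rather than a closure, also keeps the captured tree, X, and operators as concretely typed fields.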

src/EvaluateDerivative.jl

Lines changed: 46 additions & 39 deletions
@@ -206,17 +206,29 @@ function eval_grad_tree_array(
     variable::Union{Bool,Val}=Val(false),
     turbo::Union{Bool,Val}=Val(false),
 ) where {T<:Number}
-    n_gradients = if isa(variable, Val{true}) || (isa(variable, Bool) && variable)
+    variable_mode = isa(variable, Val{true}) || (isa(variable, Bool) && variable)
+    constant_mode = isa(variable, Val{false}) || (isa(variable, Bool) && !variable)
+    both_mode = isa(variable, Val{:both})
+
+    n_gradients = if variable_mode
         size(cX, 1)::Int
-    else
+    elseif constant_mode
         count_constants(tree)::Int
+    elseif both_mode
+        size(cX, 1) + count_constants(tree)
     end
-    result = if isa(variable, Val{true}) || (variable isa Bool && variable)
+
+    result = if variable_mode
         eval_grad_tree_array(tree, n_gradients, nothing, cX, operators, Val(true))
-    else
+    elseif constant_mode
         index_tree = index_constants(tree)
         eval_grad_tree_array(tree, n_gradients, index_tree, cX, operators, Val(false))
+    elseif both_mode
+        # features come first because we can use size(cX, 1) to skip them
+        index_tree = index_constants(tree)
+        eval_grad_tree_array(tree, n_gradients, index_tree, cX, operators, Val(:both))
     end
+
     return (result.x, result.dx, result.ok)
 end
 
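The variable keyword now selects one of three gradient modes instead of two. A self-contained sketch of the same three-way selection on a Union{Bool,Val} flag (select_n_gradients is a hypothetical helper, not the package's API):

# Mirrors the variable_mode / constant_mode / both_mode branching above.
function select_n_gradients(mode::Union{Bool,Val}, nfeatures::Int, nconstants::Int)
    variable_mode = isa(mode, Val{true}) || (isa(mode, Bool) && mode)
    constant_mode = isa(mode, Val{false}) || (isa(mode, Bool) && !mode)
    both_mode = isa(mode, Val{:both})
    return if variable_mode
        nfeatures                  # d/dX only
    elseif constant_mode
        nconstants                 # d/d(constants) only
    elseif both_mode
        nfeatures + nconstants     # single pass, features stacked first
    end
end

@assert select_n_gradients(true, 2, 3) == 2
@assert select_n_gradients(Val(false), 2, 3) == 3
@assert select_n_gradients(Val(:both), 2, 3) == 5

Passing the mode as a Val lets the branch fold away when the flag is known at compile time, while a plain Bool still works for callers that do not care.
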
@@ -226,11 +238,9 @@ function eval_grad_tree_array(
     index_tree::Union{NodeIndex,Nothing},
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,variable}
-    result = _eval_grad_tree_array(
-        tree, n_gradients, index_tree, cX, operators, Val(variable)
-    )
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,mode}
+    result = _eval_grad_tree_array(tree, n_gradients, index_tree, cX, operators, Val(mode))
     !result.ok && return result
     return ResultOk2(
         result.x, result.dx, !(is_bad_array(result.x) || is_bad_array(result.dx))
@@ -260,30 +270,18 @@ end
     index_tree::Union{NodeIndex,Nothing},
     cX::AbstractMatrix{T},
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,mode}
     nuna = get_nuna(operators)
     nbin = get_nbin(operators)
     deg1_branch_skeleton = quote
         grad_deg1_eval(
-            tree,
-            n_gradients,
-            index_tree,
-            cX,
-            operators.unaops[i],
-            operators,
-            Val(variable),
+            tree, n_gradients, index_tree, cX, operators.unaops[i], operators, Val(mode)
         )
     end
     deg2_branch_skeleton = quote
         grad_deg2_eval(
-            tree,
-            n_gradients,
-            index_tree,
-            cX,
-            operators.binops[i],
-            operators,
-            Val(variable),
+            tree, n_gradients, index_tree, cX, operators.binops[i], operators, Val(mode)
         )
     end
     deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
@@ -310,7 +308,7 @@ end
     end
     quote
         if tree.degree == 0
-            grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(variable))
+            grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(mode))
         elseif tree.degree == 1
             $deg1_branch
         else
@@ -324,8 +322,8 @@ function grad_deg0_eval(
     n_gradients,
     index_tree::Union{NodeIndex,Nothing},
     cX::AbstractMatrix{T},
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,mode}
     const_part = deg0_eval(tree, cX).x
 
     zero_mat = if isa(cX, Array)
@@ -334,15 +332,24 @@
         hcat([fill_similar(zero(T), cX, axes(cX, 2)) for _ in 1:n_gradients]...)'
     end
 
-    if variable == tree.constant
+    if (mode isa Bool && mode == tree.constant)
+        # No gradients at this leaf node
         return ResultOk2(const_part, zero_mat, true)
     end
 
-    index = if variable
-        tree.feature
-    else
+    index = if (mode isa Bool && mode)
+        tree.feature::UInt16
+    elseif (mode isa Bool && !mode)
         (index_tree === nothing ? zero(UInt16) : index_tree.val::UInt16)
+    elseif mode == :both
+        index_tree::NodeIndex
+        if tree.constant
+            index_tree.val::UInt16 + UInt16(size(cX, 1))
+        else
+            tree.feature::UInt16
+        end
     end
+
     derivative_part = zero_mat
     derivative_part[index, :] .= one(T)
     return ResultOk2(const_part, derivative_part, true)
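
In :both mode, each leaf writes its unit derivative into row tree.feature when it is a variable, and into row size(cX, 1) + index_tree.val when it is a constant, matching the stacked [features; constants] layout. A tiny sketch of that row-offset rule (leaf_gradient_row is a hypothetical helper, not the package's API):

# Row for a leaf's 1.0 in the stacked gradient matrix.
function leaf_gradient_row(is_constant::Bool, feature::UInt16, constant_index::UInt16, nfeatures::Int)
    return is_constant ? nfeatures + Int(constant_index) : Int(feature)
end

@assert leaf_gradient_row(false, UInt16(2), UInt16(0), 3) == 2  # feature x2 -> row 2
@assert leaf_gradient_row(true, UInt16(0), UInt16(1), 3) == 4   # constant 1 -> row nfeatures + 1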
@@ -355,15 +362,15 @@ function grad_deg1_eval(
     cX::AbstractMatrix{T},
     op::F,
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,F,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,F,mode}
     result = eval_grad_tree_array(
         tree.l,
         n_gradients,
         index_tree === nothing ? index_tree : index_tree.l,
         cX,
         operators,
-        Val(variable),
+        Val(mode),
     )
     !result.ok && return result
 
@@ -389,15 +396,15 @@ function grad_deg2_eval(
     cX::AbstractMatrix{T},
     op::F,
     operators::OperatorEnum,
-    ::Val{variable},
-)::ResultOk2 where {T<:Number,F,variable}
+    ::Val{mode},
+)::ResultOk2 where {T<:Number,F,mode}
     result_l = eval_grad_tree_array(
         tree.l,
         n_gradients,
         index_tree === nothing ? index_tree : index_tree.l,
         cX,
         operators,
-        Val(variable),
+        Val(mode),
     )
     !result_l.ok && return result_l
     result_r = eval_grad_tree_array(
@@ -406,7 +413,7 @@ function grad_deg2_eval(
         index_tree === nothing ? index_tree : index_tree.r,
         cX,
         operators,
-        Val(variable),
+        Val(mode),
     )
     !result_r.ok && return result_r
 