Skip to content

Commit c8e75f0

Browse files
authored
fix: improve wrong-mode pushforward/pullback (#932)
* fix: improve wrong-mode pushforward/pullback * Fix * Update .github/workflows/Test.yml
1 parent 0a47d24 commit c8e75f0

File tree

4 files changed

+101
-91
lines changed

4 files changed

+101
-91
lines changed

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ jobs:
166166
actions: write
167167
contents: read
168168
strategy:
169-
fail-fast: true
169+
fail-fast: true # TODO: toggle
170170
matrix:
171171
version:
172172
- '1.10'

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 55 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -325,61 +325,69 @@ function _value_and_pullback_via_pushforward(
325325
pushforward_prep::PushforwardPrep,
326326
backend::AbstractADType,
327327
x::Real,
328-
dy,
328+
ty::NTuple{B},
329329
contexts::Vararg{Context, C},
330-
) where {F, C}
330+
) where {F, B, C}
331331
y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
332-
dx = dot(a, dy)
333-
return y, dx
332+
tx = map(ty) do dy
333+
dot(a, dy)
334+
end
335+
return y, arroftup_to_tupofarr(tx)
334336
end
335337

336338
function _value_and_pullback_via_pushforward(
337339
f::F,
338340
pushforward_prep::PushforwardPrep,
339341
backend::AbstractADType,
340342
x::Complex,
341-
dy,
343+
ty::NTuple{B},
342344
contexts::Vararg{Context, C},
343-
) where {F, C}
345+
) where {F, B, C}
344346
y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
345347
b = only(pushforward(f, pushforward_prep, backend, x, (im * oneunit(x),), contexts...))
346-
dx = real(dot(a, dy)) + im * real(dot(b, dy))
347-
return y, dx
348+
tx = map(ty) do dy
349+
real(dot(a, dy)) + im * real(dot(b, dy))
350+
end
351+
return y, arroftup_to_tupofarr(tx)
348352
end
349353

350354
function _value_and_pullback_via_pushforward(
351355
f::F,
352356
pushforward_prep::PushforwardPrep,
353357
backend::AbstractADType,
354358
x::AbstractArray{<:Real},
355-
dy,
359+
ty::NTuple{B},
356360
contexts::Vararg{Context, C},
357-
) where {F, C}
361+
) where {F, B, C}
358362
y = f(x, map(unwrap, contexts)...)
359-
dx = map(CartesianIndices(x)) do j
363+
tx = map(CartesianIndices(x)) do j
360364
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
361-
dot(a, dy)
365+
map(ty) do dy
366+
dot(a, dy)
367+
end
362368
end
363-
return y, dx
369+
return y, arroftup_to_tupofarr(tx)
364370
end
365371

366372
function _value_and_pullback_via_pushforward(
367373
f::F,
368374
pushforward_prep::PushforwardPrep,
369375
backend::AbstractADType,
370376
x::AbstractArray{<:Complex},
371-
dy,
377+
ty::NTuple{B},
372378
contexts::Vararg{Context, C},
373-
) where {F, C}
379+
) where {F, B, C}
374380
y = f(x, map(unwrap, contexts)...)
375-
dx = map(CartesianIndices(x)) do j
381+
tx = map(CartesianIndices(x)) do j
376382
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
377383
b = only(
378384
pushforward(f, pushforward_prep, backend, x, (im * basis(x, j),), contexts...),
379385
)
380-
real(dot(a, dy)) + im * real(dot(b, dy))
386+
map(ty) do dy
387+
real(dot(a, dy)) + im * real(dot(b, dy))
388+
end
381389
end
382-
return y, dx
390+
return y, arroftup_to_tupofarr(tx)
383391
end
384392

385393
function value_and_pullback(
@@ -392,13 +400,7 @@ function value_and_pullback(
392400
) where {F, B, C}
393401
check_prep(f, prep, backend, x, ty, contexts...)
394402
(; pushforward_prep) = prep
395-
ys_and_tx = ntuple(
396-
b -> _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty[b], contexts...),
397-
Val(B),
398-
)
399-
y = first(first(ys_and_tx))
400-
tx = map(last, ys_and_tx)
401-
return y, tx
403+
return _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty, contexts...)
402404
end
403405

404406
function value_and_pullback!(
@@ -449,12 +451,14 @@ function _value_and_pullback_via_pushforward(
449451
pushforward_prep::PushforwardPrep,
450452
backend::AbstractADType,
451453
x::Real,
452-
dy,
454+
ty::NTuple{B},
453455
contexts::Vararg{Context, C},
454-
) where {F, C}
456+
) where {F, B, C}
455457
_, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
456-
dx = dot(a, dy)
457-
return dx
458+
tx = map(ty) do dy
459+
dot(a, dy)
460+
end
461+
return y, arroftup_to_tupofarr(tx)
458462
end
459463

460464
function _value_and_pullback_via_pushforward(
@@ -463,15 +467,17 @@ function _value_and_pullback_via_pushforward(
463467
pushforward_prep::PushforwardPrep,
464468
backend::AbstractADType,
465469
x::Complex,
466-
dy,
470+
ty::NTuple{B},
467471
contexts::Vararg{Context, C},
468-
) where {F, C}
472+
) where {F, B, C}
469473
a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
470474
_, b = onlysecond(
471475
value_and_pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)
472476
)
473-
dx = real(dot(a, dy)) + im * real(dot(b, dy))
474-
return dx
477+
tx = map(ty) do dy
478+
real(dot(a, dy)) + im * real(dot(b, dy))
479+
end
480+
return y, arroftup_to_tupofarr(tx)
475481
end
476482

477483
function _value_and_pullback_via_pushforward(
@@ -480,14 +486,16 @@ function _value_and_pullback_via_pushforward(
480486
pushforward_prep::PushforwardPrep,
481487
backend::AbstractADType,
482488
x::AbstractArray{<:Real},
483-
dy,
489+
ty::NTuple{B},
484490
contexts::Vararg{Context, C},
485-
) where {F, C}
486-
dx = map(CartesianIndices(x)) do j # preserve shape
491+
) where {F, B, C}
492+
tx = map(CartesianIndices(x)) do j # preserve shape
487493
_, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
488-
dot(a, dy)
494+
map(ty) do dy
495+
dot(a, dy)
496+
end
489497
end
490-
return dx
498+
return y, arroftup_to_tupofarr(tx)
491499
end
492500

493501
function _value_and_pullback_via_pushforward(
@@ -496,19 +504,21 @@ function _value_and_pullback_via_pushforward(
496504
pushforward_prep::PushforwardPrep,
497505
backend::AbstractADType,
498506
x::AbstractArray{<:Complex},
499-
dy,
507+
ty::NTuple{B},
500508
contexts::Vararg{Context, C},
501-
) where {F, C}
502-
dx = map(CartesianIndices(x)) do j # preserve shape
509+
) where {F, B, C}
510+
tx = map(CartesianIndices(x)) do j # preserve shape
503511
a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
504512
_, b = onlysecond(
505513
value_and_pushforward(
506514
f!, y, pushforward_prep, backend, x, (im * basis(x, j),), contexts...
507515
),
508516
)
509-
real(dot(a, dy)) + im * real(dot(b, dy))
517+
map(ty) do dy
518+
real(dot(a, dy)) + im * real(dot(b, dy))
519+
end
510520
end
511-
return dx
521+
return y, arroftup_to_tupofarr(tx)
512522
end
513523

514524
function value_and_pullback(
@@ -522,13 +532,9 @@ function value_and_pullback(
522532
) where {F, B, C}
523533
check_prep(f!, y, prep, backend, x, ty, contexts...)
524534
(; pushforward_prep) = prep
525-
tx = ntuple(
526-
b -> _value_and_pullback_via_pushforward(
527-
f!, y, pushforward_prep, backend, x, ty[b], contexts...
528-
),
529-
Val(B),
535+
return _value_and_pullback_via_pushforward(
536+
f!, y, pushforward_prep, backend, x, ty, contexts...
530537
)
531-
return y, tx
532538
end
533539

534540
function value_and_pullback!(

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,14 @@ function _value_and_pushforward_via_pullback(
324324
pullback_prep::PullbackPrep,
325325
backend::AbstractADType,
326326
x,
327-
dx,
327+
tx::NTuple{B},
328328
contexts::Vararg{Context, C},
329-
) where {F, C}
329+
) where {F, B, C}
330330
y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...))
331-
dy = dot(a, dx)
332-
return y, dy
331+
ty = map(tx) do dx
332+
dot(a, dx)
333+
end
334+
return y, arroftup_to_tupofarr(ty)
333335
end
334336

335337
function _value_and_pushforward_via_pullback(
@@ -338,13 +340,15 @@ function _value_and_pushforward_via_pullback(
338340
pullback_prep::PullbackPrep,
339341
backend::AbstractADType,
340342
x,
341-
dx,
343+
tx::NTuple{B},
342344
contexts::Vararg{Context, C},
343-
) where {F, C}
345+
) where {F, B, C}
344346
y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...))
345347
b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y_ex),), contexts...))
346-
dy = real(dot(a, dx)) + im * real(dot(b, dx))
347-
return y, dy
348+
ty = map(tx) do dx
349+
real(dot(a, dx)) + im * real(dot(b, dx))
350+
end
351+
return y, arroftup_to_tupofarr(ty)
348352
end
349353

350354
function _value_and_pushforward_via_pullback(
@@ -353,15 +357,17 @@ function _value_and_pushforward_via_pullback(
353357
pullback_prep::PullbackPrep,
354358
backend::AbstractADType,
355359
x,
356-
dx,
360+
tx::NTuple{B},
357361
contexts::Vararg{Context, C},
358-
) where {F, C}
362+
) where {F, B, C}
359363
y = f(x, map(unwrap, contexts)...)
360-
dy = map(CartesianIndices(y_ex)) do i
364+
ty = map(CartesianIndices(y_ex)) do i
361365
a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...))
362-
dot(a, dx)
366+
map(tx) do dx
367+
dot(a, dx)
368+
end
363369
end
364-
return y, dy
370+
return y, arroftup_to_tupofarr(ty)
365371
end
366372

367373
function _value_and_pushforward_via_pullback(
@@ -370,16 +376,18 @@ function _value_and_pushforward_via_pullback(
370376
pullback_prep::PullbackPrep,
371377
backend::AbstractADType,
372378
x,
373-
dx,
379+
tx::NTuple{B},
374380
contexts::Vararg{Context, C},
375-
) where {F, C}
381+
) where {F, B, C}
376382
y = f(x, map(unwrap, contexts)...)
377-
dy = map(CartesianIndices(y_ex)) do i
383+
ty = map(CartesianIndices(y_ex)) do i
378384
a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...))
379385
b = only(pullback(f, pullback_prep, backend, x, (im * basis(y_ex, i),), contexts...))
380-
real(dot(a, dx)) + im * real(dot(b, dx))
386+
map(tx) do dx
387+
real(dot(a, dx)) + im * real(dot(b, dx))
388+
end
381389
end
382-
return y, dy
390+
return y, arroftup_to_tupofarr(ty)
383391
end
384392

385393
function value_and_pushforward(
@@ -392,13 +400,7 @@ function value_and_pushforward(
392400
) where {F, B, C}
393401
check_prep(f, prep, backend, x, tx, contexts...)
394402
(; pullback_prep, y_example) = prep
395-
ys_and_ty = ntuple(
396-
b -> _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx[b], contexts...),
397-
Val(B),
398-
)
399-
y = first(first(ys_and_ty))
400-
ty = map(last, ys_and_ty)
401-
return y, ty
403+
return _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx, contexts...)
402404
end
403405

404406
function value_and_pushforward!(
@@ -449,14 +451,16 @@ function _value_and_pushforward_via_pullback(
449451
pullback_prep::PullbackPrep,
450452
backend::AbstractADType,
451453
x,
452-
dx,
454+
tx::NTuple{B},
453455
contexts::Vararg{Context, C},
454-
) where {F, C}
455-
dy = map(CartesianIndices(y)) do i # preserve shape
456+
) where {F, B, C}
457+
ty = map(CartesianIndices(y)) do i # preserve shape
456458
_, a = onlysecond(value_and_pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
457-
dot(a, dx)
459+
map(tx) do dx
460+
dot(a, dx)
461+
end
458462
end
459-
return dy
463+
return y, arroftup_to_tupofarr(ty)
460464
end
461465

462466
function _value_and_pushforward_via_pullback(
@@ -465,17 +469,19 @@ function _value_and_pushforward_via_pullback(
465469
pullback_prep::PullbackPrep,
466470
backend::AbstractADType,
467471
x,
468-
dx,
472+
tx::NTuple{B},
469473
contexts::Vararg{Context, C},
470-
) where {F, C}
471-
dy = map(CartesianIndices(y)) do i # preserve shape
474+
) where {F, B, C}
475+
ty = map(CartesianIndices(y)) do i # preserve shape
472476
a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
473477
_, b = onlysecond(
474478
value_and_pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...)
475479
)
476-
real(dot(a, dx)) + im * real(dot(b, dx))
480+
map(tx) do dx
481+
real(dot(a, dx)) + im * real(dot(b, dx))
482+
end
477483
end
478-
return dy
484+
return y, arroftup_to_tupofarr(ty)
479485
end
480486

481487
function value_and_pushforward(
@@ -489,12 +495,7 @@ function value_and_pushforward(
489495
) where {F, B, C}
490496
check_prep(f!, y, prep, backend, x, tx, contexts...)
491497
(; pullback_prep) = prep
492-
ty = ntuple(
493-
b ->
494-
_value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...),
495-
Val(B),
496-
)
497-
return y, ty
498+
return _value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx, contexts...)
498499
end
499500

500501
function value_and_pushforward!(

DifferentiationInterface/src/utils/linalg.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,6 @@ The trivial dense fallback is designed to protect against a change of format in
3535
get_pattern(M::AbstractMatrix) = trues(size(M))
3636

3737
onlysecond((a, b)) = (a, only(b))
38+
39+
arroftup_to_tupofarr(x::NTuple) = x
40+
arroftup_to_tupofarr(x::AbstractArray{<:NTuple{B}}) where {B} = ntuple(b -> getindex.(x, b), Val(B))

0 commit comments

Comments
 (0)