Skip to content

Commit 11796c5

Browse files
authored
perf: improve wrong-mode pushforward/pullback for scalars (#931)
* perf: improve wrong-mode pushforward/pullback for scalars
* Fix
* Fix
1 parent 3f4c2de commit 11796c5

File tree

4 files changed

+67
-60
lines changed

4 files changed

+67
-60
lines changed

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ jobs:
9191
actions: write
9292
contents: read
9393
strategy:
94-
fail-fast: false # TODO: toggle
94+
fail-fast: true # TODO: toggle
9595
matrix:
9696
version:
9797
- '1.10'

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -320,64 +320,66 @@ end
320320

321321
## One argument
322322

323-
function _pullback_via_pushforward(
323+
function _value_and_pullback_via_pushforward(
324324
f::F,
325325
pushforward_prep::PushforwardPrep,
326326
backend::AbstractADType,
327327
x::Real,
328328
dy,
329329
contexts::Vararg{Context, C},
330330
) where {F, C}
331-
a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
331+
y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
332332
dx = dot(a, dy)
333-
return dx
333+
return y, dx
334334
end
335335

336-
function _pullback_via_pushforward(
336+
function _value_and_pullback_via_pushforward(
337337
f::F,
338338
pushforward_prep::PushforwardPrep,
339339
backend::AbstractADType,
340340
x::Complex,
341341
dy,
342342
contexts::Vararg{Context, C},
343343
) where {F, C}
344-
a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
344+
y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
345345
b = only(pushforward(f, pushforward_prep, backend, x, (im * oneunit(x),), contexts...))
346346
dx = real(dot(a, dy)) + im * real(dot(b, dy))
347-
return dx
347+
return y, dx
348348
end
349349

350-
function _pullback_via_pushforward(
350+
function _value_and_pullback_via_pushforward(
351351
f::F,
352352
pushforward_prep::PushforwardPrep,
353353
backend::AbstractADType,
354354
x::AbstractArray{<:Real},
355355
dy,
356356
contexts::Vararg{Context, C},
357357
) where {F, C}
358+
y = f(x, map(unwrap, contexts)...)
358359
dx = map(CartesianIndices(x)) do j
359360
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
360361
dot(a, dy)
361362
end
362-
return dx
363+
return y, dx
363364
end
364365

365-
function _pullback_via_pushforward(
366+
function _value_and_pullback_via_pushforward(
366367
f::F,
367368
pushforward_prep::PushforwardPrep,
368369
backend::AbstractADType,
369370
x::AbstractArray{<:Complex},
370371
dy,
371372
contexts::Vararg{Context, C},
372373
) where {F, C}
374+
y = f(x, map(unwrap, contexts)...)
373375
dx = map(CartesianIndices(x)) do j
374376
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
375377
b = only(
376378
pushforward(f, pushforward_prep, backend, x, (im * basis(x, j),), contexts...),
377379
)
378380
real(dot(a, dy)) + im * real(dot(b, dy))
379381
end
380-
return dx
382+
return y, dx
381383
end
382384

383385
function value_and_pullback(
@@ -390,11 +392,12 @@ function value_and_pullback(
390392
) where {F, B, C}
391393
check_prep(f, prep, backend, x, ty, contexts...)
392394
(; pushforward_prep) = prep
393-
y = f(x, map(unwrap, contexts)...)
394-
tx = ntuple(
395-
b -> _pullback_via_pushforward(f, pushforward_prep, backend, x, ty[b], contexts...),
395+
ys_and_tx = ntuple(
396+
b -> _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty[b], contexts...),
396397
Val(B),
397398
)
399+
y = first(first(ys_and_tx))
400+
tx = map(last, ys_and_tx)
398401
return y, tx
399402
end
400403

@@ -440,7 +443,7 @@ end
440443

441444
## Two arguments
442445

443-
function _pullback_via_pushforward(
446+
function _value_and_pullback_via_pushforward(
444447
f!::F,
445448
y,
446449
pushforward_prep::PushforwardPrep,
@@ -449,12 +452,12 @@ function _pullback_via_pushforward(
449452
dy,
450453
contexts::Vararg{Context, C},
451454
) where {F, C}
452-
a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
455+
_, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
453456
dx = dot(a, dy)
454457
return dx
455458
end
456459

457-
function _pullback_via_pushforward(
460+
function _value_and_pullback_via_pushforward(
458461
f!::F,
459462
y,
460463
pushforward_prep::PushforwardPrep,
@@ -464,14 +467,14 @@ function _pullback_via_pushforward(
464467
contexts::Vararg{Context, C},
465468
) where {F, C}
466469
a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
467-
b = only(
468-
pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)
470+
_, b = onlysecond(
471+
value_and_pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)
469472
)
470473
dx = real(dot(a, dy)) + im * real(dot(b, dy))
471474
return dx
472475
end
473476

474-
function _pullback_via_pushforward(
477+
function _value_and_pullback_via_pushforward(
475478
f!::F,
476479
y,
477480
pushforward_prep::PushforwardPrep,
@@ -481,13 +484,13 @@ function _pullback_via_pushforward(
481484
contexts::Vararg{Context, C},
482485
) where {F, C}
483486
dx = map(CartesianIndices(x)) do j # preserve shape
484-
a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
487+
_, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
485488
dot(a, dy)
486489
end
487490
return dx
488491
end
489492

490-
function _pullback_via_pushforward(
493+
function _value_and_pullback_via_pushforward(
491494
f!::F,
492495
y,
493496
pushforward_prep::PushforwardPrep,
@@ -498,8 +501,8 @@ function _pullback_via_pushforward(
498501
) where {F, C}
499502
dx = map(CartesianIndices(x)) do j # preserve shape
500503
a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
501-
b = only(
502-
pushforward(
504+
_, b = onlysecond(
505+
value_and_pushforward(
503506
f!, y, pushforward_prep, backend, x, (im * basis(x, j),), contexts...
504507
),
505508
)
@@ -520,12 +523,11 @@ function value_and_pullback(
520523
check_prep(f!, y, prep, backend, x, ty, contexts...)
521524
(; pushforward_prep) = prep
522525
tx = ntuple(
523-
b -> _pullback_via_pushforward(
526+
b -> _value_and_pullback_via_pushforward(
524527
f!, y, pushforward_prep, backend, x, ty[b], contexts...
525528
),
526529
Val(B),
527530
)
528-
f!(y, x, map(unwrap, contexts)...)
529531
return y, tx
530532
end
531533

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,10 @@ end
252252

253253
## Preparation
254254

255-
struct PullbackPushforwardPrep{SIG, E} <: PushforwardPrep{SIG}
255+
struct PullbackPushforwardPrep{SIG, E, Y} <: PushforwardPrep{SIG}
256256
_sig::Val{SIG}
257257
pullback_prep::E
258+
y_example::Y
258259
end
259260

260261
function prepare_pushforward_nokwarg(
@@ -296,7 +297,7 @@ function _prepare_pushforward_aux(
296297
basis(y)
297298
end
298299
pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...)
299-
return PullbackPushforwardPrep(_sig, pullback_prep)
300+
return PullbackPushforwardPrep(_sig, pullback_prep, y)
300301
end
301302

302303
function _prepare_pushforward_aux(
@@ -312,71 +313,73 @@ function _prepare_pushforward_aux(
312313
_sig = signature(f!, y, backend, x, tx, contexts...; strict)
313314
dy = basis(y)
314315
pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...)
315-
return PullbackPushforwardPrep(_sig, pullback_prep)
316+
return PullbackPushforwardPrep(_sig, pullback_prep, y)
316317
end
317318

318319
## One argument
319320

320-
function _pushforward_via_pullback(
321-
y::Number,
321+
function _value_and_pushforward_via_pullback(
322+
y_ex::Number,
322323
f::F,
323324
pullback_prep::PullbackPrep,
324325
backend::AbstractADType,
325326
x,
326327
dx,
327328
contexts::Vararg{Context, C},
328329
) where {F, C}
329-
a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...))
330+
y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...))
330331
dy = dot(a, dx)
331-
return dy
332+
return y, dy
332333
end
333334

334-
function _pushforward_via_pullback(
335-
y::Complex,
335+
function _value_and_pushforward_via_pullback(
336+
y_ex::Complex,
336337
f::F,
337338
pullback_prep::PullbackPrep,
338339
backend::AbstractADType,
339340
x,
340341
dx,
341342
contexts::Vararg{Context, C},
342343
) where {F, C}
343-
a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...))
344-
b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y),), contexts...))
344+
y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...))
345+
b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y_ex),), contexts...))
345346
dy = real(dot(a, dx)) + im * real(dot(b, dx))
346-
return dy
347+
return y, dy
347348
end
348349

349-
function _pushforward_via_pullback(
350-
y::AbstractArray{<:Real},
350+
function _value_and_pushforward_via_pullback(
351+
y_ex::AbstractArray{<:Real},
351352
f::F,
352353
pullback_prep::PullbackPrep,
353354
backend::AbstractADType,
354355
x,
355356
dx,
356357
contexts::Vararg{Context, C},
357358
) where {F, C}
358-
dy = map(CartesianIndices(y)) do i
359-
a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...))
359+
y = f(x, map(unwrap, contexts)...)
360+
dy = map(CartesianIndices(y_ex)) do i
361+
a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...))
360362
dot(a, dx)
361363
end
362-
return dy
364+
return y, dy
363365
end
364366

365-
function _pushforward_via_pullback(
366-
y::AbstractArray{<:Complex},
367+
function _value_and_pushforward_via_pullback(
368+
y_ex::AbstractArray{<:Complex},
367369
f::F,
368370
pullback_prep::PullbackPrep,
369371
backend::AbstractADType,
370372
x,
371373
dx,
372374
contexts::Vararg{Context, C},
373375
) where {F, C}
374-
dy = map(CartesianIndices(y)) do i
375-
a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...))
376-
b = only(pullback(f, pullback_prep, backend, x, (im * basis(y, i),), contexts...))
376+
y = f(x, map(unwrap, contexts)...)
377+
dy = map(CartesianIndices(y_ex)) do i
378+
a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...))
379+
b = only(pullback(f, pullback_prep, backend, x, (im * basis(y_ex, i),), contexts...))
377380
real(dot(a, dx)) + im * real(dot(b, dx))
378381
end
379-
return dy
382+
return y, dy
380383
end
381384

382385
function value_and_pushforward(
@@ -388,12 +391,13 @@ function value_and_pushforward(
388391
contexts::Vararg{Context, C},
389392
) where {F, B, C}
390393
check_prep(f, prep, backend, x, tx, contexts...)
391-
(; pullback_prep) = prep
392-
y = f(x, map(unwrap, contexts)...)
393-
ty = ntuple(
394-
b -> _pushforward_via_pullback(y, f, pullback_prep, backend, x, tx[b], contexts...),
394+
(; pullback_prep, y_example) = prep
395+
ys_and_ty = ntuple(
396+
b -> _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx[b], contexts...),
395397
Val(B),
396398
)
399+
y = first(first(ys_and_ty))
400+
ty = map(last, ys_and_ty)
397401
return y, ty
398402
end
399403

@@ -439,7 +443,7 @@ end
439443

440444
## Two arguments
441445

442-
function _pushforward_via_pullback(
446+
function _value_and_pushforward_via_pullback(
443447
f!::F,
444448
y::AbstractArray{<:Real},
445449
pullback_prep::PullbackPrep,
@@ -449,13 +453,13 @@ function _pushforward_via_pullback(
449453
contexts::Vararg{Context, C},
450454
) where {F, C}
451455
dy = map(CartesianIndices(y)) do i # preserve shape
452-
a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
456+
_, a = onlysecond(value_and_pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
453457
dot(a, dx)
454458
end
455459
return dy
456460
end
457461

458-
function _pushforward_via_pullback(
462+
function _value_and_pushforward_via_pullback(
459463
f!::F,
460464
y::AbstractArray{<:Complex},
461465
pullback_prep::PullbackPrep,
@@ -466,8 +470,8 @@ function _pushforward_via_pullback(
466470
) where {F, C}
467471
dy = map(CartesianIndices(y)) do i # preserve shape
468472
a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
469-
b = only(
470-
pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...)
473+
_, b = onlysecond(
474+
value_and_pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...)
471475
)
472476
real(dot(a, dx)) + im * real(dot(b, dx))
473477
end
@@ -487,10 +491,9 @@ function value_and_pushforward(
487491
(; pullback_prep) = prep
488492
ty = ntuple(
489493
b ->
490-
_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...),
494+
_value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...),
491495
Val(B),
492496
)
493-
f!(y, x, map(unwrap, contexts)...)
494497
return y, ty
495498
end
496499

DifferentiationInterface/src/utils/linalg.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,5 @@ Only specialized on `SparseMatrixCSC` because it is used with symbolic backends,
3333
The trivial dense fallback is designed to protect against a change of format in these packages.
3434
"""
3535
get_pattern(M::AbstractMatrix) = trues(size(M))
36+
37+
onlysecond((a, b)) = (a, only(b))

0 commit comments

Comments
 (0)