252252
253253# # Preparation
254254
255- struct PullbackPushforwardPrep{SIG, E} <: PushforwardPrep{SIG}
255+ struct PullbackPushforwardPrep{SIG, E, Y } <: PushforwardPrep{SIG}
256256 _sig:: Val{SIG}
257257 pullback_prep:: E
258+ y_example:: Y
258259end
259260
260261function prepare_pushforward_nokwarg (
@@ -296,7 +297,7 @@ function _prepare_pushforward_aux(
296297 basis (y)
297298 end
298299 pullback_prep = prepare_pullback_nokwarg (strict, f, backend, x, (dy,), contexts... )
299- return PullbackPushforwardPrep (_sig, pullback_prep)
300+ return PullbackPushforwardPrep (_sig, pullback_prep, y )
300301end
301302
302303function _prepare_pushforward_aux (
@@ -312,71 +313,73 @@ function _prepare_pushforward_aux(
312313 _sig = signature (f!, y, backend, x, tx, contexts... ; strict)
313314 dy = basis (y)
314315 pullback_prep = prepare_pullback_nokwarg (strict, f!, y, backend, x, (dy,), contexts... )
315- return PullbackPushforwardPrep (_sig, pullback_prep)
316+ return PullbackPushforwardPrep (_sig, pullback_prep, y )
316317end
317318
318319# # One argument
319320
320- function _pushforward_via_pullback (
321- y :: Number ,
321+ function _value_and_pushforward_via_pullback (
322+ y_ex :: Number ,
322323 f:: F ,
323324 pullback_prep:: PullbackPrep ,
324325 backend:: AbstractADType ,
325326 x,
326327 dx,
327328 contexts:: Vararg{Context, C} ,
328329 ) where {F, C}
329- a = only ( pullback (f, pullback_prep, backend, x, (oneunit (y ),), contexts... ))
330+ y, a = onlysecond ( value_and_pullback (f, pullback_prep, backend, x, (oneunit (y_ex ),), contexts... ))
330331 dy = dot (a, dx)
331- return dy
332+ return y, dy
332333end
333334
334- function _pushforward_via_pullback (
335- y :: Complex ,
335+ function _value_and_pushforward_via_pullback (
336+ y_ex :: Complex ,
336337 f:: F ,
337338 pullback_prep:: PullbackPrep ,
338339 backend:: AbstractADType ,
339340 x,
340341 dx,
341342 contexts:: Vararg{Context, C} ,
342343 ) where {F, C}
343- a = only ( pullback (f, pullback_prep, backend, x, (oneunit (y ),), contexts... ))
344- b = only (pullback (f, pullback_prep, backend, x, (im * oneunit (y ),), contexts... ))
344+ y, a = onlysecond ( value_and_pullback (f, pullback_prep, backend, x, (oneunit (y_ex ),), contexts... ))
345+ b = only (pullback (f, pullback_prep, backend, x, (im * oneunit (y_ex ),), contexts... ))
345346 dy = real (dot (a, dx)) + im * real (dot (b, dx))
346- return dy
347+ return y, dy
347348end
348349
349- function _pushforward_via_pullback (
350- y :: AbstractArray{<:Real} ,
350+ function _value_and_pushforward_via_pullback (
351+ y_ex :: AbstractArray{<:Real} ,
351352 f:: F ,
352353 pullback_prep:: PullbackPrep ,
353354 backend:: AbstractADType ,
354355 x,
355356 dx,
356357 contexts:: Vararg{Context, C} ,
357358 ) where {F, C}
358- dy = map (CartesianIndices (y)) do i
359- a = only (pullback (f, pullback_prep, backend, x, (basis (y, i),), contexts... ))
359+ y = f (x, map (unwrap, contexts)... )
360+ dy = map (CartesianIndices (y_ex)) do i
361+ a = only (pullback (f, pullback_prep, backend, x, (basis (y_ex, i),), contexts... ))
360362 dot (a, dx)
361363 end
362- return dy
364+ return y, dy
363365end
364366
365- function _pushforward_via_pullback (
366- y :: AbstractArray{<:Complex} ,
367+ function _value_and_pushforward_via_pullback (
368+ y_ex :: AbstractArray{<:Complex} ,
367369 f:: F ,
368370 pullback_prep:: PullbackPrep ,
369371 backend:: AbstractADType ,
370372 x,
371373 dx,
372374 contexts:: Vararg{Context, C} ,
373375 ) where {F, C}
374- dy = map (CartesianIndices (y)) do i
375- a = only (pullback (f, pullback_prep, backend, x, (basis (y, i),), contexts... ))
376- b = only (pullback (f, pullback_prep, backend, x, (im * basis (y, i),), contexts... ))
376+ y = f (x, map (unwrap, contexts)... )
377+ dy = map (CartesianIndices (y_ex)) do i
378+ a = only (pullback (f, pullback_prep, backend, x, (basis (y_ex, i),), contexts... ))
379+ b = only (pullback (f, pullback_prep, backend, x, (im * basis (y_ex, i),), contexts... ))
377380 real (dot (a, dx)) + im * real (dot (b, dx))
378381 end
379- return dy
382+ return y, dy
380383end
381384
382385function value_and_pushforward (
@@ -388,12 +391,13 @@ function value_and_pushforward(
388391 contexts:: Vararg{Context, C} ,
389392 ) where {F, B, C}
390393 check_prep (f, prep, backend, x, tx, contexts... )
391- (; pullback_prep) = prep
392- y = f (x, map (unwrap, contexts)... )
393- ty = ntuple (
394- b -> _pushforward_via_pullback (y, f, pullback_prep, backend, x, tx[b], contexts... ),
394+ (; pullback_prep, y_example) = prep
395+ ys_and_ty = ntuple (
396+ b -> _value_and_pushforward_via_pullback (y_example, f, pullback_prep, backend, x, tx[b], contexts... ),
395397 Val (B),
396398 )
399+ y = first (first (ys_and_ty))
400+ ty = map (last, ys_and_ty)
397401 return y, ty
398402end
399403
439443
440444# # Two arguments
441445
442- function _pushforward_via_pullback (
446+ function _value_and_pushforward_via_pullback (
443447 f!:: F ,
444448 y:: AbstractArray{<:Real} ,
445449 pullback_prep:: PullbackPrep ,
@@ -449,13 +453,13 @@ function _pushforward_via_pullback(
449453 contexts:: Vararg{Context, C} ,
450454 ) where {F, C}
451455 dy = map (CartesianIndices (y)) do i # preserve shape
452- a = only ( pullback (f!, y, pullback_prep, backend, x, (basis (y, i),), contexts... ))
456+ _, a = onlysecond ( value_and_pullback (f!, y, pullback_prep, backend, x, (basis (y, i),), contexts... ))
453457 dot (a, dx)
454458 end
455459 return dy
456460end
457461
458- function _pushforward_via_pullback (
462+ function _value_and_pushforward_via_pullback (
459463 f!:: F ,
460464 y:: AbstractArray{<:Complex} ,
461465 pullback_prep:: PullbackPrep ,
@@ -466,8 +470,8 @@ function _pushforward_via_pullback(
466470 ) where {F, C}
467471 dy = map (CartesianIndices (y)) do i # preserve shape
468472 a = only (pullback (f!, y, pullback_prep, backend, x, (basis (y, i),), contexts... ))
469- b = only (
470- pullback (f!, y, pullback_prep, backend, x, (im * basis (y, i),), contexts... )
473+ _, b = onlysecond (
474+ value_and_pullback (f!, y, pullback_prep, backend, x, (im * basis (y, i),), contexts... )
471475 )
472476 real (dot (a, dx)) + im * real (dot (b, dx))
473477 end
@@ -487,10 +491,9 @@ function value_and_pushforward(
487491 (; pullback_prep) = prep
488492 ty = ntuple (
489493 b ->
490- _pushforward_via_pullback (f!, y, pullback_prep, backend, x, tx[b], contexts... ),
494+ _value_and_pushforward_via_pullback (f!, y, pullback_prep, backend, x, tx[b], contexts... ),
491495 Val (B),
492496 )
493- f! (y, x, map (unwrap, contexts)... )
494497 return y, ty
495498end
496499
0 commit comments