@@ -325,61 +325,69 @@ function _value_and_pullback_via_pushforward(
325325 pushforward_prep:: PushforwardPrep ,
326326 backend:: AbstractADType ,
327327 x:: Real ,
328- dy ,
328+ ty :: NTuple{B} ,
329329 contexts:: Vararg{Context, C} ,
330- ) where {F, C}
330+ ) where {F, B, C}
331331 y, a = onlysecond (value_and_pushforward (f, pushforward_prep, backend, x, (oneunit (x),), contexts... ))
332- dx = dot (a, dy)
333- return y, dx
332+ tx = map (ty) do dy
333+ dot (a, dy)
334+ end
335+ return y, arroftup_to_tupofarr (tx)
334336end
335337
336338function _value_and_pullback_via_pushforward (
337339 f:: F ,
338340 pushforward_prep:: PushforwardPrep ,
339341 backend:: AbstractADType ,
340342 x:: Complex ,
341- dy ,
343+ ty :: NTuple{B} ,
342344 contexts:: Vararg{Context, C} ,
343- ) where {F, C}
345+ ) where {F, B, C}
344346 y, a = onlysecond (value_and_pushforward (f, pushforward_prep, backend, x, (oneunit (x),), contexts... ))
345347 b = only (pushforward (f, pushforward_prep, backend, x, (im * oneunit (x),), contexts... ))
346- dx = real (dot (a, dy)) + im * real (dot (b, dy))
347- return y, dx
348+ tx = map (ty) do dy
349+ real (dot (a, dy)) + im * real (dot (b, dy))
350+ end
351+ return y, arroftup_to_tupofarr (tx)
348352end
349353
350354function _value_and_pullback_via_pushforward (
351355 f:: F ,
352356 pushforward_prep:: PushforwardPrep ,
353357 backend:: AbstractADType ,
354358 x:: AbstractArray{<:Real} ,
355- dy ,
359+ ty :: NTuple{B} ,
356360 contexts:: Vararg{Context, C} ,
357- ) where {F, C}
361+ ) where {F, B, C}
358362 y = f (x, map (unwrap, contexts)... )
359- dx = map (CartesianIndices (x)) do j
363+ tx = map (CartesianIndices (x)) do j
360364 a = only (pushforward (f, pushforward_prep, backend, x, (basis (x, j),), contexts... ))
361- dot (a, dy)
365+ map (ty) do dy
366+ dot (a, dy)
367+ end
362368 end
363- return y, dx
369+ return y, arroftup_to_tupofarr (tx)
364370end
365371
366372function _value_and_pullback_via_pushforward (
367373 f:: F ,
368374 pushforward_prep:: PushforwardPrep ,
369375 backend:: AbstractADType ,
370376 x:: AbstractArray{<:Complex} ,
371- dy ,
377+ ty :: NTuple{B} ,
372378 contexts:: Vararg{Context, C} ,
373- ) where {F, C}
379+ ) where {F, B, C}
374380 y = f (x, map (unwrap, contexts)... )
375- dx = map (CartesianIndices (x)) do j
381+ tx = map (CartesianIndices (x)) do j
376382 a = only (pushforward (f, pushforward_prep, backend, x, (basis (x, j),), contexts... ))
377383 b = only (
378384 pushforward (f, pushforward_prep, backend, x, (im * basis (x, j),), contexts... ),
379385 )
380- real (dot (a, dy)) + im * real (dot (b, dy))
386+ map (ty) do dy
387+ real (dot (a, dy)) + im * real (dot (b, dy))
388+ end
381389 end
382- return y, dx
390+ return y, arroftup_to_tupofarr (tx)
383391end
384392
385393function value_and_pullback (
@@ -392,13 +400,7 @@ function value_and_pullback(
392400 ) where {F, B, C}
393401 check_prep (f, prep, backend, x, ty, contexts... )
394402 (; pushforward_prep) = prep
395- ys_and_tx = ntuple (
396- b -> _value_and_pullback_via_pushforward (f, pushforward_prep, backend, x, ty[b], contexts... ),
397- Val (B),
398- )
399- y = first (first (ys_and_tx))
400- tx = map (last, ys_and_tx)
401- return y, tx
403+ return _value_and_pullback_via_pushforward (f, pushforward_prep, backend, x, ty, contexts... )
402404end
403405
404406function value_and_pullback! (
@@ -449,12 +451,14 @@ function _value_and_pullback_via_pushforward(
449451 pushforward_prep:: PushforwardPrep ,
450452 backend:: AbstractADType ,
451453 x:: Real ,
452- dy ,
454+ ty :: NTuple{B} ,
453455 contexts:: Vararg{Context, C} ,
454- ) where {F, C}
456+ ) where {F, B, C}
455457 _, a = onlysecond (value_and_pushforward (f!, y, pushforward_prep, backend, x, (oneunit (x),), contexts... ))
456- dx = dot (a, dy)
457- return dx
458+ tx = map (ty) do dy
459+ dot (a, dy)
460+ end
461+ return y, arroftup_to_tupofarr (tx)
458462end
459463
460464function _value_and_pullback_via_pushforward (
@@ -463,15 +467,17 @@ function _value_and_pullback_via_pushforward(
463467 pushforward_prep:: PushforwardPrep ,
464468 backend:: AbstractADType ,
465469 x:: Complex ,
466- dy ,
470+ ty :: NTuple{B} ,
467471 contexts:: Vararg{Context, C} ,
468- ) where {F, C}
472+ ) where {F, B, C}
469473 a = only (pushforward (f!, y, pushforward_prep, backend, x, (oneunit (x),), contexts... ))
470474 _, b = onlysecond (
471475 value_and_pushforward (f!, y, pushforward_prep, backend, x, (im * oneunit (x),), contexts... )
472476 )
473- dx = real (dot (a, dy)) + im * real (dot (b, dy))
474- return dx
477+ tx = map (ty) do dy
478+ real (dot (a, dy)) + im * real (dot (b, dy))
479+ end
480+ return y, arroftup_to_tupofarr (tx)
475481end
476482
477483function _value_and_pullback_via_pushforward (
@@ -480,14 +486,16 @@ function _value_and_pullback_via_pushforward(
480486 pushforward_prep:: PushforwardPrep ,
481487 backend:: AbstractADType ,
482488 x:: AbstractArray{<:Real} ,
483- dy ,
489+ ty :: NTuple{B} ,
484490 contexts:: Vararg{Context, C} ,
485- ) where {F, C}
486- dx = map (CartesianIndices (x)) do j # preserve shape
491+ ) where {F, B, C}
492+ tx = map (CartesianIndices (x)) do j # preserve shape
487493 _, a = onlysecond (value_and_pushforward (f!, y, pushforward_prep, backend, x, (basis (x, j),), contexts... ))
488- dot (a, dy)
494+ map (ty) do dy
495+ dot (a, dy)
496+ end
489497 end
490- return dx
498+ return y, arroftup_to_tupofarr (tx)
491499end
492500
493501function _value_and_pullback_via_pushforward (
@@ -496,19 +504,21 @@ function _value_and_pullback_via_pushforward(
496504 pushforward_prep:: PushforwardPrep ,
497505 backend:: AbstractADType ,
498506 x:: AbstractArray{<:Complex} ,
499- dy ,
507+ ty :: NTuple{B} ,
500508 contexts:: Vararg{Context, C} ,
501- ) where {F, C}
502- dx = map (CartesianIndices (x)) do j # preserve shape
509+ ) where {F, B, C}
510+ tx = map (CartesianIndices (x)) do j # preserve shape
503511 a = only (pushforward (f!, y, pushforward_prep, backend, x, (basis (x, j),), contexts... ))
504512 _, b = onlysecond (
505513 value_and_pushforward (
506514 f!, y, pushforward_prep, backend, x, (im * basis (x, j),), contexts...
507515 ),
508516 )
509- real (dot (a, dy)) + im * real (dot (b, dy))
517+ map (ty) do dy
518+ real (dot (a, dy)) + im * real (dot (b, dy))
519+ end
510520 end
511- return dx
521+ return y, arroftup_to_tupofarr (tx)
512522end
513523
514524function value_and_pullback (
@@ -522,13 +532,9 @@ function value_and_pullback(
522532 ) where {F, B, C}
523533 check_prep (f!, y, prep, backend, x, ty, contexts... )
524534 (; pushforward_prep) = prep
525- tx = ntuple (
526- b -> _value_and_pullback_via_pushforward (
527- f!, y, pushforward_prep, backend, x, ty[b], contexts...
528- ),
529- Val (B),
535+ return _value_and_pullback_via_pushforward (
536+ f!, y, pushforward_prep, backend, x, ty, contexts...
530537 )
531- return y, tx
532538end
533539
534540function value_and_pullback! (
0 commit comments