@@ -1364,6 +1364,22 @@ end
13641364
13651365(fn:: BroadcastIterator )(args... ) = fn. f ((args... ,))
13661366
1367+ abstract type AbstractUnwrappedBroadcastRestoreFunction end
1368+
1369+ struct __Identity <: AbstractUnwrappedBroadcastRestoreFunction end
1370+ (:: __Identity )(x) = x
1371+
1372+ struct __EachSlice{D} <: AbstractUnwrappedBroadcastRestoreFunction
1373+ dims:: D
1374+ drop:: Bool
1375+ end
1376+ (s:: __EachSlice{D} )(x) where {D} = eachslice (x; dims= s. dims, drop= s. drop)
1377+
1378+ struct __DropDims{D} <: AbstractUnwrappedBroadcastRestoreFunction
1379+ dims:: D
1380+ end
1381+ (s:: __DropDims{D} )(x) where {D} = dropdims (x; dims= s. dims)
1382+
13671383function unwrapped_broadcast (f:: F , x:: Base.Iterators.Zip , original_dims) where {F}
13681384 min_length = Base. inferencebarrier (minimum)(length, x. is)
13691385 itrs = [length (itr) > min_length ? itr[1 : min_length] : itr for itr in x. is]
@@ -1372,7 +1388,7 @@ function unwrapped_broadcast(f::F, x::Base.Iterators.Zip, original_dims) where {
13721388 else
13731389 unrolled_map (f, x)
13741390 end
1375- return result, original_dims, identity
1391+ return result, original_dims, __Identity ()
13761392end
13771393
13781394function unwrapped_broadcast (f:: F , x:: Base.Iterators.Enumerate , original_dims) where {F}
@@ -1385,7 +1401,7 @@ function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate, original_dims) w
13851401 else
13861402 unrolled_map (f, x)
13871403 end
1388- return result, original_dims, identity
1404+ return result, original_dims, __Identity ()
13891405end
13901406
13911407function unwrapped_broadcast (f:: F , x:: Slices , original_dims) where {F}
@@ -1402,25 +1418,25 @@ function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F}
14021418 updated_dims = ()
14031419 if original_dims isa Colon
14041420 updated_dims = mapslices_dims
1405- re = x -> dropdims (x; dims = mapslices_dims)
1421+ re = __DropDims ( mapslices_dims)
14061422 else
14071423 for d in original_dims
14081424 idx = findfirst (isequal (d), x. slicemap)
14091425 @assert idx != = nothing " Expected dimension $d in $(x. slicemap) "
14101426 updated_dims = (updated_dims... , idx)
14111427 end
1412- re = x -> eachslice (x; dims = mapslices_dims, drop = true )
1428+ re = __EachSlice ( mapslices_dims, true )
14131429 end
14141430
14151431 return mapslices (f, px; dims= mapslices_dims), updated_dims, re
14161432 else
14171433 mapslices_dims = Tuple (filter (i -> ! (x. slicemap[i] isa Colon), 1 : ndims (px)))
14181434 if original_dims isa Colon
14191435 updated_dims = mapslices_dims
1420- re = x -> dropdims (x; dims = mapslices_dims)
1436+ re = __DropDims ( mapslices_dims)
14211437 else
14221438 updated_dims = Tuple (d for d in original_dims if d in mapslices_dims)
1423- re = x -> eachslice (x; dims = mapslices_dims, drop = false )
1439+ re = __EachSlice ( mapslices_dims, false )
14241440 end
14251441 return mapslices (f, px; dims= mapslices_dims), updated_dims, re
14261442 end
@@ -1429,7 +1445,7 @@ end
14291445function unwrapped_broadcast (f:: F , xs, original_dims) where {F}
14301446 mapped_xs = unrolled_map (f, xs)
14311447 applicable (size, xs) && (mapped_xs = reshape (mapped_xs, size (xs)))
1432- return mapped_xs, original_dims, identity
1448+ return mapped_xs, original_dims, __Identity ()
14331449end
14341450
14351451# TODO : once traced_call supports internal mutations, we can use traced_call here
0 commit comments