@@ -1360,6 +1360,22 @@ end
13601360
13611361(fn:: BroadcastIterator )(args... ) = fn. f ((args... ,))
13621362
1363+ abstract type AbstractUnwrappedBroadcastRestoreFunction end
1364+
1365+ struct __Identity <: AbstractUnwrappedBroadcastRestoreFunction end
1366+ (:: __Identity )(x) = x
1367+
1368+ struct __EachSlice{D} <: AbstractUnwrappedBroadcastRestoreFunction
1369+ dims:: D
1370+ drop:: Bool
1371+ end
1372+ (s:: __EachSlice{D} )(x) where {D} = eachslice (x; dims= s. dims, drop= s. drop)
1373+
1374+ struct __DropDims{D} <: AbstractUnwrappedBroadcastRestoreFunction
1375+ dims:: D
1376+ end
1377+ (s:: __DropDims{D} )(x) where {D} = dropdims (x; dims= s. dims)
1378+
13631379function unwrapped_broadcast (f:: F , x:: Base.Iterators.Zip , original_dims) where {F}
13641380 min_length = Base. inferencebarrier (minimum)(length, x. is)
13651381 itrs = [length (itr) > min_length ? itr[1 : min_length] : itr for itr in x. is]
@@ -1368,7 +1384,7 @@ function unwrapped_broadcast(f::F, x::Base.Iterators.Zip, original_dims) where {
13681384 else
13691385 unrolled_map (f, x)
13701386 end
1371- return result, original_dims, identity
1387+ return result, original_dims, __Identity ()
13721388end
13731389
13741390function unwrapped_broadcast (f:: F , x:: Base.Iterators.Enumerate , original_dims) where {F}
@@ -1381,7 +1397,7 @@ function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate, original_dims) w
13811397 else
13821398 unrolled_map (f, x)
13831399 end
1384- return result, original_dims, identity
1400+ return result, original_dims, __Identity ()
13851401end
13861402
13871403function unwrapped_broadcast (f:: F , x:: Slices , original_dims) where {F}
@@ -1398,25 +1414,25 @@ function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F}
13981414 updated_dims = ()
13991415 if original_dims isa Colon
14001416 updated_dims = mapslices_dims
1401- re = x -> dropdims (x; dims = mapslices_dims)
1417+ re = __DropDims ( mapslices_dims)
14021418 else
14031419 for d in original_dims
14041420 idx = findfirst (isequal (d), x. slicemap)
14051421 @assert idx != = nothing " Expected dimension $d in $(x. slicemap) "
14061422 updated_dims = (updated_dims... , idx)
14071423 end
1408- re = x -> eachslice (x; dims = mapslices_dims, drop = true )
1424+ re = __EachSlice ( mapslices_dims, true )
14091425 end
14101426
14111427 return mapslices (f, px; dims= mapslices_dims), updated_dims, re
14121428 else
14131429 mapslices_dims = Tuple (filter (i -> ! (x. slicemap[i] isa Colon), 1 : ndims (px)))
14141430 if original_dims isa Colon
14151431 updated_dims = mapslices_dims
1416- re = x -> dropdims (x; dims = mapslices_dims)
1432+ re = __DropDims ( mapslices_dims)
14171433 else
14181434 updated_dims = Tuple (d for d in original_dims if d in mapslices_dims)
1419- re = x -> eachslice (x; dims = mapslices_dims, drop = false )
1435+ re = __EachSlice ( mapslices_dims, false )
14201436 end
14211437 return mapslices (f, px; dims= mapslices_dims), updated_dims, re
14221438 end
@@ -1425,7 +1441,7 @@ end
14251441function unwrapped_broadcast (f:: F , xs, original_dims) where {F}
14261442 mapped_xs = unrolled_map (f, xs)
14271443 applicable (size, xs) && (mapped_xs = reshape (mapped_xs, size (xs)))
1428- return mapped_xs, original_dims, identity
1444+ return mapped_xs, original_dims, __Identity ()
14291445end
14301446
14311447# TODO : once traced_call supports internal mutations, we can use traced_call here
0 commit comments