Skip to content

Commit 9e9b29e

Browse files
committed
fix: maybe fix stackoverflow??
1 parent 0f25e09 commit 9e9b29e

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

src/TracedRArray.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,6 +1360,22 @@ end
13601360

13611361
(fn::BroadcastIterator)(args...) = fn.f((args...,))
13621362

1363+
abstract type AbstractUnwrappedBroadcastRestoreFunction end
1364+
1365+
struct __Identity <: AbstractUnwrappedBroadcastRestoreFunction end
1366+
(::__Identity)(x) = x
1367+
1368+
struct __EachSlice{D} <: AbstractUnwrappedBroadcastRestoreFunction
1369+
dims::D
1370+
drop::Bool
1371+
end
1372+
(s::__EachSlice{D})(x) where {D} = eachslice(x; dims=s.dims, drop=s.drop)
1373+
1374+
struct __DropDims{D} <: AbstractUnwrappedBroadcastRestoreFunction
1375+
dims::D
1376+
end
1377+
(s::__DropDims{D})(x) where {D} = dropdims(x; dims=s.dims)
1378+
13631379
function unwrapped_broadcast(f::F, x::Base.Iterators.Zip, original_dims) where {F}
13641380
min_length = Base.inferencebarrier(minimum)(length, x.is)
13651381
itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
@@ -1368,7 +1384,7 @@ function unwrapped_broadcast(f::F, x::Base.Iterators.Zip, original_dims) where {
13681384
else
13691385
unrolled_map(f, x)
13701386
end
1371-
return result, original_dims, identity
1387+
return result, original_dims, __Identity()
13721388
end
13731389

13741390
function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate, original_dims) where {F}
@@ -1381,7 +1397,7 @@ function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate, original_dims) w
13811397
else
13821398
unrolled_map(f, x)
13831399
end
1384-
return result, original_dims, identity
1400+
return result, original_dims, __Identity()
13851401
end
13861402

13871403
function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F}
@@ -1398,25 +1414,25 @@ function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F}
13981414
updated_dims = ()
13991415
if original_dims isa Colon
14001416
updated_dims = mapslices_dims
1401-
re = x -> dropdims(x; dims=mapslices_dims)
1417+
re = __DropDims(mapslices_dims)
14021418
else
14031419
for d in original_dims
14041420
idx = findfirst(isequal(d), x.slicemap)
14051421
@assert idx !== nothing "Expected dimension $d in $(x.slicemap)"
14061422
updated_dims = (updated_dims..., idx)
14071423
end
1408-
re = x -> eachslice(x; dims=mapslices_dims, drop=true)
1424+
re = __EachSlice(mapslices_dims, true)
14091425
end
14101426

14111427
return mapslices(f, px; dims=mapslices_dims), updated_dims, re
14121428
else
14131429
mapslices_dims = Tuple(filter(i -> !(x.slicemap[i] isa Colon), 1:ndims(px)))
14141430
if original_dims isa Colon
14151431
updated_dims = mapslices_dims
1416-
re = x -> dropdims(x; dims=mapslices_dims)
1432+
re = __DropDims(mapslices_dims)
14171433
else
14181434
updated_dims = Tuple(d for d in original_dims if d in mapslices_dims)
1419-
re = x -> eachslice(x; dims=mapslices_dims, drop=false)
1435+
re = __EachSlice(mapslices_dims, false)
14201436
end
14211437
return mapslices(f, px; dims=mapslices_dims), updated_dims, re
14221438
end
@@ -1425,7 +1441,7 @@ end
14251441
function unwrapped_broadcast(f::F, xs, original_dims) where {F}
14261442
mapped_xs = unrolled_map(f, xs)
14271443
applicable(size, xs) && (mapped_xs = reshape(mapped_xs, size(xs)))
1428-
return mapped_xs, original_dims, identity
1444+
return mapped_xs, original_dims, __Identity()
14291445
end
14301446

14311447
# TODO: once traced_call supports internal mutations, we can use traced_call here

0 commit comments

Comments
 (0)