Skip to content

Commit 74c16b6

Browse files
committed
fix: maybe fix stackoverflow??
1 parent 258bf4a commit 74c16b6

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

src/TracedRArray.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,22 @@ end
13641364

13651365
(fn::BroadcastIterator)(args...) = fn.f((args...,))
13661366

1367+
abstract type AbstractUnwrappedBroadcastRestoreFunction end
1368+
1369+
struct __Identity <: AbstractUnwrappedBroadcastRestoreFunction end
1370+
(::__Identity)(x) = x
1371+
1372+
struct __EachSlice{D} <: AbstractUnwrappedBroadcastRestoreFunction
1373+
dims::D
1374+
drop::Bool
1375+
end
1376+
(s::__EachSlice{D})(x) where {D} = eachslice(x; dims=s.dims, drop=s.drop)
1377+
1378+
struct __DropDims{D} <: AbstractUnwrappedBroadcastRestoreFunction
1379+
dims::D
1380+
end
1381+
(s::__DropDims{D})(x) where {D} = dropdims(x; dims=s.dims)
1382+
13671383
function unwrapped_broadcast(f::F, x::Base.Iterators.Zip, original_dims) where {F}
13681384
min_length = Base.inferencebarrier(minimum)(length, x.is)
13691385
itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
@@ -1372,7 +1388,7 @@ function unwrapped_broadcast(f::F, x::Base.Iterators.Zip, original_dims) where {
13721388
else
13731389
unrolled_map(f, x)
13741390
end
1375-
return result, original_dims, identity
1391+
return result, original_dims, __Identity()
13761392
end
13771393

13781394
function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate, original_dims) where {F}
@@ -1385,7 +1401,7 @@ function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate, original_dims) w
13851401
else
13861402
unrolled_map(f, x)
13871403
end
1388-
return result, original_dims, identity
1404+
return result, original_dims, __Identity()
13891405
end
13901406

13911407
function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F}
@@ -1402,25 +1418,25 @@ function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F}
14021418
updated_dims = ()
14031419
if original_dims isa Colon
14041420
updated_dims = mapslices_dims
1405-
re = x -> dropdims(x; dims=mapslices_dims)
1421+
re = __DropDims(mapslices_dims)
14061422
else
14071423
for d in original_dims
14081424
idx = findfirst(isequal(d), x.slicemap)
14091425
@assert idx !== nothing "Expected dimension $d in $(x.slicemap)"
14101426
updated_dims = (updated_dims..., idx)
14111427
end
1412-
re = x -> eachslice(x; dims=mapslices_dims, drop=true)
1428+
re = __EachSlice(mapslices_dims, true)
14131429
end
14141430

14151431
return mapslices(f, px; dims=mapslices_dims), updated_dims, re
14161432
else
14171433
mapslices_dims = Tuple(filter(i -> !(x.slicemap[i] isa Colon), 1:ndims(px)))
14181434
if original_dims isa Colon
14191435
updated_dims = mapslices_dims
1420-
re = x -> dropdims(x; dims=mapslices_dims)
1436+
re = __DropDims(mapslices_dims)
14211437
else
14221438
updated_dims = Tuple(d for d in original_dims if d in mapslices_dims)
1423-
re = x -> eachslice(x; dims=mapslices_dims, drop=false)
1439+
re = __EachSlice(mapslices_dims, false)
14241440
end
14251441
return mapslices(f, px; dims=mapslices_dims), updated_dims, re
14261442
end
@@ -1429,7 +1445,7 @@ end
14291445
function unwrapped_broadcast(f::F, xs, original_dims) where {F}
14301446
mapped_xs = unrolled_map(f, xs)
14311447
applicable(size, xs) && (mapped_xs = reshape(mapped_xs, size(xs)))
1432-
return mapped_xs, original_dims, identity
1448+
return mapped_xs, original_dims, __Identity()
14331449
end
14341450

14351451
# TODO: once traced_call supports internal mutations, we can use traced_call here

0 commit comments

Comments
 (0)