Skip to content

Commit 0f25e09

Browse files
committed
fix: drop=false
1 parent 3609b5b commit 0f25e09

File tree

2 files changed

+51
-10
lines changed

2 files changed

+51
-10
lines changed

src/TracedRArray.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,26 +1398,27 @@ function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F}
13981398
updated_dims = ()
13991399
if original_dims isa Colon
14001400
updated_dims = mapslices_dims
1401+
re = x -> dropdims(x; dims=mapslices_dims)
14011402
else
14021403
for d in original_dims
14031404
idx = findfirst(isequal(d), x.slicemap)
14041405
@assert idx !== nothing "Expected dimension $d in $(x.slicemap)"
14051406
updated_dims = (updated_dims..., idx)
14061407
end
1408+
re = x -> eachslice(x; dims=mapslices_dims, drop=true)
14071409
end
14081410

1409-
return (
1410-
mapslices(f, px; dims=mapslices_dims),
1411-
updated_dims,
1412-
x -> eachslice(x; dims=mapslices_dims, drop=true),
1413-
)
1411+
return mapslices(f, px; dims=mapslices_dims), updated_dims, re
14141412
else
14151413
mapslices_dims = Tuple(filter(i -> !(x.slicemap[i] isa Colon), 1:ndims(px)))
1416-
return (
1417-
mapslices(f, px; dims=mapslices_dims),
1418-
original_dims,
1419-
x -> eachslice(x; dims=mapslices_dims, drop=false),
1420-
)
1414+
if original_dims isa Colon
1415+
updated_dims = mapslices_dims
1416+
re = x -> dropdims(x; dims=mapslices_dims)
1417+
else
1418+
updated_dims = Tuple(d for d in original_dims if d in mapslices_dims)
1419+
re = x -> eachslice(x; dims=mapslices_dims, drop=false)
1420+
end
1421+
return mapslices(f, px; dims=mapslices_dims), updated_dims, re
14211422
end
14221423
end
14231424

test/basic.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,3 +1725,43 @@ end
17251725
fn = @compile sum(x_ra1)
17261726
@test_throws Reactant.Compiler.MisMatchedThunkTypeError fn(x_ra2)
17271727
end
1728+
1729+
@testset "Slices" begin
1730+
@testset "drop=true" begin
1731+
x = eachslice(
1732+
Reactant.TestUtils.construct_test_array(Float32, 2, 3, 4, 5); dims=(3, 1)
1733+
)
1734+
x_ra = Reactant.to_rarray(x)
1735+
1736+
@test @jit(sum(x_ra)) sum(x)
1737+
1738+
@testset for dims in (1, 2, (1, 2), (2, 1))
1739+
res_ra = @jit sum(x_ra; dims)
1740+
res = sum(x; dims)
1741+
@test size(res_ra) == size(res)
1742+
for (gt, comp) in zip(res_ra, res)
1743+
@test gt comp
1744+
end
1745+
end
1746+
end
1747+
1748+
@testset "drop=false" begin
1749+
x = eachslice(
1750+
Reactant.TestUtils.construct_test_array(Float32, 2, 3, 4, 5);
1751+
dims=(3, 1),
1752+
drop=false,
1753+
)
1754+
x_ra = Reactant.to_rarray(x)
1755+
1756+
@test @jit(sum(x_ra)) sum(x)
1757+
1758+
@testset for dims in (1, 2, 3, 4, (1, 2), (1, 2, 4), (3, 4, 1), (2, 1))
1759+
res_ra = @jit sum(x_ra; dims)
1760+
res = sum(x; dims)
1761+
@test size(res_ra) == size(res)
1762+
for (gt, comp) in zip(res_ra, res)
1763+
@test gt comp
1764+
end
1765+
end
1766+
end
1767+
end

0 commit comments

Comments
 (0)