Skip to content

Commit 258bf4a

Browse files
committed
fix: drop=false
1 parent 2070ad3 commit 258bf4a

File tree

2 files changed

+51
-10
lines changed

2 files changed

+51
-10
lines changed

src/TracedRArray.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,26 +1402,27 @@ function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F}
14021402
updated_dims = ()
14031403
if original_dims isa Colon
14041404
updated_dims = mapslices_dims
1405+
re = x -> dropdims(x; dims=mapslices_dims)
14051406
else
14061407
for d in original_dims
14071408
idx = findfirst(isequal(d), x.slicemap)
14081409
@assert idx !== nothing "Expected dimension $d in $(x.slicemap)"
14091410
updated_dims = (updated_dims..., idx)
14101411
end
1412+
re = x -> eachslice(x; dims=mapslices_dims, drop=true)
14111413
end
14121414

1413-
return (
1414-
mapslices(f, px; dims=mapslices_dims),
1415-
updated_dims,
1416-
x -> eachslice(x; dims=mapslices_dims, drop=true),
1417-
)
1415+
return mapslices(f, px; dims=mapslices_dims), updated_dims, re
14181416
else
14191417
mapslices_dims = Tuple(filter(i -> !(x.slicemap[i] isa Colon), 1:ndims(px)))
1420-
return (
1421-
mapslices(f, px; dims=mapslices_dims),
1422-
original_dims,
1423-
x -> eachslice(x; dims=mapslices_dims, drop=false),
1424-
)
1418+
if original_dims isa Colon
1419+
updated_dims = mapslices_dims
1420+
re = x -> dropdims(x; dims=mapslices_dims)
1421+
else
1422+
updated_dims = Tuple(d for d in original_dims if d in mapslices_dims)
1423+
re = x -> eachslice(x; dims=mapslices_dims, drop=false)
1424+
end
1425+
return mapslices(f, px; dims=mapslices_dims), updated_dims, re
14251426
end
14261427
end
14271428

test/basic.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,3 +1742,43 @@ end
17421742
fn = @compile sum(x_ra1)
17431743
@test_throws Reactant.Compiler.MisMatchedThunkTypeError fn(x_ra2)
17441744
end
1745+
1746+
@testset "Slices" begin
1747+
@testset "drop=true" begin
1748+
x = eachslice(
1749+
Reactant.TestUtils.construct_test_array(Float32, 2, 3, 4, 5); dims=(3, 1)
1750+
)
1751+
x_ra = Reactant.to_rarray(x)
1752+
1753+
@test @jit(sum(x_ra)) sum(x)
1754+
1755+
@testset for dims in (1, 2, (1, 2), (2, 1))
1756+
res_ra = @jit sum(x_ra; dims)
1757+
res = sum(x; dims)
1758+
@test size(res_ra) == size(res)
1759+
for (gt, comp) in zip(res_ra, res)
1760+
@test gt comp
1761+
end
1762+
end
1763+
end
1764+
1765+
@testset "drop=false" begin
1766+
x = eachslice(
1767+
Reactant.TestUtils.construct_test_array(Float32, 2, 3, 4, 5);
1768+
dims=(3, 1),
1769+
drop=false,
1770+
)
1771+
x_ra = Reactant.to_rarray(x)
1772+
1773+
@test @jit(sum(x_ra)) sum(x)
1774+
1775+
@testset for dims in (1, 2, 3, 4, (1, 2), (1, 2, 4), (3, 4, 1), (2, 1))
1776+
res_ra = @jit sum(x_ra; dims)
1777+
res = sum(x; dims)
1778+
@test size(res_ra) == size(res)
1779+
for (gt, comp) in zip(res_ra, res)
1780+
@test gt comp
1781+
end
1782+
end
1783+
end
1784+
end

0 commit comments

Comments
 (0)