Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3564,12 +3564,16 @@ end
$(size(input, dimension)) (got $(lhs))"
@assert 0 ≤ rhs ≤ size(input, dimension) "rhs must be between 0 and \
$(size(input, dimension)) (got $(rhs))"

sz = collect(Int64, size(input))
sz[dimension] = sz[dimension] + lhs + rhs

return TracedRArray{T,N}(
(),
MLIR.IR.result(
enzymexla.wrap(input.mlir_data; lhs, rhs, dimension=dimension - 1, location), 1
),
size(input),
sz,
)
end

Expand Down
9 changes: 9 additions & 0 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,15 @@ end
@test fr!(vr) ≈ f!(v)
end

fn_test_wrap(x) = Reactant.Ops.wrap(x, 2, 1; dimension=3)

@testset "Ops.wrap" begin
x = Reactant.to_rarray(rand(2, 3, 4, 5))
out = @jit fn_test_wrap(x)

@test size(out) == (2, 3, 7, 5)
end

@testset "Ops.fill" begin
@testset "Fill with TracedScalar" begin
fn(x) = Ops.fill(x, [2, 3])
Expand Down
16 changes: 16 additions & 0 deletions test/optimize_comm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ function dus2(x, y)
return nothing
end

function wrap(x)
return Reactant.Ops.@opcall wrap(x, 7, 7; dimension=1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you may want to change this to 2

end

if length(addressable_devices) ≥ 8
@testset "Rotate" begin
N = min((length(Reactant.devices()) ÷ 2) * 2, 8)
Expand Down Expand Up @@ -108,4 +112,16 @@ if length(addressable_devices) ≥ 8
@test all(x .== convert(Array, rx))
@test all(y .== convert(Array, ry))
end

@testset "Wrap" begin
mesh = Sharding.Mesh(Reactant.devices(), (:x,))
sharding = Sharding.NamedSharding(mesh, (:x,))

x = Reactant.to_rarray(rand(8192); sharding)
hlo = repr(@code_xla wrap(x))

@test !contains(hlo, "all-to-all")
@test !contains(hlo, "all-gather")
@test contains(hlo, "collective-permute")
end
end
Loading