Skip to content

Commit 0683201

Browse files
committed
wrap comm test
1 parent fb35874 commit 0683201

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

test/optimize_comm.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ function dus2(x, y)
2222
return nothing
2323
end
2424

25+
function wrap(x)
26+
return Reactant.Ops.@opcall wrap(x, 7, 7; dimension=1)
27+
end
28+
2529
if length(addressable_devices) 8
2630
@testset "Rotate" begin
2731
N = min((length(Reactant.devices()) ÷ 2) * 2, 8)
@@ -108,4 +112,16 @@ if length(addressable_devices) ≥ 8
108112
@test all(x .== convert(Array, rx))
109113
@test all(y .== convert(Array, ry))
110114
end
115+
116+
@testset "Wrap" begin
117+
mesh = Sharding.Mesh(Reactant.devices(), (:x,))
118+
sharding = Sharding.NamedSharding(mesh, (:x,))
119+
120+
x = Reactant.to_rarray(rand(8192); sharding)
121+
hlo = repr(@code_xla wrap(x))
122+
123+
@test !contains(hlo, "all-to-all")
124+
@test !contains(hlo, "all-gather")
125+
@test contains(hlo, "collective-permute")
126+
end
111127
end

0 commit comments

Comments
 (0)