From 0daa1eb7a7e39512edfa2f9bd302d11b444cf556 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sun, 14 Dec 2025 16:16:48 -0600 Subject: [PATCH 1/8] fix Ops.wrap size computation --- src/Ops.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Ops.jl b/src/Ops.jl index ed162501ba..75e4c659cb 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -3564,12 +3564,16 @@ end $(size(input, dimension)) (got $(lhs))" @assert 0 ≤ rhs ≤ size(input, dimension) "rhs must be between 0 and \ $(size(input, dimension)) (got $(rhs))" + + sz = collect(Int64, size(input)) + sz[dimension] = sz[dimension] + lhs + rhs + return TracedRArray{T,N}( (), MLIR.IR.result( enzymexla.wrap(input.mlir_data; lhs, rhs, dimension=dimension - 1, location), 1 ), - size(input), + sz, ) end From fb358747db5e8b2bf2d9c8368bfd97a4986e7575 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sun, 14 Dec 2025 16:26:23 -0600 Subject: [PATCH 2/8] test --- test/ops.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/ops.jl b/test/ops.jl index eba91db228..c5edf06493 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -1138,6 +1138,15 @@ end @test fr!(vr) ≈ f!(v) end +fn_test_wrap(x) = Reactant.Ops.wrap(x, 2, 1; dimension=3) + +@testset "Ops.wrap" begin + x = Reactant.to_rarray(rand(2, 3, 4, 5)) + out = @jit fn_test_wrap(x) + + @test size(out) == (2, 3, 7, 5) +end + @testset "Ops.fill" begin @testset "Fill with TracedScalar" begin fn(x) = Ops.fill(x, [2, 3]) From 0683201494f7ea30a58ba1f8d68b55fae2d33163 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sun, 14 Dec 2025 16:16:58 -0600 Subject: [PATCH 3/8] wrap comm test --- test/optimize_comm.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/optimize_comm.jl b/test/optimize_comm.jl index 404ccaa3a0..3c4fcf1d4b 100644 --- a/test/optimize_comm.jl +++ b/test/optimize_comm.jl @@ -22,6 +22,10 @@ function dus2(x, y) return nothing end +function wrap(x) + return Reactant.Ops.@opcall wrap(x, 7, 7; dimension=1) +end + if length(addressable_devices) ≥ 8 @testset "Rotate" begin N = min((length(Reactant.devices()) ÷ 2) * 2, 8) @@ -108,4 +112,16 @@ if length(addressable_devices) ≥ 8 @test all(x .== convert(Array, rx)) @test all(y .== convert(Array, ry)) end + + @testset "Wrap" begin + mesh = Sharding.Mesh(Reactant.devices(), (:x,)) + sharding = Sharding.NamedSharding(mesh, (:x,)) + + x = Reactant.to_rarray(rand(8192); sharding) + hlo = repr(@code_xla wrap(x)) + + @test !contains(hlo, "all-to-all") + @test !contains(hlo, "all-gather") + @test contains(hlo, "collective-permute") + end end From f28f9f785f9cc1e7b7bf8a7944708544c578e597 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sun, 14 Dec 2025 16:29:28 -0600 Subject: [PATCH 4/8] disable most of CI for testing --- .buildkite/pipeline.yml | 90 ++++++++++++++++++++-------------------- .github/workflows/CI.yml | 48 ++++++++++----------- 2 files changed, 69 insertions(+), 69 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 064ef33235..9032452c2b 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,52 +1,52 @@ -steps: - - group: ":test_tube: Tests" - steps: - - label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}} -- {{matrix.runtime}}" - matrix: - setup: - version: - - "1.10" - group: - - core - - neural_networks - - integration - runtime: - - "PJRT" - - "IFRT" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.version}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - - lib/ReactantCore/src - commands: | - touch LocalPreferences.toml +# steps: +# - group: ":test_tube: Tests" +# steps: +# - label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}} -- {{matrix.runtime}}" +# matrix: +# setup: +# version: +# - "1.10" +# group: +# - core +# - neural_networks +# - integration +# runtime: +# - "PJRT" +# - "IFRT" +# plugins: +# - JuliaCI/julia#v1: +# version: "{{matrix.version}}" +# - JuliaCI/julia-coverage#v1: +# codecov: true +# dirs: +# - src +# - ext +# - lib/ReactantCore/src +# commands: | +# touch LocalPreferences.toml - echo "[Reactant]" >> LocalPreferences.toml - echo "xla_runtime = \"{{matrix.runtime}}\"" >> LocalPreferences.toml +# echo "[Reactant]" >> LocalPreferences.toml +# echo "xla_runtime = \"{{matrix.runtime}}\"" >> LocalPreferences.toml - cat LocalPreferences.toml +# cat LocalPreferences.toml - julia --project=. -e 'println("--- :julia: Instantiating project") - using Pkg - Pkg.develop([PackageSpec(path="lib/ReactantCore")])' +# julia --project=. -e 'println("--- :julia: Instantiating project") +# using Pkg +# Pkg.develop([PackageSpec(path="lib/ReactantCore")])' - julia --project=. -e 'println("--- :julia: Run Tests") - using Pkg - Pkg.test(; coverage="user")' - agents: - queue: "juliagpu" - cuda: "*" - env: - REACTANT_TEST_GROUP: "{{matrix.group}}" - JULIA_DEBUG: "Reactant,Reactant_jll" - CUDA_VISIBLE_DEVICES: 0 - REACTANT_BACKEND_GROUP: "GPU" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 120 +# julia --project=. -e 'println("--- :julia: Run Tests") +# using Pkg +# Pkg.test(; coverage="user")' +# agents: +# queue: "juliagpu" +# cuda: "*" +# env: +# REACTANT_TEST_GROUP: "{{matrix.group}}" +# JULIA_DEBUG: "Reactant,Reactant_jll" +# CUDA_VISIBLE_DEVICES: 0 +# REACTANT_BACKEND_GROUP: "GPU" +# if: build.message !~ /\[skip tests\]/ +# timeout_in_minutes: 120 # - label: ":julia: :linux: AMDGPU Julia v{{matrix.version}} -- {{matrix.group}} -- {{matrix.runtime}}" # matrix: diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b60f801cbc..71a52b1013 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -34,26 +34,26 @@ jobs: fail-fast: false matrix: version: - - "1.10" + # - "1.10" - "1.11" # - 'nightly' os: - ubuntu-24.04 # `ubuntu-22.04-arm` is considered more stable than `ubuntu-24.04-arm`: # . - - ubuntu-22.04-arm + # - ubuntu-22.04-arm # Disable `macOS-13` until # is resolved. # - macOS-13 - - macOS-latest - - windows-latest - - linux-x86-ct6e-180-4tpu + # - macOS-latest + # - windows-latest + # - linux-x86-ct6e-180-4tpu test_group: - core - neural_networks - integration runtime: - - "pjrt" + # - "pjrt" - "ifrt" exclude: - os: linux-x86-ct6e-180-4tpu @@ -86,21 +86,21 @@ jobs: # assertions: true # test_group: ${{ matrix.test_group }} - downgrade: - strategy: - fail-fast: false - matrix: - test_group: - - core - - neural_networks - - integration - runtime: - - "pjrt" - - "ifrt" - uses: ./.github/workflows/CommonCI.yml - with: - julia_version: "1.10" - os: "ubuntu-24.04" - runtime: ${{ matrix.runtime }} - test_group: ${{ matrix.test_group }} - downgrade_testing: true + # downgrade: + # strategy: + # fail-fast: false + # matrix: + # test_group: + # - core + # - neural_networks + # - integration + # runtime: + # - "pjrt" + # - "ifrt" + # uses: ./.github/workflows/CommonCI.yml + # with: + # julia_version: "1.10" + # os: "ubuntu-24.04" + # runtime: ${{ matrix.runtime }} + # test_group: ${{ matrix.test_group }} + # downgrade_testing: true From a66810b76c45b342905fc3b4e8f3a717bfe84aee Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sun, 14 Dec 2025 16:33:23 -0600 Subject: [PATCH 5/8] change wrap sizes to 2 --- test/optimize_comm.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimize_comm.jl b/test/optimize_comm.jl index 3c4fcf1d4b..846a7d3cbd 100644 --- a/test/optimize_comm.jl +++ b/test/optimize_comm.jl @@ -23,7 +23,7 @@ function dus2(x, y) end function wrap(x) - return Reactant.Ops.@opcall wrap(x, 7, 7; dimension=1) + return Reactant.Ops.@opcall wrap(x, 2, 2; dimension=1) end if length(addressable_devices) ≥ 8 From 7ec4be5484ff52a805afa9e85b53d18e75120af9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Dec 2025 13:23:14 -0600 Subject: [PATCH 6/8] test: run only optimize_comm --- .github/workflows/CI.yml | 16 ++--- test/runtests.jl | 130 +++++++++++++++++++-------------------- 2 files changed, 73 insertions(+), 73 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 71a52b1013..f3aff51422 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -38,7 +38,7 @@ jobs: - "1.11" # - 'nightly' os: - - ubuntu-24.04 + - ubuntu-latest # `ubuntu-22.04-arm` is considered more stable than `ubuntu-24.04-arm`: # . # - ubuntu-22.04-arm @@ -50,16 +50,16 @@ jobs: # - linux-x86-ct6e-180-4tpu test_group: - core - - neural_networks - - integration + # - neural_networks + # - integration runtime: # - "pjrt" - "ifrt" - exclude: - - os: linux-x86-ct6e-180-4tpu - version: "1.10" - - os: linux-x86-ct6e-180-4tpu - runtime: "pjrt" + # exclude: + # - os: linux-x86-ct6e-180-4tpu + # version: "1.10" + # - os: linux-x86-ct6e-180-4tpu + # runtime: "pjrt" uses: ./.github/workflows/CommonCI.yml with: julia_version: ${{ matrix.version }} diff --git a/test/runtests.jl b/test/runtests.jl index bbd5e0855f..20b1735d21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,72 +19,72 @@ if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" end end -@testset "Reactant.jl Tests" begin - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" - if Sys.isapple() && haskey(Reactant.XLA.global_backend_state.clients, "metal") - @safetestset "Metal Plugin" include("plugins/metal.jl") - end +# @testset "Reactant.jl Tests" begin +# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" +# if Sys.isapple() && haskey(Reactant.XLA.global_backend_state.clients, "metal") +# @safetestset "Metal Plugin" include("plugins/metal.jl") +# end - @safetestset "Layout" include("layout.jl") - @safetestset "Tracing" include("tracing.jl") - @safetestset "Basic" include("basic.jl") - @safetestset "Constructor" include("constructor.jl") - @safetestset "Autodiff" include("autodiff.jl") - @safetestset "Complex" include("complex.jl") - @safetestset "Broadcast" include("bcast.jl") - @safetestset "Struct" include("struct.jl") - @safetestset "Closure" include("closure.jl") - @safetestset "Compile" include("compile.jl") - @safetestset "IR" include("ir.jl") - @safetestset "Buffer Donation" include("buffer_donation.jl") - @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") - @safetestset "Control Flow" include("control_flow.jl") - @safetestset "Sorting" include("sorting.jl") - @safetestset "Shortcuts to MLIR ops" include("ops.jl") - @safetestset "Indexing" include("indexing.jl") - @safetestset "Ranges" include("ranges.jl") - if !Sys.isapple() - @safetestset "Custom Number Types" include("custom_number_types.jl") - end - @safetestset "Sharding" include("sharding.jl") - @safetestset "Comm Optimization" include("optimize_comm.jl") - @safetestset "Cluster Detection" include("cluster_detector.jl") - @safetestset "Config" include("config.jl") - @safetestset "Batching" include("batching.jl") - @safetestset "QA" include("qa.jl") - end +# @safetestset "Layout" include("layout.jl") +# @safetestset "Tracing" include("tracing.jl") +# @safetestset "Basic" include("basic.jl") +# @safetestset "Constructor" include("constructor.jl") +# @safetestset "Autodiff" include("autodiff.jl") +# @safetestset "Complex" include("complex.jl") +# @safetestset "Broadcast" include("bcast.jl") +# @safetestset "Struct" include("struct.jl") +# @safetestset "Closure" include("closure.jl") +# @safetestset "Compile" include("compile.jl") +# @safetestset "IR" include("ir.jl") +# @safetestset "Buffer Donation" include("buffer_donation.jl") +# @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") +# @safetestset "Control Flow" include("control_flow.jl") +# @safetestset "Sorting" include("sorting.jl") +# @safetestset "Shortcuts to MLIR ops" include("ops.jl") +# @safetestset "Indexing" include("indexing.jl") +# @safetestset "Ranges" include("ranges.jl") +# if !Sys.isapple() +# @safetestset "Custom Number Types" include("custom_number_types.jl") +# end +# @safetestset "Sharding" include("sharding.jl") +@safetestset "Comm Optimization" include("optimize_comm.jl") +# @safetestset "Cluster Detection" include("cluster_detector.jl") +# @safetestset "Config" include("config.jl") +# @safetestset "Batching" include("batching.jl") +# @safetestset "QA" include("qa.jl") +# end - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" - @safetestset "CUDA" include("integration/cuda.jl") - @safetestset "KernelAbstractions" include("integration/kernelabstractions.jl") - @safetestset "Linear Algebra" include("integration/linear_algebra.jl") - @safetestset "OffsetArrays" include("integration/offsetarrays.jl") - @safetestset "OneHotArrays" include("integration/onehotarrays.jl") - @safetestset "AbstractFFTs" include("integration/fft.jl") - @safetestset "SpecialFunctions" include("integration/special_functions.jl") - @safetestset "Random" include("integration/random.jl") - @safetestset "Python" include("integration/python.jl") - @safetestset "Optimisers" include("integration/optimisers.jl") - @safetestset "FillArrays" include("integration/fillarrays.jl") - if ENZYMEJAX_INSTALLED[] && !Sys.isapple() - @safetestset "EnzymeJAX Export" include("integration/enzymejax.jl") - end - @safetestset "MPI" begin - using MPI - nranks = 2 - run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`) - end +# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" +# @safetestset "CUDA" include("integration/cuda.jl") +# @safetestset "KernelAbstractions" include("integration/kernelabstractions.jl") +# @safetestset "Linear Algebra" include("integration/linear_algebra.jl") +# @safetestset "OffsetArrays" include("integration/offsetarrays.jl") +# @safetestset "OneHotArrays" include("integration/onehotarrays.jl") +# @safetestset "AbstractFFTs" include("integration/fft.jl") +# @safetestset "SpecialFunctions" include("integration/special_functions.jl") +# @safetestset "Random" include("integration/random.jl") +# @safetestset "Python" include("integration/python.jl") +# @safetestset "Optimisers" include("integration/optimisers.jl") +# @safetestset "FillArrays" include("integration/fillarrays.jl") +# if ENZYMEJAX_INSTALLED[] && !Sys.isapple() +# @safetestset "EnzymeJAX Export" include("integration/enzymejax.jl") +# end +# @safetestset "MPI" begin +# using MPI +# nranks = 2 +# run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`) +# end - # Zygote is not supported on 1.12 https://github.com/FluxML/Zygote.jl/issues/1580 - if VERSION < v"1.12-" - @safetestset "Zygote" include("integration/zygote.jl") - end - end +# # Zygote is not supported on 1.12 https://github.com/FluxML/Zygote.jl/issues/1580 +# if VERSION < v"1.12-" +# @safetestset "Zygote" include("integration/zygote.jl") +# end +# end - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" - @safetestset "NNlib Primitives" include("nn/nnlib.jl") - @safetestset "Flux.jl Integration" include("nn/flux.jl") - @safetestset "LuxLib Primitives" include("nn/luxlib.jl") - @safetestset "Lux Integration" include("nn/lux.jl") - end -end +# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" +# @safetestset "NNlib Primitives" include("nn/nnlib.jl") +# @safetestset "Flux.jl Integration" include("nn/flux.jl") +# @safetestset "LuxLib Primitives" include("nn/luxlib.jl") +# @safetestset "Lux Integration" include("nn/lux.jl") +# end +# end From 64511c362bf7ef64b0339b32451df662c8ce45a3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Dec 2025 14:01:34 -0600 Subject: [PATCH 7/8] test: force exactly 8 devices --- test/optimize_comm.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimize_comm.jl b/test/optimize_comm.jl index 846a7d3cbd..3b73e9cb5e 100644 --- a/test/optimize_comm.jl +++ b/test/optimize_comm.jl @@ -114,7 +114,7 @@ if length(addressable_devices) ≥ 8 end @testset "Wrap" begin - mesh = Sharding.Mesh(Reactant.devices(), (:x,)) + mesh = Sharding.Mesh(Reactant.devices()[1:8], (:x,)) sharding = Sharding.NamedSharding(mesh, (:x,)) x = Reactant.to_rarray(rand(8192); sharding) From 8e6250527218a8f478637e6fcddf61b3744c4659 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Mon, 15 Dec 2025 16:27:07 -0600 Subject: [PATCH 8/8] change input size to fail with num_devices==12 --- test/optimize_comm.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/optimize_comm.jl b/test/optimize_comm.jl index 3b73e9cb5e..dae7a3afa8 100644 --- a/test/optimize_comm.jl +++ b/test/optimize_comm.jl @@ -114,11 +114,11 @@ if length(addressable_devices) ≥ 8 end @testset "Wrap" begin - mesh = Sharding.Mesh(Reactant.devices()[1:8], (:x,)) + mesh = Sharding.Mesh(Reactant.devices(), (:x,)) sharding = Sharding.NamedSharding(mesh, (:x,)) - x = Reactant.to_rarray(rand(8192); sharding) - hlo = repr(@code_xla wrap(x)) + x = Reactant.to_rarray(rand(192 * length(addressable_devices)); sharding) + @assert x isa ConcreteIFRTArray @test !contains(hlo, "all-to-all") @test !contains(hlo, "all-gather")