From 801fbb67db1974fc56c75c91612ae5b53eb8fa24 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 9 Dec 2025 22:02:54 -0600 Subject: [PATCH 1/8] initial StructArrays extension --- Project.toml | 3 ++ ext/ReactantStructArraysExt.jl | 74 ++++++++++++++++++++++++++++++++++ src/TracedRArray.jl | 6 ++- 3 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 ext/ReactantStructArraysExt.jl diff --git a/Project.toml b/Project.toml index 0bf5002667..3ebfcdcc44 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,7 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" Scratch = "6c6a2e73-6563-6170-7368-637461726353" Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" p7zip_jll = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" [weakdeps] @@ -71,6 +72,7 @@ ReactantRandom123Ext = "Random123" ReactantSparseArraysExt = "SparseArrays" ReactantSpecialFunctionsExt = "SpecialFunctions" ReactantStatisticsExt = "Statistics" +ReactantStructArraysExt = "StructArrays" ReactantYaoBlocksExt = "YaoBlocks" ReactantZygoteExt = "Zygote" @@ -115,6 +117,7 @@ Sockets = "1.10" SparseArrays = "1.10" SpecialFunctions = "2.4" Statistics = "1.10" +StructArrays = "0.7.2" YaoBlocks = "0.13, 0.14" Zygote = "0.7" julia = "1.10" diff --git a/ext/ReactantStructArraysExt.jl b/ext/ReactantStructArraysExt.jl new file mode 100644 index 0000000000..949872641b --- /dev/null +++ b/ext/ReactantStructArraysExt.jl @@ -0,0 +1,74 @@ +module ReactantStructArraysExt + +import Reactant +import StructArrays + +import StructArrays: StructArrayStyle, StructArray, StructVector, index_type +import Reactant: TraceMode, TracedToTypes, traced_type_inner, append_path, make_tracer, traced_type +import Reactant.TracedRArrayOverrides: AbstractReactantArrayStyle, _copy +import Base.Broadcast: Broadcasted + +StructArrays.always_struct_broadcast(::AbstractReactantArrayStyle) = true + +function Base.copy(bc::Broadcasted{StructArrays.StructArrayStyle{S, N}}) where {S<:AbstractReactantArrayStyle, N} + return _copy(bc) +end + +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(prev::Type{<:StructVector{NT}}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) where {NT <: NamedTuple} + T, N, C, I = prev.parameters + C_traced = traced_type_inner( + C, + seen, + mode, + track_numbers, + sharding, + runtime, + ) + T_traced = traced_type_inner( + T, + seen, + mode, + # The elements in the NamedTuple are backed by vectors, + # these vectors are converted to RArrays so we need to track numbers: + Number #= track_numbers =#, + sharding, + runtime, + ) + return StructVector{T_traced, C_traced, index_type(fieldtypes(C_traced))} +end + +function Reactant.make_tracer( + seen, @nospecialize(prev::StructVector{NT}), @nospecialize(path), mode; track_numbers=false, sharding=Reactant.Sharding.Sharding.NoSharding(), runtime=nothing, kwargs... +) where {NT <: NamedTuple} + track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{}) + components = getfield(prev, :components) + if mode == TracedToTypes + push!(path, typeof(prev)) + for c in components + make_tracer(seen, c, path, mode; track_numbers, sharding, runtime, kwargs...) + end + return nothing + end + traced_components = make_tracer(seen, components, append_path(path, :components), mode; track_numbers, sharding, runtime, kwargs...) + T_traced = traced_type( + typeof(prev), + Val(mode), + track_numbers, + sharding, + runtime, + ) + return StructVector{first(T_traced.parameters)}(traced_components) +end + +@inline function Reactant.traced_getfield(@nospecialize(obj::StructArray), field) + return Base.getfield(obj, field) +end + +end \ No newline at end of file diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 85f692db41..72bba9ca50 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -459,7 +459,7 @@ first_scalar(x) = @allowscalar first(x) # we need to override the outer copy method to make sure we never fall back to scalar # iteration (see, e.g., CUDA.jl#145) -function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) +function _copy(bc) fn = if bc.f isa Type && bc.f <: Reactant.ReactantPrimitive TracedUtils.TypeCast{bc.f}() else @@ -477,6 +477,10 @@ function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) return copyto!(sim, bc) end +@noinline function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) + return _copy(bc) +end + function Base.materialize!( ::Style, dest, bc::Broadcasted ) where {Style<:AbstractReactantArrayStyle} From 31726d1981f951cd251ebfb9446da868b8041f69 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 9 Dec 2025 22:25:02 -0600 Subject: [PATCH 2/8] generalize StructVector to StructArray of NamedTuple --- ext/ReactantStructArraysExt.jl | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/ext/ReactantStructArraysExt.jl b/ext/ReactantStructArraysExt.jl index 949872641b..6691db018a 100644 --- a/ext/ReactantStructArraysExt.jl +++ b/ext/ReactantStructArraysExt.jl @@ -4,7 +4,7 @@ import Reactant import StructArrays import StructArrays: StructArrayStyle, StructArray, StructVector, index_type -import Reactant: TraceMode, TracedToTypes, traced_type_inner, append_path, make_tracer, traced_type +import Reactant: TraceMode, TracedToTypes, traced_type_inner, append_path, make_tracer, traced_type, ReactantPrimitive import Reactant.TracedRArrayOverrides: AbstractReactantArrayStyle, _copy import Base.Broadcast: Broadcasted @@ -15,7 +15,7 @@ function Base.copy(bc::Broadcasted{StructArrays.StructArrayStyle{S, N}}) where { end Base.@nospecializeinfer function Reactant.traced_type_inner( - @nospecialize(prev::Type{<:StructVector{NT}}), + @nospecialize(prev::Type{<:StructArray{NT}}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @@ -31,22 +31,30 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( sharding, runtime, ) - T_traced = traced_type_inner( - T, - seen, - mode, + + names = T.parameters[1] + valuetypes = T.parameters[2].parameters + traced_value_types = map(valuetypes) do VT # The elements in the NamedTuple are backed by vectors, # these vectors are converted to RArrays so we need to track numbers: - Number #= track_numbers =#, - sharding, - runtime, - ) + track_numbers = VT <: ReactantPrimitive ? ReactantPrimitive : track_numbers + traced_type_inner( + VT, + seen, + mode, + track_numbers, + sharding, + runtime, + ) + end + T_traced = NamedTuple{names, Tuple{traced_value_types...}} + return StructVector{T_traced, C_traced, index_type(fieldtypes(C_traced))} end function Reactant.make_tracer( - seen, @nospecialize(prev::StructVector{NT}), @nospecialize(path), mode; track_numbers=false, sharding=Reactant.Sharding.Sharding.NoSharding(), runtime=nothing, kwargs... -) where {NT <: NamedTuple} + seen, @nospecialize(prev::StructArray{NT, N}), @nospecialize(path), mode; track_numbers=false, sharding=Reactant.Sharding.Sharding.NoSharding(), runtime=nothing, kwargs... +) where {NT <: NamedTuple, N} track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{}) components = getfield(prev, :components) if mode == TracedToTypes @@ -64,7 +72,7 @@ function Reactant.make_tracer( sharding, runtime, ) - return StructVector{first(T_traced.parameters)}(traced_components) + return StructArray{first(T_traced.parameters)}(traced_components) end @inline function Reactant.traced_getfield(@nospecialize(obj::StructArray), field) From eff01caeef1ef563e42f9e7fcb1f3d88b72629c9 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 9 Dec 2025 22:37:58 -0600 Subject: [PATCH 3/8] test structarray to_rarray and make_tracer --- test/integration/structarrays.jl | 41 ++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 42 insertions(+) create mode 100644 test/integration/structarrays.jl diff --git a/test/integration/structarrays.jl b/test/integration/structarrays.jl new file mode 100644 index 0000000000..bd3634db78 --- /dev/null +++ b/test/integration/structarrays.jl @@ -0,0 +1,41 @@ +using StructArrays, Reactant, Test + +@testset "StructArray to_rarray and make_tracer" begin + x = StructArray(; + a=rand(10, 2), b=fill("some strings", (10, 2)), c=rand(Float32, 10, 2) + ) + x_ra = Reactant.to_rarray(x) + + # Note that the element type (the NamedTuple) contains ConcreteRNumbers even though track_numbers is not enabled. + # This is because when the backing arrays are converted to TracedRArrays, their elements will contain TracedRNumbers. + # In order for the element type to match the backing arrays, we need to use ConcreteRNumbers here as well: + @test typeof(x_ra) == StructArray{ + @NamedTuple{ + a::ConcretePJRTNumber{Float64,1}, b::String, c::ConcretePJRTNumber{Float32,1} + }, + 2, + @NamedTuple{ + a::ConcretePJRTArray{Float64,2,1}, + b::Matrix{String}, + c::ConcretePJRTArray{Float32,2,1}, + }, + CartesianIndex{2}, + } + + @test typeof( + make_tracer(Reactant.OrderedIdDict(), x_ra, (), Reactant.ConcreteToTraced) + ) == StructArray{ + @NamedTuple{ + a::Reactant.TracedRNumber{Float64}, + b::String, + c::Reactant.TracedRNumber{Float32}, + }, + 2, + @NamedTuple{ + a::Reactant.TracedRArray{Float64,2}, + b::Matrix{String}, + c::Reactant.TracedRArray{Float32,2}, + }, + Int64, + } +end diff --git a/test/runtests.jl b/test/runtests.jl index bbd5e0855f..6d1c04e9df 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,6 +62,7 @@ end @safetestset "OneHotArrays" include("integration/onehotarrays.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") @safetestset "SpecialFunctions" include("integration/special_functions.jl") + @safetestset "StructArrays" include("integration/structarrays.jl") @safetestset "Random" include("integration/random.jl") @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") From a9e8bed8649576fd8d6e71df7d3d379bf4a8ca52 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 12 Dec 2025 10:56:29 -0600 Subject: [PATCH 4/8] structarrays getindex, setindex, broadcasting, and small fixes + formatting --- ext/ReactantStructArraysExt.jl | 115 ++++++++++++++++++++++++--------- 1 file changed, 84 insertions(+), 31 deletions(-) diff --git a/ext/ReactantStructArraysExt.jl b/ext/ReactantStructArraysExt.jl index 6691db018a..993556299f 100644 --- a/ext/ReactantStructArraysExt.jl +++ b/ext/ReactantStructArraysExt.jl @@ -1,19 +1,76 @@ module ReactantStructArraysExt -import Reactant -import StructArrays +using Reactant: Reactant +using StructArrays: StructArrays -import StructArrays: StructArrayStyle, StructArray, StructVector, index_type -import Reactant: TraceMode, TracedToTypes, traced_type_inner, append_path, make_tracer, traced_type, ReactantPrimitive +import StructArrays: + StructArrayStyle, + StructArray, + index_type, + components, + createinstance, + get_ith, + maybe_convert_elt, + foreachfield +import Reactant: + TraceMode, + TracedToTypes, + traced_type_inner, + append_path, + make_tracer, + traced_type, + ReactantPrimitive, + broadcast_to_size, + TracedRNumber, + TracedRArray, + unwrapped_eltype import Reactant.TracedRArrayOverrides: AbstractReactantArrayStyle, _copy import Base.Broadcast: Broadcasted StructArrays.always_struct_broadcast(::AbstractReactantArrayStyle) = true -function Base.copy(bc::Broadcasted{StructArrays.StructArrayStyle{S, N}}) where {S<:AbstractReactantArrayStyle, N} +function Base.copy( + bc::Broadcasted{StructArrays.StructArrayStyle{S,N}} +) where {S<:AbstractReactantArrayStyle,N} return _copy(bc) end +function Reactant.broadcast_to_size(arg::StructArray{T}, rsize) where {T} + new = [broadcast_to_size(c, rsize) for c in components(arg)] + return StructArray{T}(NamedTuple(Base.propertynames(arg) .=> new)) +end + +function Base.copyto!( + dest::StructArray, bc::Base.Broadcast.Broadcasted{<:AbstractReactantArrayStyle} +) + axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) + isempty(dest) && return dest + + bc = Broadcast.preprocess(dest, bc) + + args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) + + res = Reactant.TracedUtils.elem_apply_via_while_loop(bc.f, args...) + + return copyto!(dest, res) +end + +Base.@propagate_inbounds function StructArrays._getindex( + x::StructArray{T}, I::Vararg{TracedRNumber{<:Integer}} +) where {T} + cols = components(x) + @boundscheck checkbounds(x, I...) + return createinstance(T, get_ith(cols, I...)...) +end + +Base.@propagate_inbounds function Base.setindex!( + s::StructArray{T,<:Any,<:Any,Int}, vals, I::TracedRNumber{TI} +) where {T,TI<:Integer} + valsT = maybe_convert_elt(T, vals) + foreachfield((col, val) -> (@inbounds col[I] = val), s, valsT) + return s +end + Base.@nospecializeinfer function Reactant.traced_type_inner( @nospecialize(prev::Type{<:StructArray{NT}}), seen, @@ -21,16 +78,9 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( @nospecialize(track_numbers::Type), @nospecialize(sharding), @nospecialize(runtime) -) where {NT <: NamedTuple} +) where {NT<:NamedTuple} T, N, C, I = prev.parameters - C_traced = traced_type_inner( - C, - seen, - mode, - track_numbers, - sharding, - runtime, - ) + C_traced = traced_type_inner(C, seen, mode, track_numbers, sharding, runtime) names = T.parameters[1] valuetypes = T.parameters[2].parameters @@ -38,23 +88,23 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( # The elements in the NamedTuple are backed by vectors, # these vectors are converted to RArrays so we need to track numbers: track_numbers = VT <: ReactantPrimitive ? ReactantPrimitive : track_numbers - traced_type_inner( - VT, - seen, - mode, - track_numbers, - sharding, - runtime, - ) + traced_type_inner(VT, seen, mode, track_numbers, sharding, runtime) end - T_traced = NamedTuple{names, Tuple{traced_value_types...}} + T_traced = NamedTuple{names,Tuple{traced_value_types...}} - return StructVector{T_traced, C_traced, index_type(fieldtypes(C_traced))} + return StructArray{T_traced,N,C_traced,index_type(fieldtypes(C_traced))} end function Reactant.make_tracer( - seen, @nospecialize(prev::StructArray{NT, N}), @nospecialize(path), mode; track_numbers=false, sharding=Reactant.Sharding.Sharding.NoSharding(), runtime=nothing, kwargs... -) where {NT <: NamedTuple, N} + seen, + @nospecialize(prev::StructArray{NT,N}), + @nospecialize(path), + mode; + track_numbers=false, + sharding=Reactant.Sharding.Sharding.NoSharding(), + runtime=nothing, + kwargs..., +) where {NT<:NamedTuple,N} track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{}) components = getfield(prev, :components) if mode == TracedToTypes @@ -64,14 +114,17 @@ function Reactant.make_tracer( end return nothing end - traced_components = make_tracer(seen, components, append_path(path, :components), mode; track_numbers, sharding, runtime, kwargs...) - T_traced = traced_type( - typeof(prev), - Val(mode), + traced_components = make_tracer( + seen, + components, + append_path(path, 1), + mode; track_numbers, sharding, runtime, + kwargs..., ) + T_traced = traced_type(typeof(prev), Val(mode), track_numbers, sharding, runtime) return StructArray{first(T_traced.parameters)}(traced_components) end @@ -79,4 +132,4 @@ end return Base.getfield(obj, field) end -end \ No newline at end of file +end From 1102ff4a15ddf303269cacee18e653de85d92b44 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 12 Dec 2025 10:57:50 -0600 Subject: [PATCH 5/8] mark getindex as inbounds in __elem_apply_loop_body --- src/TracedUtils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index a61dfa289e..ee6ff9f7c5 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -1099,7 +1099,7 @@ function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) wh res = res_ref[] idx = idx_ref[] + 1 - scalar_args = [@allowscalar(arg[idx]) for arg in args] + scalar_args = [@allowscalar((@inbounds arg[idx])) for arg in args] @allowscalar res[idx] = fn(scalar_args...) idx_ref[] = idx From a43dedc6296378e59469749df17243a8bc164be4 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 12 Dec 2025 10:58:26 -0600 Subject: [PATCH 6/8] don't blindly unwrap_eltype in elem_apply_via_while_loop --- src/TracedUtils.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index ee6ff9f7c5..e411a6d69d 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -1115,7 +1115,12 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs} # This wont be a mutating function so we can safely execute it once res_tmp = @allowscalar(f([@allowscalar(arg[1]) for arg in flat_args]...)) - result = similar(first(flat_args), Reactant.unwrapped_eltype(res_tmp), L) + + # TODO: perhaps instead of this logic, we should have + # `similar(::TracedRArray, TracedRNumber{T}) where T = similar(::TracedRArray, T)` + # and just not unwrap here? + T_res = typeof(res_tmp) <: TracedRNumber ? unwrapped_eltype(res_tmp) : typeof(res_tmp) + result = similar(first(flat_args), T_res, L) ind_var = Ref(0) f_ref = Ref(f) From e57d690c3ac2287503520eef91eb8cd027e488ce Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 12 Dec 2025 11:19:28 -0600 Subject: [PATCH 7/8] test structarray broadcasting --- test/integration/structarrays.jl | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test/integration/structarrays.jl b/test/integration/structarrays.jl index bd3634db78..83c0ee3af1 100644 --- a/test/integration/structarrays.jl +++ b/test/integration/structarrays.jl @@ -39,3 +39,29 @@ using StructArrays, Reactant, Test Int64, } end + +@noinline function elwise(e::NamedTuple) + return (; c=e.b, d=sin(e.a)) +end + +function broadcast_elwise(x) + return elwise.(x) +end + +@testset "structarray broadcasting" begin + x = StructVector(; a=rand(10), b=rand(Float32, 10)) + + x_ra = Reactant.to_rarray(x) + + result = @jit broadcast_elwise(x_ra) + + @test typeof(result) == StructVector{ + @NamedTuple{c::ConcretePJRTNumber{Float32,1}, d::ConcretePJRTNumber{Float64,1}}, + @NamedTuple{c::ConcretePJRTArray{Float32,1,1}, d::ConcretePJRTArray{Float64,1,1}}, + CartesianIndex{1}, + } + for (component_ra, component) in + zip(components(result), components(broadcast_elwise(x))) + @test component_ra ≈ component + end +end From c23680cccbc87ff6505b8716ffe3efa37b34c3d0 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 12 Dec 2025 14:06:05 -0600 Subject: [PATCH 8/8] StructArrays test dependency --- test/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Project.toml b/test/Project.toml index 29584d8c7c..1ab04c3e29 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -35,6 +35,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"