diff --git a/Project.toml b/Project.toml index 0bf5002667..3ebfcdcc44 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,7 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" Scratch = "6c6a2e73-6563-6170-7368-637461726353" Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" p7zip_jll = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" [weakdeps] @@ -71,6 +72,7 @@ ReactantRandom123Ext = "Random123" ReactantSparseArraysExt = "SparseArrays" ReactantSpecialFunctionsExt = "SpecialFunctions" ReactantStatisticsExt = "Statistics" +ReactantStructArraysExt = "StructArrays" ReactantYaoBlocksExt = "YaoBlocks" ReactantZygoteExt = "Zygote" @@ -115,6 +117,7 @@ Sockets = "1.10" SparseArrays = "1.10" SpecialFunctions = "2.4" Statistics = "1.10" +StructArrays = "0.7.2" YaoBlocks = "0.13, 0.14" Zygote = "0.7" julia = "1.10" diff --git a/ext/ReactantStructArraysExt.jl b/ext/ReactantStructArraysExt.jl new file mode 100644 index 0000000000..993556299f --- /dev/null +++ b/ext/ReactantStructArraysExt.jl @@ -0,0 +1,135 @@ +module ReactantStructArraysExt + +using Reactant: Reactant +using StructArrays: StructArrays + +import StructArrays: + StructArrayStyle, + StructArray, + index_type, + components, + createinstance, + get_ith, + maybe_convert_elt, + foreachfield +import Reactant: + TraceMode, + TracedToTypes, + traced_type_inner, + append_path, + make_tracer, + traced_type, + ReactantPrimitive, + broadcast_to_size, + TracedRNumber, + TracedRArray, + unwrapped_eltype +import Reactant.TracedRArrayOverrides: AbstractReactantArrayStyle, _copy +import Base.Broadcast: Broadcasted + +StructArrays.always_struct_broadcast(::AbstractReactantArrayStyle) = true + +function Base.copy( + bc::Broadcasted{StructArrays.StructArrayStyle{S,N}} +) where {S<:AbstractReactantArrayStyle,N} + return _copy(bc) +end + +function Reactant.broadcast_to_size(arg::StructArray{T}, rsize) where {T} + new = [broadcast_to_size(c, rsize) for c in components(arg)] + return StructArray{T}(NamedTuple(Base.propertynames(arg) .=> new)) +end + +function Base.copyto!( + dest::StructArray, bc::Base.Broadcast.Broadcasted{<:AbstractReactantArrayStyle} +) + axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) + isempty(dest) && return dest + + bc = Broadcast.preprocess(dest, bc) + + args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) + + res = Reactant.TracedUtils.elem_apply_via_while_loop(bc.f, args...) + + return copyto!(dest, res) +end + +Base.@propagate_inbounds function StructArrays._getindex( + x::StructArray{T}, I::Vararg{TracedRNumber{<:Integer}} +) where {T} + cols = components(x) + @boundscheck checkbounds(x, I...) + return createinstance(T, get_ith(cols, I...)...) +end + +Base.@propagate_inbounds function Base.setindex!( + s::StructArray{T,<:Any,<:Any,Int}, vals, I::TracedRNumber{TI} +) where {T,TI<:Integer} + valsT = maybe_convert_elt(T, vals) + foreachfield((col, val) -> (@inbounds col[I] = val), s, valsT) + return s +end + +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(prev::Type{<:StructArray{NT}}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) where {NT<:NamedTuple} + T, N, C, I = prev.parameters + C_traced = traced_type_inner(C, seen, mode, track_numbers, sharding, runtime) + + names = T.parameters[1] + valuetypes = T.parameters[2].parameters + traced_value_types = map(valuetypes) do VT + # The elements in the NamedTuple are backed by vectors, + # these vectors are converted to RArrays so we need to track numbers: + track_numbers = VT <: ReactantPrimitive ? ReactantPrimitive : track_numbers + traced_type_inner(VT, seen, mode, track_numbers, sharding, runtime) + end + T_traced = NamedTuple{names,Tuple{traced_value_types...}} + + return StructArray{T_traced,N,C_traced,index_type(fieldtypes(C_traced))} +end + +function Reactant.make_tracer( + seen, + @nospecialize(prev::StructArray{NT,N}), + @nospecialize(path), + mode; + track_numbers=false, + sharding=Reactant.Sharding.Sharding.NoSharding(), + runtime=nothing, + kwargs..., +) where {NT<:NamedTuple,N} + track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{}) + components = getfield(prev, :components) + if mode == TracedToTypes + push!(path, typeof(prev)) + for c in components + make_tracer(seen, c, path, mode; track_numbers, sharding, runtime, kwargs...) + end + return nothing + end + traced_components = make_tracer( + seen, + components, + append_path(path, 1), + mode; + track_numbers, + sharding, + runtime, + kwargs..., + ) + T_traced = traced_type(typeof(prev), Val(mode), track_numbers, sharding, runtime) + return StructArray{first(T_traced.parameters)}(traced_components) +end + +@inline function Reactant.traced_getfield(@nospecialize(obj::StructArray), field) + return Base.getfield(obj, field) +end + +end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 85f692db41..72bba9ca50 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -459,7 +459,7 @@ first_scalar(x) = @allowscalar first(x) # we need to override the outer copy method to make sure we never fall back to scalar # iteration (see, e.g., CUDA.jl#145) -function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) +function _copy(bc) fn = if bc.f isa Type && bc.f <: Reactant.ReactantPrimitive TracedUtils.TypeCast{bc.f}() else @@ -477,6 +477,10 @@ function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) return copyto!(sim, bc) end +@noinline function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) + return _copy(bc) +end + function Base.materialize!( ::Style, dest, bc::Broadcasted ) where {Style<:AbstractReactantArrayStyle} diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index a61dfa289e..e411a6d69d 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -1099,7 +1099,7 @@ function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) wh res = res_ref[] idx = idx_ref[] + 1 - scalar_args = [@allowscalar(arg[idx]) for arg in args] + scalar_args = [@allowscalar((@inbounds arg[idx])) for arg in args] @allowscalar res[idx] = fn(scalar_args...) idx_ref[] = idx @@ -1115,7 +1115,12 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs} # This wont be a mutating function so we can safely execute it once res_tmp = @allowscalar(f([@allowscalar(arg[1]) for arg in flat_args]...)) - result = similar(first(flat_args), Reactant.unwrapped_eltype(res_tmp), L) + + # TODO: perhaps instead of this logic, we should have + # `similar(::TracedRArray, TracedRNumber{T}) where T = similar(::TracedRArray, T)` + # and just not unwrap here? + T_res = typeof(res_tmp) <: TracedRNumber ? unwrapped_eltype(res_tmp) : typeof(res_tmp) + result = similar(first(flat_args), T_res, L) ind_var = Ref(0) f_ref = Ref(f) diff --git a/test/Project.toml b/test/Project.toml index 29584d8c7c..1ab04c3e29 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -35,6 +35,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/integration/structarrays.jl b/test/integration/structarrays.jl new file mode 100644 index 0000000000..83c0ee3af1 --- /dev/null +++ b/test/integration/structarrays.jl @@ -0,0 +1,67 @@ +using StructArrays, Reactant, Test + +@testset "StructArray to_rarray and make_tracer" begin + x = StructArray(; + a=rand(10, 2), b=fill("some strings", (10, 2)), c=rand(Float32, 10, 2) + ) + x_ra = Reactant.to_rarray(x) + + # Note that the element type (the NamedTuple) contains ConcreteRNumbers even though track_numbers is not enabled. + # This is because when the backing arrays are converted to TracedRArrays, their elements will contain TracedRNumbers. + # In order for the element type to match the backing arrays, we need to use ConcreteRNumbers here as well: + @test typeof(x_ra) == StructArray{ + @NamedTuple{ + a::ConcretePJRTNumber{Float64,1}, b::String, c::ConcretePJRTNumber{Float32,1} + }, + 2, + @NamedTuple{ + a::ConcretePJRTArray{Float64,2,1}, + b::Matrix{String}, + c::ConcretePJRTArray{Float32,2,1}, + }, + CartesianIndex{2}, + } + + @test typeof( + make_tracer(Reactant.OrderedIdDict(), x_ra, (), Reactant.ConcreteToTraced) + ) == StructArray{ + @NamedTuple{ + a::Reactant.TracedRNumber{Float64}, + b::String, + c::Reactant.TracedRNumber{Float32}, + }, + 2, + @NamedTuple{ + a::Reactant.TracedRArray{Float64,2}, + b::Matrix{String}, + c::Reactant.TracedRArray{Float32,2}, + }, + Int64, + } +end + +@noinline function elwise(e::NamedTuple) + return (; c=e.b, d=sin(e.a)) +end + +function broadcast_elwise(x) + return elwise.(x) +end + +@testset "structarray broadcasting" begin + x = StructVector(; a=rand(10), b=rand(Float32, 10)) + + x_ra = Reactant.to_rarray(x) + + result = @jit broadcast_elwise(x_ra) + + @test typeof(result) == StructVector{ + @NamedTuple{c::ConcretePJRTNumber{Float32,1}, d::ConcretePJRTNumber{Float64,1}}, + @NamedTuple{c::ConcretePJRTArray{Float32,1,1}, d::ConcretePJRTArray{Float64,1,1}}, + CartesianIndex{1}, + } + for (component_ra, component) in + zip(components(result), components(broadcast_elwise(x))) + @test component_ra ≈ component + end +end diff --git a/test/runtests.jl b/test/runtests.jl index bbd5e0855f..6d1c04e9df 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,6 +62,7 @@ end @safetestset "OneHotArrays" include("integration/onehotarrays.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") @safetestset "SpecialFunctions" include("integration/special_functions.jl") + @safetestset "StructArrays" include("integration/structarrays.jl") @safetestset "Random" include("integration/random.jl") @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl")