Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
p7zip_jll = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"

[weakdeps]
Expand Down Expand Up @@ -71,6 +72,7 @@ ReactantRandom123Ext = "Random123"
ReactantSparseArraysExt = "SparseArrays"
ReactantSpecialFunctionsExt = "SpecialFunctions"
ReactantStatisticsExt = "Statistics"
ReactantStructArraysExt = "StructArrays"
ReactantYaoBlocksExt = "YaoBlocks"
ReactantZygoteExt = "Zygote"

Expand Down Expand Up @@ -115,6 +117,7 @@ Sockets = "1.10"
SparseArrays = "1.10"
SpecialFunctions = "2.4"
Statistics = "1.10"
StructArrays = "0.7.2"
YaoBlocks = "0.13, 0.14"
Zygote = "0.7"
julia = "1.10"
Expand Down
135 changes: 135 additions & 0 deletions ext/ReactantStructArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
module ReactantStructArraysExt

using Reactant: Reactant
using StructArrays: StructArrays

import StructArrays:
StructArrayStyle,
StructArray,
index_type,
components,
createinstance,
get_ith,
maybe_convert_elt,
foreachfield
import Reactant:
TraceMode,
TracedToTypes,
traced_type_inner,
append_path,
make_tracer,
traced_type,
ReactantPrimitive,
broadcast_to_size,
TracedRNumber,
TracedRArray,
unwrapped_eltype
import Reactant.TracedRArrayOverrides: AbstractReactantArrayStyle, _copy
import Base.Broadcast: Broadcasted

StructArrays.always_struct_broadcast(::AbstractReactantArrayStyle) = true

function Base.copy(
bc::Broadcasted{StructArrays.StructArrayStyle{S,N}}
) where {S<:AbstractReactantArrayStyle,N}
return _copy(bc)
end

function Reactant.broadcast_to_size(arg::StructArray{T}, rsize) where {T}
new = [broadcast_to_size(c, rsize) for c in components(arg)]
return StructArray{T}(NamedTuple(Base.propertynames(arg) .=> new))
end

function Base.copyto!(
dest::StructArray, bc::Base.Broadcast.Broadcasted{<:AbstractReactantArrayStyle}
)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest

bc = Broadcast.preprocess(dest, bc)

args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)

res = Reactant.TracedUtils.elem_apply_via_while_loop(bc.f, args...)

return copyto!(dest, res)
end

Base.@propagate_inbounds function StructArrays._getindex(
x::StructArray{T}, I::Vararg{TracedRNumber{<:Integer}}
) where {T}
cols = components(x)
@boundscheck checkbounds(x, I...)
return createinstance(T, get_ith(cols, I...)...)
end

Base.@propagate_inbounds function Base.setindex!(
s::StructArray{T,<:Any,<:Any,Int}, vals, I::TracedRNumber{TI}
) where {T,TI<:Integer}
valsT = maybe_convert_elt(T, vals)
foreachfield((col, val) -> (@inbounds col[I] = val), s, valsT)
return s
end

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(prev::Type{<:StructArray{NT}}),
seen,
@nospecialize(mode::TraceMode),
@nospecialize(track_numbers::Type),
@nospecialize(sharding),
@nospecialize(runtime)
) where {NT<:NamedTuple}
T, N, C, I = prev.parameters
C_traced = traced_type_inner(C, seen, mode, track_numbers, sharding, runtime)

names = T.parameters[1]
valuetypes = T.parameters[2].parameters
traced_value_types = map(valuetypes) do VT
# The elements in the NamedTuple are backed by vectors,
# these vectors are converted to RArrays so we need to track numbers:
track_numbers = VT <: ReactantPrimitive ? ReactantPrimitive : track_numbers
traced_type_inner(VT, seen, mode, track_numbers, sharding, runtime)
end
T_traced = NamedTuple{names,Tuple{traced_value_types...}}

return StructArray{T_traced,N,C_traced,index_type(fieldtypes(C_traced))}
end

function Reactant.make_tracer(
seen,
@nospecialize(prev::StructArray{NT,N}),
@nospecialize(path),
mode;
track_numbers=false,
sharding=Reactant.Sharding.Sharding.NoSharding(),
runtime=nothing,
kwargs...,
) where {NT<:NamedTuple,N}
track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{})
components = getfield(prev, :components)
if mode == TracedToTypes
push!(path, typeof(prev))
for c in components
make_tracer(seen, c, path, mode; track_numbers, sharding, runtime, kwargs...)
end
return nothing
end
traced_components = make_tracer(
seen,
components,
append_path(path, 1),
mode;
track_numbers,
sharding,
runtime,
kwargs...,
)
T_traced = traced_type(typeof(prev), Val(mode), track_numbers, sharding, runtime)
return StructArray{first(T_traced.parameters)}(traced_components)
end

@inline function Reactant.traced_getfield(@nospecialize(obj::StructArray), field)
return Base.getfield(obj, field)
end

end
6 changes: 5 additions & 1 deletion src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ first_scalar(x) = @allowscalar first(x)

# we need to override the outer copy method to make sure we never fall back to scalar
# iteration (see, e.g., CUDA.jl#145)
function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
function _copy(bc)
fn = if bc.f isa Type && bc.f <: Reactant.ReactantPrimitive
TracedUtils.TypeCast{bc.f}()
else
Expand All @@ -477,6 +477,10 @@ function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
return copyto!(sim, bc)
end

@noinline function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
return _copy(bc)
end

function Base.materialize!(
::Style, dest, bc::Broadcasted
) where {Style<:AbstractReactantArrayStyle}
Expand Down
9 changes: 7 additions & 2 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) wh
res = res_ref[]
idx = idx_ref[] + 1

scalar_args = [@allowscalar(arg[idx]) for arg in args]
scalar_args = [@allowscalar((@inbounds arg[idx])) for arg in args]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a specific reason we need inbounds?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without, I run into:

ERROR: ArgumentError: unable to check bounds for indices of type Reactant.TracedRNumber{Int64}
Stacktrace:
  [1] checkindex(::Type{Bool}, inds::Base.OneTo{Int64}, i::Reactant.TracedRNumber{Int64})
    @ Base ./abstractarray.jl:751
  [2] checkbounds
    @ ./abstractarray.jl:689 [inlined]
  [3] checkbounds
    @ ./abstractarray.jl:699 [inlined]
  [4] _getindex
    @ ~/Reactant2/ext/ReactantStructArraysExt.jl:62 [inlined]
  [5] getindex
    @ ~/comrade-reactant/StructArrays.jl/src/structarray.jl:345 [inlined]
  [6] #26
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:206 [inlined]
  [7] (::Reactant.TracedUtils.var"#26#28"{…})(arg::StructVector{…})
    @ Reactant.TracedUtils ./none:0
  [8] iterate
    @ ./generator.jl:48 [inlined]
  [9] collect
    @ ./array.jl:791 [inlined]
 [10] (::Nothing)(none::typeof(collect), none::Base.Generator{Vector{…}, Reactant.TracedUtils.var"#26#28"{…}})
    @ Reactant ./<missing>:0
 [11] call_with_reactant(::typeof(collect), ::Base.Generator{Vector{…}, Reactant.TracedUtils.var"#26#28"{…}})
    @ Reactant ~/Reactant2/src/utils.jl:523
 [12] __elem_apply_loop_body
    @ ~/Reactant2/src/TracedUtils.jl:1101 [inlined]
 [13] (::Nothing)(none::typeof(Reactant.TracedUtils.__elem_apply_loop_body), none::Base.RefValue{…}, none::Base.RefValue{…}, none::Base.RefValue{…}, none::Base.RefValue{…}, none::Base.RefValue{…})
    @ Reactant ./<missing>:0
 [14] getproperty
    @ ./Base.jl:49 [inlined]
 [15] getindex
    @ ./refvalue.jl:59 [inlined]
 [16] __elem_apply_loop_body
    @ ~/Reactant2/src/TracedUtils.jl:1095 [inlined]

But that probably just means I need to add
checkindex(::Type{Bool, inds, i::TracedRNumber{<:Integer})
Though now I'm confused as to why we're not hitting this in other places?

@allowscalar res[idx] = fn(scalar_args...)

idx_ref[] = idx
Expand All @@ -1115,7 +1115,12 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs}

# This wont be a mutating function so we can safely execute it once
res_tmp = @allowscalar(f([@allowscalar(arg[1]) for arg in flat_args]...))
result = similar(first(flat_args), Reactant.unwrapped_eltype(res_tmp), L)

# TODO: perhaps instead of this logic, we should have
# `similar(::TracedRArray, TracedRNumber{T}) where T = similar(::TracedRArray, T)`
# and just not unwrap here?
T_res = typeof(res_tmp) <: TracedRNumber ? unwrapped_eltype(res_tmp) : typeof(res_tmp)
result = similar(first(flat_args), T_res, L)

ind_var = Ref(0)
f_ref = Ref(f)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down
67 changes: 67 additions & 0 deletions test/integration/structarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using StructArrays, Reactant, Test

@testset "StructArray to_rarray and make_tracer" begin
x = StructArray(;
a=rand(10, 2), b=fill("some strings", (10, 2)), c=rand(Float32, 10, 2)
)
x_ra = Reactant.to_rarray(x)

# Note that the element type (the NamedTuple) contains ConcreteRNumbers even though track_numbers is not enabled.
# This is because when the backing arrays are converted to TracedRArrays, their elements will contain TracedRNumbers.
# In order for the element type to match the backing arrays, we need to use ConcreteRNumbers here as well:
@test typeof(x_ra) == StructArray{
@NamedTuple{
a::ConcretePJRTNumber{Float64,1}, b::String, c::ConcretePJRTNumber{Float32,1}
},
2,
@NamedTuple{
a::ConcretePJRTArray{Float64,2,1},
b::Matrix{String},
c::ConcretePJRTArray{Float32,2,1},
},
CartesianIndex{2},
}

@test typeof(
make_tracer(Reactant.OrderedIdDict(), x_ra, (), Reactant.ConcreteToTraced)
) == StructArray{
@NamedTuple{
a::Reactant.TracedRNumber{Float64},
b::String,
c::Reactant.TracedRNumber{Float32},
},
2,
@NamedTuple{
a::Reactant.TracedRArray{Float64,2},
b::Matrix{String},
c::Reactant.TracedRArray{Float32,2},
},
Int64,
}
end

@noinline function elwise(e::NamedTuple)
return (; c=e.b, d=sin(e.a))
end

function broadcast_elwise(x)
return elwise.(x)
end

@testset "structarray broadcasting" begin
x = StructVector(; a=rand(10), b=rand(Float32, 10))

x_ra = Reactant.to_rarray(x)

result = @jit broadcast_elwise(x_ra)

@test typeof(result) == StructVector{
@NamedTuple{c::ConcretePJRTNumber{Float32,1}, d::ConcretePJRTNumber{Float64,1}},
@NamedTuple{c::ConcretePJRTArray{Float32,1,1}, d::ConcretePJRTArray{Float64,1,1}},
CartesianIndex{1},
}
for (component_ra, component) in
zip(components(result), components(broadcast_elwise(x)))
@test component_ra ≈ component
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ end
@safetestset "OneHotArrays" include("integration/onehotarrays.jl")
@safetestset "AbstractFFTs" include("integration/fft.jl")
@safetestset "SpecialFunctions" include("integration/special_functions.jl")
@safetestset "StructArrays" include("integration/structarrays.jl")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing a test dependency on structarrays

@safetestset "Random" include("integration/random.jl")
@safetestset "Python" include("integration/python.jl")
@safetestset "Optimisers" include("integration/optimisers.jl")
Expand Down
Loading