Skip to content

Commit 801fbb6

Browse files
committed
initial StructArrays extension
1 parent d7fbda5 commit 801fbb6

File tree

3 files changed

+82
-1
lines changed

3 files changed

+82
-1
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
2626
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
2727
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
2828
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
29+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
2930
p7zip_jll = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
3031

3132
[weakdeps]
@@ -71,6 +72,7 @@ ReactantRandom123Ext = "Random123"
7172
ReactantSparseArraysExt = "SparseArrays"
7273
ReactantSpecialFunctionsExt = "SpecialFunctions"
7374
ReactantStatisticsExt = "Statistics"
75+
ReactantStructArraysExt = "StructArrays"
7476
ReactantYaoBlocksExt = "YaoBlocks"
7577
ReactantZygoteExt = "Zygote"
7678

@@ -115,6 +117,7 @@ Sockets = "1.10"
115117
SparseArrays = "1.10"
116118
SpecialFunctions = "2.4"
117119
Statistics = "1.10"
120+
StructArrays = "0.7.2"
118121
YaoBlocks = "0.13, 0.14"
119122
Zygote = "0.7"
120123
julia = "1.10"

ext/ReactantStructArraysExt.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
module ReactantStructArraysExt
2+
3+
import Reactant
4+
import StructArrays
5+
6+
import StructArrays: StructArrayStyle, StructArray, StructVector, index_type
7+
import Reactant: TraceMode, TracedToTypes, traced_type_inner, append_path, make_tracer, traced_type
8+
import Reactant.TracedRArrayOverrides: AbstractReactantArrayStyle, _copy
9+
import Base.Broadcast: Broadcasted
10+
11+
StructArrays.always_struct_broadcast(::AbstractReactantArrayStyle) = true
12+
13+
function Base.copy(bc::Broadcasted{StructArrays.StructArrayStyle{S, N}}) where {S<:AbstractReactantArrayStyle, N}
14+
return _copy(bc)
15+
end
16+
17+
Base.@nospecializeinfer function Reactant.traced_type_inner(
18+
@nospecialize(prev::Type{<:StructVector{NT}}),
19+
seen,
20+
@nospecialize(mode::TraceMode),
21+
@nospecialize(track_numbers::Type),
22+
@nospecialize(sharding),
23+
@nospecialize(runtime)
24+
) where {NT <: NamedTuple}
25+
T, N, C, I = prev.parameters
26+
C_traced = traced_type_inner(
27+
C,
28+
seen,
29+
mode,
30+
track_numbers,
31+
sharding,
32+
runtime,
33+
)
34+
T_traced = traced_type_inner(
35+
T,
36+
seen,
37+
mode,
38+
# The elements in the NamedTuple are backed by vectors,
39+
# these vectors are converted to RArrays so we need to track numbers:
40+
Number #= track_numbers =#,
41+
sharding,
42+
runtime,
43+
)
44+
return StructVector{T_traced, C_traced, index_type(fieldtypes(C_traced))}
45+
end
46+
47+
function Reactant.make_tracer(
48+
seen, @nospecialize(prev::StructVector{NT}), @nospecialize(path), mode; track_numbers=false, sharding=Reactant.Sharding.Sharding.NoSharding(), runtime=nothing, kwargs...
49+
) where {NT <: NamedTuple}
50+
track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{})
51+
components = getfield(prev, :components)
52+
if mode == TracedToTypes
53+
push!(path, typeof(prev))
54+
for c in components
55+
make_tracer(seen, c, path, mode; track_numbers, sharding, runtime, kwargs...)
56+
end
57+
return nothing
58+
end
59+
traced_components = make_tracer(seen, components, append_path(path, :components), mode; track_numbers, sharding, runtime, kwargs...)
60+
T_traced = traced_type(
61+
typeof(prev),
62+
Val(mode),
63+
track_numbers,
64+
sharding,
65+
runtime,
66+
)
67+
return StructVector{first(T_traced.parameters)}(traced_components)
68+
end
69+
70+
@inline function Reactant.traced_getfield(@nospecialize(obj::StructArray), field)
71+
return Base.getfield(obj, field)
72+
end
73+
74+
end

src/TracedRArray.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ first_scalar(x) = @allowscalar first(x)
459459

460460
# we need to override the outer copy method to make sure we never fall back to scalar
461461
# iteration (see, e.g., CUDA.jl#145)
462-
function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
462+
function _copy(bc)
463463
fn = if bc.f isa Type && bc.f <: Reactant.ReactantPrimitive
464464
TracedUtils.TypeCast{bc.f}()
465465
else
@@ -477,6 +477,10 @@ function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
477477
return copyto!(sim, bc)
478478
end
479479

480+
@noinline function Base.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
481+
return _copy(bc)
482+
end
483+
480484
function Base.materialize!(
481485
::Style, dest, bc::Broadcasted
482486
) where {Style<:AbstractReactantArrayStyle}

0 commit comments

Comments
 (0)