Skip to content

Commit 31726d1

Browse files
committed
generalize StructVector to StructArray of NamedTuple
1 parent 801fbb6 commit 31726d1

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

ext/ReactantStructArraysExt.jl

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import Reactant
44
import StructArrays
55

66
import StructArrays: StructArrayStyle, StructArray, StructVector, index_type
7-
import Reactant: TraceMode, TracedToTypes, traced_type_inner, append_path, make_tracer, traced_type
7+
import Reactant: TraceMode, TracedToTypes, traced_type_inner, append_path, make_tracer, traced_type, ReactantPrimitive
88
import Reactant.TracedRArrayOverrides: AbstractReactantArrayStyle, _copy
99
import Base.Broadcast: Broadcasted
1010

@@ -15,7 +15,7 @@ function Base.copy(bc::Broadcasted{StructArrays.StructArrayStyle{S, N}}) where {
1515
end
1616

1717
Base.@nospecializeinfer function Reactant.traced_type_inner(
18-
@nospecialize(prev::Type{<:StructVector{NT}}),
18+
@nospecialize(prev::Type{<:StructArray{NT}}),
1919
seen,
2020
@nospecialize(mode::TraceMode),
2121
@nospecialize(track_numbers::Type),
@@ -31,22 +31,30 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
3131
sharding,
3232
runtime,
3333
)
34-
T_traced = traced_type_inner(
35-
T,
36-
seen,
37-
mode,
34+
35+
names = T.parameters[1]
36+
valuetypes = T.parameters[2].parameters
37+
traced_value_types = map(valuetypes) do VT
3838
# The elements in the NamedTuple are backed by vectors,
3939
# these vectors are converted to RArrays so we need to track numbers:
40-
Number #= track_numbers =#,
41-
sharding,
42-
runtime,
43-
)
40+
track_numbers = VT <: ReactantPrimitive ? ReactantPrimitive : track_numbers
41+
traced_type_inner(
42+
VT,
43+
seen,
44+
mode,
45+
track_numbers,
46+
sharding,
47+
runtime,
48+
)
49+
end
50+
T_traced = NamedTuple{names, Tuple{traced_value_types...}}
51+
4452
return StructVector{T_traced, C_traced, index_type(fieldtypes(C_traced))}
4553
end
4654

4755
function Reactant.make_tracer(
48-
seen, @nospecialize(prev::StructVector{NT}), @nospecialize(path), mode; track_numbers=false, sharding=Reactant.Sharding.Sharding.NoSharding(), runtime=nothing, kwargs...
49-
) where {NT <: NamedTuple}
56+
seen, @nospecialize(prev::StructArray{NT, N}), @nospecialize(path), mode; track_numbers=false, sharding=Reactant.Sharding.Sharding.NoSharding(), runtime=nothing, kwargs...
57+
) where {NT <: NamedTuple, N}
5058
track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{})
5159
components = getfield(prev, :components)
5260
if mode == TracedToTypes
@@ -64,7 +72,7 @@ function Reactant.make_tracer(
6472
sharding,
6573
runtime,
6674
)
67-
return StructVector{first(T_traced.parameters)}(traced_components)
75+
return StructArray{first(T_traced.parameters)}(traced_components)
6876
end
6977

7078
@inline function Reactant.traced_getfield(@nospecialize(obj::StructArray), field)

0 commit comments

Comments
 (0)