@@ -4,7 +4,7 @@ import Reactant
44import StructArrays
55
66import StructArrays: StructArrayStyle, StructArray, StructVector, index_type
7- import Reactant: TraceMode, TracedToTypes, traced_type_inner, append_path, make_tracer, traced_type
7+ import Reactant: TraceMode, TracedToTypes, traced_type_inner, append_path, make_tracer, traced_type, ReactantPrimitive
88import Reactant. TracedRArrayOverrides: AbstractReactantArrayStyle, _copy
99import Base. Broadcast: Broadcasted
1010
@@ -15,7 +15,7 @@ function Base.copy(bc::Broadcasted{StructArrays.StructArrayStyle{S, N}}) where {
1515end
1616
1717Base. @nospecializeinfer function Reactant. traced_type_inner (
18- @nospecialize (prev:: Type{<:StructVector {NT}} ),
18+ @nospecialize (prev:: Type{<:StructArray {NT}} ),
1919 seen,
2020 @nospecialize (mode:: TraceMode ),
2121 @nospecialize (track_numbers:: Type ),
@@ -31,22 +31,30 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
3131 sharding,
3232 runtime,
3333 )
34- T_traced = traced_type_inner (
35- T,
36- seen,
37- mode,
34+
35+ names = T . parameters[ 1 ]
36+ valuetypes = T . parameters[ 2 ] . parameters
37+ traced_value_types = map (valuetypes) do VT
3838 # The elements in the NamedTuple are backed by vectors,
3939 # these vectors are converted to RArrays so we need to track numbers:
40- Number #= track_numbers =# ,
41- sharding,
42- runtime,
43- )
40+ track_numbers = VT <: ReactantPrimitive ? ReactantPrimitive : track_numbers
41+ traced_type_inner (
42+ VT,
43+ seen,
44+ mode,
45+ track_numbers,
46+ sharding,
47+ runtime,
48+ )
49+ end
50+ T_traced = NamedTuple{names, Tuple{traced_value_types... }}
51+
4452 return StructVector{T_traced, C_traced, index_type (fieldtypes (C_traced))}
4553end
4654
4755function Reactant. make_tracer (
48- seen, @nospecialize (prev:: StructVector {NT} ), @nospecialize (path), mode; track_numbers= false , sharding= Reactant. Sharding. Sharding. NoSharding (), runtime= nothing , kwargs...
49- ) where {NT <: NamedTuple }
56+ seen, @nospecialize (prev:: StructArray {NT, N } ), @nospecialize (path), mode; track_numbers= false , sharding= Reactant. Sharding. Sharding. NoSharding (), runtime= nothing , kwargs...
57+ ) where {NT <: NamedTuple , N }
5058 track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{})
5159 components = getfield (prev, :components )
5260 if mode == TracedToTypes
@@ -64,7 +72,7 @@ function Reactant.make_tracer(
6472 sharding,
6573 runtime,
6674 )
67- return StructVector {first(T_traced.parameters)} (traced_components)
75+ return StructArray {first(T_traced.parameters)} (traced_components)
6876end
6977
7078@inline function Reactant. traced_getfield (@nospecialize (obj:: StructArray ), field)
0 commit comments