1+ module ReactantStructArraysExt
2+
3+ import Reactant
4+ import StructArrays
5+
6+ import StructArrays: StructArrayStyle, StructArray, StructVector, index_type
7+ import Reactant: TraceMode, TracedToTypes, traced_type_inner, append_path, make_tracer, traced_type
8+ import Reactant. TracedRArrayOverrides: AbstractReactantArrayStyle, _copy
9+ import Base. Broadcast: Broadcasted
10+
11+ StructArrays. always_struct_broadcast (:: AbstractReactantArrayStyle ) = true
12+
13+ function Base. copy (bc:: Broadcasted{StructArrays.StructArrayStyle{S, N}} ) where {S<: AbstractReactantArrayStyle , N}
14+ return _copy (bc)
15+ end
16+
17+ Base. @nospecializeinfer function Reactant. traced_type_inner (
18+ @nospecialize (prev:: Type{<:StructVector{NT}} ),
19+ seen,
20+ @nospecialize (mode:: TraceMode ),
21+ @nospecialize (track_numbers:: Type ),
22+ @nospecialize (sharding),
23+ @nospecialize (runtime)
24+ ) where {NT <: NamedTuple }
25+ T, N, C, I = prev. parameters
26+ C_traced = traced_type_inner (
27+ C,
28+ seen,
29+ mode,
30+ track_numbers,
31+ sharding,
32+ runtime,
33+ )
34+ T_traced = traced_type_inner (
35+ T,
36+ seen,
37+ mode,
38+ # The elements in the NamedTuple are backed by vectors,
39+ # these vectors are converted to RArrays so we need to track numbers:
40+ Number #= track_numbers =# ,
41+ sharding,
42+ runtime,
43+ )
44+ return StructVector{T_traced, C_traced, index_type (fieldtypes (C_traced))}
45+ end
46+
47+ function Reactant. make_tracer (
48+ seen, @nospecialize (prev:: StructVector{NT} ), @nospecialize (path), mode; track_numbers= false , sharding= Reactant. Sharding. Sharding. NoSharding (), runtime= nothing , kwargs...
49+ ) where {NT <: NamedTuple }
50+ track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{})
51+ components = getfield (prev, :components )
52+ if mode == TracedToTypes
53+ push! (path, typeof (prev))
54+ for c in components
55+ make_tracer (seen, c, path, mode; track_numbers, sharding, runtime, kwargs... )
56+ end
57+ return nothing
58+ end
59+ traced_components = make_tracer (seen, components, append_path (path, :components ), mode; track_numbers, sharding, runtime, kwargs... )
60+ T_traced = traced_type (
61+ typeof (prev),
62+ Val (mode),
63+ track_numbers,
64+ sharding,
65+ runtime,
66+ )
67+ return StructVector {first(T_traced.parameters)} (traced_components)
68+ end
69+
70+ @inline function Reactant. traced_getfield (@nospecialize (obj:: StructArray ), field)
71+ return Base. getfield (obj, field)
72+ end
73+
74+ end
0 commit comments