Skip to content

Commit 2070ad3

Browse files
committed
fix: old dispatches
1 parent 5ce1ae0 commit 2070ad3

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/TracedRArray.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,19 +1367,25 @@ end
13671367
function unwrapped_broadcast(f::F, x::Base.Iterators.Zip, original_dims) where {F}
13681368
min_length = Base.inferencebarrier(minimum)(length, x.is)
13691369
itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
1370-
any(Base.Fix2(isa, AnyTracedRArray), itrs) || return unrolled_map(f, x)
1371-
return broadcast(BroadcastIterator(f), itrs...), original_dims, identity
1370+
result = if any(Base.Fix2(isa, AnyTracedRArray), itrs)
1371+
broadcast(BroadcastIterator(f), itrs...)
1372+
else
1373+
unrolled_map(f, x)
1374+
end
1375+
return result, original_dims, identity
13721376
end
13731377

13741378
function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate, original_dims) where {F}
1375-
x.itr isa AnyTracedRArray || return unrolled_map(f, x)
1376-
return (
1379+
result = if x.itr isa AnyTracedRArray
13771380
broadcast(
1378-
BroadcastIterator(f), Reactant.promote_to(TracedRArray, 1:length(x.itr)), x.itr
1379-
),
1380-
original_dims,
1381-
identity,
1382-
)
1381+
BroadcastIterator(f),
1382+
Reactant.promote_to(TracedRArray, 1:length(x.itr)),
1383+
x.itr,
1384+
)
1385+
else
1386+
unrolled_map(f, x)
1387+
end
1388+
return result, original_dims, identity
13831389
end
13841390

13851391
function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F}

0 commit comments

Comments
 (0)