@@ -1367,19 +1367,25 @@ end
13671367function unwrapped_broadcast (f:: F , x:: Base.Iterators.Zip , original_dims) where {F}
13681368 min_length = Base. inferencebarrier (minimum)(length, x. is)
13691369 itrs = [length (itr) > min_length ? itr[1 : min_length] : itr for itr in x. is]
1370- any (Base. Fix2 (isa, AnyTracedRArray), itrs) || return unrolled_map (f, x)
1371- return broadcast (BroadcastIterator (f), itrs... ), original_dims, identity
1370+ result = if any (Base. Fix2 (isa, AnyTracedRArray), itrs)
1371+ broadcast (BroadcastIterator (f), itrs... )
1372+ else
1373+ unrolled_map (f, x)
1374+ end
1375+ return result, original_dims, identity
13721376end
13731377
13741378function unwrapped_broadcast (f:: F , x:: Base.Iterators.Enumerate , original_dims) where {F}
1375- x. itr isa AnyTracedRArray || return unrolled_map (f, x)
1376- return (
1379+ result = if x. itr isa AnyTracedRArray
13771380 broadcast (
1378- BroadcastIterator (f), Reactant. promote_to (TracedRArray, 1 : length (x. itr)), x. itr
1379- ),
1380- original_dims,
1381- identity,
1382- )
1381+ BroadcastIterator (f),
1382+ Reactant. promote_to (TracedRArray, 1 : length (x. itr)),
1383+ x. itr,
1384+ )
1385+ else
1386+ unrolled_map (f, x)
1387+ end
1388+ return result, original_dims, identity
13831389end
13841390
13851391function unwrapped_broadcast (f:: F , x:: Slices , original_dims) where {F}
0 commit comments