From 070d141526642c3c2928406362792a0524700312 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 30 Dec 2025 19:06:44 +0100 Subject: [PATCH 1/2] allow integers in reshape --- src/Reshape/LocalReshape.jl | 5 ++- src/Reshape/Reshape.jl | 54 +++++++++++++++++++++++++++++-- src/Reshape/utils.jl | 57 +++++++++++++++++++++++++++++--- src/enhanced-base/dropdims.jl | 5 ++- src/enhanced-base/reshape.jl | 61 +++++++++++++++++++---------------- src/enhanced-base/vec.jl | 2 +- 6 files changed, 147 insertions(+), 37 deletions(-) diff --git a/src/Reshape/LocalReshape.jl b/src/Reshape/LocalReshape.jl index 68952ee..dc7d7cd 100644 --- a/src/Reshape/LocalReshape.jl +++ b/src/Reshape/LocalReshape.jl @@ -50,9 +50,10 @@ Merge(N::IntOrEllipsis) = Merge{N}() """ Split(N, sizes) Split{N}(sizes...) + Split(sizes...) Split(sizes) -Split the first `N` dimensions into `M` dimensions, with sizes given by a tuple +Split the first `N` (1, if not provided) dimensions into `M` dimensions, with sizes given by a tuple of integers and at most one colon (`:`). This can be interpreted as a local reshape operation on the `N` dimensions, and doesn't have many of the compile time guarantees of the other operations. @@ -82,6 +83,8 @@ Split(N::IntOrEllipsis, sizes::T) where {M,T<:NTuple{M,IntOrColon}} = Split{N}(sizes::IntOrColon...) where N = Split(N, sizes) +Split(sizes::IntOrColon...) = Split(1, sizes) + Split(sizes::Tuple{Vararg{IntOrColon}}) = Split(1, sizes) """ diff --git a/src/Reshape/Reshape.jl b/src/Reshape/Reshape.jl index 28e9fa2..2cf7b41 100644 --- a/src/Reshape/Reshape.jl +++ b/src/Reshape/Reshape.jl @@ -15,7 +15,57 @@ struct Reshape{OpsT,N,M} <: GlobalAxisOp{N,M} end @generated function Reshape(ops::OpsT, ::Val{N}) where {OpsT<:Tuple,N} - op_types = OpsT.parameters + raw_op_types = OpsT.parameters + _, preprocess_result = _preprocess_op_types(raw_op_types) + + op_types = Any[] + op_exprs = Any[] + + if preprocess_result === nothing + for (k, opT) in enumerate(raw_op_types) + if is_ellipsis(opT) + push!(op_types, Keep{..}) + push!(op_exprs, :(Keep(..))) + else + push!(op_types, opT) + push!(op_exprs, :(ops[$k])) + end + end + elseif preprocess_result.first === :lone_colon + idx = preprocess_result.second + for (k, opT) in enumerate(raw_op_types) + if k == idx + push!(op_types, Merge{..}) + push!(op_exprs, :(Merge(..))) + elseif is_ellipsis(opT) + push!(op_types, Keep{..}) + push!(op_exprs, :(Keep(..))) + else + push!(op_types, opT) + push!(op_exprs, :(ops[$k])) + end + end + else + (start, stop) = preprocess_result.second + split_type = _build_split_type(raw_op_types, start, stop) + split_expr = _build_split_expr(start, stop) + + for (k, opT) in enumerate(raw_op_types) + if k == start + push!(op_types, split_type) + push!(op_exprs, split_expr) + elseif k > start && k <= stop + continue + elseif is_ellipsis(opT) + push!(op_types, Keep{..}) + push!(op_exprs, :(Keep(..))) + else + push!(op_types, opT) + push!(op_exprs, :(ops[$k])) + end + end + end + _ops_has_ellipsis(op_types) || nothing ellipsis_seen = false @@ -70,7 +120,7 @@ end T0 = opT.parameters[3] M0 = opT.parameters[2] mout == M0 || throw(ArgumentError("Split output rank cannot be ellipsis-resolved")) - push!(resolved_op_exprs, :(Split($nin, ops[$k].sizes))) + push!(resolved_op_exprs, :(Split($nin, $(op_exprs[k]).sizes))) push!(resolved_op_types, Split{nin, M0, T0}) elseif opT <: Resqueeze diff --git a/src/Reshape/utils.jl b/src/Reshape/utils.jl index 297c0d0..99f484d 100644 --- a/src/Reshape/utils.jl +++ b/src/Reshape/utils.jl @@ -104,16 +104,14 @@ end return Expr(:call, :*, (Expr(:call, :size, xsym, in_dim + j) for j in 1:n_in)...) end -function _ops_has_ellipsis(op_types::Tuple) +function _ops_has_ellipsis(op_types) for op in op_types (is_ellipsis(ndims_in(op)) || is_ellipsis(ndims_out(op))) && return true end return false end -_ops_has_ellipsis(op_types::Core.SimpleVector) = _ops_has_ellipsis(Tuple(op_types)) - -function _ops_total_in(op_types::Tuple) +function _ops_total_in(op_types) total = 0 for op in op_types n = ndims_in(op) @@ -123,4 +121,53 @@ function _ops_total_in(op_types::Tuple) return total end -_ops_total_in(op_types::Core.SimpleVector) = _ops_total_in(Tuple(op_types)) +function _is_intcolon(T::Type) + T <: Int && return true + T <: Colon && return true + return false +end + +function _preprocess_op_types(op_types) + n = length(op_types) + n == 0 && return op_types, nothing + + runs = Tuple{Int,Int}[] + i = 1 + while i <= n + if _is_intcolon(op_types[i]) + start = i + while i <= n && _is_intcolon(op_types[i]) + i += 1 + end + push!(runs, (start, i - 1)) + else + i += 1 + end + end + + has_ellipsis = any(T -> is_ellipsis(T), op_types) + + isempty(runs) && return op_types, nothing + + if length(runs) == 1 && runs[1][1] == runs[1][2] && op_types[runs[1][1]] <: Colon + has_ellipsis && throw(ArgumentError("At most one Colon or Ellipsis is allowed")) + return op_types, :lone_colon => runs[1][1] + end + + length(runs) == 1 || throw(ArgumentError("Int/Colon must form a single contiguous sequence")) + has_ellipsis && throw(ArgumentError("Cannot mix Ellipsis (..) with Int/Colon sequence")) + + return op_types, :split_sequence => runs[1] +end + +function _build_split_type(op_types, start::Int, stop::Int) + M = stop - start + 1 + size_types = Any[op_types[i] for i in start:stop] + T = Tuple{size_types...} + return Split{.., M, T} +end + +function _build_split_expr(start::Int, stop::Int) + sizes_expr = Expr(:tuple, [:(ops[$i]) for i in start:stop]...) + return :(Split(.., $sizes_expr)) +end diff --git a/src/enhanced-base/dropdims.jl b/src/enhanced-base/dropdims.jl index 027f783..b44c22d 100644 --- a/src/enhanced-base/dropdims.jl +++ b/src/enhanced-base/dropdims.jl @@ -3,6 +3,9 @@ Drop the specified dimensions from the array `x`. +!!! note + This method may not be type stable if `dims` cannot be constant-propagated. + ```jldoctest julia> x = [1 3 5; 2 4 6;;;] 2×3×1 Array{Int64, 3}: @@ -27,5 +30,5 @@ julia> Rewrap.dropdims(y; dims=3) ) where N dims′ = dims isa Int ? (dims,) : dims ops = ntuple(i -> i in dims′ ? Squeeze() : Keep(), N) - return reshape(x, ops) + return Rewrap.reshape(x, ops) end diff --git a/src/enhanced-base/reshape.jl b/src/enhanced-base/reshape.jl index 638c02f..15d335d 100644 --- a/src/enhanced-base/reshape.jl +++ b/src/enhanced-base/reshape.jl @@ -1,11 +1,13 @@ -const AnyOp = Union{LocalReshape,ColonOrEllipsis} +const AnyOp = Union{LocalReshape,Int,Colon,Ellipsis} """ Rewrap.reshape(x, ops...) Rewrap.reshape(x, ops::Tuple) Reshape the array `x` using the given operations, which can include -`:` (Base.Colon) and `..` (EllipsisNotation.Ellipsis). +`:` (Base.Colon), `..` (EllipsisNotation.Ellipsis), and integers. + +Integers and colons can form a single contiguous sequence that becomes a `Split(.., sizes)`. See also [`Base.reshape`](@ref). @@ -35,38 +37,26 @@ function reshape end Rewrap.reshape(x::AbstractArray, args...) = Base.reshape(x, args...) -@constprop function Rewrap.reshape(x::AbstractArray, ops::Tuple{LocalReshape,Vararg{LocalReshape}}) +@constprop function Rewrap.reshape(x::AbstractArray, ops::Tuple{AnyOp,Vararg{AnyOp}}) r = Reshape(ops, Val(ndims(x))) r(x) end -@constprop function Rewrap.reshape(x::AbstractArray, ops::Tuple{AnyOp,Vararg{AnyOp}}) - count(op -> op isa ColonOrEllipsis, ops) > 1 && throw(ArgumentError("At most one Colon or Ellipsis is allowed")) - ops′ = map(ops) do op - if op isa Colon - Merge(..) - elseif op isa Ellipsis - Keep(..) - else - op - end - end - return Rewrap.reshape(x, ops′) -end - @constprop function Rewrap.reshape(x::AbstractArray, op1::AnyOp, ops::AnyOp...) return Rewrap.reshape(x, (op1, ops...)) end ## Base.reshape +const NonLocalOp = Union{Int,Colon,Ellipsis} + """ Base.reshape(x, ops...) Base.reshape(x, ops::Tuple) Reshape the array `x` using the given operations, which can include -`:` (Base.Colon) and `..` (EllipsisNotation.Ellipsis), but there -must be at least one `LocalReshape`. +`:` (Base.Colon), `..` (EllipsisNotation.Ellipsis), and integers, +but there must be at least one `LocalReshape`. See also [`Rewrap.reshape`](@ref). @@ -94,30 +84,47 @@ julia> reshape(view(rand(2, 3), :, 1:2), Merge(..)) |> summary # can not use a s """ Base.reshape -@constprop function Base.reshape(x::AbstractArray, ops::Tuple{LocalReshape,Vararg{LocalReshape}}) - return Rewrap.reshape(x, ops) -end - @constprop function Base.reshape( x::AbstractArray, ops::Union{ - Tuple{ColonOrEllipsis,LocalReshape,Vararg{LocalReshape}}, - Tuple{LocalReshape,Vararg{AnyOp}} + Tuple{LocalReshape,Vararg{AnyOp}}, + Tuple{NonLocalOp,LocalReshape,Vararg{AnyOp}}, + Tuple{NonLocalOp,NonLocalOp,LocalReshape,Vararg{AnyOp}}, + Tuple{NonLocalOp,NonLocalOp,NonLocalOp,LocalReshape,Vararg{AnyOp}}, + Tuple{NonLocalOp,NonLocalOp,NonLocalOp,NonLocalOp,LocalReshape,Vararg{AnyOp}} } ) return Rewrap.reshape(x, ops) end @constprop function Base.reshape( - x::AbstractArray, op1::LocalReshape, ops::Union{LocalReshape,ColonOrEllipsis}... + x::AbstractArray, op1::LocalReshape, ops::AnyOp... ) return Rewrap.reshape(x, (op1, ops...)) end @constprop function Base.reshape( - x::AbstractArray, op1::ColonOrEllipsis, op2::LocalReshape, ops::LocalReshape... + x::AbstractArray, op1::NonLocalOp, op2::LocalReshape, ops::AnyOp... ) return Rewrap.reshape(x, (op1, op2, ops...)) end +@constprop function Base.reshape( + x::AbstractArray, op1::NonLocalOp, op2::NonLocalOp, op3::LocalReshape, ops::AnyOp... +) + return Rewrap.reshape(x, (op1, op2, op3, ops...)) +end + +@constprop function Base.reshape( + x::AbstractArray, op1::NonLocalOp, op2::NonLocalOp, op3::NonLocalOp, op4::LocalReshape, ops::AnyOp... +) + return Rewrap.reshape(x, (op1, op2, op3, op4, ops...)) +end + +@constprop function Base.reshape( + x::AbstractArray, op1::NonLocalOp, op2::NonLocalOp, op3::NonLocalOp, op4::NonLocalOp, op5::LocalReshape, ops::AnyOp... +) + return Rewrap.reshape(x, (op1, op2, op3, op4, op5, ops...)) +end + diff --git a/src/enhanced-base/vec.jl b/src/enhanced-base/vec.jl index efaf044..0e9c01a 100644 --- a/src/enhanced-base/vec.jl +++ b/src/enhanced-base/vec.jl @@ -23,4 +23,4 @@ julia> Rewrap.vec(view(x, 1:2, :)) # not contiguous! 6 ``` """ -vec(x) = reshape(x, Merge(..)) +vec(x) = Rewrap.reshape(x, :) From 25b73a590fbea24444f9ddd6ab466bbeca0e2bc4 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 30 Dec 2025 19:23:22 +0100 Subject: [PATCH 2/2] fix --- src/Reshape/utils.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/Reshape/utils.jl b/src/Reshape/utils.jl index 99f484d..aefb4d3 100644 --- a/src/Reshape/utils.jl +++ b/src/Reshape/utils.jl @@ -157,6 +157,10 @@ function _preprocess_op_types(op_types) length(runs) == 1 || throw(ArgumentError("Int/Colon must form a single contiguous sequence")) has_ellipsis && throw(ArgumentError("Cannot mix Ellipsis (..) with Int/Colon sequence")) + start, stop = runs[1] + colon_count = count(i -> op_types[i] <: Colon, start:stop) + colon_count <= 1 || throw(ArgumentError("Split can have at most one Colon in sizes")) + return op_types, :split_sequence => runs[1] end