Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/Reshape/LocalReshape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ Merge(N::IntOrEllipsis) = Merge{N}()
"""
Split(N, sizes)
Split{N}(sizes...)
Split(sizes...)
Split(sizes)

Split the first `N` dimensions into `M` dimensions, with sizes given by a tuple
Split the first `N` (1, if not provided) dimensions into `M` dimensions, with sizes given by a tuple
of integers and at most one colon (`:`).
This can be interpreted as a local reshape operation on the `N` dimensions,
and doesn't have many of the compile time guarantees of the other operations.
Expand Down Expand Up @@ -82,6 +83,8 @@ Split(N::IntOrEllipsis, sizes::T) where {M,T<:NTuple{M,IntOrColon}} =

Split{N}(sizes::IntOrColon...) where N = Split(N, sizes)

Split(sizes::IntOrColon...) = Split(1, sizes)

Split(sizes::Tuple{Vararg{IntOrColon}}) = Split(1, sizes)

"""
Expand Down
54 changes: 52 additions & 2 deletions src/Reshape/Reshape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,57 @@ struct Reshape{OpsT,N,M} <: GlobalAxisOp{N,M}
end

@generated function Reshape(ops::OpsT, ::Val{N}) where {OpsT<:Tuple,N}
op_types = OpsT.parameters
raw_op_types = OpsT.parameters
_, preprocess_result = _preprocess_op_types(raw_op_types)

op_types = Any[]
op_exprs = Any[]

if preprocess_result === nothing
for (k, opT) in enumerate(raw_op_types)
if is_ellipsis(opT)
push!(op_types, Keep{..})
push!(op_exprs, :(Keep(..)))
else
push!(op_types, opT)
push!(op_exprs, :(ops[$k]))
end
end
elseif preprocess_result.first === :lone_colon
idx = preprocess_result.second
for (k, opT) in enumerate(raw_op_types)
if k == idx
push!(op_types, Merge{..})
push!(op_exprs, :(Merge(..)))
elseif is_ellipsis(opT)
push!(op_types, Keep{..})
push!(op_exprs, :(Keep(..)))
else
push!(op_types, opT)
push!(op_exprs, :(ops[$k]))
end
end
else
(start, stop) = preprocess_result.second
split_type = _build_split_type(raw_op_types, start, stop)
split_expr = _build_split_expr(start, stop)

for (k, opT) in enumerate(raw_op_types)
if k == start
push!(op_types, split_type)
push!(op_exprs, split_expr)
elseif k > start && k <= stop
continue
elseif is_ellipsis(opT)
push!(op_types, Keep{..})
push!(op_exprs, :(Keep(..)))
else
push!(op_types, opT)
push!(op_exprs, :(ops[$k]))
end
end
end

_ops_has_ellipsis(op_types) || nothing

ellipsis_seen = false
Expand Down Expand Up @@ -70,7 +120,7 @@ end
T0 = opT.parameters[3]
M0 = opT.parameters[2]
mout == M0 || throw(ArgumentError("Split output rank cannot be ellipsis-resolved"))
push!(resolved_op_exprs, :(Split($nin, ops[$k].sizes)))
push!(resolved_op_exprs, :(Split($nin, $(op_exprs[k]).sizes)))
push!(resolved_op_types, Split{nin, M0, T0})

elseif opT <: Resqueeze
Expand Down
61 changes: 56 additions & 5 deletions src/Reshape/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,14 @@ end
return Expr(:call, :*, (Expr(:call, :size, xsym, in_dim + j) for j in 1:n_in)...)
end

function _ops_has_ellipsis(op_types::Tuple)
function _ops_has_ellipsis(op_types)
for op in op_types
(is_ellipsis(ndims_in(op)) || is_ellipsis(ndims_out(op))) && return true
end
return false
end

_ops_has_ellipsis(op_types::Core.SimpleVector) = _ops_has_ellipsis(Tuple(op_types))

function _ops_total_in(op_types::Tuple)
function _ops_total_in(op_types)
total = 0
for op in op_types
n = ndims_in(op)
Expand All @@ -123,4 +121,57 @@ function _ops_total_in(op_types::Tuple)
return total
end

_ops_total_in(op_types::Core.SimpleVector) = _ops_total_in(Tuple(op_types))
function _is_intcolon(T::Type)
T <: Int && return true
T <: Colon && return true
return false
end

function _preprocess_op_types(op_types)
n = length(op_types)
n == 0 && return op_types, nothing

runs = Tuple{Int,Int}[]
i = 1
while i <= n
if _is_intcolon(op_types[i])
start = i
while i <= n && _is_intcolon(op_types[i])
i += 1
end
push!(runs, (start, i - 1))
else
i += 1
end
end

has_ellipsis = any(T -> is_ellipsis(T), op_types)

isempty(runs) && return op_types, nothing

if length(runs) == 1 && runs[1][1] == runs[1][2] && op_types[runs[1][1]] <: Colon
has_ellipsis && throw(ArgumentError("At most one Colon or Ellipsis is allowed"))
return op_types, :lone_colon => runs[1][1]
end

length(runs) == 1 || throw(ArgumentError("Int/Colon must form a single contiguous sequence"))
has_ellipsis && throw(ArgumentError("Cannot mix Ellipsis (..) with Int/Colon sequence"))

start, stop = runs[1]
colon_count = count(i -> op_types[i] <: Colon, start:stop)
colon_count <= 1 || throw(ArgumentError("Split can have at most one Colon in sizes"))

return op_types, :split_sequence => runs[1]
end

function _build_split_type(op_types, start::Int, stop::Int)
M = stop - start + 1
size_types = Any[op_types[i] for i in start:stop]
T = Tuple{size_types...}
return Split{.., M, T}
end

function _build_split_expr(start::Int, stop::Int)
sizes_expr = Expr(:tuple, [:(ops[$i]) for i in start:stop]...)
return :(Split(.., $sizes_expr))
end
5 changes: 4 additions & 1 deletion src/enhanced-base/dropdims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

Drop the specified dimensions from the array `x`.

!!! note
This method may not be type stable if `dims` cannot be constant-propagated.

```jldoctest
julia> x = [1 3 5; 2 4 6;;;]
2×3×1 Array{Int64, 3}:
Expand All @@ -27,5 +30,5 @@ julia> Rewrap.dropdims(y; dims=3)
) where N
dims′ = dims isa Int ? (dims,) : dims
ops = ntuple(i -> i in dims′ ? Squeeze() : Keep(), N)
return reshape(x, ops)
return Rewrap.reshape(x, ops)
end
61 changes: 34 additions & 27 deletions src/enhanced-base/reshape.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
const AnyOp = Union{LocalReshape,ColonOrEllipsis}
const AnyOp = Union{LocalReshape,Int,Colon,Ellipsis}

"""
Rewrap.reshape(x, ops...)
Rewrap.reshape(x, ops::Tuple)

Reshape the array `x` using the given operations, which can include
`:` (Base.Colon) and `..` (EllipsisNotation.Ellipsis).
`:` (Base.Colon), `..` (EllipsisNotation.Ellipsis), and integers.

Integers and colons can form a single contiguous sequence that becomes a `Split(.., sizes)`.

See also [`Base.reshape`](@ref).

Expand Down Expand Up @@ -35,38 +37,26 @@ function reshape end

Rewrap.reshape(x::AbstractArray, args...) = Base.reshape(x, args...)

@constprop function Rewrap.reshape(x::AbstractArray, ops::Tuple{LocalReshape,Vararg{LocalReshape}})
@constprop function Rewrap.reshape(x::AbstractArray, ops::Tuple{AnyOp,Vararg{AnyOp}})
r = Reshape(ops, Val(ndims(x)))
r(x)
end

@constprop function Rewrap.reshape(x::AbstractArray, ops::Tuple{AnyOp,Vararg{AnyOp}})
count(op -> op isa ColonOrEllipsis, ops) > 1 && throw(ArgumentError("At most one Colon or Ellipsis is allowed"))
ops′ = map(ops) do op
if op isa Colon
Merge(..)
elseif op isa Ellipsis
Keep(..)
else
op
end
end
return Rewrap.reshape(x, ops′)
end

@constprop function Rewrap.reshape(x::AbstractArray, op1::AnyOp, ops::AnyOp...)
return Rewrap.reshape(x, (op1, ops...))
end

## Base.reshape

const NonLocalOp = Union{Int,Colon,Ellipsis}

"""
Base.reshape(x, ops...)
Base.reshape(x, ops::Tuple)

Reshape the array `x` using the given operations, which can include
`:` (Base.Colon) and `..` (EllipsisNotation.Ellipsis), but there
must be at least one `LocalReshape`.
`:` (Base.Colon), `..` (EllipsisNotation.Ellipsis), and integers,
but there must be at least one `LocalReshape`.

See also [`Rewrap.reshape`](@ref).

Expand Down Expand Up @@ -94,30 +84,47 @@ julia> reshape(view(rand(2, 3), :, 1:2), Merge(..)) |> summary # can not use a s
"""
Base.reshape

@constprop function Base.reshape(x::AbstractArray, ops::Tuple{LocalReshape,Vararg{LocalReshape}})
return Rewrap.reshape(x, ops)
end

@constprop function Base.reshape(
x::AbstractArray,
ops::Union{
Tuple{ColonOrEllipsis,LocalReshape,Vararg{LocalReshape}},
Tuple{LocalReshape,Vararg{AnyOp}}
Tuple{LocalReshape,Vararg{AnyOp}},
Tuple{NonLocalOp,LocalReshape,Vararg{AnyOp}},
Tuple{NonLocalOp,NonLocalOp,LocalReshape,Vararg{AnyOp}},
Tuple{NonLocalOp,NonLocalOp,NonLocalOp,LocalReshape,Vararg{AnyOp}},
Tuple{NonLocalOp,NonLocalOp,NonLocalOp,NonLocalOp,LocalReshape,Vararg{AnyOp}}
}
)
return Rewrap.reshape(x, ops)
end

@constprop function Base.reshape(
x::AbstractArray, op1::LocalReshape, ops::Union{LocalReshape,ColonOrEllipsis}...
x::AbstractArray, op1::LocalReshape, ops::AnyOp...
)
return Rewrap.reshape(x, (op1, ops...))
end

@constprop function Base.reshape(
x::AbstractArray, op1::ColonOrEllipsis, op2::LocalReshape, ops::LocalReshape...
x::AbstractArray, op1::NonLocalOp, op2::LocalReshape, ops::AnyOp...
)
return Rewrap.reshape(x, (op1, op2, ops...))
end

@constprop function Base.reshape(
x::AbstractArray, op1::NonLocalOp, op2::NonLocalOp, op3::LocalReshape, ops::AnyOp...
)
return Rewrap.reshape(x, (op1, op2, op3, ops...))
end

@constprop function Base.reshape(
x::AbstractArray, op1::NonLocalOp, op2::NonLocalOp, op3::NonLocalOp, op4::LocalReshape, ops::AnyOp...
)
return Rewrap.reshape(x, (op1, op2, op3, op4, ops...))
end

@constprop function Base.reshape(
x::AbstractArray, op1::NonLocalOp, op2::NonLocalOp, op3::NonLocalOp, op4::NonLocalOp, op5::LocalReshape, ops::AnyOp...
)
return Rewrap.reshape(x, (op1, op2, op3, op4, op5, ops...))
end


2 changes: 1 addition & 1 deletion src/enhanced-base/vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ julia> Rewrap.vec(view(x, 1:2, :)) # not contiguous!
6
```
"""
vec(x) = reshape(x, Merge(..))
vec(x) = Rewrap.reshape(x, :)