Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from setuptools import setup

setup(
name='steerable_pytorch',
name='steerable',
version='0.1',
author='Tom Runia',
author_email='tomrunia@gmail.com',
url='https://github.com/tomrunia/PyTorchSteerablePyramid',
description='Complex Steerable Pyramids in PyTorch',
long_description='Fast CPU/CUDA implementation of the Complex Steerable Pyramid in PyTorch.',
license='MIT',
packages=['steerable_pytorch'],
packages=['steerable'],
scripts=[]
)
)
90 changes: 47 additions & 43 deletions steerable/SCFpyr_PyTorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@

import numpy as np
import torch
from scipy.misc import factorial

# Support multiple SciPy versions: fall back to scipy.special.factorial when scipy.misc.factorial is unavailable
try:
from scipy.misc import factorial
except ImportError:
from scipy.special import factorial

import steerable.math_utils as math_utils
pointOp = math_utils.pointOp
Expand All @@ -45,6 +50,8 @@ class SCFpyr_PyTorch(object):

Also looks very similar to the original Python code presented here:
https://github.com/LabForComputationalVision/pyPyrTools/blob/master/pyPyrTools/SCFpyr.py

VD: Modified March 2022 for PyTorch 1.7+, where the legacy torch.fft() / torch.rfft() / torch.ifft() functions are replaced by the torch.fft module (e.g. torch.fft.fft2, torch.fft.ifft2). See the porting guide: https://github.com/pytorch/pytorch/issues/49637

'''

Expand Down Expand Up @@ -86,6 +93,7 @@ def build(self, im_batch):
# Check whether image size is sufficient for number of levels
if self.height > int(np.floor(np.log2(min(width, height))) - 2):
raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))
print(f'height of pyramid is, {self.height}')

# Prepare a grid
log_rad, angle = math_utils.prepare_grid(height, width)
Expand All @@ -100,35 +108,36 @@ def build(self, im_batch):
hi0mask = pointOp(log_rad, Yrcos, Xrcos)

# Note that we expand dims to support broadcasting later
lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device)
hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device)
lo0mask = torch.from_numpy(lo0mask).float()[None,:,:].to(self.device)
hi0mask = torch.from_numpy(hi0mask).float()[None,:,:].to(self.device)

# Fourier transform (2D) and shifting
batch_dft = torch.rfft(im_batch, signal_ndim=2, onesided=False)
batch_dft = math_utils.batch_fftshift2d(batch_dft)

batch_dft = torch.fft.fft2(im_batch) #updated pytorch
batch_dft = torch.fft.fftshift(batch_dft,dim=(1,2)).real
# Low-pass
lo0dft = batch_dft * lo0mask

# Start recursively building the pyramids
coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1)

# High-pass
hi0dft = batch_dft * hi0mask
hi0 = math_utils.batch_ifftshift2d(hi0dft)
hi0 = torch.ifft(hi0, signal_ndim=2)
hi0_real = torch.unbind(hi0, -1)[0]
hi0 = torch.fft.fftshift(hi0dft,dim=(1,2))
hi0 = torch.fft.ifft2(hi0) #updated pytorch
hi0_real = hi0.real
coeff.insert(0, hi0_real)

return coeff

def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):

if height <= 1:

# Low-pass
lo0 = math_utils.batch_ifftshift2d(lodft)
lo0 = torch.ifft(lo0, signal_ndim=2)
lo0_real = torch.unbind(lo0, -1)[0]
lo0 = torch.fft.fftshift(lodft,dim=(1,2))
lo0 = torch.fft.ifft2(lo0) #new pytorch version with complex support
lo0_real = lo0.real #new pytorch with complex support
coeff = [lo0_real]

else:
Expand All @@ -140,7 +149,7 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
####################################################################

himask = pointOp(log_rad, Yrcos, Xrcos)
himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device)
himask = torch.from_numpy(himask[None,:,:]).float().to(self.device)

order = self.nbands - 1
const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
Expand All @@ -151,23 +160,17 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
for b in range(self.nbands):

anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands)
anglemask = anglemask[None,:,:,None] # for broadcasting
anglemask = anglemask[None,:,:] # for broadcasting
anglemask = torch.from_numpy(anglemask).float().to(self.device)

# Bandpass filtering
banddft = lodft * anglemask * himask

# Now multiply with complex number
# (x+yi)(u+vi) = (xu-yv) + (xv+yu)i
banddft = torch.unbind(banddft, -1)
banddft_real = self.complex_fact_construct.real*banddft[0] - self.complex_fact_construct.imag*banddft[1]
banddft_imag = self.complex_fact_construct.real*banddft[1] + self.complex_fact_construct.imag*banddft[0]
banddft = torch.stack((banddft_real, banddft_imag), -1)

band = math_utils.batch_ifftshift2d(banddft)
band = torch.ifft(band, signal_ndim=2)
band = torch.fft.fftshift(banddft,dim=(1,2))
band = torch.fft.ifft2(band) #new pytorch version
band = torch.stack((band.real,band.imag),dim=-1)
orientations.append(band)

####################################################################
######################## Subsample lowpass #########################
####################################################################
Expand All @@ -184,12 +187,12 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
angle = angle[low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]]

# Actual subsampling
lodft = lodft[:,low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1],:]
lodft = lodft[:,low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]]

# Filtering
YIrcos = np.abs(np.sqrt(1 - Yrcos**2))
lomask = pointOp(log_rad, YIrcos, Xrcos)
lomask = torch.from_numpy(lomask[None,:,:,None]).float()
lomask = torch.from_numpy(lomask[None,:,:]).float()
lomask = lomask.to(self.device)

# Convolution in spatial domain
Expand All @@ -200,7 +203,7 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
####################################################################

coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1)
coeff.insert(0, orientations)
coeff.insert(0,orientations)

return coeff

Expand All @@ -224,28 +227,28 @@ def reconstruct(self, coeff):
hi0mask = pointOp(log_rad, Yrcos, Xrcos)

# Note that we expand dims to support broadcasting later
lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device)
hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device)
lo0mask = torch.from_numpy(lo0mask).float()[None,:,:].to(self.device)
hi0mask = torch.from_numpy(hi0mask).float()[None,:,:].to(self.device)

# Start recursive reconstruction
tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle)

hidft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
hidft = math_utils.batch_fftshift2d(hidft)
hidft = torch.fft.fft2(coeff[0]) #new pytorch
hidft = torch.fft.fftshift(hidft,dim=(1,2)).real # new version of pytorch

outdft = tempdft * lo0mask + hidft * hi0mask

reconstruction = math_utils.batch_ifftshift2d(outdft)
reconstruction = torch.ifft(reconstruction, signal_ndim=2)
reconstruction = torch.unbind(reconstruction, -1)[0] # real
reconstruction = torch.fft.fftshift(outdft,dim=(1,2))
reconstruction = torch.fft.ifft2(reconstruction).real #new pytorch version

return reconstruction

def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):

if len(coeff) == 1:
dft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
dft = math_utils.batch_fftshift2d(dft)
dft = torch.fft.fft2(coeff[0]) #new pytorch
dft = torch.fft.fftshift(dft,dim=(1,2)) #new pytorch
dft = torch.stack((dft.real,dft.imag),dim=-1)
return dft

Xrcos = Xrcos - np.log2(self.scale_factor)
Expand All @@ -255,7 +258,7 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
####################################################################

himask = pointOp(log_rad, Yrcos, Xrcos)
himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device)
himask = torch.from_numpy(himask[None,:,:]).float().to(self.device)

lutsize = 1024
Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize
Expand All @@ -267,11 +270,12 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
for b in range(self.nbands):

anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands)
anglemask = anglemask[None,:,:,None] # for broadcasting
anglemask = anglemask[None,:,:] # for broadcasting
anglemask = torch.from_numpy(anglemask).float().to(self.device)

banddft = torch.fft(coeff[0][b], signal_ndim=2)
banddft = math_utils.batch_fftshift2d(banddft)
banddft = torch.fft.fft2(coeff[0][b]) #new pytorch version
banddft = torch.fft.fftshift(banddft,dim=(1,2)).real #new pytorch version


banddft = banddft * anglemask * himask
banddft = torch.unbind(banddft, -1)
Expand All @@ -297,7 +301,7 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):

# Filtering
lomask = pointOp(nlog_rad, YIrcos, Xrcos)
lomask = torch.from_numpy(lomask[None,:,:,None])
lomask = torch.from_numpy(lomask[None,:,:])
lomask = lomask.float().to(self.device)

################################################################################
Expand All @@ -306,6 +310,6 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle)

resdft = torch.zeros_like(coeff[0][0]).to(self.device)
resdft[:,lostart[0]:loend[0], lostart[1]:loend[1],:] = nresdft * lomask
resdft[:,lostart[0]:loend[0], lostart[1]:loend[1]] = nresdft * lomask

return resdft + orientdft
2 changes: 1 addition & 1 deletion steerable/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
################################################################################
################################################################################

def roll_n(X, axis, n):
def roll_n(X, axis, n):
f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim()))
b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim()))
front = X[f_idx]
Expand Down
38 changes: 38 additions & 0 deletions steerable/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,44 @@ def show_image_batch(im_batch):
plt.show()
return im_batch

def vectorize_batch(coeff_batch, real=True):
    '''
    Flatten a batched Complex Steerable Pyramid into one vector per example.

    The pyramid is walked in order: first element (high-pass), then every
    orientation band of every intermediate level, then the last element
    (low-pass). Each tensor is flattened from dimension 1 onwards and the
    results are concatenated along the last dimension, giving a 2D tensor.

    Args:
        coeff_batch (list): pyramid coefficients; the first and last entries
            are tensors, every entry in between is an iterable of
            orientation-band tensors. Each band's last axis is indexed at
            0 and 1 as its real and imaginary parts.
        real (bool, optional): if True, keep the bands as they are stored
            (real and imaginary parts as separate real values). If False,
            combine them into complex values before flattening.

    Returns:
        vector_pyr (tensor): tensor of dimensions (batch, vectorsize) where
            each row holds the entire pyramid of one example, collapsed.

    Raises:
        ValueError: if coeff_batch is not a list, or has fewer than two
            entries (the first and last entries must be distinct).
    '''
    if not isinstance(coeff_batch, list):
        raise ValueError('Batch of coefficients must be a list')
    if len(coeff_batch) < 2:
        # With a single entry the same tensor would be emitted twice,
        # once as high-pass and once as low-pass.
        raise ValueError('Batch of coefficients must contain at least the high-pass and low-pass entries')

    # Collect the flattened pieces and concatenate once at the end; this
    # avoids re-copying the growing vector for every band.
    pieces = [torch.flatten(coeff_batch[0], start_dim=1)]

    for level in coeff_batch[1:-1]:
        for orientation in level:
            if not real:
                # Rebuild complex values from the (real, imag) pair stored
                # along the last axis.
                orientation = orientation[..., 0] + 1j * orientation[..., 1]
            pieces.append(torch.flatten(orientation, start_dim=1))

    pieces.append(torch.flatten(coeff_batch[-1], start_dim=1))

    # torch.cat rather than the torch.concat alias, which is only available
    # in newer PyTorch releases.
    return torch.cat(pieces, dim=-1)

def extract_from_batch(coeff_batch, example_idx=0):
'''
Given the batched Complex Steerable Pyramid, extract the coefficients
Expand Down