Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
description='Complex Steerable Pyramids in PyTorch',
long_description='Fast CPU/CUDA implementation of the Complex Steerable Pyramid in PyTorch.',
license='MIT',
packages=['steerable_pytorch'],
packages=['steerable'],
scripts=[]
)
)
90 changes: 47 additions & 43 deletions steerable/SCFpyr_PyTorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@

import numpy as np
import torch
from scipy.misc import factorial

#support for mulitple versions of scipy
try:
from scipy.misc import factorial
except ImportError:
from scipy.special import factorial

import steerable.math_utils as math_utils
pointOp = math_utils.pointOp
Expand All @@ -45,6 +50,8 @@ class SCFpyr_PyTorch(object):

Also looks very similar to the original Python code presented here:
https://github.com/LabForComputationalVision/pyPyrTools/blob/master/pyPyrTools/SCFpyr.py

VD: Modified March 2022 for updated pytorch versions 1.7+ where torch.fft is replaced with torch.fft.fft : See porting guide here: https://github.com/pytorch/pytorch/issues/49637

'''

Expand Down Expand Up @@ -86,6 +93,7 @@ def build(self, im_batch):
# Check whether image size is sufficient for number of levels
if self.height > int(np.floor(np.log2(min(width, height))) - 2):
raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))
print(f'height of pyramid is, {self.height}')

# Prepare a grid
log_rad, angle = math_utils.prepare_grid(height, width)
Expand All @@ -100,35 +108,36 @@ def build(self, im_batch):
hi0mask = pointOp(log_rad, Yrcos, Xrcos)

# Note that we expand dims to support broadcasting later
lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device)
hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device)
lo0mask = torch.from_numpy(lo0mask).float()[None,:,:].to(self.device)
hi0mask = torch.from_numpy(hi0mask).float()[None,:,:].to(self.device)

# Fourier transform (2D) and shifting
batch_dft = torch.rfft(im_batch, signal_ndim=2, onesided=False)
batch_dft = math_utils.batch_fftshift2d(batch_dft)

batch_dft = torch.fft.fft2(im_batch) #updated pytorch
batch_dft = torch.fft.fftshift(batch_dft,dim=(1,2)).real
# Low-pass
lo0dft = batch_dft * lo0mask

# Start recursively building the pyramids
coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1)

# High-pass
hi0dft = batch_dft * hi0mask
hi0 = math_utils.batch_ifftshift2d(hi0dft)
hi0 = torch.ifft(hi0, signal_ndim=2)
hi0_real = torch.unbind(hi0, -1)[0]
hi0 = torch.fft.fftshift(hi0dft,dim=(1,2))
hi0 = torch.fft.ifft2(hi0) #updated pytorch
hi0_real = hi0.real
coeff.insert(0, hi0_real)

return coeff

def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):

if height <= 1:

# Low-pass
lo0 = math_utils.batch_ifftshift2d(lodft)
lo0 = torch.ifft(lo0, signal_ndim=2)
lo0_real = torch.unbind(lo0, -1)[0]
lo0 = torch.fft.fftshift(lodft,dim=(1,2))
lo0 = torch.fft.ifft2(lo0) #new pytorch version with complex support
lo0_real = lo0.real #new pytorch with complex support
coeff = [lo0_real]

else:
Expand All @@ -140,7 +149,7 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
####################################################################

himask = pointOp(log_rad, Yrcos, Xrcos)
himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device)
himask = torch.from_numpy(himask[None,:,:]).float().to(self.device)

order = self.nbands - 1
const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
Expand All @@ -151,23 +160,17 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
for b in range(self.nbands):

anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands)
anglemask = anglemask[None,:,:,None] # for broadcasting
anglemask = anglemask[None,:,:] # for broadcasting
anglemask = torch.from_numpy(anglemask).float().to(self.device)

# Bandpass filtering
banddft = lodft * anglemask * himask
Copy link

@ChairManMeow-SY ChairManMeow-SY May 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should line 167 be:

banddft=(-1j)**order*lodft *anglemask * himask ?

As the code here


# Now multiply with complex number
# (x+yi)(u+vi) = (xu-yv) + (xv+yu)i
banddft = torch.unbind(banddft, -1)
banddft_real = self.complex_fact_construct.real*banddft[0] - self.complex_fact_construct.imag*banddft[1]
banddft_imag = self.complex_fact_construct.real*banddft[1] + self.complex_fact_construct.imag*banddft[0]
banddft = torch.stack((banddft_real, banddft_imag), -1)

band = math_utils.batch_ifftshift2d(banddft)
band = torch.ifft(band, signal_ndim=2)
band = torch.fft.fftshift(banddft,dim=(1,2))
band = torch.fft.ifft2(band) #new pytorch version
band = torch.stack((band.real,band.imag),dim=-1)
orientations.append(band)

####################################################################
######################## Subsample lowpass #########################
####################################################################
Expand All @@ -184,12 +187,12 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
angle = angle[low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]]

# Actual subsampling
lodft = lodft[:,low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1],:]
lodft = lodft[:,low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]]

# Filtering
YIrcos = np.abs(np.sqrt(1 - Yrcos**2))
lomask = pointOp(log_rad, YIrcos, Xrcos)
lomask = torch.from_numpy(lomask[None,:,:,None]).float()
lomask = torch.from_numpy(lomask[None,:,:]).float()
lomask = lomask.to(self.device)

# Convolution in spatial domain
Expand All @@ -200,7 +203,7 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
####################################################################

coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1)
coeff.insert(0, orientations)
coeff.insert(0,orientations)

return coeff

Expand All @@ -224,28 +227,28 @@ def reconstruct(self, coeff):
hi0mask = pointOp(log_rad, Yrcos, Xrcos)

# Note that we expand dims to support broadcasting later
lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device)
hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device)
lo0mask = torch.from_numpy(lo0mask).float()[None,:,:].to(self.device)
hi0mask = torch.from_numpy(hi0mask).float()[None,:,:].to(self.device)

# Start recursive reconstruction
tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle)

hidft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
hidft = math_utils.batch_fftshift2d(hidft)
hidft = torch.fft.fft2(coeff[0]) #new pytorch
hidft = torch.fft.fftshift(hidft,dim=(1,2)).real # new version of pytorch

outdft = tempdft * lo0mask + hidft * hi0mask

reconstruction = math_utils.batch_ifftshift2d(outdft)
reconstruction = torch.ifft(reconstruction, signal_ndim=2)
reconstruction = torch.unbind(reconstruction, -1)[0] # real
reconstruction = torch.fft.fftshift(outdft,dim=(1,2))
reconstruction = torch.fft.ifft2(reconstruction).real #new pytorch version

return reconstruction

def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):

if len(coeff) == 1:
dft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
dft = math_utils.batch_fftshift2d(dft)
dft = torch.fft.fft2(coeff[0]) #new pytorch
dft = torch.fft.fftshift(dft,dim=(1,2)) #new pytorch
dft = torch.stack((dft.real,dft.imag),dim=-1)
return dft

Xrcos = Xrcos - np.log2(self.scale_factor)
Expand All @@ -255,7 +258,7 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
####################################################################

himask = pointOp(log_rad, Yrcos, Xrcos)
himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device)
himask = torch.from_numpy(himask[None,:,:]).float().to(self.device)

lutsize = 1024
Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize
Expand All @@ -267,11 +270,12 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
for b in range(self.nbands):

anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands)
anglemask = anglemask[None,:,:,None] # for broadcasting
anglemask = anglemask[None,:,:] # for broadcasting
anglemask = torch.from_numpy(anglemask).float().to(self.device)

banddft = torch.fft(coeff[0][b], signal_ndim=2)
banddft = math_utils.batch_fftshift2d(banddft)
banddft = torch.fft.fft2(coeff[0][b]) #new pytorch version
banddft = torch.fft.fftshift(banddft,dim=(1,2)).real #new pytorch version


banddft = banddft * anglemask * himask
banddft = torch.unbind(banddft, -1)
Expand All @@ -297,7 +301,7 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):

# Filtering
lomask = pointOp(nlog_rad, YIrcos, Xrcos)
lomask = torch.from_numpy(lomask[None,:,:,None])
lomask = torch.from_numpy(lomask[None,:,:])
lomask = lomask.float().to(self.device)

################################################################################
Expand All @@ -306,6 +310,6 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle)

resdft = torch.zeros_like(coeff[0][0]).to(self.device)
resdft[:,lostart[0]:loend[0], lostart[1]:loend[1],:] = nresdft * lomask
resdft[:,lostart[0]:loend[0], lostart[1]:loend[1]] = nresdft * lomask

return resdft + orientdft
2 changes: 1 addition & 1 deletion steerable/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
################################################################################
################################################################################

def roll_n(X, axis, n):
def roll_n(X, axis, n):
f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim()))
b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim()))
front = X[f_idx]
Expand Down