diff --git a/setup.py b/setup.py index abf0d9e..02767bc 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup setup( - name='steerable_pytorch', + name='steerable', version='0.1', author='Tom Runia', author_email='tomrunia@gmail.com', @@ -10,6 +10,6 @@ description='Complex Steerable Pyramids in PyTorch', long_description='Fast CPU/CUDA implementation of the Complex Steerable Pyramid in PyTorch.', license='MIT', - packages=['steerable_pytorch'], + packages=['steerable'], scripts=[] -) \ No newline at end of file +) diff --git a/steerable/SCFpyr_PyTorch.py b/steerable/SCFpyr_PyTorch.py index ec46e63..c50e2b9 100644 --- a/steerable/SCFpyr_PyTorch.py +++ b/steerable/SCFpyr_PyTorch.py @@ -18,7 +18,12 @@ import numpy as np import torch -from scipy.misc import factorial + +#support for mulitple versions of scipy +try: + from scipy.misc import factorial +except ImportError: + from scipy.special import factorial import steerable.math_utils as math_utils pointOp = math_utils.pointOp @@ -45,6 +50,8 @@ class SCFpyr_PyTorch(object): Also looks very similar to the original Python code presented here: https://github.com/LabForComputationalVision/pyPyrTools/blob/master/pyPyrTools/SCFpyr.py + + VD: Modified March 2022 for updated pytorch versions 1.7+ where torch.fft is replaced with torch.fft.fft : See porting guide here: https://github.com/pytorch/pytorch/issues/49637 ''' @@ -86,6 +93,7 @@ def build(self, im_batch): # Check whether image size is sufficient for number of levels if self.height > int(np.floor(np.log2(min(width, height))) - 2): raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height)) + print(f'height of pyramid is, {self.height}') # Prepare a grid log_rad, angle = math_utils.prepare_grid(height, width) @@ -100,25 +108,26 @@ def build(self, im_batch): hi0mask = pointOp(log_rad, Yrcos, Xrcos) # Note that we expand dims to support broadcasting later - lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device) - hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device) + lo0mask = torch.from_numpy(lo0mask).float()[None,:,:].to(self.device) + hi0mask = torch.from_numpy(hi0mask).float()[None,:,:].to(self.device) # Fourier transform (2D) and shifting - batch_dft = torch.rfft(im_batch, signal_ndim=2, onesided=False) - batch_dft = math_utils.batch_fftshift2d(batch_dft) - + batch_dft = torch.fft.fft2(im_batch) #updated pytorch + batch_dft = torch.fft.fftshift(batch_dft,dim=(1,2)).real + # Low-pass lo0dft = batch_dft * lo0mask # Start recursively building the pyramids coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1) - + # High-pass hi0dft = batch_dft * hi0mask - hi0 = math_utils.batch_ifftshift2d(hi0dft) - hi0 = torch.ifft(hi0, signal_ndim=2) - hi0_real = torch.unbind(hi0, -1)[0] + hi0 = torch.fft.fftshift(hi0dft,dim=(1,2)) + hi0 = torch.fft.ifft2(hi0) #updated pytorch + hi0_real = hi0.real coeff.insert(0, hi0_real) + return coeff def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): @@ -126,9 +135,9 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): if height <= 1: # Low-pass - lo0 = math_utils.batch_ifftshift2d(lodft) - lo0 = torch.ifft(lo0, signal_ndim=2) - lo0_real = torch.unbind(lo0, -1)[0] + lo0 = torch.fft.fftshift(lodft,dim=(1,2)) + lo0 = torch.fft.ifft2(lo0) #new pytorch version with complex support + lo0_real = lo0.real #new pytorch with complex support coeff = [lo0_real] else: @@ -140,7 +149,7 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): #################################################################### himask = pointOp(log_rad, Yrcos, Xrcos) - himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device) + himask = torch.from_numpy(himask[None,:,:]).float().to(self.device) order = self.nbands - 1 const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order)) @@ -151,23 +160,17 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): for b in range(self.nbands): anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands) - anglemask = anglemask[None,:,:,None] # for broadcasting + anglemask = anglemask[None,:,:] # for broadcasting anglemask = torch.from_numpy(anglemask).float().to(self.device) # Bandpass filtering banddft = lodft * anglemask * himask - # Now multiply with complex number - # (x+yi)(u+vi) = (xu-yv) + (xv+yu)i - banddft = torch.unbind(banddft, -1) - banddft_real = self.complex_fact_construct.real*banddft[0] - self.complex_fact_construct.imag*banddft[1] - banddft_imag = self.complex_fact_construct.real*banddft[1] + self.complex_fact_construct.imag*banddft[0] - banddft = torch.stack((banddft_real, banddft_imag), -1) - - band = math_utils.batch_ifftshift2d(banddft) - band = torch.ifft(band, signal_ndim=2) + band = torch.fft.fftshift(banddft,dim=(1,2)) + band = torch.fft.ifft2(band) #new pytorch version + band = torch.stack((band.real,band.imag),dim=-1) orientations.append(band) - + #################################################################### ######################## Subsample lowpass ######################### #################################################################### @@ -184,12 +187,12 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): angle = angle[low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]] # Actual subsampling - lodft = lodft[:,low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1],:] + lodft = lodft[:,low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]] # Filtering YIrcos = np.abs(np.sqrt(1 - Yrcos**2)) lomask = pointOp(log_rad, YIrcos, Xrcos) - lomask = torch.from_numpy(lomask[None,:,:,None]).float() + lomask = torch.from_numpy(lomask[None,:,:]).float() lomask = lomask.to(self.device) # Convolution in spatial domain @@ -200,7 +203,7 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): #################################################################### coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1) - coeff.insert(0, orientations) + coeff.insert(0,orientations) return coeff @@ -224,28 +227,28 @@ def reconstruct(self, coeff): hi0mask = pointOp(log_rad, Yrcos, Xrcos) # Note that we expand dims to support broadcasting later - lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device) - hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device) + lo0mask = torch.from_numpy(lo0mask).float()[None,:,:].to(self.device) + hi0mask = torch.from_numpy(hi0mask).float()[None,:,:].to(self.device) # Start recursive reconstruction tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle) - hidft = torch.rfft(coeff[0], signal_ndim=2, onesided=False) - hidft = math_utils.batch_fftshift2d(hidft) + hidft = torch.fft.fft2(coeff[0]) #new pytorch + hidft = torch.fft.fftshift(hidft,dim=(1,2)).real # new version of pytorch outdft = tempdft * lo0mask + hidft * hi0mask - reconstruction = math_utils.batch_ifftshift2d(outdft) - reconstruction = torch.ifft(reconstruction, signal_ndim=2) - reconstruction = torch.unbind(reconstruction, -1)[0] # real + reconstruction = torch.fft.fftshift(outdft,dim=(1,2)) + reconstruction = torch.fft.ifft2(reconstruction).real #new pytorch version return reconstruction def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): if len(coeff) == 1: - dft = torch.rfft(coeff[0], signal_ndim=2, onesided=False) - dft = math_utils.batch_fftshift2d(dft) + dft = torch.fft.fft2(coeff[0]) #new pytorch + dft = torch.fft.fftshift(dft,dim=(1,2)) #new pytorch + dft = torch.stack((dft.real,dft.imag),dim=-1) return dft Xrcos = Xrcos - np.log2(self.scale_factor) @@ -255,7 +258,7 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): #################################################################### himask = pointOp(log_rad, Yrcos, Xrcos) - himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device) + himask = torch.from_numpy(himask[None,:,:]).float().to(self.device) lutsize = 1024 Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize @@ -267,11 +270,12 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): for b in range(self.nbands): anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands) - anglemask = anglemask[None,:,:,None] # for broadcasting + anglemask = anglemask[None,:,:] # for broadcasting anglemask = torch.from_numpy(anglemask).float().to(self.device) - banddft = torch.fft(coeff[0][b], signal_ndim=2) - banddft = math_utils.batch_fftshift2d(banddft) + banddft = torch.fft.fft2(coeff[0][b]) #new pytorch version + banddft = torch.fft.fftshift(banddft,dim=(1,2)).real #new pytorch version + banddft = banddft * anglemask * himask banddft = torch.unbind(banddft, -1) @@ -297,7 +301,7 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): # Filtering lomask = pointOp(nlog_rad, YIrcos, Xrcos) - lomask = torch.from_numpy(lomask[None,:,:,None]) + lomask = torch.from_numpy(lomask[None,:,:]) lomask = lomask.float().to(self.device) ################################################################################ @@ -306,6 +310,6 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle) resdft = torch.zeros_like(coeff[0][0]).to(self.device) - resdft[:,lostart[0]:loend[0], lostart[1]:loend[1],:] = nresdft * lomask + resdft[:,lostart[0]:loend[0], lostart[1]:loend[1]] = nresdft * lomask return resdft + orientdft diff --git a/steerable/math_utils.py b/steerable/math_utils.py index 688d04b..129dd49 100644 --- a/steerable/math_utils.py +++ b/steerable/math_utils.py @@ -22,7 +22,7 @@ ################################################################################ ################################################################################ -def roll_n(X, axis, n): +def roll_n(X, axis, n): f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim())) b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim())) front = X[f_idx] diff --git a/steerable/utils.py b/steerable/utils.py index 9e74c95..7e75672 100644 --- a/steerable/utils.py +++ b/steerable/utils.py @@ -69,6 +69,44 @@ def show_image_batch(im_batch): plt.show() return im_batch +def vectorize_batch(coeff_batch, real=True): + ''' + Given the batched Complex Steerable Pyramid, create a vectorized version which is 2D, batch by an ordered vector containing each pyramid, flattened. Store this as an all-real pyramid with complex pyramid split into two, or as a complex pyramid. + + Args: + coeff_batch (list): list containing low-pass, high-pass and pyr levels + real (bool, optional): Store pyramid as all real values, with complex pyramid stored as vectorized real, imaginary values + + Returns: + vector_pyr (tensor): tensor of dimensions (batch, vectorsize), where each vector contains the entire steerable pyramid, collapsed. + + ''' + if not isinstance(coeff_batch, list): + raise ValueError('Batch of coefficients must be a list') + + #first element is high pass + vector_pyr = torch.flatten(coeff_batch[0],start_dim=1) + + #loop through levels and stack them on top + for i in range(1,len(coeff_batch)-1): + #loop through orientations within each level + for orientation in coeff_batch[i]: + #if we want a real only pyramid, just use already separated representation + if(real): + print(orientation.shape) + level_vector = torch.flatten(orientation,start_dim=1) + #if we want a complex pyramid, need to create it with real and imaginary parts + else: + level_vector = torch.flatten(orientation[:,:,:,0] + 1j*orientation[:,:,:,1],start_dim=1) + #Stack this level vector on our pyramid + vector_pyr = torch.concat((vector_pyr,level_vector),dim=-1) + + #last element is low pass. Stack on top + vector_pyr = torch.concat((vector_pyr,torch.flatten(coeff_batch[-1],start_dim=1)),dim=-1) + + #return final batched tensor + return(vector_pyr) + def extract_from_batch(coeff_batch, example_idx=0): ''' Given the batched Complex Steerable Pyramid, extract the coefficients