diff --git a/examples/network_compression/wavelet_linear.py b/examples/network_compression/wavelet_linear.py index 217836c4..af14de0f 100644 --- a/examples/network_compression/wavelet_linear.py +++ b/examples/network_compression/wavelet_linear.py @@ -1,14 +1,21 @@ # Originally created by moritz (wolter@cs.uni-bonn.de) # at https://github.com/v0lta/Wavelet-network-compression/blob/master/wavelet_learning/wavelet_linear.py +from abc import ABC, abstractmethod +from typing import Callable, TypeVar, Generic, TypeAlias + import numpy as np import pywt import torch from torch.nn.parameter import Parameter import torch.nn -from ptwt import wavedec, waverec +from ptwt import wavedec, waverec, waverec2, waverec3 +from ptwt.constants import WaveletCoeff1d, WaveletCoeff2d, WaveletCoeffNd from ptwt.wavelets_learnable import WaveletFilter +X = TypeVar("X") +Y = TypeVar("Y") + class WaveletLayer(torch.nn.Module): """ @@ -38,7 +45,7 @@ def __init__( wavelet_decomposition = WaveletDecomposition1D(scales, coefficient_lengths, init_wavelet) mul_p = Permutator(wave_depth) mul_g = MMDropoutDiagonal(p_drop, wave_depth) - wavelet_reconstruction = WaveletReconstruction1D(scales, coefficient_lengths, init_wavelet) + wavelet_reconstruction = WaveletReconstruction1d(scales, coefficient_lengths, init_wavelet) mul_s = MMDropoutDiagonal(p_drop, depth) self.sequence = torch.nn.Sequential( @@ -119,12 +126,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return c_tensor.squeeze(1) -class WaveletReconstruction1D(torch.nn.Module): - def __init__(self, scales: int, coefficient_lengths: list[int], wavelet: WaveletFilter) -> None: +Reconstruction: TypeAlias = Callable[[X, WaveletFilter, Y], torch.Tensor] + + +class WaveletReconstruction(torch.nn.Module, Generic[X, Y], ABC): + """Am abstract wavelet reconstruction module.""" + + def __init__( + self, + scales: int, + coefficient_lengths: list[int], + wavelet: WaveletFilter, + func: Reconstruction[X, Y], + axis: Y, + ) -> None: super().__init__() self.scales = scales self.wavelet = wavelet self.coefficient_lengths = coefficient_lengths + self.axis = axis + self.func = func + + @abstractmethod + def get_coefficients(self, x: torch.Tensor) -> X: + """Get coefficients.""" + raise NotImplementedError def forward(self, x: torch.Tensor) -> torch.Tensor: """Reconstruction from a tensor input. @@ -133,6 +159,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Input reconstruction. """ + coefficients = self.get_coefficients(x) + y = self.func(coefficients, self.wavelet, self.axis) + return y + + +class WaveletReconstruction1d(WaveletReconstruction[WaveletCoeff1d, int]): + """A module for 1D wavelet construction.""" + + def __init__( + self, + scales: int, + coefficient_lengths: list[int], + wavelet: WaveletFilter, + axis: int | None = None, + ) -> None: + super().__init__( + scales=scales, + wavelet=wavelet, + coefficient_lengths=coefficient_lengths, + axis=-1 if axis is None else None, + func=waverec, + ) + + def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff1d: + """Get coefficients for 1D reconstruction.""" coeff_lst = [] start = 0 # turn tensor into list @@ -140,5 +191,52 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: stop = start + self.coefficient_lengths[::-1][s] coeff_lst.append(x[..., start:stop]) start = self.coefficient_lengths[s] - y = waverec(coeff_lst, self.wavelet) - return y + return coeff_lst + + +class WaveletReconstruction2d(WaveletReconstruction[WaveletCoeff2d, tuple[int, int]]): + """A module for 2D wavelet construction.""" + + def __init__( + self, + scales: int, + coefficient_lengths: list[int], + wavelet: WaveletFilter, + axis: tuple[int, int] | None = None, + ) -> None: + super().__init__( + scales=scales, + wavelet=wavelet, + coefficient_lengths=coefficient_lengths, + axis=(-2, -1) if axis is None else None, + func=waverec2, + ) + + def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff2d: + """Get coefficients for 2D reconstruction.""" + raise NotImplementedError + + +class WaveletReconstruction3d( + WaveletReconstruction[WaveletCoeffNd, tuple[int, int, int]] +): + """A module for 3D wavelet construction.""" + + def __init__( + self, + scales: int, + coefficient_lengths: list[int], + wavelet: WaveletFilter, + axis: tuple[int, int, int] | None = None, + ) -> None: + super().__init__( + scales=scales, + wavelet=wavelet, + coefficient_lengths=coefficient_lengths, + axis=(-3, -2, -1) if axis is None else None, + func=waverec3, + ) + + def get_coefficients(self, x: torch.Tensor) -> WaveletCoeffNd: + """Get coefficients for 3D reconstruction.""" + raise NotImplementedError