From a745f35ef60d7cf612f8501b1e1b08c13ef6aa27 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Mon, 1 Dec 2025 23:40:16 +0100 Subject: [PATCH 1/7] Add reconstruction 2D ad 3D modules --- .../network_compression/wavelet_linear.py | 118 ++++++++++++++++-- 1 file changed, 106 insertions(+), 12 deletions(-) diff --git a/examples/network_compression/wavelet_linear.py b/examples/network_compression/wavelet_linear.py index 217836c4..1eb25b46 100644 --- a/examples/network_compression/wavelet_linear.py +++ b/examples/network_compression/wavelet_linear.py @@ -1,14 +1,21 @@ # Originally created by moritz (wolter@cs.uni-bonn.de) # at https://github.com/v0lta/Wavelet-network-compression/blob/master/wavelet_learning/wavelet_linear.py +from abc import ABC, abstractmethod +from typing import Callable, TypeVar, Generic, TypeAlias + import numpy as np import pywt import torch from torch.nn.parameter import Parameter import torch.nn -from ptwt import wavedec, waverec +from ptwt import wavedec, waverec, waverec2, waverec3 +from ptwt.constants import WaveletCoeff1d, WaveletCoeff2d, WaveletCoeffNd from ptwt.wavelets_learnable import WaveletFilter +X = TypeVar("X") +Y = TypeVar("Y") + class WaveletLayer(torch.nn.Module): """ @@ -119,26 +126,113 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return c_tensor.squeeze(1) -class WaveletReconstruction1D(torch.nn.Module): - def __init__(self, scales: int, coefficient_lengths: list[int], wavelet: WaveletFilter) -> None: +Reconstruction: TypeAlias = Callable[[X, WaveletFilter, Y], torch.Tensor] + + +class WaveletReconstruction(torch.nn.Module, Generic[X, Y], ABC): + """Am abstract wavelet reconstruction module.""" + + def __init__( + self, + scales: int, + coefficient_lengths: list[int], + wavelet: WaveletFilter, + func: Reconstruction[X, Y], + axis: Y, + ) -> None: super().__init__() self.scales = scales self.wavelet = wavelet self.coefficient_lengths = coefficient_lengths + self.axis = axis + self.func = func + + @abstractmethod + def get_coefficients(self, x: torch.Tensor) -> X: + """Get coefficients.""" + raise NotImplementedError def forward(self, x: torch.Tensor) -> torch.Tensor: - """Reconstruction from a tensor input. - Args: - x (torch.Tensor): Analysis coefficient tensor. - Returns: - torch.Tensor: Input reconstruction. - """ - coeff_lst = [] + """Reconstruction from a tensor input.""" + coefficients = self.get_coefficients(x) + y = self.func(coefficients, self.wavelet, self.axis) + return y + + +class WaveletReconstruction1d(WaveletReconstruction[WaveletCoeff1d, int]): + """A module for 1D wavelet construction.""" + + def __init__( + self, + scales: int, + coefficient_lengths: list[int], + wavelet: WaveletFilter, + axis: int | None = None, + ) -> None: + super().__init__( + scales=scales, + wavelet=wavelet, + coefficient_lengths=coefficient_lengths, + axis=-1 if axis is None else None, + func=waverec, + ) + + def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff1d: + """Get coefficients for 1D reconstruction.""" + coefficients = [] start = 0 # turn tensor into list for s in range(self.scales + 1): stop = start + self.coefficient_lengths[::-1][s] coeff_lst.append(x[..., start:stop]) + coefficients.append(x[..., start:stop]) start = self.coefficient_lengths[s] - y = waverec(coeff_lst, self.wavelet) - return y + return coefficients + + +class WaveletReconstruction2d(WaveletReconstruction[WaveletCoeff2d, tuple[int, int]]): + """A module for 2D wavelet construction.""" + + def __init__( + self, + scales: int, + coefficient_lengths: list[int], + wavelet: WaveletFilter, + axis: tuple[int, int] | None = None, + ) -> None: + super().__init__( + scales=scales, + wavelet=wavelet, + coefficient_lengths=coefficient_lengths, + axis=(-2, -1) if axis is None else None, + func=waverec2, + ) + + def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff1d: + """Get coefficients for 2D reconstruction.""" + raise NotImplementedError + + +class WaveletReconstruction3d( + WaveletReconstruction[WaveletCoeffNd, tuple[int, int, int]] +): + """A module for 3D wavelet construction.""" + + def __init__( + self, + scales: int, + coefficient_lengths: list[int], + wavelet: WaveletFilter, + axis: tuple[int, int] | None = None, + ) -> None: + super().__init__( + scales=scales, + wavelet=wavelet, + coefficient_lengths=coefficient_lengths, + axis=(-3, -2, -1) if axis is None else None, + func=waverec3, + ) + + def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff1d: + """Get coefficients for 3D reconstruction.""" + raise NotImplementedError From b82992fb8c669f06336431923e28f07c9e71244f Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Mon, 1 Dec 2025 23:44:21 +0100 Subject: [PATCH 2/7] Update wavelet_linear.py --- examples/network_compression/wavelet_linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/network_compression/wavelet_linear.py b/examples/network_compression/wavelet_linear.py index 1eb25b46..4036083f 100644 --- a/examples/network_compression/wavelet_linear.py +++ b/examples/network_compression/wavelet_linear.py @@ -184,7 +184,6 @@ def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff1d: # turn tensor into list for s in range(self.scales + 1): stop = start + self.coefficient_lengths[::-1][s] - coeff_lst.append(x[..., start:stop]) coefficients.append(x[..., start:stop]) start = self.coefficient_lengths[s] return coefficients From 9f2ae9ad35b916c01bc39d4ee14d656381b76483 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Mon, 1 Dec 2025 23:50:09 +0100 Subject: [PATCH 3/7] Update wavelet_linear.py --- examples/network_compression/wavelet_linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/network_compression/wavelet_linear.py b/examples/network_compression/wavelet_linear.py index 4036083f..135fd0fd 100644 --- a/examples/network_compression/wavelet_linear.py +++ b/examples/network_compression/wavelet_linear.py @@ -207,7 +207,7 @@ def __init__( func=waverec2, ) - def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff1d: + def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff2d: """Get coefficients for 2D reconstruction.""" raise NotImplementedError @@ -232,6 +232,6 @@ def __init__( func=waverec3, ) - def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff1d: + def get_coefficients(self, x: torch.Tensor) -> WaveletCoeffNd: """Get coefficients for 3D reconstruction.""" raise NotImplementedError From b97e7c5611449d3ba69240e623702a12a068059d Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Mon, 1 Dec 2025 23:50:56 +0100 Subject: [PATCH 4/7] Update wavelet_linear.py --- examples/network_compression/wavelet_linear.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/network_compression/wavelet_linear.py b/examples/network_compression/wavelet_linear.py index 135fd0fd..f1344211 100644 --- a/examples/network_compression/wavelet_linear.py +++ b/examples/network_compression/wavelet_linear.py @@ -179,14 +179,14 @@ def __init__( def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff1d: """Get coefficients for 1D reconstruction.""" - coefficients = [] + coeff_lst = [] start = 0 # turn tensor into list for s in range(self.scales + 1): stop = start + self.coefficient_lengths[::-1][s] - coefficients.append(x[..., start:stop]) + coeff_lst.append(x[..., start:stop]) start = self.coefficient_lengths[s] - return coefficients + return coeff_lst class WaveletReconstruction2d(WaveletReconstruction[WaveletCoeff2d, tuple[int, int]]): From a05dac47005e3f52e9556a964414c9b849e33c9b Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Mon, 1 Dec 2025 23:51:15 +0100 Subject: [PATCH 5/7] Update wavelet_linear.py --- examples/network_compression/wavelet_linear.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/network_compression/wavelet_linear.py b/examples/network_compression/wavelet_linear.py index f1344211..d8a2aaad 100644 --- a/examples/network_compression/wavelet_linear.py +++ b/examples/network_compression/wavelet_linear.py @@ -153,7 +153,12 @@ def get_coefficients(self, x: torch.Tensor) -> X: raise NotImplementedError def forward(self, x: torch.Tensor) -> torch.Tensor: - """Reconstruction from a tensor input.""" + """Reconstruction from a tensor input. + Args: + x (torch.Tensor): Analysis coefficient tensor. + Returns: + torch.Tensor: Input reconstruction. + """ coefficients = self.get_coefficients(x) y = self.func(coefficients, self.wavelet, self.axis) return y From 6c83ff6e90282bf61cc97566832931200cb9486a Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Mon, 1 Dec 2025 23:52:43 +0100 Subject: [PATCH 6/7] Update wavelet_linear.py --- examples/network_compression/wavelet_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/network_compression/wavelet_linear.py b/examples/network_compression/wavelet_linear.py index d8a2aaad..12bd43f9 100644 --- a/examples/network_compression/wavelet_linear.py +++ b/examples/network_compression/wavelet_linear.py @@ -45,7 +45,7 @@ def __init__( wavelet_decomposition = WaveletDecomposition1D(scales, coefficient_lengths, init_wavelet) mul_p = Permutator(wave_depth) mul_g = MMDropoutDiagonal(p_drop, wave_depth) - wavelet_reconstruction = WaveletReconstruction1D(scales, coefficient_lengths, init_wavelet) + wavelet_reconstruction = WaveletReconstruction1d(scales, coefficient_lengths, init_wavelet) mul_s = MMDropoutDiagonal(p_drop, depth) self.sequence = torch.nn.Sequential( From d4e8b838c0cd8df5e45e6c54e31c14861b3b4881 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 2 Dec 2025 08:55:52 +0100 Subject: [PATCH 7/7] Update wavelet_linear.py --- examples/network_compression/wavelet_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/network_compression/wavelet_linear.py b/examples/network_compression/wavelet_linear.py index 12bd43f9..af14de0f 100644 --- a/examples/network_compression/wavelet_linear.py +++ b/examples/network_compression/wavelet_linear.py @@ -227,7 +227,7 @@ def __init__( scales: int, coefficient_lengths: list[int], wavelet: WaveletFilter, - axis: tuple[int, int] | None = None, + axis: tuple[int, int, int] | None = None, ) -> None: super().__init__( scales=scales,