Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 104 additions & 6 deletions examples/network_compression/wavelet_linear.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
# Originally created by moritz (wolter@cs.uni-bonn.de)
# at https://github.com/v0lta/Wavelet-network-compression/blob/master/wavelet_learning/wavelet_linear.py

from abc import ABC, abstractmethod
from typing import Callable, TypeVar, Generic, TypeAlias

import numpy as np
import pywt
import torch
from torch.nn.parameter import Parameter
import torch.nn
from ptwt import wavedec, waverec
from ptwt import wavedec, waverec, waverec2, waverec3
from ptwt.constants import WaveletCoeff1d, WaveletCoeff2d, WaveletCoeffNd
from ptwt.wavelets_learnable import WaveletFilter

X = TypeVar("X")
Y = TypeVar("Y")


class WaveletLayer(torch.nn.Module):
"""
Expand Down Expand Up @@ -38,7 +45,7 @@ def __init__(
wavelet_decomposition = WaveletDecomposition1D(scales, coefficient_lengths, init_wavelet)
mul_p = Permutator(wave_depth)
mul_g = MMDropoutDiagonal(p_drop, wave_depth)
wavelet_reconstruction = WaveletReconstruction1D(scales, coefficient_lengths, init_wavelet)
wavelet_reconstruction = WaveletReconstruction1d(scales, coefficient_lengths, init_wavelet)
mul_s = MMDropoutDiagonal(p_drop, depth)

self.sequence = torch.nn.Sequential(
Expand Down Expand Up @@ -119,12 +126,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return c_tensor.squeeze(1)


class WaveletReconstruction1D(torch.nn.Module):
def __init__(self, scales: int, coefficient_lengths: list[int], wavelet: WaveletFilter) -> None:
Reconstruction: TypeAlias = Callable[[X, WaveletFilter, Y], torch.Tensor]


class WaveletReconstruction(torch.nn.Module, Generic[X, Y], ABC):
"""Am abstract wavelet reconstruction module."""

def __init__(
self,
scales: int,
coefficient_lengths: list[int],
wavelet: WaveletFilter,
func: Reconstruction[X, Y],
axis: Y,
) -> None:
super().__init__()
self.scales = scales
self.wavelet = wavelet
self.coefficient_lengths = coefficient_lengths
self.axis = axis
self.func = func

@abstractmethod
def get_coefficients(self, x: torch.Tensor) -> X:
"""Get coefficients."""
raise NotImplementedError

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Reconstruction from a tensor input.
Expand All @@ -133,12 +159,84 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: Input reconstruction.
"""
coefficients = self.get_coefficients(x)
y = self.func(coefficients, self.wavelet, self.axis)
return y


class WaveletReconstruction1d(WaveletReconstruction[WaveletCoeff1d, int]):
"""A module for 1D wavelet construction."""

def __init__(
self,
scales: int,
coefficient_lengths: list[int],
wavelet: WaveletFilter,
axis: int | None = None,
) -> None:
super().__init__(
scales=scales,
wavelet=wavelet,
coefficient_lengths=coefficient_lengths,
axis=-1 if axis is None else None,
func=waverec,
)

def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff1d:
"""Get coefficients for 1D reconstruction."""
coeff_lst = []
start = 0
# turn tensor into list
for s in range(self.scales + 1):
stop = start + self.coefficient_lengths[::-1][s]
coeff_lst.append(x[..., start:stop])
start = self.coefficient_lengths[s]
y = waverec(coeff_lst, self.wavelet)
return y
return coeff_lst


class WaveletReconstruction2d(WaveletReconstruction[WaveletCoeff2d, tuple[int, int]]):
"""A module for 2D wavelet construction."""

def __init__(
self,
scales: int,
coefficient_lengths: list[int],
wavelet: WaveletFilter,
axis: tuple[int, int] | None = None,
) -> None:
super().__init__(
scales=scales,
wavelet=wavelet,
coefficient_lengths=coefficient_lengths,
axis=(-2, -1) if axis is None else None,
func=waverec2,
)

def get_coefficients(self, x: torch.Tensor) -> WaveletCoeff2d:
"""Get coefficients for 2D reconstruction."""
raise NotImplementedError


class WaveletReconstruction3d(
WaveletReconstruction[WaveletCoeffNd, tuple[int, int, int]]
):
"""A module for 3D wavelet construction."""

def __init__(
self,
scales: int,
coefficient_lengths: list[int],
wavelet: WaveletFilter,
axis: tuple[int, int, int] | None = None,
) -> None:
super().__init__(
scales=scales,
wavelet=wavelet,
coefficient_lengths=coefficient_lengths,
axis=(-3, -2, -1) if axis is None else None,
func=waverec3,
)

def get_coefficients(self, x: torch.Tensor) -> WaveletCoeffNd:
"""Get coefficients for 3D reconstruction."""
raise NotImplementedError