From 81698f311cc4608f3c22ffc0adcfb2d9428a71e9 Mon Sep 17 00:00:00 2001
From: Max Ehrlich
Date: Wed, 2 Aug 2023 10:38:35 -0400
Subject: [PATCH] Correctly annotate autograd functions to support AMP

Signed-off-by: Max Ehrlich
---
 pytorch_wavelets/dtcwt/transform_funcs.py |  9 +++++++++
 pytorch_wavelets/dwt/lowlevel.py          |  9 +++++++++
 pytorch_wavelets/scatternet/lowlevel.py   | 11 +++++++++++
 3 files changed, 29 insertions(+)

diff --git a/pytorch_wavelets/dtcwt/transform_funcs.py b/pytorch_wavelets/dtcwt/transform_funcs.py
index b00ea29..a630efc 100644
--- a/pytorch_wavelets/dtcwt/transform_funcs.py
+++ b/pytorch_wavelets/dtcwt/transform_funcs.py
@@ -1,6 +1,7 @@
 import torch
 from torch import tensor
 from torch.autograd import Function
+from torch.cuda.amp import custom_fwd, custom_bwd
 from pytorch_wavelets.dtcwt.lowlevel import colfilter, rowfilter
 from pytorch_wavelets.dtcwt.lowlevel import coldfilt, rowdfilt
 from pytorch_wavelets.dtcwt.lowlevel import colifilt, rowifilt, q2c, c2q
@@ -343,6 +344,7 @@ def inv_j2plus_rot(ll, highr, highi, g0a, g1a, g0b, g1b, g2a, g2b,
 class FWD_J1(Function):
     """ Differentiable function doing 1 level forward DTCWT """
     @staticmethod
+    @custom_fwd
     def forward(ctx, x, h0, h1, skip_hps, o_dim, ri_dim, mode):
         mode = int_to_mode(mode)
         ctx.mode = mode
@@ -358,6 +360,7 @@ def forward(ctx, x, h0, h1, skip_hps, o_dim, ri_dim, mode):
         return ll, highs
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, dl, dh):
         h0, h1 = ctx.saved_tensors
         mode = ctx.mode
@@ -377,6 +380,7 @@ def backward(ctx, dl, dh):
 class FWD_J2PLUS(Function):
     """ Differentiable function doing second level forward DTCWT """
     @staticmethod
+    @custom_fwd
     def forward(ctx, x, h0a, h1a, h0b, h1b, skip_hps, o_dim, ri_dim, mode):
         mode = 'symmetric'
         ctx.mode = mode
@@ -392,6 +396,7 @@ def forward(ctx, x, h0a, h1a, h0b, h1b, skip_hps, o_dim, ri_dim, mode):
         return ll, highs
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, dl, dh):
         h0a, h1a, h0b, h1b = ctx.saved_tensors
         mode = ctx.mode
@@ -416,6 +421,7 @@ def backward(ctx, dl, dh):
 class INV_J1(Function):
     """ Differentiable function doing 1 level inverse DTCWT """
     @staticmethod
+    @custom_fwd
     def forward(ctx, lows, highs, g0, g1, o_dim, ri_dim, mode):
         mode = int_to_mode(mode)
         ctx.mode = mode
@@ -431,6 +437,7 @@ def forward(ctx, lows, highs, g0, g1, o_dim, ri_dim, mode):
         return y
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, dy):
         g0, g1 = ctx.saved_tensors
         dl = None
@@ -452,6 +459,7 @@ def backward(ctx, dy):
 class INV_J2PLUS(Function):
     """ Differentiable function doing level 2 onwards inverse DTCWT """
     @staticmethod
+    @custom_fwd
     def forward(ctx, lows, highs, g0a, g1a, g0b, g1b, o_dim, ri_dim, mode):
         mode = 'symmetric'
         ctx.mode = mode
@@ -468,6 +476,7 @@ def forward(ctx, lows, highs, g0a, g1a, g0b, g1b, o_dim, ri_dim, mode):
         return y
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, dy):
         g0a, g1a, g0b, g1b = ctx.saved_tensors
         g0a, g0b = g0b, g0a
diff --git a/pytorch_wavelets/dwt/lowlevel.py b/pytorch_wavelets/dwt/lowlevel.py
index b2453f2..c70d57c 100644
--- a/pytorch_wavelets/dwt/lowlevel.py
+++ b/pytorch_wavelets/dwt/lowlevel.py
@@ -2,6 +2,7 @@
 import torch.nn.functional as F
 import numpy as np
 from torch.autograd import Function
+from torch.cuda.amp import custom_fwd, custom_bwd
 from pytorch_wavelets.utils import reflect
 import pywt
 
@@ -333,6 +334,7 @@ class AFB2D(Function):
         y: Tensor of shape (N, C*4, H, W)
     """
     @staticmethod
+    @custom_fwd
     def forward(ctx, x, h0_row, h1_row, h0_col, h1_col, mode):
         ctx.save_for_backward(h0_row, h1_row, h0_col, h1_col)
         ctx.shape = x.shape[-2:]
@@ -347,6 +349,7 @@ def forward(ctx, x, h0_row, h1_row, h0_col, h1_col, mode):
         return low, highs
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, low, highs):
         dx = None
         if ctx.needs_input_grad[0]:
@@ -386,6 +389,7 @@ class AFB1D(Function):
         x1: Tensor of shape (N, C, L') - highpass
     """
     @staticmethod
+    @custom_fwd
     def forward(ctx, x, h0, h1, mode):
         mode = int_to_mode(mode)
 
@@ -405,6 +409,7 @@ def forward(ctx, x, h0, h1, mode):
         return x0, x1
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, dx0, dx1):
         dx = None
         if ctx.needs_input_grad[0]:
@@ -668,6 +673,7 @@ class SFB2D(Function):
         y: Tensor of shape (N, C*4, H, W)
     """
     @staticmethod
+    @custom_fwd
     def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode):
         mode = int_to_mode(mode)
         ctx.mode = mode
@@ -680,6 +686,7 @@ def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode):
         return y
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, dy):
         dlow, dhigh = None, None
         if ctx.needs_input_grad[0]:
@@ -715,6 +722,7 @@ class SFB1D(Function):
         y: Tensor of shape (N, C*2, L')
     """
     @staticmethod
+    @custom_fwd
     def forward(ctx, low, high, g0, g1, mode):
         mode = int_to_mode(mode)
         # Make into a 2d tensor with 1 row
@@ -729,6 +737,7 @@ def forward(ctx, low, high, g0, g1, mode):
         return sfb1d(low, high, g0, g1, mode=mode, dim=3)[:, :, 0]
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, dy):
         dlow, dhigh = None, None
         if ctx.needs_input_grad[0]:
diff --git a/pytorch_wavelets/scatternet/lowlevel.py b/pytorch_wavelets/scatternet/lowlevel.py
index d2b2238..d036a8b 100644
--- a/pytorch_wavelets/scatternet/lowlevel.py
+++ b/pytorch_wavelets/scatternet/lowlevel.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 import torch
 import torch.nn.functional as F
+from torch.cuda.amp import custom_fwd, custom_bwd
 
 from pytorch_wavelets.dtcwt.transform_funcs import fwd_j1, inv_j1
 from pytorch_wavelets.dtcwt.transform_funcs import fwd_j1_rot, inv_j1_rot
@@ -49,6 +50,7 @@ def int_to_mode(mode):
 class SmoothMagFn(torch.autograd.Function):
     """ Class to do complex magnitude """
     @staticmethod
+    @custom_fwd
     def forward(ctx, x, y, b):
         r = torch.sqrt(x**2 + y**2 + b**2)
         if x.requires_grad:
@@ -59,6 +61,7 @@ def forward(ctx, x, y, b):
         return r - b
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, dr):
         dx = None
         if ctx.needs_input_grad[0]:
@@ -73,6 +76,7 @@ class ScatLayerj1_f(torch.autograd.Function):
     layer with the DTCWT biorthogonal filters.
     """
     @staticmethod
+    @custom_fwd
     def forward(ctx, x, h0o, h1o, mode, bias, combine_colour):
         # bias = 1e-2
         # bias = 0
@@ -111,6 +115,7 @@ def forward(ctx, x, h0o, h1o, mode, bias, combine_colour):
         return Z
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, dZ):
         dX = None
         mode = ctx.mode
@@ -143,6 +148,7 @@ class ScatLayerj1_rot_f(torch.autograd.Function):
     filters, i.e. a slightly more expensive operation."""
 
     @staticmethod
+    @custom_fwd
     def forward(ctx, x, h0o, h1o, h2o, mode, bias, combine_colour):
         mode = int_to_mode(mode)
         ctx.mode = mode
@@ -179,6 +185,7 @@ def forward(ctx, x, h0o, h1o, h2o, mode, bias, combine_colour):
         return Z
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, dZ):
         dX = None
         mode = ctx.mode
@@ -208,6 +215,7 @@ class ScatLayerj2_f(torch.autograd.Function):
     layer with the DTCWT biorthogonal filters.
""" @staticmethod + @custom_fwd def forward(ctx, x, h0o, h1o, h0a, h0b, h1a, h1b, mode, bias, combine_colour): # bias = 1e-2 # bias = 0 @@ -309,6 +317,7 @@ def forward(ctx, x, h0o, h1o, h0a, h0b, h1a, h1b, mode, bias, combine_colour): return Z @staticmethod + @custom_bwd def backward(ctx, dZ): dX = None mode = ctx.mode @@ -403,6 +412,7 @@ class ScatLayerj2_rot_f(torch.autograd.Function): layer with the DTCWT bandpass biorthogonal and qshift filters . """ @staticmethod + @custom_fwd def forward(ctx, x, h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, mode, bias, combine_colour): # bias = 1e-2 # bias = 0 @@ -502,6 +512,7 @@ def forward(ctx, x, h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, mode, bias, com return Z @staticmethod + @custom_bwd def backward(ctx, dZ): dX = None mode = ctx.mode