1+ '''independent attempt to implement
2+
3+ MaxBlurPool2d in a more general fashion(separate maxpooling from BlurPool)
4+ which was again inspired by
5+ Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
6+
7+ '''
8+
9+ import torch
10+ import torch .nn as nn
11+ import torch .nn .functional as F
12+
13+
14+ class BlurPool2d (nn .Module ):
15+ r"""Creates a module that computes blurs and downsample a given feature map.
16+ See :cite:`zhang2019shiftinvar` for more details.
17+ Corresponds to the Downsample class, which does blurring and subsampling
18+ Args:
19+ channels = Number of input channels
20+ blur_filter_size (int): filter size for blurring. currently supports either 3 or 5 (most common)
21+ defaults to 3.
22+ stride (int): downsampling filter stride
23+ Shape:
24+ Returns:
25+ torch.Tensor: the transformed tensor.
26+ Examples:
27+ """
28+
29+ def __init__ (self , channels = None , blur_filter_size = 3 , stride = 2 ) -> None :
30+ super (BlurPool2d , self ).__init__ ()
31+ assert blur_filter_size in [3 , 5 ]
32+ self .channels = channels
33+ self .blur_filter_size = blur_filter_size
34+ self .stride = stride
35+
36+ if blur_filter_size == 3 :
37+ pad_size = [1 ] * 4
38+ blur_matrix = torch .Tensor ([[1. , 2. , 1 ]]) / 4 # binomial kernel b2
39+ else :
40+ pad_size = [2 ] * 4
41+ blur_matrix = torch .Tensor ([[1. , 4. , 6. , 4. , 1. ]]) / 16 # binomial filter kernel b4
42+
43+ self .padding = nn .ReflectionPad2d (pad_size )
44+ blur_filter = blur_matrix * blur_matrix .T
45+ self .register_buffer ('blur_filter' , blur_filter [None , None , :, :].repeat ((self .channels , 1 , 1 , 1 )))
46+
47+ def forward (self , input_tensor : torch .Tensor ) -> torch .Tensor : # type: ignore
48+ if not torch .is_tensor (input_tensor ):
49+ raise TypeError ("Input input type is not a torch.Tensor. Got {}"
50+ .format (type (input_tensor )))
51+ if not len (input_tensor .shape ) == 4 :
52+ raise ValueError ("Invalid input shape, we expect BxCxHxW. Got: {}"
53+ .format (input_tensor .shape ))
54+ # apply blur_filter on input
55+ return F .conv2d (self .padding (input_tensor ), self .blur_filter , stride = self .stride , groups = input_tensor .shape [1 ])
56+
57+
58+ ######################
59+ # functional interface
60+ ######################
61+
62+
63+ '''def blur_pool2d() -> torch.Tensor:
64+ r"""Creates a module that computes pools and blurs and downsample a given
65+ feature map.
66+ See :class:`~kornia.contrib.MaxBlurPool2d` for details.
67+ """
68+ return BlurPool2d(kernel_size, ceil_mode)(input)'''
0 commit comments