1- '''independent attempt to implement
2-
3- MaxBlurPool2d in a more general fashion(separate maxpooling from BlurPool)
4- which was again inspired by
1+ '''
2+ BlurPool layer inspired by
3+ Kornia's Max_BlurPool2d
4+ and
55Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
66
77'''
@@ -17,8 +17,7 @@ class BlurPool2d(nn.Module):
1717 Corresponds to the Downsample class, which does blurring and subsampling
1818 Args:
1919 channels = Number of input channels
20- blur_filter_size (int): filter size for blurring. currently supports either 3 or 5 (most common)
21- defaults to 3.
20+ blur_filter_size (int): binomial filter size for blurring. currently supports 3(default) and 5.
2221 stride (int): downsampling filter stride
2322 Shape:
2423 Returns:
@@ -35,34 +34,21 @@ def __init__(self, channels=None, blur_filter_size=3, stride=2) -> None:
3534
3635 if blur_filter_size == 3 :
3736 pad_size = [1 ] * 4
38- blur_matrix = torch .Tensor ([[1. , 2. , 1 ]]) / 4 # binomial kernel b2
37+ blur_matrix = torch .Tensor ([[1. , 2. , 1 ]]) / 4 # binomial filter b2
3938 else :
4039 pad_size = [2 ] * 4
41- blur_matrix = torch .Tensor ([[1. , 4. , 6. , 4. , 1. ]]) / 16 # binomial filter kernel b4
40+ blur_matrix = torch .Tensor ([[1. , 4. , 6. , 4. , 1. ]]) / 16 # binomial filter b4
4241
4342 self .padding = nn .ReflectionPad2d (pad_size )
4443 blur_filter = blur_matrix * blur_matrix .T
4544 self .register_buffer ('blur_filter' , blur_filter [None , None , :, :].repeat ((self .channels , 1 , 1 , 1 )))
4645
47- def forward (self , input_tensor : torch .Tensor ) -> torch .Tensor : # type: ignore
46+ def forward (self , input_tensor : torch .Tensor ) -> torch .Tensor : # type: ignore
4847 if not torch .is_tensor (input_tensor ):
4948 raise TypeError ("Input input type is not a torch.Tensor. Got {}"
5049 .format (type (input_tensor )))
5150 if not len (input_tensor .shape ) == 4 :
5251 raise ValueError ("Invalid input shape, we expect BxCxHxW. Got: {}"
5352 .format (input_tensor .shape ))
5453 # apply blur_filter on input
55- return F .conv2d (self .padding (input_tensor ), self .blur_filter , stride = self .stride , groups = input_tensor .shape [1 ])
56-
57-
58- ######################
59- # functional interface
60- ######################
61-
62-
63- '''def blur_pool2d() -> torch.Tensor:
64- r"""Creates a module that computes pools and blurs and downsample a given
65- feature map.
66- See :class:`~kornia.contrib.MaxBlurPool2d` for details.
67- """
68- return BlurPool2d(kernel_size, ceil_mode)(input)'''
54+ return F .conv2d (self .padding (input_tensor ), self .blur_filter , stride = self .stride , groups = input_tensor .shape [1 ])
0 commit comments