Commit 9590f30

Merge branch 'blur' of https://github.com/VRandme/pytorch-image-models into VRandme-blur
2 parents 7a9942a + 1a9ab07 commit 9590f30

3 files changed: +115, -9 lines


timm/models/layers/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -17,4 +17,5 @@
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .anti_aliasing import AntiAliasDownsampleLayer
-from .space_to_depth import SpaceToDepthModule
+from .space_to_depth import SpaceToDepthModule
+from .blurpool import BlurPool2d

timm/models/layers/blurpool.py

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+"""
+BlurPool layer inspired by
+ - Kornia's Max_BlurPool2d
+ - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
+
+Hacked together by Chris Ha and Ross Wightman
+"""
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from .padding import get_padding
+
+
+class BlurPool2d(nn.Module):
+    r"""Creates a module that blurs and downsamples a given feature map.
+    See :cite:`zhang2019shiftinvar` for more details.
+    Corresponds to the Downsample class, which does blurring and subsampling
+    Args:
+        channels (int): Number of input channels
+        blur_filter_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
+        stride (int): downsampling filter stride
+    Shape:
+    Returns:
+        torch.Tensor: the transformed tensor.
+    Examples:
+    """
+
+    def __init__(self, channels, blur_filter_size=3, stride=2) -> None:
+        super(BlurPool2d, self).__init__()
+        assert blur_filter_size > 1
+        self.channels = channels
+        self.blur_filter_size = blur_filter_size
+        self.stride = stride
+
+        pad_size = [get_padding(blur_filter_size, stride, dilation=1)] * 4
+        self.padding = nn.ReflectionPad2d(pad_size)
+
+        blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
+        blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
+        self.blur_filter = blur_filter[None, None, :, :]
+
+    def _apply(self, fn):
+        # override nn.Module _apply to prevent need for blur_filter to be registered as a buffer,
+        # this keeps it out of state dict, but allows .cuda(), .type(), etc to work as expected
+        super(BlurPool2d, self)._apply(fn)
+        self.blur_filter = fn(self.blur_filter)
+
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:  # type: ignore
+        C = input_tensor.shape[1]
+        return F.conv2d(
+            self.padding(input_tensor),
+            self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), stride=self.stride, groups=C)
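
For reference, a minimal usage sketch (not part of this commit) of the new layer on a dummy feature map; it assumes a timm build that includes this commit, so BlurPool2d is importable from timm.models.layers per the __init__.py change above.

import torch
from timm.models.layers import BlurPool2d

pool = BlurPool2d(channels=64, blur_filter_size=3, stride=2)  # 3-tap binomial filter, [1, 2, 1] / 4
x = torch.randn(2, 64, 56, 56)                                # NCHW feature map
y = pool(x)                                                   # reflection pad -> depthwise blur -> stride-2 subsample
print(y.shape)                                                # expected: torch.Size([2, 64, 28, 28])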

timm/models/resnet.py

Lines changed: 58 additions & 8 deletions

@@ -12,7 +12,7 @@
 
 from .registry import register_model
 from .helpers import load_pretrained, adapt_model_from_file
-from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn
+from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 
 __all__ = ['ResNet', 'BasicBlock', 'Bottleneck']  # model_registry will add each entrypoint fn to this
@@ -118,6 +118,8 @@ def _cfg(url='', **kwargs):
     'ecaresnet101d_pruned': _cfg(
         url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
         interpolation='bicubic'),
+    'resnetblur18': _cfg(),
+    'resnetblur50': _cfg()
 }
 
 
@@ -131,7 +133,7 @@ class BasicBlock(nn.Module):
 
     def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
                  reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-                 attn_layer=None, drop_block=None, drop_path=None):
+                 attn_layer=None, drop_block=None, drop_path=None, blur=False):
         super(BasicBlock, self).__init__()
 
         assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@@ -141,10 +143,12 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, b
         first_dilation = first_dilation or dilation
 
         self.conv1 = nn.Conv2d(
-            inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
+            inplanes, first_planes, kernel_size=3, stride=1 if blur else stride, padding=first_dilation,
             dilation=first_dilation, bias=False)
         self.bn1 = norm_layer(first_planes)
         self.act1 = act_layer(inplace=True)
+        self.blurpool = BlurPool2d(channels=first_planes) if stride == 2 and blur else None
+
         self.conv2 = nn.Conv2d(
             first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
         self.bn2 = norm_layer(outplanes)
@@ -169,6 +173,8 @@ def forward(self, x):
         if self.drop_block is not None:
             x = self.drop_block(x)
         x = self.act1(x)
+        if self.blurpool is not None:
+            x = self.blurpool(x)
 
         x = self.conv2(x)
         x = self.bn2(x)
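
The BasicBlock changes above (and the matching Bottleneck changes below) move spatial downsampling out of the strided 3x3 convolution: when blur is enabled and stride == 2, the convolution runs at stride 1 and a BlurPool2d is applied after the activation. A hypothetical standalone sketch (not in the commit) of that pattern:

import torch
import torch.nn as nn
from timm.models.layers import BlurPool2d

# stride is moved from the conv to the blur stage, so the feature map is
# low-pass filtered before it is subsampled
blurred_downsample = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    BlurPool2d(channels=128, stride=2),
)

x = torch.randn(1, 64, 56, 56)
print(blurred_downsample(x).shape)  # expected: torch.Size([1, 128, 28, 28])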
@@ -195,22 +201,26 @@ class Bottleneck(nn.Module):
 
     def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
                  reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-                 attn_layer=None, drop_block=None, drop_path=None):
+                 attn_layer=None, drop_block=None, drop_path=None, blur=False):
         super(Bottleneck, self).__init__()
 
         width = int(math.floor(planes * (base_width / 64)) * cardinality)
         first_planes = width // reduce_first
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation
+        self.blur = blur
 
         self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
         self.bn1 = norm_layer(first_planes)
         self.act1 = act_layer(inplace=True)
+
         self.conv2 = nn.Conv2d(
-            first_planes, width, kernel_size=3, stride=stride,
+            first_planes, width, kernel_size=3, stride=1 if blur else stride,
             padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
         self.bn2 = norm_layer(width)
         self.act2 = act_layer(inplace=True)
+        self.blurpool = BlurPool2d(channels=width) if stride == 2 and blur else None
+
         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
 
@@ -240,6 +250,8 @@ def forward(self, x):
         if self.drop_block is not None:
             x = self.drop_block(x)
         x = self.act2(x)
+        if self.blurpool is not None:
+            x = self.blurpool(x)
 
         x = self.conv3(x)
         x = self.bn3(x)
@@ -359,12 +371,19 @@ class ResNet(nn.Module):
         Dropout probability before classifier, for training
     global_pool : str, default 'avg'
         Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
+    blur : str, default ''
+        Location of blurring:
+            * '', default - not applied
+            * 'max' - only the stem layer MaxPool will be blurred
+            * 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style)
+            * 'max_strided' - both the stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets)
+
     """
     def __init__(self, block, layers, num_classes=1000, in_chans=3,
                  cardinality=1, base_width=64, stem_width=64, stem_type='',
                  block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
                  act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
-                 drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
+                 drop_block_rate=0., global_pool='avg', blur='', zero_init_last_bn=True, block_args=None):
         block_args = block_args or dict()
         self.num_classes = num_classes
         deep_stem = 'deep' in stem_type
@@ -373,6 +392,7 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
         self.base_width = base_width
         self.drop_rate = drop_rate
         self.expansion = block.expansion
+        self.blur = 'strided' in blur
         super(ResNet, self).__init__()
 
         # Stem
@@ -393,7 +413,14 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
         self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = norm_layer(self.inplanes)
         self.act1 = act_layer(inplace=True)
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        # Stem pooling
+        if 'max' in blur:
+            self.maxpool = nn.Sequential(*[
+                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+                BlurPool2d(channels=self.inplanes, stride=2)
+            ])
+        else:
+            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # Feature Blocks
         dp = DropPath(drop_path_rate) if drop_path_rate else None
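
The hunks above add a blur argument to the ResNet constructor ('', 'max', 'strided', or 'max_strided') and, for the 'max' variants, swap the stride-2 stem MaxPool for a stride-1 MaxPool followed by BlurPool2d. A sketch (not part of the commit) of constructing the model directly with the new argument; names follow the diff above:

import torch
from timm.models.resnet import ResNet, Bottleneck

# 'max'         -> blur only the stem max pool
# 'strided'     -> blur only the strided convs inside the blocks
# 'max_strided' -> blur both
model = ResNet(Bottleneck, [3, 4, 6, 3], blur='max_strided')
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])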
@@ -445,7 +472,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=
 
         block_kwargs = dict(
             cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
-            dilation=dilation, **kwargs)
+            dilation=dilation, blur=self.blur, **kwargs)
         layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
         self.inplanes = planes * block.expansion
         layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
@@ -1114,3 +1141,26 @@ def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwarg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
     return model
+
+
+@register_model
+def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a ResNet-18 model with blur anti-aliasing
+    """
+    default_cfg = default_cfgs['resnetblur18']
+    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+@register_model
+def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a ResNet-50 model with blur anti-aliasing
+    """
+    default_cfg = default_cfgs['resnetblur50']
+    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
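
Since the two entrypoints above are decorated with @register_model, they should also be reachable through the timm registry; a sketch (not part of the commit), assuming a timm build that includes it. Note that the new _cfg() entries carry no weight URL, so pretrained=True has nothing to download at this point:

import torch
import timm

model = timm.create_model('resnetblur50', pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000])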
