Skip to content

Commit acd1b6c

Browse files
committed
Implement Functional Blur on resnet.py
1. Add a ResNet argument blur=''. 2. Implement blur for the stem max pool and for the strided convs in downsampling blocks.
1 parent ce3d82b commit acd1b6c

File tree

3 files changed

+70
-12
lines changed

3 files changed

+70
-12
lines changed

timm/models/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
1616
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
1717
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
18+
from .blurpool import BlurPool2d

timm/models/layers/blurpool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class BlurPool2d(nn.Module):
1717
Corresponds to the Downsample class, which does blurring and subsampling
1818
Args:
1919
channels = Number of input channels
20-
blur_filter_size (int): binomial filter size for blurring. currently supports 3(default) and 5.
20+
blur_filter_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
2121
stride (int): downsampling filter stride
2222
Shape:
2323
Returns:

timm/models/resnet.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from .registry import register_model
1414
from .helpers import load_pretrained
15-
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn
15+
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
1616
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1717

1818

@@ -104,6 +104,8 @@ def _cfg(url='', **kwargs):
104104
interpolation='bicubic'),
105105
'ecaresnet18': _cfg(),
106106
'ecaresnet50': _cfg(),
107+
'resnetblur18': _cfg(),
108+
'resnetblur50': _cfg()
107109
}
108110

109111

@@ -117,18 +119,27 @@ class BasicBlock(nn.Module):
117119

118120
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
119121
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
120-
attn_layer=None, drop_block=None, drop_path=None):
122+
attn_layer=None, drop_block=None, drop_path=None, blur=False):
121123
super(BasicBlock, self).__init__()
122124

123125
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
124126
assert base_width == 64, 'BasicBlock doest not support changing base width'
125127
first_planes = planes // reduce_first
126128
outplanes = planes * self.expansion
127129
first_dilation = first_dilation or dilation
130+
self.blur = blur
128131

129-
self.conv1 = nn.Conv2d(
132+
if blur and stride==2:
133+
self.conv1 = nn.Conv2d(
134+
inplanes, first_planes, kernel_size=3, stride=1, padding=first_dilation,
135+
dilation=first_dilation, bias=False)
136+
self.blurpool=BlurPool2d(channels=first_planes)
137+
else:
138+
self.conv1 = nn.Conv2d(
130139
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
131140
dilation=first_dilation, bias=False)
141+
self.blurpool = None
142+
132143
self.bn1 = norm_layer(first_planes)
133144
self.act1 = act_layer(inplace=True)
134145
self.conv2 = nn.Conv2d(
@@ -154,7 +165,11 @@ def forward(self, x):
154165
x = self.bn1(x)
155166
if self.drop_block is not None:
156167
x = self.drop_block(x)
157-
x = self.act1(x)
168+
if self.blurpool is not None:
169+
x = self.act1(x)
170+
x = self.blurpool(x)
171+
else:
172+
x = self.act1(x)
158173

159174
x = self.conv2(x)
160175
x = self.bn2(x)
@@ -181,20 +196,30 @@ class Bottleneck(nn.Module):
181196

182197
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
183198
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
184-
attn_layer=None, drop_block=None, drop_path=None):
199+
attn_layer=None, drop_block=None, drop_path=None, blur=False):
185200
super(Bottleneck, self).__init__()
186201

187202
width = int(math.floor(planes * (base_width / 64)) * cardinality)
188203
first_planes = width // reduce_first
189204
outplanes = planes * self.expansion
190205
first_dilation = first_dilation or dilation
206+
self.blur = blur
191207

192208
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
193209
self.bn1 = norm_layer(first_planes)
194210
self.act1 = act_layer(inplace=True)
195-
self.conv2 = nn.Conv2d(
196-
first_planes, width, kernel_size=3, stride=stride,
197-
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
211+
212+
if blur and stride==2:
213+
self.conv2 = nn.Conv2d(
214+
first_planes, width, kernel_size=3, stride=1,
215+
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
216+
self.blurpool = BlurPool2d(channels=width)
217+
else:
218+
self.conv2 = nn.Conv2d(
219+
first_planes, width, kernel_size=3, stride=stride,
220+
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
221+
self.blurpool = None
222+
198223
self.bn2 = norm_layer(width)
199224
self.act2 = act_layer(inplace=True)
200225
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
@@ -345,12 +370,19 @@ class ResNet(nn.Module):
345370
Dropout probability before classifier, for training
346371
global_pool : str, default 'avg'
347372
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
373+
blur : str, default ''
374+
Location of Blurring:
375+
* '', default - Not applied
376+
* 'max' - only stem layer MaxPool will be blurred
377+
* 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style)
378+
* 'max_strided' - on both stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets)
379+
348380
"""
349381
def __init__(self, block, layers, num_classes=1000, in_chans=3,
350382
cardinality=1, base_width=64, stem_width=64, stem_type='',
351383
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
352384
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
353-
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
385+
drop_block_rate=0., global_pool='avg', blur='', zero_init_last_bn=True, block_args=None):
354386
block_args = block_args or dict()
355387
self.num_classes = num_classes
356388
deep_stem = 'deep' in stem_type
@@ -359,6 +391,7 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
359391
self.base_width = base_width
360392
self.drop_rate = drop_rate
361393
self.expansion = block.expansion
394+
self.blur = 'strided' in blur
362395
super(ResNet, self).__init__()
363396

364397
# Stem
@@ -379,7 +412,13 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
379412
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
380413
self.bn1 = norm_layer(self.inplanes)
381414
self.act1 = act_layer(inplace=True)
382-
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
415+
# Stem Blur
416+
if 'max' in blur :
417+
self.maxpool = nn.Sequential(*[
418+
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
419+
BlurPool2d(channels=self.inplanes)])
420+
else :
421+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
383422

384423
# Feature Blocks
385424
dp = DropPath(drop_path_rate) if drop_path_rate else None
@@ -432,7 +471,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=
432471
block_kwargs = dict(
433472
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
434473
dilation=dilation, **kwargs)
435-
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
474+
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, blur=self.blur, **block_kwargs)]
436475
self.inplanes = planes * block.expansion
437476
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
438477

@@ -1022,3 +1061,21 @@ def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
10221061
if pretrained:
10231062
load_pretrained(model, default_cfg, num_classes, in_chans)
10241063
return model
1064+
1065+
@register_model
def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a ResNet-18 model with blur anti-aliasing.

    The model is built with ``blur='max_strided'``, i.e. blur is applied on
    both the stem MaxPool and the strided convolutions of the downsampling
    blocks.

    Args:
        pretrained: If True, pass the model through ``load_pretrained`` using
            this entry point's default config.
        num_classes: Number of classifier outputs, forwarded to ``ResNet``.
        in_chans: Number of input channels, forwarded to ``ResNet``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` with ``default_cfg`` attached.
    """
    default_cfg = default_cfgs['resnetblur18']
    model = ResNet(
        BasicBlock, [2, 2, 2, 2],
        num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        # Honor the `pretrained` flag the same way ecaresnet50 does instead of
        # silently ignoring it.
        # NOTE(review): this cfg is registered without a URL; confirm that
        # load_pretrained handles an empty URL gracefully.
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
1073+
1074+
@register_model
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a ResNet-50 model with assembled-cnn style blur.

    The model is built with ``blur='strided'``, i.e. blur is applied only on
    the strided convolutions of the downsampling blocks; the stem MaxPool is
    left unchanged.

    Args:
        pretrained: If True, pass the model through ``load_pretrained`` using
            this entry point's default config.
        num_classes: Number of classifier outputs, forwarded to ``ResNet``.
        in_chans: Number of input channels, forwarded to ``ResNet``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` with ``default_cfg`` attached.
    """
    # Use this model's own config entry; the original looked up
    # 'resnetblur18' here, which was a copy-paste slip.
    default_cfg = default_cfgs['resnetblur50']
    model = ResNet(
        Bottleneck, [3, 4, 6, 3],
        num_classes=num_classes, in_chans=in_chans, blur='strided', **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        # Honor the `pretrained` flag the same way ecaresnet50 does instead of
        # silently ignoring it.
        # NOTE(review): this cfg is registered without a URL; confirm that
        # load_pretrained handles an empty URL gracefully.
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model

0 commit comments

Comments
 (0)