Skip to content

Commit 1aa617c

Browse files
committed
Add AvgPool2d anti-aliasing support to ResNet arch (as per OpenAI CLIP models), add a few blur aa models as well
1 parent f0f9ecc commit 1aa617c

File tree

1 file changed

+81
-7
lines changed

1 file changed

+81
-7
lines changed

timm/models/resnet.py

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,21 @@ def _cfg(url='', **kwargs):
251251
'resnetblur50': _cfg(
252252
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth',
253253
interpolation='bicubic'),
254+
'resnetblur50d': _cfg(
255+
url='',
256+
interpolation='bicubic', first_conv='conv1.0'),
257+
'resnetblur101d': _cfg(
258+
url='',
259+
interpolation='bicubic', first_conv='conv1.0'),
260+
'resnetaa50d': _cfg(
261+
url='',
262+
interpolation='bicubic', first_conv='conv1.0'),
263+
'resnetaa101d': _cfg(
264+
url='',
265+
interpolation='bicubic', first_conv='conv1.0'),
266+
'seresnetaa50d': _cfg(
267+
url='',
268+
interpolation='bicubic', first_conv='conv1.0'),
254269

255270
# ResNet-RS models
256271
'resnetrs50': _cfg(
@@ -289,6 +304,12 @@ def get_padding(kernel_size, stride, dilation=1):
289304
return padding
290305

291306

307+
def create_aa(aa_layer, channels, stride=2, enable=True):
    """Build the anti-aliasing layer for a downsampling step, or return None.

    Args:
        aa_layer: Layer class or factory callable that creates the anti-aliasing
            module. A falsy value (e.g. None) disables anti-aliasing.
        channels: Channel count forwarded to layers accepting a ``channels`` keyword.
        stride: Stride handed to the created layer.
        enable: When False, no layer is created regardless of ``aa_layer``.

    Returns:
        The constructed layer, or None when anti-aliasing is disabled.
    """
    if not aa_layer or not enable:
        return None
    # nn.AvgPool2d (and subclasses) take no ``channels`` argument; they are built
    # from a single positional value instead. issubclass() raises TypeError for
    # non-class callables (functools.partial, factory functions), so verify that
    # aa_layer is a class before asking about its bases.
    if isinstance(aa_layer, type) and issubclass(aa_layer, nn.AvgPool2d):
        return aa_layer(stride)
    return aa_layer(channels=channels, stride=stride)
311+
312+
292313
class BasicBlock(nn.Module):
293314
expansion = 1
294315

@@ -309,7 +330,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, b
309330
dilation=first_dilation, bias=False)
310331
self.bn1 = norm_layer(first_planes)
311332
self.act1 = act_layer(inplace=True)
312-
self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None
333+
self.aa = create_aa(aa_layer, channels=first_planes, stride=stride, enable=use_aa)
313334

314335
self.conv2 = nn.Conv2d(
315336
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@@ -380,7 +401,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, b
380401
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
381402
self.bn2 = norm_layer(width)
382403
self.act2 = act_layer(inplace=True)
383-
self.aa = aa_layer(channels=width, stride=stride) if use_aa else None
404+
self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa)
384405

385406
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
386407
self.bn3 = norm_layer(outplanes)
@@ -617,19 +638,22 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
617638
self.act1 = act_layer(inplace=True)
618639
self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
619640

620-
# Stem Pooling
641+
# Stem pooling. The name 'maxpool' remains for weight compatibility.
621642
if replace_stem_pool:
622643
self.maxpool = nn.Sequential(*filter(None, [
623644
nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False),
624-
aa_layer(channels=inplanes, stride=2) if aa_layer else None,
645+
create_aa(aa_layer, channels=inplanes, stride=2),
625646
norm_layer(inplanes),
626647
act_layer(inplace=True)
627648
]))
628649
else:
629650
if aa_layer is not None:
630-
self.maxpool = nn.Sequential(*[
631-
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
632-
aa_layer(channels=inplanes, stride=2)])
651+
if issubclass(aa_layer, nn.AvgPool2d):
652+
self.maxpool = aa_layer(2)
653+
else:
654+
self.maxpool = nn.Sequential(*[
655+
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
656+
aa_layer(channels=inplanes, stride=2)])
633657
else:
634658
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
635659

@@ -1342,6 +1366,56 @@ def resnetblur50(pretrained=False, **kwargs):
13421366
return _create_resnet('resnetblur50', pretrained, **model_args)
13431367

13441368

1369+
@register_model
def resnetblur50d(pretrained=False, **kwargs):
    """Create the 'resnetblur50d' variant: ResNet-50-D using BlurPool2d anti-aliasing.

    Caller-supplied ``kwargs`` are merged into the architecture arguments and
    forwarded to ``_create_resnet`` together with ``pretrained``.
    """
    arch_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        aa_layer=BlurPool2d,
        stem_width=32,
        stem_type='deep',
        avg_down=True,
        **kwargs)
    return _create_resnet('resnetblur50d', pretrained, **arch_kwargs)
1377+
1378+
1379+
@register_model
def resnetblur101d(pretrained=False, **kwargs):
    """Create the 'resnetblur101d' variant: ResNet-101-D using BlurPool2d anti-aliasing.

    Caller-supplied ``kwargs`` are merged into the architecture arguments and
    forwarded to ``_create_resnet`` together with ``pretrained``.
    """
    arch_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        aa_layer=BlurPool2d,
        stem_width=32,
        stem_type='deep',
        avg_down=True,
        **kwargs)
    return _create_resnet('resnetblur101d', pretrained, **arch_kwargs)
1387+
1388+
1389+
@register_model
def resnetaa50d(pretrained=False, **kwargs):
    """Create the 'resnetaa50d' variant: ResNet-50-D using nn.AvgPool2d anti-aliasing.

    Caller-supplied ``kwargs`` are merged into the architecture arguments and
    forwarded to ``_create_resnet`` together with ``pretrained``.
    """
    arch_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        aa_layer=nn.AvgPool2d,
        stem_width=32,
        stem_type='deep',
        avg_down=True,
        **kwargs)
    return _create_resnet('resnetaa50d', pretrained, **arch_kwargs)
1397+
1398+
1399+
@register_model
def resnetaa101d(pretrained=False, **kwargs):
    """Create the 'resnetaa101d' variant: ResNet-101-D using nn.AvgPool2d anti-aliasing.

    Caller-supplied ``kwargs`` are merged into the architecture arguments and
    forwarded to ``_create_resnet`` together with ``pretrained``.
    """
    arch_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        aa_layer=nn.AvgPool2d,
        stem_width=32,
        stem_type='deep',
        avg_down=True,
        **kwargs)
    return _create_resnet('resnetaa101d', pretrained, **arch_kwargs)
1407+
1408+
1409+
@register_model
def seresnetaa50d(pretrained=False, **kwargs):
    """Create the 'seresnetaa50d' variant: SE-ResNet-50-D using nn.AvgPool2d anti-aliasing.

    Blocks are configured with ``attn_layer='se'`` through ``block_args``.
    Caller-supplied ``kwargs`` are merged into the architecture arguments and
    forwarded to ``_create_resnet`` together with ``pretrained``.
    """
    arch_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        aa_layer=nn.AvgPool2d,
        stem_width=32,
        stem_type='deep',
        avg_down=True,
        block_args=dict(attn_layer='se'),
        **kwargs)
    return _create_resnet('seresnetaa50d', pretrained, **arch_kwargs)
1417+
1418+
13451419
@register_model
13461420
def seresnet18(pretrained=False, **kwargs):
13471421
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs)

0 commit comments

Comments
 (0)