@@ -251,6 +251,21 @@ def _cfg(url='', **kwargs):
251251 'resnetblur50' : _cfg (
252252 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth' ,
253253 interpolation = 'bicubic' ),
254+ 'resnetblur50d' : _cfg (
255+ url = '' ,
256+ interpolation = 'bicubic' , first_conv = 'conv1.0' ),
257+ 'resnetblur101d' : _cfg (
258+ url = '' ,
259+ interpolation = 'bicubic' , first_conv = 'conv1.0' ),
260+ 'resnetaa50d' : _cfg (
261+ url = '' ,
262+ interpolation = 'bicubic' , first_conv = 'conv1.0' ),
263+ 'resnetaa101d' : _cfg (
264+ url = '' ,
265+ interpolation = 'bicubic' , first_conv = 'conv1.0' ),
266+ 'seresnetaa50d' : _cfg (
267+ url = '' ,
268+ interpolation = 'bicubic' , first_conv = 'conv1.0' ),
254269
255270 # ResNet-RS models
256271 'resnetrs50' : _cfg (
@@ -289,6 +304,12 @@ def get_padding(kernel_size, stride, dilation=1):
289304 return padding
290305
291306
307+ def create_aa (aa_layer , channels , stride = 2 , enable = True ):
308+ if not aa_layer or not enable :
309+ return None
310+ return aa_layer (stride ) if issubclass (aa_layer , nn .AvgPool2d ) else aa_layer (channels = channels , stride = stride )
311+
312+
292313class BasicBlock (nn .Module ):
293314 expansion = 1
294315
@@ -309,7 +330,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, b
309330 dilation = first_dilation , bias = False )
310331 self .bn1 = norm_layer (first_planes )
311332 self .act1 = act_layer (inplace = True )
312- self .aa = aa_layer ( channels = first_planes , stride = stride ) if use_aa else None
333+ self .aa = create_aa ( aa_layer , channels = first_planes , stride = stride , enable = use_aa )
313334
314335 self .conv2 = nn .Conv2d (
315336 first_planes , outplanes , kernel_size = 3 , padding = dilation , dilation = dilation , bias = False )
@@ -380,7 +401,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, b
380401 padding = first_dilation , dilation = first_dilation , groups = cardinality , bias = False )
381402 self .bn2 = norm_layer (width )
382403 self .act2 = act_layer (inplace = True )
383- self .aa = aa_layer ( channels = width , stride = stride ) if use_aa else None
404+ self .aa = create_aa ( aa_layer , channels = width , stride = stride , enable = use_aa )
384405
385406 self .conv3 = nn .Conv2d (width , outplanes , kernel_size = 1 , bias = False )
386407 self .bn3 = norm_layer (outplanes )
@@ -617,19 +638,22 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
617638 self .act1 = act_layer (inplace = True )
618639 self .feature_info = [dict (num_chs = inplanes , reduction = 2 , module = 'act1' )]
619640
620- # Stem Pooling
641+ # Stem pooling. The name 'maxpool' remains for weight compatibility.
621642 if replace_stem_pool :
622643 self .maxpool = nn .Sequential (* filter (None , [
623644 nn .Conv2d (inplanes , inplanes , 3 , stride = 1 if aa_layer else 2 , padding = 1 , bias = False ),
624- aa_layer ( channels = inplanes , stride = 2 ) if aa_layer else None ,
645+ create_aa ( aa_layer , channels = inplanes , stride = 2 ),
625646 norm_layer (inplanes ),
626647 act_layer (inplace = True )
627648 ]))
628649 else :
629650 if aa_layer is not None :
630- self .maxpool = nn .Sequential (* [
631- nn .MaxPool2d (kernel_size = 3 , stride = 1 , padding = 1 ),
632- aa_layer (channels = inplanes , stride = 2 )])
651+ if issubclass (aa_layer , nn .AvgPool2d ):
652+ self .maxpool = aa_layer (2 )
653+ else :
654+ self .maxpool = nn .Sequential (* [
655+ nn .MaxPool2d (kernel_size = 3 , stride = 1 , padding = 1 ),
656+ aa_layer (channels = inplanes , stride = 2 )])
633657 else :
634658 self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
635659
@@ -1342,6 +1366,56 @@ def resnetblur50(pretrained=False, **kwargs):
13421366 return _create_resnet ('resnetblur50' , pretrained , ** model_args )
13431367
13441368
1369+ @register_model
1370+ def resnetblur50d (pretrained = False , ** kwargs ):
1371+ """Constructs a ResNet-50-D model with blur anti-aliasing
1372+ """
1373+ model_args = dict (
1374+ block = Bottleneck , layers = [3 , 4 , 6 , 3 ], aa_layer = BlurPool2d ,
1375+ stem_width = 32 , stem_type = 'deep' , avg_down = True , ** kwargs )
1376+ return _create_resnet ('resnetblur50d' , pretrained , ** model_args )
1377+
1378+
1379+ @register_model
1380+ def resnetblur101d (pretrained = False , ** kwargs ):
1381+ """Constructs a ResNet-101-D model with blur anti-aliasing
1382+ """
1383+ model_args = dict (
1384+ block = Bottleneck , layers = [3 , 4 , 23 , 3 ], aa_layer = BlurPool2d ,
1385+ stem_width = 32 , stem_type = 'deep' , avg_down = True , ** kwargs )
1386+ return _create_resnet ('resnetblur101d' , pretrained , ** model_args )
1387+
1388+
1389+ @register_model
1390+ def resnetaa50d (pretrained = False , ** kwargs ):
1391+ """Constructs a ResNet-50-D model with avgpool anti-aliasing
1392+ """
1393+ model_args = dict (
1394+ block = Bottleneck , layers = [3 , 4 , 6 , 3 ], aa_layer = nn .AvgPool2d ,
1395+ stem_width = 32 , stem_type = 'deep' , avg_down = True , ** kwargs )
1396+ return _create_resnet ('resnetaa50d' , pretrained , ** model_args )
1397+
1398+
1399+ @register_model
1400+ def resnetaa101d (pretrained = False , ** kwargs ):
1401+ """Constructs a ResNet-101-D model with avgpool anti-aliasing
1402+ """
1403+ model_args = dict (
1404+ block = Bottleneck , layers = [3 , 4 , 23 , 3 ], aa_layer = nn .AvgPool2d ,
1405+ stem_width = 32 , stem_type = 'deep' , avg_down = True , ** kwargs )
1406+ return _create_resnet ('resnetaa101d' , pretrained , ** model_args )
1407+
1408+
1409+ @register_model
1410+ def seresnetaa50d (pretrained = False , ** kwargs ):
1411+ """Constructs a SE=ResNet-50-D model with avgpool anti-aliasing
1412+ """
1413+ model_args = dict (
1414+ block = Bottleneck , layers = [3 , 4 , 6 , 3 ], aa_layer = nn .AvgPool2d ,
1415+ stem_width = 32 , stem_type = 'deep' , avg_down = True , block_args = dict (attn_layer = 'se' ), ** kwargs )
1416+ return _create_resnet ('seresnetaa50d' , pretrained , ** model_args )
1417+
1418+
13451419@register_model
13461420def seresnet18 (pretrained = False , ** kwargs ):
13471421 model_args = dict (block = BasicBlock , layers = [2 , 2 , 2 , 2 ], block_args = dict (attn_layer = 'se' ), ** kwargs )
0 commit comments