1212
1313from .registry import register_model
1414from .helpers import load_pretrained , adapt_model_from_file
15- from .layers import SelectAdaptivePool2d , DropBlock2d , DropPath , AvgPool2dSame , create_attn
15+ from .layers import SelectAdaptivePool2d , DropBlock2d , DropPath , AvgPool2dSame , create_attn , BlurPool2d
1616from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
1717
1818__all__ = ['ResNet' , 'BasicBlock' , 'Bottleneck' ] # model_registry will add each entrypoint fn to this
@@ -118,6 +118,11 @@ def _cfg(url='', **kwargs):
118118 'ecaresnet101d_pruned' : _cfg (
119119 url = 'https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth' ,
120120 interpolation = 'bicubic' ),
121+ 'resnetblur18' : _cfg (
122+ interpolation = 'bicubic' ),
123+ 'resnetblur50' : _cfg (
124+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth' ,
125+ interpolation = 'bicubic' )
121126}
122127
123128
@@ -131,20 +136,23 @@ class BasicBlock(nn.Module):
131136
132137 def __init__ (self , inplanes , planes , stride = 1 , downsample = None , cardinality = 1 , base_width = 64 ,
133138 reduce_first = 1 , dilation = 1 , first_dilation = None , act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d ,
134- attn_layer = None , drop_block = None , drop_path = None ):
139+ attn_layer = None , aa_layer = None , drop_block = None , drop_path = None ):
135140 super (BasicBlock , self ).__init__ ()
136141
137142 assert cardinality == 1 , 'BasicBlock only supports cardinality of 1'
138143 assert base_width == 64 , 'BasicBlock does not support changing base width'
139144 first_planes = planes // reduce_first
140145 outplanes = planes * self .expansion
141146 first_dilation = first_dilation or dilation
147+ use_aa = aa_layer is not None
142148
143149 self .conv1 = nn .Conv2d (
144- inplanes , first_planes , kernel_size = 3 , stride = stride , padding = first_dilation ,
150+ inplanes , first_planes , kernel_size = 3 , stride = 1 if use_aa else stride , padding = first_dilation ,
145151 dilation = first_dilation , bias = False )
146152 self .bn1 = norm_layer (first_planes )
147153 self .act1 = act_layer (inplace = True )
154+ self .aa = aa_layer (channels = first_planes ) if stride == 2 and use_aa else None
155+
148156 self .conv2 = nn .Conv2d (
149157 first_planes , outplanes , kernel_size = 3 , padding = dilation , dilation = dilation , bias = False )
150158 self .bn2 = norm_layer (outplanes )
@@ -169,6 +177,8 @@ def forward(self, x):
169177 if self .drop_block is not None :
170178 x = self .drop_block (x )
171179 x = self .act1 (x )
180+ if self .aa is not None :
181+ x = self .aa (x )
172182
173183 x = self .conv2 (x )
174184 x = self .bn2 (x )
@@ -195,22 +205,26 @@ class Bottleneck(nn.Module):
195205
196206 def __init__ (self , inplanes , planes , stride = 1 , downsample = None , cardinality = 1 , base_width = 64 ,
197207 reduce_first = 1 , dilation = 1 , first_dilation = None , act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d ,
198- attn_layer = None , drop_block = None , drop_path = None ):
208+ attn_layer = None , aa_layer = None , drop_block = None , drop_path = None ):
199209 super (Bottleneck , self ).__init__ ()
200210
201211 width = int (math .floor (planes * (base_width / 64 )) * cardinality )
202212 first_planes = width // reduce_first
203213 outplanes = planes * self .expansion
204214 first_dilation = first_dilation or dilation
215+ use_aa = aa_layer is not None
205216
206217 self .conv1 = nn .Conv2d (inplanes , first_planes , kernel_size = 1 , bias = False )
207218 self .bn1 = norm_layer (first_planes )
208219 self .act1 = act_layer (inplace = True )
220+
209221 self .conv2 = nn .Conv2d (
210- first_planes , width , kernel_size = 3 , stride = stride ,
222+ first_planes , width , kernel_size = 3 , stride = 1 if use_aa else stride ,
211223 padding = first_dilation , dilation = first_dilation , groups = cardinality , bias = False )
212224 self .bn2 = norm_layer (width )
213225 self .act2 = act_layer (inplace = True )
226+ self .aa = aa_layer (channels = width ) if stride == 2 and use_aa else None
227+
214228 self .conv3 = nn .Conv2d (width , outplanes , kernel_size = 1 , bias = False )
215229 self .bn3 = norm_layer (outplanes )
216230
@@ -240,6 +254,8 @@ def forward(self, x):
240254 if self .drop_block is not None :
241255 x = self .drop_block (x )
242256 x = self .act2 (x )
257+ if self .aa is not None :
258+ x = self .aa (x )
243259
244260 x = self .conv3 (x )
245261 x = self .bn3 (x )
@@ -353,8 +369,9 @@ class ResNet(nn.Module):
353369 Whether to use average pooling for projection skip connection between stages/downsample.
354370 output_stride : int, default 32
355371 Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
356- act_layer : class, activation layer
357- norm_layer : class, normalization layer
372+ act_layer : nn.Module, activation layer
373+ norm_layer : nn.Module, normalization layer
374+ aa_layer : nn.Module, anti-aliasing layer
358375 drop_rate : float, default 0.
359376 Dropout probability before classifier, for training
360377 global_pool : str, default 'avg'
@@ -363,7 +380,7 @@ class ResNet(nn.Module):
363380 def __init__ (self , block , layers , num_classes = 1000 , in_chans = 3 ,
364381 cardinality = 1 , base_width = 64 , stem_width = 64 , stem_type = '' ,
365382 block_reduce_first = 1 , down_kernel_size = 1 , avg_down = False , output_stride = 32 ,
366- act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d , drop_rate = 0.0 , drop_path_rate = 0. ,
383+ act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d , aa_layer = None , drop_rate = 0.0 , drop_path_rate = 0. ,
367384 drop_block_rate = 0. , global_pool = 'avg' , zero_init_last_bn = True , block_args = None ):
368385 block_args = block_args or dict ()
369386 self .num_classes = num_classes
@@ -393,7 +410,14 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
393410 self .conv1 = nn .Conv2d (in_chans , self .inplanes , kernel_size = 7 , stride = 2 , padding = 3 , bias = False )
394411 self .bn1 = norm_layer (self .inplanes )
395412 self .act1 = act_layer (inplace = True )
396- self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
413+ # Stem Pooling
414+ if aa_layer is not None :
415+ self .maxpool = nn .Sequential (* [
416+ nn .MaxPool2d (kernel_size = 3 , stride = 1 , padding = 1 ),
417+ aa_layer (channels = self .inplanes , stride = 2 )
418+ ])
419+ else :
420+ self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
397421
398422 # Feature Blocks
399423 dp = DropPath (drop_path_rate ) if drop_path_rate else None
@@ -410,7 +434,7 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
410434 assert output_stride == 32
411435 layer_args = list (zip (channels , layers , strides , dilations ))
412436 layer_kwargs = dict (
413- reduce_first = block_reduce_first , act_layer = act_layer , norm_layer = norm_layer ,
437+ reduce_first = block_reduce_first , act_layer = act_layer , norm_layer = norm_layer , aa_layer = aa_layer ,
414438 avg_down = avg_down , down_kernel_size = down_kernel_size , drop_path = dp , ** block_args )
415439 self .layer1 = self ._make_layer (block , * layer_args [0 ], ** layer_kwargs )
416440 self .layer2 = self ._make_layer (block , * layer_args [1 ], ** layer_kwargs )
@@ -1114,3 +1138,29 @@ def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwarg
11141138 if pretrained :
11151139 load_pretrained (model , default_cfg , num_classes , in_chans )
11161140 return model
1141+
1142+
@register_model
def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Build a ResNet-18 variant that uses BlurPool2d as its anti-aliasing layer.

    Args:
        pretrained: when truthy, `load_pretrained` is called on the new model
            with this entrypoint's default config.
        num_classes: forwarded to `ResNet` and to `load_pretrained`.
        in_chans: forwarded to `ResNet` and to `load_pretrained`.
        **kwargs: any further keyword arguments, passed to `ResNet` untouched.

    Returns:
        The constructed `ResNet` with `default_cfg` attached.
    """
    # NOTE(review): the 'resnetblur18' config entry appears to carry no weights
    # URL -- confirm what load_pretrained does in that case.
    cfg = default_cfgs['resnetblur18']
    stage_depths = [2, 2, 2, 2]
    net = ResNet(
        BasicBlock,
        stage_depths,
        num_classes=num_classes,
        in_chans=in_chans,
        aa_layer=BlurPool2d,
        **kwargs
    )
    net.default_cfg = cfg
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes, in_chans)
    return net
1154+
1155+
@register_model
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Build a ResNet-50 variant that uses BlurPool2d as its anti-aliasing layer.

    Args:
        pretrained: when truthy, `load_pretrained` is called on the new model
            with this entrypoint's default config.
        num_classes: forwarded to `ResNet` and to `load_pretrained`.
        in_chans: forwarded to `ResNet` and to `load_pretrained`.
        **kwargs: any further keyword arguments, passed to `ResNet` untouched.

    Returns:
        The constructed `ResNet` with `default_cfg` attached.
    """
    cfg = default_cfgs['resnetblur50']
    stage_depths = [3, 4, 6, 3]
    net = ResNet(
        Bottleneck,
        stage_depths,
        num_classes=num_classes,
        in_chans=in_chans,
        aa_layer=BlurPool2d,
        **kwargs
    )
    net.default_cfg = cfg
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes, in_chans)
    return net
0 commit comments