@@ -118,8 +118,11 @@ def _cfg(url='', **kwargs):
118118 'ecaresnet101d_pruned' : _cfg (
119119 url = 'https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth' ,
120120 interpolation = 'bicubic' ),
121- 'resnetblur18' : _cfg (),
122- 'resnetblur50' : _cfg ()
121+ 'resnetblur18' : _cfg (
122+ interpolation = 'bicubic' ),
123+ 'resnetblur50' : _cfg (
124+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth' ,
125+ interpolation = 'bicubic' )
123126}
124127
125128
@@ -133,21 +136,22 @@ class BasicBlock(nn.Module):
133136
134137 def __init__ (self , inplanes , planes , stride = 1 , downsample = None , cardinality = 1 , base_width = 64 ,
135138 reduce_first = 1 , dilation = 1 , first_dilation = None , act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d ,
136- attn_layer = None , drop_block = None , drop_path = None , blur = False ):
139+ attn_layer = None , aa_layer = None , drop_block = None , drop_path = None ):
137140 super (BasicBlock , self ).__init__ ()
138141
139142 assert cardinality == 1 , 'BasicBlock only supports cardinality of 1'
140143 assert base_width == 64 , 'BasicBlock does not support changing base width'
141144 first_planes = planes // reduce_first
142145 outplanes = planes * self .expansion
143146 first_dilation = first_dilation or dilation
147+ use_aa = aa_layer is not None
144148
145149 self .conv1 = nn .Conv2d (
146- inplanes , first_planes , kernel_size = 3 , stride = 1 if blur else stride , padding = first_dilation ,
150+ inplanes , first_planes , kernel_size = 3 , stride = 1 if use_aa else stride , padding = first_dilation ,
147151 dilation = first_dilation , bias = False )
148152 self .bn1 = norm_layer (first_planes )
149153 self .act1 = act_layer (inplace = True )
150- self .blurpool = BlurPool2d (channels = first_planes ) if stride == 2 and blur else None
154+ self .aa = aa_layer (channels = first_planes ) if stride == 2 and use_aa else None
151155
152156 self .conv2 = nn .Conv2d (
153157 first_planes , outplanes , kernel_size = 3 , padding = dilation , dilation = dilation , bias = False )
@@ -173,8 +177,8 @@ def forward(self, x):
173177 if self .drop_block is not None :
174178 x = self .drop_block (x )
175179 x = self .act1 (x )
176- if self .blurpool is not None :
177- x = self .blurpool (x )
180+ if self .aa is not None :
181+ x = self .aa (x )
178182
179183 x = self .conv2 (x )
180184 x = self .bn2 (x )
@@ -201,25 +205,25 @@ class Bottleneck(nn.Module):
201205
202206 def __init__ (self , inplanes , planes , stride = 1 , downsample = None , cardinality = 1 , base_width = 64 ,
203207 reduce_first = 1 , dilation = 1 , first_dilation = None , act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d ,
204- attn_layer = None , drop_block = None , drop_path = None , blur = False ):
208+ attn_layer = None , aa_layer = None , drop_block = None , drop_path = None ):
205209 super (Bottleneck , self ).__init__ ()
206210
207211 width = int (math .floor (planes * (base_width / 64 )) * cardinality )
208212 first_planes = width // reduce_first
209213 outplanes = planes * self .expansion
210214 first_dilation = first_dilation or dilation
211- self . blur = blur
215+ use_aa = aa_layer is not None
212216
213217 self .conv1 = nn .Conv2d (inplanes , first_planes , kernel_size = 1 , bias = False )
214218 self .bn1 = norm_layer (first_planes )
215219 self .act1 = act_layer (inplace = True )
216220
217221 self .conv2 = nn .Conv2d (
218- first_planes , width , kernel_size = 3 , stride = 1 if blur else stride ,
222+ first_planes , width , kernel_size = 3 , stride = 1 if use_aa else stride ,
219223 padding = first_dilation , dilation = first_dilation , groups = cardinality , bias = False )
220224 self .bn2 = norm_layer (width )
221225 self .act2 = act_layer (inplace = True )
222- self .blurpool = BlurPool2d (channels = width ) if stride == 2 and blur else None
226+ self .aa = aa_layer (channels = width ) if stride == 2 and use_aa else None
223227
224228 self .conv3 = nn .Conv2d (width , outplanes , kernel_size = 1 , bias = False )
225229 self .bn3 = norm_layer (outplanes )
@@ -250,8 +254,8 @@ def forward(self, x):
250254 if self .drop_block is not None :
251255 x = self .drop_block (x )
252256 x = self .act2 (x )
253- if self .blurpool is not None :
254- x = self .blurpool (x )
257+ if self .aa is not None :
258+ x = self .aa (x )
255259
256260 x = self .conv3 (x )
257261 x = self .bn3 (x )
@@ -365,25 +369,19 @@ class ResNet(nn.Module):
365369 Whether to use average pooling for projection skip connection between stages/downsample.
366370 output_stride : int, default 32
367371 Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
368- act_layer : class, activation layer
369- norm_layer : class, normalization layer
372+ act_layer : nn.Module, activation layer
373+ norm_layer : nn.Module, normalization layer
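374+ aa_layer : nn.Module, anti-aliasing layer class (e.g. BlurPool2d) used for downsampling in the stem pool and in stride-2 blocks; default None disables it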
370375 drop_rate : float, default 0.
371376 Dropout probability before classifier, for training
372377 global_pool : str, default 'avg'
373378 Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
374- blur : str, default ''
375- Location of Blurring:
376- * '', default - Not applied
377- * 'max' - only stem layer MaxPool will be blurred
378- * 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style)
379- * 'max_strided' - on both stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets)
380-
381379 """
382380 def __init__ (self , block , layers , num_classes = 1000 , in_chans = 3 ,
383381 cardinality = 1 , base_width = 64 , stem_width = 64 , stem_type = '' ,
384382 block_reduce_first = 1 , down_kernel_size = 1 , avg_down = False , output_stride = 32 ,
385- act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d , drop_rate = 0.0 , drop_path_rate = 0. ,
386- drop_block_rate = 0. , global_pool = 'avg' , blur = '' , zero_init_last_bn = True , block_args = None ):
383+ act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d , aa_layer = None , drop_rate = 0.0 , drop_path_rate = 0. ,
384+ drop_block_rate = 0. , global_pool = 'avg' , zero_init_last_bn = True , block_args = None ):
387385 block_args = block_args or dict ()
388386 self .num_classes = num_classes
389387 deep_stem = 'deep' in stem_type
@@ -392,7 +390,6 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
392390 self .base_width = base_width
393391 self .drop_rate = drop_rate
394392 self .expansion = block .expansion
395- self .blur = 'strided' in blur
396393 super (ResNet , self ).__init__ ()
397394
398395 # Stem
@@ -414,12 +411,12 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
414411 self .bn1 = norm_layer (self .inplanes )
415412 self .act1 = act_layer (inplace = True )
416413 # Stem Pooling
417- if 'max' in blur :
414+ if aa_layer is not None :
418415 self .maxpool = nn .Sequential (* [
419416 nn .MaxPool2d (kernel_size = 3 , stride = 1 , padding = 1 ),
420- BlurPool2d (channels = self .inplanes , stride = 2 )
417+ aa_layer (channels = self .inplanes , stride = 2 )
421418 ])
422- else :
419+ else :
423420 self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
424421
425422 # Feature Blocks
@@ -437,7 +434,7 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
437434 assert output_stride == 32
438435 layer_args = list (zip (channels , layers , strides , dilations ))
439436 layer_kwargs = dict (
440- reduce_first = block_reduce_first , act_layer = act_layer , norm_layer = norm_layer ,
437+ reduce_first = block_reduce_first , act_layer = act_layer , norm_layer = norm_layer , aa_layer = aa_layer ,
441438 avg_down = avg_down , down_kernel_size = down_kernel_size , drop_path = dp , ** block_args )
442439 self .layer1 = self ._make_layer (block , * layer_args [0 ], ** layer_kwargs )
443440 self .layer2 = self ._make_layer (block , * layer_args [1 ], ** layer_kwargs )
@@ -472,7 +469,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=
472469
473470 block_kwargs = dict (
474471 cardinality = self .cardinality , base_width = self .base_width , reduce_first = reduce_first ,
475- dilation = dilation , blur = self . blur , ** kwargs )
472+ dilation = dilation , ** kwargs )
476473 layers = [block (self .inplanes , planes , stride , downsample , first_dilation = first_dilation , ** block_kwargs )]
477474 self .inplanes = planes * block .expansion
478475 layers += [block (self .inplanes , planes , ** block_kwargs ) for _ in range (1 , blocks )]
@@ -1148,18 +1145,21 @@ def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
11481145 """Constructs a ResNet-18 model with blur anti-aliasing
11491146 """
11501147 default_cfg = default_cfgs ['resnetblur18' ]
1151- model = ResNet (BasicBlock , [2 , 2 , 2 , 2 ], num_classes = num_classes , in_chans = in_chans , blur = 'max_strided' ,** kwargs )
1148+ model = ResNet (
1149+ BasicBlock , [2 , 2 , 2 , 2 ], num_classes = num_classes , in_chans = in_chans , aa_layer = BlurPool2d , ** kwargs )
11521150 model .default_cfg = default_cfg
11531151 if pretrained :
11541152 load_pretrained (model , default_cfg , num_classes , in_chans )
11551153 return model
11561154
1155+
11571156@register_model
11581157def resnetblur50 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
11591158 """Constructs a ResNet-50 model with blur anti-aliasing
11601159 """
11611160 default_cfg = default_cfgs ['resnetblur50' ]
1162- model = ResNet (Bottleneck , [3 , 4 , 6 , 3 ], num_classes = num_classes , in_chans = in_chans , blur = 'max_strided' , ** kwargs )
1161+ model = ResNet (
1162+ Bottleneck , [3 , 4 , 6 , 3 ], num_classes = num_classes , in_chans = in_chans , aa_layer = BlurPool2d , ** kwargs )
11631163 model .default_cfg = default_cfg
11641164 if pretrained :
11651165 load_pretrained (model , default_cfg , num_classes , in_chans )
0 commit comments