1212
1313from .registry import register_model
1414from .helpers import load_pretrained
15- from .layers import SelectAdaptivePool2d , DropBlock2d , DropPath , AvgPool2dSame , create_attn
15+ from .layers import SelectAdaptivePool2d , DropBlock2d , DropPath , AvgPool2dSame , create_attn , BlurPool2d
1616from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
1717
1818
@@ -104,6 +104,8 @@ def _cfg(url='', **kwargs):
104104 interpolation = 'bicubic' ),
105105 'ecaresnet18' : _cfg (),
106106 'ecaresnet50' : _cfg (),
107+ 'resnetblur18' : _cfg (),
108+ 'resnetblur50' : _cfg ()
107109}
108110
109111
@@ -117,18 +119,27 @@ class BasicBlock(nn.Module):
117119
118120 def __init__ (self , inplanes , planes , stride = 1 , downsample = None , cardinality = 1 , base_width = 64 ,
119121 reduce_first = 1 , dilation = 1 , first_dilation = None , act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d ,
120- attn_layer = None , drop_block = None , drop_path = None ):
122+ attn_layer = None , drop_block = None , drop_path = None , blur = False ):
121123 super (BasicBlock , self ).__init__ ()
122124
123125 assert cardinality == 1 , 'BasicBlock only supports cardinality of 1'
124126 assert base_width == 64 , 'BasicBlock doest not support changing base width'
125127 first_planes = planes // reduce_first
126128 outplanes = planes * self .expansion
127129 first_dilation = first_dilation or dilation
130+ self .blur = blur
128131
129- self .conv1 = nn .Conv2d (
132+ if blur and stride == 2 :
133+ self .conv1 = nn .Conv2d (
134+ inplanes , first_planes , kernel_size = 3 , stride = 1 , padding = first_dilation ,
135+ dilation = first_dilation , bias = False )
136+ self .blurpool = BlurPool2d (channels = first_planes )
137+ else :
138+ self .conv1 = nn .Conv2d (
130139 inplanes , first_planes , kernel_size = 3 , stride = stride , padding = first_dilation ,
131140 dilation = first_dilation , bias = False )
141+ self .blurpool = None
142+
132143 self .bn1 = norm_layer (first_planes )
133144 self .act1 = act_layer (inplace = True )
134145 self .conv2 = nn .Conv2d (
@@ -154,7 +165,11 @@ def forward(self, x):
154165 x = self .bn1 (x )
155166 if self .drop_block is not None :
156167 x = self .drop_block (x )
157- x = self .act1 (x )
168+ if self .blurpool is not None :
169+ x = self .act1 (x )
170+ x = self .blurpool (x )
171+ else :
172+ x = self .act1 (x )
158173
159174 x = self .conv2 (x )
160175 x = self .bn2 (x )
@@ -181,20 +196,30 @@ class Bottleneck(nn.Module):
181196
182197 def __init__ (self , inplanes , planes , stride = 1 , downsample = None , cardinality = 1 , base_width = 64 ,
183198 reduce_first = 1 , dilation = 1 , first_dilation = None , act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d ,
184- attn_layer = None , drop_block = None , drop_path = None ):
199+ attn_layer = None , drop_block = None , drop_path = None , blur = False ):
185200 super (Bottleneck , self ).__init__ ()
186201
187202 width = int (math .floor (planes * (base_width / 64 )) * cardinality )
188203 first_planes = width // reduce_first
189204 outplanes = planes * self .expansion
190205 first_dilation = first_dilation or dilation
206+ self .blur = blur
191207
192208 self .conv1 = nn .Conv2d (inplanes , first_planes , kernel_size = 1 , bias = False )
193209 self .bn1 = norm_layer (first_planes )
194210 self .act1 = act_layer (inplace = True )
195- self .conv2 = nn .Conv2d (
196- first_planes , width , kernel_size = 3 , stride = stride ,
197- padding = first_dilation , dilation = first_dilation , groups = cardinality , bias = False )
211+
212+ if blur and stride == 2 :
213+ self .conv2 = nn .Conv2d (
214+ first_planes , width , kernel_size = 3 , stride = 1 ,
215+ padding = first_dilation , dilation = first_dilation , groups = cardinality , bias = False )
216+ self .blurpool = BlurPool2d (channels = width )
217+ else :
218+ self .conv2 = nn .Conv2d (
219+ first_planes , width , kernel_size = 3 , stride = stride ,
220+ padding = first_dilation , dilation = first_dilation , groups = cardinality , bias = False )
221+ self .blurpool = None
222+
198223 self .bn2 = norm_layer (width )
199224 self .act2 = act_layer (inplace = True )
200225 self .conv3 = nn .Conv2d (width , outplanes , kernel_size = 1 , bias = False )
@@ -345,12 +370,19 @@ class ResNet(nn.Module):
345370 Dropout probability before classifier, for training
346371 global_pool : str, default 'avg'
347372 Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
373+ blur : str, default ''
374+ Location of Blurring:
375+ * '', default - Not applied
376+ * 'max' - only stem layer MaxPool will be blurred
377+ * 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style)
378+ * 'max_strided' - on both stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets)
379+
348380 """
349381 def __init__ (self , block , layers , num_classes = 1000 , in_chans = 3 ,
350382 cardinality = 1 , base_width = 64 , stem_width = 64 , stem_type = '' ,
351383 block_reduce_first = 1 , down_kernel_size = 1 , avg_down = False , output_stride = 32 ,
352384 act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d , drop_rate = 0.0 , drop_path_rate = 0. ,
353- drop_block_rate = 0. , global_pool = 'avg' , zero_init_last_bn = True , block_args = None ):
385+ drop_block_rate = 0. , global_pool = 'avg' , blur = '' , zero_init_last_bn = True , block_args = None ):
354386 block_args = block_args or dict ()
355387 self .num_classes = num_classes
356388 deep_stem = 'deep' in stem_type
@@ -359,6 +391,7 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
359391 self .base_width = base_width
360392 self .drop_rate = drop_rate
361393 self .expansion = block .expansion
394+ self .blur = 'strided' in blur
362395 super (ResNet , self ).__init__ ()
363396
364397 # Stem
@@ -379,7 +412,13 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
379412 self .conv1 = nn .Conv2d (in_chans , self .inplanes , kernel_size = 7 , stride = 2 , padding = 3 , bias = False )
380413 self .bn1 = norm_layer (self .inplanes )
381414 self .act1 = act_layer (inplace = True )
382- self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
415+ # Stem Blur
416+ if 'max' in blur :
417+ self .maxpool = nn .Sequential (* [
418+ nn .MaxPool2d (kernel_size = 3 , stride = 1 , padding = 1 ),
419+ BlurPool2d (channels = self .inplanes )])
420+ else :
421+ self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
383422
384423 # Feature Blocks
385424 dp = DropPath (drop_path_rate ) if drop_path_rate else None
@@ -432,7 +471,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=
432471 block_kwargs = dict (
433472 cardinality = self .cardinality , base_width = self .base_width , reduce_first = reduce_first ,
434473 dilation = dilation , ** kwargs )
435- layers = [block (self .inplanes , planes , stride , downsample , first_dilation = first_dilation , ** block_kwargs )]
474+ layers = [block (self .inplanes , planes , stride , downsample , first_dilation = first_dilation , blur = self . blur , ** block_kwargs )]
436475 self .inplanes = planes * block .expansion
437476 layers += [block (self .inplanes , planes , ** block_kwargs ) for _ in range (1 , blocks )]
438477
@@ -1022,3 +1061,21 @@ def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
10221061 if pretrained :
10231062 load_pretrained (model , default_cfg , num_classes , in_chans )
10241063 return model
1064+
1065+ @register_model
1066+ def resnetblur18 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1067+ """Constructs a ResNet-18 model. With original style blur
1068+ """
1069+ default_cfg = default_cfgs ['resnetblur18' ]
1070+ model = ResNet (BasicBlock , [2 , 2 , 2 , 2 ], num_classes = num_classes , in_chans = in_chans , blur = 'max_strided' ,** kwargs )
1071+ model .default_cfg = default_cfg
1072+ return model
1073+
1074+ @register_model
1075+ def resnetblur50 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1076+ """Constructs a ResNet-50 model. With assembled-cnn style blur
1077+ """
1078+ default_cfg = default_cfgs ['resnetblur18' ]
1079+ model = ResNet (Bottleneck , [3 , 4 , 6 , 3 ], num_classes = num_classes , in_chans = in_chans , blur = 'strided' , ** kwargs )
1080+ model .default_cfg = default_cfg
1081+ return model
0 commit comments