@@ -127,21 +127,14 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, b
127127 first_planes = planes // reduce_first
128128 outplanes = planes * self .expansion
129129 first_dilation = first_dilation or dilation
130- self .blur = blur
131130
132- if blur and stride == 2 :
133- self .conv1 = nn .Conv2d (
134- inplanes , first_planes , kernel_size = 3 , stride = 1 , padding = first_dilation ,
135- dilation = first_dilation , bias = False )
136- self .blurpool = BlurPool2d (channels = first_planes )
137- else :
138- self .conv1 = nn .Conv2d (
139- inplanes , first_planes , kernel_size = 3 , stride = stride , padding = first_dilation ,
131+ self .conv1 = nn .Conv2d (
132+ inplanes , first_planes , kernel_size = 3 , stride = 1 if blur else stride , padding = first_dilation ,
140133 dilation = first_dilation , bias = False )
141- self .blurpool = None
142-
143134 self .bn1 = norm_layer (first_planes )
144135 self .act1 = act_layer (inplace = True )
136+ self .blurpool = BlurPool2d (channels = first_planes ) if stride == 2 and blur else None
137+
145138 self .conv2 = nn .Conv2d (
146139 first_planes , outplanes , kernel_size = 3 , padding = dilation , dilation = dilation , bias = False )
147140 self .bn2 = norm_layer (outplanes )
@@ -165,11 +158,9 @@ def forward(self, x):
165158 x = self .bn1 (x )
166159 if self .drop_block is not None :
167160 x = self .drop_block (x )
161+ x = self .act1 (x )
168162 if self .blurpool is not None :
169- x = self .act1 (x )
170163 x = self .blurpool (x )
171- else :
172- x = self .act1 (x )
173164
174165 x = self .conv2 (x )
175166 x = self .bn2 (x )
@@ -209,19 +200,13 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, b
209200 self .bn1 = norm_layer (first_planes )
210201 self .act1 = act_layer (inplace = True )
211202
212- if blur and stride == 2 :
213- self .conv2 = nn .Conv2d (
214- first_planes , width , kernel_size = 3 , stride = 1 ,
215- padding = first_dilation , dilation = first_dilation , groups = cardinality , bias = False )
216- self .blurpool = BlurPool2d (channels = width )
217- else :
218- self .conv2 = nn .Conv2d (
219- first_planes , width , kernel_size = 3 , stride = stride ,
220- padding = first_dilation , dilation = first_dilation , groups = cardinality , bias = False )
221- self .blurpool = None
222-
203+ self .conv2 = nn .Conv2d (
204+ first_planes , width , kernel_size = 3 , stride = 1 if blur else stride ,
205+ padding = first_dilation , dilation = first_dilation , groups = cardinality , bias = False )
223206 self .bn2 = norm_layer (width )
224207 self .act2 = act_layer (inplace = True )
208+ self .blurpool = BlurPool2d (channels = width ) if stride == 2 and blur else None
209+
225210 self .conv3 = nn .Conv2d (width , outplanes , kernel_size = 1 , bias = False )
226211 self .bn3 = norm_layer (outplanes )
227212
@@ -251,6 +236,8 @@ def forward(self, x):
251236 if self .drop_block is not None :
252237 x = self .drop_block (x )
253238 x = self .act2 (x )
239+ if self .blurpool is not None :
240+ x = self .blurpool (x )
254241
255242 x = self .conv3 (x )
256243 x = self .bn3 (x )
@@ -412,11 +399,12 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
412399 self .conv1 = nn .Conv2d (in_chans , self .inplanes , kernel_size = 7 , stride = 2 , padding = 3 , bias = False )
413400 self .bn1 = norm_layer (self .inplanes )
414401 self .act1 = act_layer (inplace = True )
415- # Stem Blur
402+ # Stem Pooling
416403 if 'max' in blur :
417404 self .maxpool = nn .Sequential (* [
418- nn .MaxPool2d (kernel_size = 3 , stride = 1 , padding = 1 ),
419- BlurPool2d (channels = self .inplanes )])
405+ nn .MaxPool2d (kernel_size = 3 , stride = 1 , padding = 1 ),
406+ BlurPool2d (channels = self .inplanes , stride = 2 )
407+ ])
420408 else :
421409 self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
422410
@@ -470,8 +458,8 @@ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=
470458
471459 block_kwargs = dict (
472460 cardinality = self .cardinality , base_width = self .base_width , reduce_first = reduce_first ,
473- dilation = dilation , ** kwargs )
474- layers = [block (self .inplanes , planes , stride , downsample , first_dilation = first_dilation , blur = self . blur , ** block_kwargs )]
461+ dilation = dilation , blur = self . blur , ** kwargs )
462+ layers = [block (self .inplanes , planes , stride , downsample , first_dilation = first_dilation , ** block_kwargs )]
475463 self .inplanes = planes * block .expansion
476464 layers += [block (self .inplanes , planes , ** block_kwargs ) for _ in range (1 , blocks )]
477465
@@ -1075,7 +1063,7 @@ def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
10751063def resnetblur50 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
10761064 """Constructs a ResNet-50 model. With assembled-cnn style blur
10771065 """
1078- default_cfg = default_cfgs ['resnetblur18 ' ]
1079- model = ResNet (Bottleneck , [3 , 4 , 6 , 3 ], num_classes = num_classes , in_chans = in_chans , blur = 'strided ' , ** kwargs )
1066+ default_cfg = default_cfgs ['resnetblur50 ' ]
1067+ model = ResNet (Bottleneck , [3 , 4 , 6 , 3 ], num_classes = num_classes , in_chans = in_chans , blur = 'max_strided ' , ** kwargs )
10801068 model .default_cfg = default_cfg
10811069 return model
0 commit comments