@@ -221,6 +221,7 @@ def __init__(
221221 self ,
222222 in_chans = 3 ,
223223 num_classes = 1000 ,
224+ output_stride = 32 ,
224225 depths = (3 , 3 , 9 , 3 ),
225226 dims = (96 , 192 , 384 , 768 ),
226227 token_mixers = nn .Identity ,
@@ -239,22 +240,30 @@ def __init__(
239240 token_mixers = [token_mixers ] * num_stage
240241 if not isinstance (mlp_ratios , (list , tuple )):
241242 mlp_ratios = [mlp_ratios ] * num_stage
242-
243243 self .num_classes = num_classes
244244 self .drop_rate = drop_rate
245+ self .feature_info = []
246+
245247 self .stem = nn .Sequential (
246248 nn .Conv2d (in_chans , dims [0 ], kernel_size = 4 , stride = 4 ),
247249 norm_layer (dims [0 ])
248250 )
249251
250- self .stages = nn .Sequential ()
251252 dp_rates = [x .tolist () for x in torch .linspace (0 , drop_path_rate , sum (depths )).split (depths )]
252- stages = []
253253 prev_chs = dims [0 ]
254+ curr_stride = 4
255+ dilation = 1
254256 # feature resolution stages, each consisting of multiple residual blocks
257+ self .stages = nn .Sequential ()
255258 for i in range (num_stage ):
259+ stride = 2 if curr_stride == 2 or i > 0 else 1
260+ if curr_stride >= output_stride and stride > 1 :
261+ dilation *= stride
262+ stride = 1
263+ curr_stride *= stride
264+ first_dilation = 1 if dilation in (1 , 2 ) else 2
256265 out_chs = dims [i ]
257- stages .append (MetaNeXtStage (
266+ self . stages .append (MetaNeXtStage (
258267 prev_chs ,
259268 out_chs ,
260269 ds_stride = 2 if i > 0 else 1 ,
@@ -267,7 +276,7 @@ def __init__(
267276 mlp_ratio = mlp_ratios [i ],
268277 ))
269278 prev_chs = out_chs
270- self .stages = nn . Sequential ( * stages )
279+ self .feature_info += [ dict ( num_chs = prev_chs , reduction = curr_stride , module = f' stages. { i } ' )]
271280 self .num_features = prev_chs
272281 self .head = head_fn (self .num_features , num_classes , drop = drop_rate )
273282 self .apply (self ._init_weights )
@@ -353,7 +362,8 @@ def _create_inception_next(variant, pretrained=False, **kwargs):
353362 model = build_model_with_cfg (
354363 MetaNeXt , variant , pretrained ,
355364 feature_cfg = dict (out_indices = (0 , 1 , 2 , 3 ), flatten_sequential = True ),
356- ** kwargs )
365+ ** kwargs ,
366+ )
357367 return model
358368
359369
0 commit comments