Skip to content

Commit 2d33b9d

Browse files
committed
Add features_only support to inception_next
1 parent 3d8d745 commit 2d33b9d

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

timm/models/inception_next.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def __init__(
221221
self,
222222
in_chans=3,
223223
num_classes=1000,
224+
output_stride=32,
224225
depths=(3, 3, 9, 3),
225226
dims=(96, 192, 384, 768),
226227
token_mixers=nn.Identity,
@@ -239,22 +240,30 @@ def __init__(
239240
token_mixers = [token_mixers] * num_stage
240241
if not isinstance(mlp_ratios, (list, tuple)):
241242
mlp_ratios = [mlp_ratios] * num_stage
242-
243243
self.num_classes = num_classes
244244
self.drop_rate = drop_rate
245+
self.feature_info = []
246+
245247
self.stem = nn.Sequential(
246248
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
247249
norm_layer(dims[0])
248250
)
249251

250-
self.stages = nn.Sequential()
251252
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
252-
stages = []
253253
prev_chs = dims[0]
254+
curr_stride = 4
255+
dilation = 1
254256
# feature resolution stages, each consisting of multiple residual blocks
257+
self.stages = nn.Sequential()
255258
for i in range(num_stage):
259+
stride = 2 if curr_stride == 2 or i > 0 else 1
260+
if curr_stride >= output_stride and stride > 1:
261+
dilation *= stride
262+
stride = 1
263+
curr_stride *= stride
264+
first_dilation = 1 if dilation in (1, 2) else 2
256265
out_chs = dims[i]
257-
stages.append(MetaNeXtStage(
266+
self.stages.append(MetaNeXtStage(
258267
prev_chs,
259268
out_chs,
260269
ds_stride=2 if i > 0 else 1,
@@ -267,7 +276,7 @@ def __init__(
267276
mlp_ratio=mlp_ratios[i],
268277
))
269278
prev_chs = out_chs
270-
self.stages = nn.Sequential(*stages)
279+
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
271280
self.num_features = prev_chs
272281
self.head = head_fn(self.num_features, num_classes, drop=drop_rate)
273282
self.apply(self._init_weights)
@@ -353,7 +362,8 @@ def _create_inception_next(variant, pretrained=False, **kwargs):
353362
model = build_model_with_cfg(
354363
MetaNeXt, variant, pretrained,
355364
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
356-
**kwargs)
365+
**kwargs,
366+
)
357367
return model
358368

359369

0 commit comments

Comments
 (0)