Commit 3d8d745

InceptionNeXt using timm builder, more cleanup
1 parent f4cf977 commit 3d8d745

File tree

1 file changed (+78, -68)
timm/models/inception_next.py

Lines changed: 78 additions & 68 deletions
@@ -1,7 +1,5 @@
 """
 InceptionNeXt implementation, paper: https://arxiv.org/abs/2303.16900
-
-Some code is borrowed from timm: https://github.com/huggingface/pytorch-image-models
 """
 
 from functools import partial
@@ -11,24 +9,31 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_, DropPath, to_2tuple
+from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
 
 
 class InceptionDWConv2d(nn.Module):
     """ Inception depthweise convolution
     """
 
-    def __init__(self, in_channels, square_kernel_size=3, band_kernel_size=11, branch_ratio=0.125):
+    def __init__(
+            self,
+            in_chs,
+            square_kernel_size=3,
+            band_kernel_size=11,
+            branch_ratio=0.125
+    ):
         super().__init__()
 
-        gc = int(in_channels * branch_ratio)  # channel numbers of a convolution branch
+        gc = int(in_chs * branch_ratio)  # channel numbers of a convolution branch
         self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_kernel_size // 2, groups=gc)
         self.dwconv_w = nn.Conv2d(
             gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size // 2), groups=gc)
         self.dwconv_h = nn.Conv2d(
             gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size // 2, 0), groups=gc)
-        self.split_indexes = (in_channels - 3 * gc, gc, gc, gc)
+        self.split_indexes = (in_chs - 3 * gc, gc, gc, gc)
 
     def forward(self, x):
         x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1)
@@ -47,8 +52,15 @@ class ConvMlp(nn.Module):
     """
 
     def __init__(
-            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
-            norm_layer=None, bias=True, drop=0.):
+            self,
+            in_features,
+            hidden_features=None,
+            out_features=None,
+            act_layer=nn.ReLU,
+            norm_layer=None,
+            bias=True,
+            drop=0.,
+    ):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
@@ -69,13 +81,20 @@ def forward(self, x):
         return x
 
 
-class MlpHead(nn.Module):
+class MlpClassifierHead(nn.Module):
     """ MLP classification head
     """
 
     def __init__(
-            self, dim, num_classes=1000, mlp_ratio=3, act_layer=nn.GELU,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6), drop=0., bias=True):
+            self,
+            dim,
+            num_classes=1000,
+            mlp_ratio=3,
+            act_layer=nn.GELU,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            drop=0.,
+            bias=True
+    ):
         super().__init__()
         hidden_features = int(mlp_ratio * dim)
         self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
@@ -168,7 +187,6 @@ def __init__(
                 norm_layer=norm_layer,
                 mlp_ratio=mlp_ratio,
             ))
-            in_chs = out_chs
         self.blocks = nn.Sequential(*stage_blocks)
 
     def forward(self, x):
@@ -209,11 +227,10 @@ def __init__(
             norm_layer=nn.BatchNorm2d,
             act_layer=nn.GELU,
             mlp_ratios=(4, 4, 4, 3),
-            head_fn=MlpHead,
+            head_fn=MlpClassifierHead,
             drop_rate=0.,
             drop_path_rate=0.,
             ls_init_value=1e-6,
-            **kwargs,
     ):
         super().__init__()
 
@@ -255,14 +272,38 @@ def __init__(
         self.head = head_fn(self.num_features, num_classes, drop=drop_rate)
         self.apply(self._init_weights)
 
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear)):
+            trunc_normal_(m.weight, std=.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+)\.downsample', (0,)),  # blocks
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+            ]
+        )
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc2
+
+    def reset_classifier(self, num_classes=0, global_pool=None):
+        # FIXME
+        self.head.reset(num_classes, global_pool)
+
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         for s in self.stages:
             s.grad_checkpointing = enable
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'norm'}
+        return set()
 
     def forward_features(self, x):
         x = self.stem(x)
@@ -278,97 +319,66 @@ def forward(self, x):
         x = self.forward_head(x)
         return x
 
-    def _init_weights(self, m):
-        if isinstance(m, (nn.Conv2d, nn.Linear)):
-            trunc_normal_(m.weight, std=.02)
-            if m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
 
 def _cfg(url='', **kwargs):
     return {
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.0', 'classifier': 'head.fc',
+        'first_conv': 'stem.0', 'classifier': 'head.fc2',
         **kwargs
     }
 
 
-default_cfgs = dict(
-    inception_next_tiny=_cfg(
+default_cfgs = generate_default_cfgs({
+    'inception_next_tiny.sail_in1k': _cfg(
         url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth',
     ),
-    inception_next_small=_cfg(
+    'inception_next_small.sail_in1k': _cfg(
         url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth',
     ),
-    inception_next_base=_cfg(
+    'inception_next_base.sail_in1k': _cfg(
        url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth',
+        crop_pct=0.95,
     ),
-    inception_next_base_384=_cfg(
+    'inception_next_base.sail_in1k_384': _cfg(
         url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth',
         input_size=(3, 384, 384), crop_pct=1.0,
     ),
-)
+})
+
+
+def _create_inception_next(variant, pretrained=False, **kwargs):
+    model = build_model_with_cfg(
+        MetaNeXt, variant, pretrained,
+        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
+        **kwargs)
+    return model
 
 
 @register_model
 def inception_next_tiny(pretrained=False, **kwargs):
-    model = MetaNeXt(
+    model_args = dict(
         depths=(3, 3, 9, 3), dims=(96, 192, 384, 768),
         token_mixers=InceptionDWConv2d,
-        **kwargs
     )
-    model.default_cfg = default_cfgs['inception_next_tiny']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    return _create_inception_next('inception_next_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def inception_next_small(pretrained=False, **kwargs):
-    model = MetaNeXt(
+    model_args = dict(
         depths=(3, 3, 27, 3), dims=(96, 192, 384, 768),
         token_mixers=InceptionDWConv2d,
-        **kwargs
     )
-    model.default_cfg = default_cfgs['inception_next_small']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    return _create_inception_next('inception_next_small', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def inception_next_base(pretrained=False, **kwargs):
-    model = MetaNeXt(
+    model_args = dict(
         depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024),
         token_mixers=InceptionDWConv2d,
-        **kwargs
     )
-    model.default_cfg = default_cfgs['inception_next_base']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
-
-
-@register_model
-def inception_next_base_384(pretrained=False, **kwargs):
-    model = MetaNeXt(
-        depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024],
-        mlp_ratios=[4, 4, 4, 3],
-        token_mixers=InceptionDWConv2d,
-        **kwargs
-    )
-    model.default_cfg = default_cfgs['inception_next_base_384']
-    if pretrained:
-        state_dict = torch.hub.load_state_dict_from_url(
-            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
-        model.load_state_dict(state_dict)
-    return model
+    return _create_inception_next('inception_next_base', pretrained=pretrained, **dict(model_args, **kwargs))
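
As a sanity check on the split bookkeeping in the InceptionDWConv2d hunk: the mixer divides the in_chs input channels into an identity branch plus three small depthwise branches (3x3 square, 1x11 horizontal band, 11x1 vertical band). Below is a minimal standalone sketch of that arithmetic; in_chs=64, branch_ratio=0.125, and the tensor shapes are illustrative values, not from the commit:

    import torch

    # Mirrors the split logic in InceptionDWConv2d.__init__ / forward.
    in_chs, branch_ratio = 64, 0.125               # example values only
    gc = int(in_chs * branch_ratio)                # 8 channels per conv branch
    split_indexes = (in_chs - 3 * gc, gc, gc, gc)  # (40, 8, 8, 8)

    x = torch.randn(2, in_chs, 56, 56)             # NCHW feature map
    x_id, x_hw, x_w, x_h = torch.split(x, split_indexes, dim=1)
    # x_id (40ch) passes through untouched; x_hw, x_w, x_h (8ch each) go to
    # the square, horizontal-band, and vertical-band depthwise convolutions.
    assert [t.shape[1] for t in (x_id, x_hw, x_w, x_h)] == [40, 8, 8, 8]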

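Usage sketch for the new builder path (a hedged example, assuming a timm install that includes this commit; the model names come from the @register_model entry points above, and the .sail_in1k tags from the generate_default_cfgs block):

    import timm
    import torch

    # Classification: a tag-qualified name selects the pretrained weights.
    model = timm.create_model('inception_next_tiny.sail_in1k', pretrained=True)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])

    # Feature extraction, enabled by feature_cfg in _create_inception_next.
    feat_model = timm.create_model('inception_next_tiny', features_only=True)
    for f in feat_model(torch.randn(1, 3, 224, 224)):
        print(f.shape)  # one NCHW feature map per stage, out_indices=(0, 1, 2, 3)

Note also the dict(model_args, **kwargs) merge in each entry point: caller kwargs override the per-variant defaults, which is why the explicit **kwargs passthrough inside the old MetaNeXt(...) calls could be dropped.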