Skip to content

Commit f4cf977

Browse files
committed
Adding InceptionNeXt
1 parent d2e3c09 commit f4cf977

File tree

2 files changed

+375
-0
lines changed

2 files changed

+375
-0
lines changed

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .ghostnet import *
2727
from .hardcorenas import *
2828
from .hrnet import *
29+
from .inception_next import *
2930
from .inception_resnet_v2 import *
3031
from .inception_v3 import *
3132
from .inception_v4 import *

timm/models/inception_next.py

Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
"""
2+
InceptionNeXt implementation, paper: https://arxiv.org/abs/2303.16900
3+
4+
Some code is borrowed from timm: https://github.com/huggingface/pytorch-image-models
5+
"""
6+
7+
from functools import partial
8+
9+
import torch
10+
import torch.nn as nn
11+
12+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
13+
from timm.layers import trunc_normal_, DropPath, to_2tuple
14+
from ._manipulate import checkpoint_seq
15+
from ._registry import register_model
16+
17+
18+
class InceptionDWConv2d(nn.Module):
19+
""" Inception depthweise convolution
20+
"""
21+
22+
def __init__(self, in_channels, square_kernel_size=3, band_kernel_size=11, branch_ratio=0.125):
23+
super().__init__()
24+
25+
gc = int(in_channels * branch_ratio) # channel numbers of a convolution branch
26+
self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_kernel_size // 2, groups=gc)
27+
self.dwconv_w = nn.Conv2d(
28+
gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size // 2), groups=gc)
29+
self.dwconv_h = nn.Conv2d(
30+
gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size // 2, 0), groups=gc)
31+
self.split_indexes = (in_channels - 3 * gc, gc, gc, gc)
32+
33+
def forward(self, x):
34+
x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1)
35+
return torch.cat((
36+
x_id,
37+
self.dwconv_hw(x_hw),
38+
self.dwconv_w(x_w),
39+
self.dwconv_h(x_h)
40+
), dim=1,
41+
)
42+
43+
44+
class ConvMlp(nn.Module):
45+
""" MLP using 1x1 convs that keeps spatial dims
46+
copied from timm: https://github.com/huggingface/pytorch-image-models/blob/v0.6.11/timm/models/layers/mlp.py
47+
"""
48+
49+
def __init__(
50+
self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
51+
norm_layer=None, bias=True, drop=0.):
52+
super().__init__()
53+
out_features = out_features or in_features
54+
hidden_features = hidden_features or in_features
55+
bias = to_2tuple(bias)
56+
57+
self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
58+
self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
59+
self.act = act_layer()
60+
self.drop = nn.Dropout(drop)
61+
self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])
62+
63+
def forward(self, x):
64+
x = self.fc1(x)
65+
x = self.norm(x)
66+
x = self.act(x)
67+
x = self.drop(x)
68+
x = self.fc2(x)
69+
return x
70+
71+
72+
class MlpHead(nn.Module):
73+
""" MLP classification head
74+
"""
75+
76+
def __init__(
77+
self, dim, num_classes=1000, mlp_ratio=3, act_layer=nn.GELU,
78+
norm_layer=partial(nn.LayerNorm, eps=1e-6), drop=0., bias=True):
79+
super().__init__()
80+
hidden_features = int(mlp_ratio * dim)
81+
self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
82+
self.act = act_layer()
83+
self.norm = norm_layer(hidden_features)
84+
self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
85+
self.drop = nn.Dropout(drop)
86+
87+
def forward(self, x):
88+
x = x.mean((2, 3)) # global average pooling
89+
x = self.fc1(x)
90+
x = self.act(x)
91+
x = self.norm(x)
92+
x = self.drop(x)
93+
x = self.fc2(x)
94+
return x
95+
96+
97+
class MetaNeXtBlock(nn.Module):
98+
""" MetaNeXtBlock Block
99+
Args:
100+
dim (int): Number of input channels.
101+
drop_path (float): Stochastic depth rate. Default: 0.0
102+
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
103+
"""
104+
105+
def __init__(
106+
self,
107+
dim,
108+
token_mixer=nn.Identity,
109+
norm_layer=nn.BatchNorm2d,
110+
mlp_layer=ConvMlp,
111+
mlp_ratio=4,
112+
act_layer=nn.GELU,
113+
ls_init_value=1e-6,
114+
drop_path=0.,
115+
116+
):
117+
super().__init__()
118+
self.token_mixer = token_mixer(dim)
119+
self.norm = norm_layer(dim)
120+
self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer)
121+
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value else None
122+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
123+
124+
def forward(self, x):
125+
shortcut = x
126+
x = self.token_mixer(x)
127+
x = self.norm(x)
128+
x = self.mlp(x)
129+
if self.gamma is not None:
130+
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
131+
x = self.drop_path(x) + shortcut
132+
return x
133+
134+
135+
class MetaNeXtStage(nn.Module):
136+
def __init__(
137+
self,
138+
in_chs,
139+
out_chs,
140+
ds_stride=2,
141+
depth=2,
142+
drop_path_rates=None,
143+
ls_init_value=1.0,
144+
token_mixer=nn.Identity,
145+
act_layer=nn.GELU,
146+
norm_layer=None,
147+
mlp_ratio=4,
148+
):
149+
super().__init__()
150+
self.grad_checkpointing = False
151+
if ds_stride > 1:
152+
self.downsample = nn.Sequential(
153+
norm_layer(in_chs),
154+
nn.Conv2d(in_chs, out_chs, kernel_size=ds_stride, stride=ds_stride),
155+
)
156+
else:
157+
self.downsample = nn.Identity()
158+
159+
drop_path_rates = drop_path_rates or [0.] * depth
160+
stage_blocks = []
161+
for i in range(depth):
162+
stage_blocks.append(MetaNeXtBlock(
163+
dim=out_chs,
164+
drop_path=drop_path_rates[i],
165+
ls_init_value=ls_init_value,
166+
token_mixer=token_mixer,
167+
act_layer=act_layer,
168+
norm_layer=norm_layer,
169+
mlp_ratio=mlp_ratio,
170+
))
171+
in_chs = out_chs
172+
self.blocks = nn.Sequential(*stage_blocks)
173+
174+
def forward(self, x):
175+
x = self.downsample(x)
176+
if self.grad_checkpointing and not torch.jit.is_scripting():
177+
x = checkpoint_seq(self.blocks, x)
178+
else:
179+
x = self.blocks(x)
180+
return x
181+
182+
183+
class MetaNeXt(nn.Module):
184+
r""" MetaNeXt
185+
A PyTorch impl of : `InceptionNeXt: When Inception Meets ConvNeXt` - https://arxiv.org/pdf/2203.xxxxx.pdf
186+
187+
Args:
188+
in_chans (int): Number of input image channels. Default: 3
189+
num_classes (int): Number of classes for classification head. Default: 1000
190+
depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3)
191+
dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768)
192+
token_mixers: Token mixer function. Default: nn.Identity
193+
norm_layer: Normalziation layer. Default: nn.BatchNorm2d
194+
act_layer: Activation function for MLP. Default: nn.GELU
195+
mlp_ratios (int or tuple(int)): MLP ratios. Default: (4, 4, 4, 3)
196+
head_fn: classifier head
197+
drop_rate (float): Head dropout rate
198+
drop_path_rate (float): Stochastic depth rate. Default: 0.
199+
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
200+
"""
201+
202+
def __init__(
203+
self,
204+
in_chans=3,
205+
num_classes=1000,
206+
depths=(3, 3, 9, 3),
207+
dims=(96, 192, 384, 768),
208+
token_mixers=nn.Identity,
209+
norm_layer=nn.BatchNorm2d,
210+
act_layer=nn.GELU,
211+
mlp_ratios=(4, 4, 4, 3),
212+
head_fn=MlpHead,
213+
drop_rate=0.,
214+
drop_path_rate=0.,
215+
ls_init_value=1e-6,
216+
**kwargs,
217+
):
218+
super().__init__()
219+
220+
num_stage = len(depths)
221+
if not isinstance(token_mixers, (list, tuple)):
222+
token_mixers = [token_mixers] * num_stage
223+
if not isinstance(mlp_ratios, (list, tuple)):
224+
mlp_ratios = [mlp_ratios] * num_stage
225+
226+
self.num_classes = num_classes
227+
self.drop_rate = drop_rate
228+
self.stem = nn.Sequential(
229+
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
230+
norm_layer(dims[0])
231+
)
232+
233+
self.stages = nn.Sequential()
234+
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
235+
stages = []
236+
prev_chs = dims[0]
237+
# feature resolution stages, each consisting of multiple residual blocks
238+
for i in range(num_stage):
239+
out_chs = dims[i]
240+
stages.append(MetaNeXtStage(
241+
prev_chs,
242+
out_chs,
243+
ds_stride=2 if i > 0 else 1,
244+
depth=depths[i],
245+
drop_path_rates=dp_rates[i],
246+
ls_init_value=ls_init_value,
247+
act_layer=act_layer,
248+
token_mixer=token_mixers[i],
249+
norm_layer=norm_layer,
250+
mlp_ratio=mlp_ratios[i],
251+
))
252+
prev_chs = out_chs
253+
self.stages = nn.Sequential(*stages)
254+
self.num_features = prev_chs
255+
self.head = head_fn(self.num_features, num_classes, drop=drop_rate)
256+
self.apply(self._init_weights)
257+
258+
@torch.jit.ignore
259+
def set_grad_checkpointing(self, enable=True):
260+
for s in self.stages:
261+
s.grad_checkpointing = enable
262+
263+
@torch.jit.ignore
264+
def no_weight_decay(self):
265+
return {'norm'}
266+
267+
def forward_features(self, x):
268+
x = self.stem(x)
269+
x = self.stages(x)
270+
return x
271+
272+
def forward_head(self, x):
273+
x = self.head(x)
274+
return x
275+
276+
def forward(self, x):
277+
x = self.forward_features(x)
278+
x = self.forward_head(x)
279+
return x
280+
281+
def _init_weights(self, m):
282+
if isinstance(m, (nn.Conv2d, nn.Linear)):
283+
trunc_normal_(m.weight, std=.02)
284+
if m.bias is not None:
285+
nn.init.constant_(m.bias, 0)
286+
287+
288+
def _cfg(url='', **kwargs):
289+
return {
290+
'url': url,
291+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
292+
'crop_pct': 0.875, 'interpolation': 'bicubic',
293+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
294+
'first_conv': 'stem.0', 'classifier': 'head.fc',
295+
**kwargs
296+
}
297+
298+
299+
default_cfgs = dict(
300+
inception_next_tiny=_cfg(
301+
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth',
302+
),
303+
inception_next_small=_cfg(
304+
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth',
305+
),
306+
inception_next_base=_cfg(
307+
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth',
308+
),
309+
inception_next_base_384=_cfg(
310+
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth',
311+
input_size=(3, 384, 384), crop_pct=1.0,
312+
),
313+
)
314+
315+
316+
@register_model
317+
def inception_next_tiny(pretrained=False, **kwargs):
318+
model = MetaNeXt(
319+
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768),
320+
token_mixers=InceptionDWConv2d,
321+
**kwargs
322+
)
323+
model.default_cfg = default_cfgs['inception_next_tiny']
324+
if pretrained:
325+
state_dict = torch.hub.load_state_dict_from_url(
326+
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
327+
model.load_state_dict(state_dict)
328+
return model
329+
330+
331+
@register_model
332+
def inception_next_small(pretrained=False, **kwargs):
333+
model = MetaNeXt(
334+
depths=(3, 3, 27, 3), dims=(96, 192, 384, 768),
335+
token_mixers=InceptionDWConv2d,
336+
**kwargs
337+
)
338+
model.default_cfg = default_cfgs['inception_next_small']
339+
if pretrained:
340+
state_dict = torch.hub.load_state_dict_from_url(
341+
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
342+
model.load_state_dict(state_dict)
343+
return model
344+
345+
346+
@register_model
347+
def inception_next_base(pretrained=False, **kwargs):
348+
model = MetaNeXt(
349+
depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024),
350+
token_mixers=InceptionDWConv2d,
351+
**kwargs
352+
)
353+
model.default_cfg = default_cfgs['inception_next_base']
354+
if pretrained:
355+
state_dict = torch.hub.load_state_dict_from_url(
356+
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
357+
model.load_state_dict(state_dict)
358+
return model
359+
360+
361+
@register_model
362+
def inception_next_base_384(pretrained=False, **kwargs):
363+
model = MetaNeXt(
364+
depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024],
365+
mlp_ratios=[4, 4, 4, 3],
366+
token_mixers=InceptionDWConv2d,
367+
**kwargs
368+
)
369+
model.default_cfg = default_cfgs['inception_next_base_384']
370+
if pretrained:
371+
state_dict = torch.hub.load_state_dict_from_url(
372+
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
373+
model.load_state_dict(state_dict)
374+
return model

0 commit comments

Comments
 (0)