11"""
22InceptionNeXt implementation, paper: https://arxiv.org/abs/2303.16900
3-
4- Some code is borrowed from timm: https://github.com/huggingface/pytorch-image-models
53"""
64
75from functools import partial
119
1210from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
1311from timm .layers import trunc_normal_ , DropPath , to_2tuple
12+ from ._builder import build_model_with_cfg
1413from ._manipulate import checkpoint_seq
15- from ._registry import register_model
14+ from ._registry import register_model , generate_default_cfgs
1615
1716
1817class InceptionDWConv2d (nn .Module ):
1918 """ Inception depthweise convolution
2019 """
2120
22- def __init__ (self , in_channels , square_kernel_size = 3 , band_kernel_size = 11 , branch_ratio = 0.125 ):
21+ def __init__ (
22+ self ,
23+ in_chs ,
24+ square_kernel_size = 3 ,
25+ band_kernel_size = 11 ,
26+ branch_ratio = 0.125
27+ ):
2328 super ().__init__ ()
2429
25- gc = int (in_channels * branch_ratio ) # channel numbers of a convolution branch
30+ gc = int (in_chs * branch_ratio ) # channel numbers of a convolution branch
2631 self .dwconv_hw = nn .Conv2d (gc , gc , square_kernel_size , padding = square_kernel_size // 2 , groups = gc )
2732 self .dwconv_w = nn .Conv2d (
2833 gc , gc , kernel_size = (1 , band_kernel_size ), padding = (0 , band_kernel_size // 2 ), groups = gc )
2934 self .dwconv_h = nn .Conv2d (
3035 gc , gc , kernel_size = (band_kernel_size , 1 ), padding = (band_kernel_size // 2 , 0 ), groups = gc )
31- self .split_indexes = (in_channels - 3 * gc , gc , gc , gc )
36+ self .split_indexes = (in_chs - 3 * gc , gc , gc , gc )
3237
3338 def forward (self , x ):
3439 x_id , x_hw , x_w , x_h = torch .split (x , self .split_indexes , dim = 1 )
@@ -47,8 +52,15 @@ class ConvMlp(nn.Module):
4752 """
4853
4954 def __init__ (
50- self , in_features , hidden_features = None , out_features = None , act_layer = nn .ReLU ,
51- norm_layer = None , bias = True , drop = 0. ):
55+ self ,
56+ in_features ,
57+ hidden_features = None ,
58+ out_features = None ,
59+ act_layer = nn .ReLU ,
60+ norm_layer = None ,
61+ bias = True ,
62+ drop = 0. ,
63+ ):
5264 super ().__init__ ()
5365 out_features = out_features or in_features
5466 hidden_features = hidden_features or in_features
@@ -69,13 +81,20 @@ def forward(self, x):
6981 return x
7082
7183
72- class MlpHead (nn .Module ):
84+ class MlpClassifierHead (nn .Module ):
7385 """ MLP classification head
7486 """
7587
7688 def __init__ (
77- self , dim , num_classes = 1000 , mlp_ratio = 3 , act_layer = nn .GELU ,
78- norm_layer = partial (nn .LayerNorm , eps = 1e-6 ), drop = 0. , bias = True ):
89+ self ,
90+ dim ,
91+ num_classes = 1000 ,
92+ mlp_ratio = 3 ,
93+ act_layer = nn .GELU ,
94+ norm_layer = partial (nn .LayerNorm , eps = 1e-6 ),
95+ drop = 0. ,
96+ bias = True
97+ ):
7998 super ().__init__ ()
8099 hidden_features = int (mlp_ratio * dim )
81100 self .fc1 = nn .Linear (dim , hidden_features , bias = bias )
@@ -168,7 +187,6 @@ def __init__(
168187 norm_layer = norm_layer ,
169188 mlp_ratio = mlp_ratio ,
170189 ))
171- in_chs = out_chs
172190 self .blocks = nn .Sequential (* stage_blocks )
173191
174192 def forward (self , x ):
@@ -209,11 +227,10 @@ def __init__(
209227 norm_layer = nn .BatchNorm2d ,
210228 act_layer = nn .GELU ,
211229 mlp_ratios = (4 , 4 , 4 , 3 ),
212- head_fn = MlpHead ,
230+ head_fn = MlpClassifierHead ,
213231 drop_rate = 0. ,
214232 drop_path_rate = 0. ,
215233 ls_init_value = 1e-6 ,
216- ** kwargs ,
217234 ):
218235 super ().__init__ ()
219236
@@ -255,14 +272,38 @@ def __init__(
255272 self .head = head_fn (self .num_features , num_classes , drop = drop_rate )
256273 self .apply (self ._init_weights )
257274
275+ def _init_weights (self , m ):
276+ if isinstance (m , (nn .Conv2d , nn .Linear )):
277+ trunc_normal_ (m .weight , std = .02 )
278+ if m .bias is not None :
279+ nn .init .constant_ (m .bias , 0 )
280+
281+ @torch .jit .ignore
282+ def group_matcher (self , coarse = False ):
283+ return dict (
284+ stem = r'^stem' ,
285+ blocks = r'^stages\.(\d+)' if coarse else [
286+ (r'^stages\.(\d+)\.downsample' , (0 ,)), # blocks
287+ (r'^stages\.(\d+)\.blocks\.(\d+)' , None ),
288+ ]
289+ )
290+
291+ @torch .jit .ignore
292+ def get_classifier (self ):
293+ return self .head .fc2
294+
295+ def reset_classifier (self , num_classes = 0 , global_pool = None ):
296+ # FIXME
297+ self .head .reset (num_classes , global_pool )
298+
258299 @torch .jit .ignore
259300 def set_grad_checkpointing (self , enable = True ):
260301 for s in self .stages :
261302 s .grad_checkpointing = enable
262303
263304 @torch .jit .ignore
264305 def no_weight_decay (self ):
265- return { 'norm' }
306+ return set ()
266307
267308 def forward_features (self , x ):
268309 x = self .stem (x )
@@ -278,97 +319,66 @@ def forward(self, x):
278319 x = self .forward_head (x )
279320 return x
280321
281- def _init_weights (self , m ):
282- if isinstance (m , (nn .Conv2d , nn .Linear )):
283- trunc_normal_ (m .weight , std = .02 )
284- if m .bias is not None :
285- nn .init .constant_ (m .bias , 0 )
286-
287322
288323def _cfg (url = '' , ** kwargs ):
289324 return {
290325 'url' : url ,
291326 'num_classes' : 1000 , 'input_size' : (3 , 224 , 224 ), 'pool_size' : (7 , 7 ),
292327 'crop_pct' : 0.875 , 'interpolation' : 'bicubic' ,
293328 'mean' : IMAGENET_DEFAULT_MEAN , 'std' : IMAGENET_DEFAULT_STD ,
294- 'first_conv' : 'stem.0' , 'classifier' : 'head.fc ' ,
329+ 'first_conv' : 'stem.0' , 'classifier' : 'head.fc2 ' ,
295330 ** kwargs
296331 }
297332
298333
299- default_cfgs = dict (
300- inception_next_tiny = _cfg (
334+ default_cfgs = generate_default_cfgs ({
335+ ' inception_next_tiny.sail_in1k' : _cfg (
301336 url = 'https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth' ,
302337 ),
303- inception_next_small = _cfg (
338+ ' inception_next_small.sail_in1k' : _cfg (
304339 url = 'https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth' ,
305340 ),
306- inception_next_base = _cfg (
341+ ' inception_next_base.sail_in1k' : _cfg (
307342 url = 'https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth' ,
343+ crop_pct = 0.95 ,
308344 ),
309- inception_next_base_384 = _cfg (
345+ 'inception_next_base.sail_in1k_384' : _cfg (
310346 url = 'https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth' ,
311347 input_size = (3 , 384 , 384 ), crop_pct = 1.0 ,
312348 ),
313- )
349+ })
350+
351+
352+ def _create_inception_next (variant , pretrained = False , ** kwargs ):
353+ model = build_model_with_cfg (
354+ MetaNeXt , variant , pretrained ,
355+ feature_cfg = dict (out_indices = (0 , 1 , 2 , 3 ), flatten_sequential = True ),
356+ ** kwargs )
357+ return model
314358
315359
316360@register_model
317361def inception_next_tiny (pretrained = False , ** kwargs ):
318- model = MetaNeXt (
362+ model_args = dict (
319363 depths = (3 , 3 , 9 , 3 ), dims = (96 , 192 , 384 , 768 ),
320364 token_mixers = InceptionDWConv2d ,
321- ** kwargs
322365 )
323- model .default_cfg = default_cfgs ['inception_next_tiny' ]
324- if pretrained :
325- state_dict = torch .hub .load_state_dict_from_url (
326- url = model .default_cfg ['url' ], map_location = "cpu" , check_hash = True )
327- model .load_state_dict (state_dict )
328- return model
366+ return _create_inception_next ('inception_next_tiny' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
329367
330368
331369@register_model
332370def inception_next_small (pretrained = False , ** kwargs ):
333- model = MetaNeXt (
371+ model_args = dict (
334372 depths = (3 , 3 , 27 , 3 ), dims = (96 , 192 , 384 , 768 ),
335373 token_mixers = InceptionDWConv2d ,
336- ** kwargs
337374 )
338- model .default_cfg = default_cfgs ['inception_next_small' ]
339- if pretrained :
340- state_dict = torch .hub .load_state_dict_from_url (
341- url = model .default_cfg ['url' ], map_location = "cpu" , check_hash = True )
342- model .load_state_dict (state_dict )
343- return model
375+ return _create_inception_next ('inception_next_small' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
344376
345377
346378@register_model
347379def inception_next_base (pretrained = False , ** kwargs ):
348- model = MetaNeXt (
380+ model_args = dict (
349381 depths = (3 , 3 , 27 , 3 ), dims = (128 , 256 , 512 , 1024 ),
350382 token_mixers = InceptionDWConv2d ,
351- ** kwargs
352383 )
353- model .default_cfg = default_cfgs ['inception_next_base' ]
354- if pretrained :
355- state_dict = torch .hub .load_state_dict_from_url (
356- url = model .default_cfg ['url' ], map_location = "cpu" , check_hash = True )
357- model .load_state_dict (state_dict )
358- return model
359-
360-
361- @register_model
362- def inception_next_base_384 (pretrained = False , ** kwargs ):
363- model = MetaNeXt (
364- depths = [3 , 3 , 27 , 3 ], dims = [128 , 256 , 512 , 1024 ],
365- mlp_ratios = [4 , 4 , 4 , 3 ],
366- token_mixers = InceptionDWConv2d ,
367- ** kwargs
368- )
369- model .default_cfg = default_cfgs ['inception_next_base_384' ]
370- if pretrained :
371- state_dict = torch .hub .load_state_dict_from_url (
372- url = model .default_cfg ['url' ], map_location = "cpu" , check_hash = True )
373- model .load_state_dict (state_dict )
374- return model
384+ return _create_inception_next ('inception_next_base' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
0 commit comments