# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the MIT license
-
+from collections import OrderedDict
from functools import partial

import torch
@@ -32,7 +32,7 @@ def _cfg(url='', **kwargs):
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.0', 'classifier': 'head',
+        'first_conv': 'stem.0', 'classifier': 'head.fc',
        **kwargs
    }

@@ -43,7 +43,7 @@ def _cfg(url='', **kwargs):
    convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
    convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),

-    convnext_tiny_hnf=_cfg(url='', classifier='head.fc'),
+    convnext_tiny_hnf=_cfg(url=''),

    convnext_base_in22k=_cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
@@ -65,16 +65,12 @@ def _is_contiguous(tensor: torch.Tensor) -> bool:


@register_notrace_module
-class LayerNorm2d(nn.Module):
+class LayerNorm2d(nn.LayerNorm):
    r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
    """

    def __init__(self, normalized_shape, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(normalized_shape))
-        self.bias = nn.Parameter(torch.zeros(normalized_shape))
-        self.eps = eps
-        self.normalized_shape = (normalized_shape,)
+        super().__init__(normalized_shape, eps=eps)

    def forward(self, x) -> torch.Tensor:
        if _is_contiguous(x):
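For reference, the channels-first layer norm this subclass provides is equivalent to the permute-based computation sketched below. This is a minimal standalone illustration, not the diffed code; it assumes (N, C, H, W) input and the weight/bias/eps held by the parent nn.LayerNorm.

import torch
import torch.nn.functional as F

def layer_norm_2d_reference(x, weight, bias, eps=1e-6):
    # Normalize over the channel dim of an (N, C, H, W) tensor by moving
    # channels last, applying standard layer norm, and moving them back.
    x = x.permute(0, 2, 3, 1)
    x = F.layer_norm(x, x.shape[-1:], weight, bias, eps)
    return x.permute(0, 3, 1, 2)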
@@ -105,7 +101,8 @@ class ConvNeXtBlock(nn.Module):

    def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=True, mlp_ratio=4, norm_layer=None):
        super().__init__()
-        norm_layer = norm_layer or (partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6))
+        if not norm_layer:
+            norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
        mlp_layer = ConvMlp if conv_mlp else Mlp
        self.use_conv_mlp = conv_mlp
        self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
@@ -120,15 +117,13 @@ def forward(self, x):
        if self.use_conv_mlp:
            x = self.norm(x)
            x = self.mlp(x)
-            if self.gamma is not None:
-                x.mul_(self.gamma.reshape(1, -1, 1, 1))
        else:
            x = x.permute(0, 2, 3, 1)
            x = self.norm(x)
            x = self.mlp(x)
-            if self.gamma is not None:
-                x.mul_(self.gamma)
            x = x.permute(0, 3, 1, 2)
+        if self.gamma is not None:
+            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
        x = self.drop_path(x) + shortcut
        return x

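The gamma term is the per-channel layer-scale parameter; applied once on the channels-first output it broadcasts as in this small standalone illustration (tensor shapes assumed for a tiny-stage block):

import torch

x = torch.randn(2, 96, 56, 56)           # block output, (N, C, H, W)
gamma = torch.full((96,), 1e-6)          # per-channel layer scale (ls_init_value)
x = x.mul(gamma.reshape(1, -1, 1, 1))    # out-of-place scale, broadcast over N, H, W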
@@ -191,7 +186,6 @@ def __init__(
                'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
            cl_norm_layer = norm_layer

-        partial(LayerNorm2d, eps=1e-6)
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.feature_info = []
@@ -226,51 +220,46 @@ def __init__(
        self.num_features = prev_chs
        if head_norm_first:
            # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights)
-            self.norm = norm_layer(self.num_features)  # final norm layer
-            self.pool = None  # global pool in ClassifierHead, pool == None being used to differentiate
+            self.norm_pre = norm_layer(self.num_features)  # final norm layer, before pooling
            self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
        else:
            # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
-            self.pool = SelectAdaptivePool2d(pool_type=global_pool)
-            # NOTE when cl_norm_layer != norm_layer we could flatten here and use cl, but makes no performance diff
-            self.norm = norm_layer(self.num_features)
-            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.norm_pre = nn.Identity()
+            self.head = nn.Sequential(OrderedDict([
+                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
+                ('norm', norm_layer(self.num_features)),
+                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
+                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
+            ]))

        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)

    def get_classifier(self):
-        return self.head.fc if self.pool is None else self.head
+        return self.head.fc

    def reset_classifier(self, num_classes=0, global_pool='avg'):
-        if self.pool is None:
-            # norm -> global pool -> fc ordering
+        if isinstance(self.head, ClassifierHead):
+            # norm -> global pool -> fc
            self.head = ClassifierHead(
                self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
        else:
            # pool -> norm -> fc
-            self.pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.head = nn.Sequential(OrderedDict([
+                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
+                ('norm', self.head.norm),
+                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
+                ('drop', nn.Dropout(self.drop_rate)),
+                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
+            ]))

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
-        if self.pool is None:
-            # standard head, norm -> spatial pool -> fc
-            # ideally, last norm is within forward_features, but can only do so if norm precedes pooling
-            x = self.norm(x)
+        x = self.norm_pre(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
-        if self.pool is not None:
-            # ConvNeXt head, spatial pool -> norm -> fc
-            # FIXME clean this up
-            x = self.pool(x)
-            x = self.norm(x)
-            if not self.pool.is_identity():
-                x = x.flatten(1)
-            if self.drop_rate > 0:
-                x = F.dropout(x, self.drop_rate, self.training)
        x = self.head(x)
        return x

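With the head collapsed into a single nn.Sequential, the classifier is reachable as head.fc in both orderings, which is what the updated 'classifier' entry in _cfg now points at. A usage sketch, assuming a timm build that includes this refactor:

import timm

model = timm.create_model('convnext_tiny', pretrained=False)
print(model.get_classifier())           # model.head.fc, an nn.Linear
model.reset_classifier(num_classes=10)  # rebuilds pool -> norm -> flatten -> drop -> fc
print(model.head.fc)                    # Linear(..., out_features=10)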
@@ -282,7 +271,7 @@ def _init_weights(module, name=None, head_init_scale=1.0):
    elif isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        nn.init.constant_(module.bias, 0)
-        if name and '.head' in name:
+        if name and 'head.' in name:
            module.weight.data.mul_(head_init_scale)
            module.bias.data.mul_(head_init_scale)

@@ -299,6 +288,9 @@ def checkpoint_filter_fn(state_dict, model):
        k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
        k = k.replace('dwconv', 'conv_dw')
        k = k.replace('pwconv', 'mlp.fc')
+        k = k.replace('head.', 'head.fc.')
+        if k.startswith('norm.'):
+            k = k.replace('norm', 'head.norm')
        if v.ndim == 2 and 'head' not in k:
            model_shape = model.state_dict()[k].shape
            v = v.reshape(model_shape)
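The added remapping routes the original checkpoint's top-level 'norm.*' and 'head.*' tensors to the new 'head.norm' and 'head.fc' modules. A standalone sketch covering only the remap lines visible in this hunk, with example keys assumed for illustration:

import re

def remap(k):
    # Mirrors only the key rewrites shown above, not the full checkpoint filter.
    k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
    k = k.replace('dwconv', 'conv_dw')
    k = k.replace('pwconv', 'mlp.fc')
    k = k.replace('head.', 'head.fc.')
    if k.startswith('norm.'):
        k = k.replace('norm', 'head.norm')
    return k

print(remap('downsample_layers.0.0.weight'))  # -> stages.0.downsample.0.weight
print(remap('norm.weight'))                   # -> head.norm.weight
print(remap('head.weight'))                   # -> head.fc.weight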