
from collections import OrderedDict
from functools import partial
+from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
    LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
+from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
@@ -188,48 +190,50 @@ class ConvNeXt(nn.Module):

    def __init__(
            self,
-            in_chans=3,
-            num_classes=1000,
-            global_pool='avg',
-            output_stride=32,
-            depths=(3, 3, 9, 3),
-            dims=(96, 192, 384, 768),
-            kernel_sizes=7,
-            ls_init_value=1e-6,
-            stem_type='patch',
-            patch_size=4,
-            head_init_scale=1.,
-            head_norm_first=False,
-            conv_mlp=False,
-            conv_bias=True,
-            use_grn=False,
-            act_layer='gelu',
-            norm_layer=None,
-            norm_eps=None,
-            drop_rate=0.,
-            drop_path_rate=0.,
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            global_pool: str = 'avg',
+            output_stride: int = 32,
+            depths: Tuple[int, ...] = (3, 3, 9, 3),
+            dims: Tuple[int, ...] = (96, 192, 384, 768),
+            kernel_sizes: Union[int, Tuple[int, ...]] = 7,
+            ls_init_value: Optional[float] = 1e-6,
+            stem_type: str = 'patch',
+            patch_size: int = 4,
+            head_init_scale: float = 1.,
+            head_norm_first: bool = False,
+            head_hidden_size: Optional[int] = None,
+            conv_mlp: bool = False,
+            conv_bias: bool = True,
+            use_grn: bool = False,
+            act_layer: Union[str, Callable] = 'gelu',
+            norm_layer: Optional[Union[str, Callable]] = None,
+            norm_eps: Optional[float] = None,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
    ):
        """
        Args:
-            in_chans (int): Number of input image channels (default: 3)
-            num_classes (int): Number of classes for classification head (default: 1000)
-            global_pool (str): Global pooling type (default: 'avg')
-            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
-            depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3])
-            dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768])
-            kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7)
-            ls_init_value (float): Init value for Layer Scale (default: 1e-6)
-            stem_type (str): Type of stem (default: 'patch')
-            patch_size (int): Stem patch size for patch stem (default: 4)
-            head_init_scale (float): Init scaling value for classifier weights and biases (default: 1)
-            head_norm_first (bool): Apply normalization before global pool + head (default: False)
-            conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False)
-            conv_bias (bool): Use bias layers w/ all convolutions (default: True)
-            use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False)
-            act_layer (Union[str, nn.Module]): Activation Layer
-            norm_layer (Union[str, nn.Module]): Normalization Layer
-            drop_rate (float): Head dropout rate (default: 0.)
-            drop_path_rate (float): Stochastic depth rate (default: 0.)
+            in_chans: Number of input image channels.
+            num_classes: Number of classes for classification head.
+            global_pool: Global pooling type.
+            output_stride: Output stride of network, one of (8, 16, 32).
+            depths: Number of blocks at each stage.
+            dims: Feature dimension at each stage.
+            kernel_sizes: Depthwise convolution kernel-sizes for each stage.
+            ls_init_value: Init value for Layer Scale, disabled if None.
+            stem_type: Type of stem.
+            patch_size: Stem patch size for patch stem.
+            head_init_scale: Init scaling value for classifier weights and biases.
+            head_norm_first: Apply normalization before global pool + head.
+            head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
+            conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
+            conv_bias: Use bias layers w/ all convolutions.
+            use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
+            act_layer: Activation layer type.
+            norm_layer: Normalization layer type.
+            drop_rate: Head pre-classifier dropout rate.
+            drop_path_rate: Stochastic depth drop rate.
        """
        super().__init__()
        assert output_stride in (8, 16, 32)
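For context, a minimal sketch of exercising the new head_hidden_size argument by constructing the class directly; the toy depths/dims/input size here are arbitrary illustration values, not a published variant. With head_norm_first left at its default of False, this routes through the MLP head wired up further down in this diff.

from timm.models.convnext import ConvNeXt
import torch

# Toy config: head_hidden_size inserts a pre-logits MLP layer into the
# pool -> norm -> fc head (head_norm_first must remain False for this).
model = ConvNeXt(depths=(2, 2, 2, 2), dims=(32, 64, 128, 256), head_hidden_size=256, num_classes=10)
out = model(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 10])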
@@ -307,14 +311,26 @@ def __init__(

        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
        # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
-        self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
-        self.head = nn.Sequential(OrderedDict([
-            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
-            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
-            ('drop', nn.Dropout(self.drop_rate)),
-            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
-
+        if head_norm_first:
+            assert not head_hidden_size
+            self.norm_pre = norm_layer(self.num_features)
+            self.head = ClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+            )
+        else:
+            self.norm_pre = nn.Identity()
+            self.head = NormMlpClassifierHead(
+                self.num_features,
+                num_classes,
+                hidden_size=head_hidden_size,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+                norm_layer=norm_layer,
+                act_layer='gelu',
+            )
        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)

    @torch.jit.ignore
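The branch above only swaps which head module gets built; conceptually the two orderings differ as in this plain-torch sketch (just the data flow the comment describes, not the timm ClassifierHead/NormMlpClassifierHead implementations).

import torch
import torch.nn.functional as F

def norm_first_head(x, norm2d, fc):
    # head_norm_first=True: norm -> global pool -> fc (like most other nets)
    x = norm2d(x)               # normalize the NCHW feature map
    x = x.mean(dim=(2, 3))      # global average pool -> (N, C)
    return fc(x)

def convnext_default_head(x, norm, fc, pre_logits_fc=None):
    # default ConvNeXt ordering: pool -> norm -> optional hidden MLP -> fc
    x = x.mean(dim=(2, 3))
    x = norm(x)
    if pre_logits_fc is not None:   # enabled by head_hidden_size
        x = F.gelu(pre_logits_fc(x))
    return fc(x)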
@@ -338,10 +354,7 @@ def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes=0, global_pool=None):
-        if global_pool is not None:
-            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head.reset(num_classes, global_pool=global_pool)

    def forward_features(self, x):
        x = self.stem(x)
@@ -350,12 +363,7 @@ def forward_features(self, x):
        return x

    def forward_head(self, x, pre_logits: bool = False):
-        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
-        x = self.head.global_pool(x)
-        x = self.head.norm(x)
-        x = self.head.flatten(x)
-        x = self.head.drop(x)
-        return x if pre_logits else self.head.fc(x)
+        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
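With the head unified into a single module, pre-classifier features come straight from forward_head; a usage sketch (model name and input size are arbitrary choices for illustration).

import torch
import timm

model = timm.create_model('convnext_tiny', pretrained=False)
feats = model.forward_features(torch.randn(1, 3, 224, 224))  # NCHW feature map
pooled = model.forward_head(feats, pre_logits=True)           # pooled pre-classifier features
logits = model.forward_head(feats)                            # classification logits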
@@ -389,6 +397,11 @@ def checkpoint_filter_fn(state_dict, model):
        if 'visual.head.proj.weight' in state_dict:
            out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
            out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+        elif 'visual.head.mlp.fc1.weight' in state_dict:
+            out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
+            out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
+            out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
+            out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
        return out_dict

    import re
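For clarity, the new elif maps the OpenCLIP MLP image-tower head onto the timm NormMlpClassifierHead names; a summary of the remap (key names from the code above, the zero bias is synthesized because the CLIP projection has none).

# visual.head.mlp.fc1.{weight,bias} -> head.pre_logits.fc.{weight,bias}   # hidden layer (head_hidden_size)
# visual.head.mlp.fc2.weight        -> head.fc.weight                     # CLIP embedding projection
# head.fc.bias                      <- zeros of matching output dim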
@@ -708,6 +721,22 @@ def _cfgv2(url='', **kwargs):

    'convnextv2_small.untrained': _cfg(),

+    # CLIP weights, fine-tuned on in1k or in12k + in1k
+    'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
+    'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+    'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
+    ),
+
+
    # CLIP based weights, original image tower weights and fine-tunes
    'convnext_base.clip_laion2b': _cfg(
        hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
@@ -734,6 +763,11 @@ def _cfgv2(url='', **kwargs):
        hf_hub_filename='open_clip_pytorch_model.bin',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
+    'convnext_large_mlp.clip_laion2b_augreg': _cfg(
+        hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
})


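These cfg entries carry the preprocessing the weights expect (OPENAI_CLIP mean/std, square crop at crop_pct=1.0); a sketch of the usual timm pattern for turning them into an eval transform, shown here with pretrained=False so nothing is downloaded.

import timm
from timm.data import resolve_data_config, create_transform

model = timm.create_model('convnext_base.clip_laion2b_augreg_ft_in1k', pretrained=False)
config = resolve_data_config({}, model=model)  # picks up mean/std, input_size, crop_pct from the cfg
transform = create_transform(**config)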
@@ -846,6 +880,13 @@ def convnext_large(pretrained=False, **kwargs):
    return model


+@register_model
+def convnext_large_mlp(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs)
+    model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args)
+    return model
+
+
@register_model
def convnext_xlarge(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
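Finally, a usage sketch for the new convnext_large_mlp entry point (pretrained tag as registered in the cfgs above; pretrained=False here so no weights are fetched).

import torch
import timm

model = timm.create_model('convnext_large_mlp', pretrained=False)
# or, with the CLIP-pretrained, ImageNet-1k fine-tuned weights registered above:
# model = timm.create_model('convnext_large_mlp.clip_laion2b_augreg_ft_in1k', pretrained=True)
logits = model(torch.randn(1, 3, 256, 256))
print(logits.shape)  # torch.Size([1, 1000])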