@@ -397,6 +397,11 @@ def checkpoint_filter_fn(state_dict, model):
         if 'visual.head.proj.weight' in state_dict:
             out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
             out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+        elif 'visual.head.mlp.fc1.weight' in state_dict:
+            out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
+            out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
+            out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
+            out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
         return out_dict

     import re
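The new elif branch extends the OpenCLIP-to-timm head remap: checkpoints whose image tower ends in a two-layer MLP head (visual.head.mlp.fc1/fc2) rather than a single projection get fc1 mapped to timm's pre-logits layer and fc2 to the classifier, with a zero bias synthesized because the CLIP projection carries none. A minimal, self-contained sketch of what the remap does on a toy state dict (tensor shapes here are illustrative, not taken from a real checkpoint):

import torch

# Toy OpenCLIP-style state dict with an MLP head (shapes illustrative).
state_dict = {
    'visual.head.mlp.fc1.weight': torch.randn(1536, 1536),
    'visual.head.mlp.fc1.bias': torch.randn(1536),
    'visual.head.mlp.fc2.weight': torch.randn(768, 1536),  # CLIP projection has no bias
}

out_dict = {}
if 'visual.head.mlp.fc1.weight' in state_dict:
    # fc1 -> timm pre-logits layer, fc2 -> timm classifier head
    out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
    out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
    out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
    # the CLIP head is bias-free, so synthesize a zero bias for timm's Linear
    out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])

print({k: tuple(v.shape) for k, v in out_dict.items()})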
@@ -716,6 +721,16 @@ def _cfgv2(url='', **kwargs):

     'convnextv2_small.untrained': _cfg(),

+    # CLIP weights, fine-tuned on in1k or in12k + in1k
+    'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
+        # hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
+    'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
+        # hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+
     # CLIP based weights, original image tower weights and fine-tunes
     'convnext_base.clip_laion2b': _cfg(
         hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
@@ -742,6 +757,11 @@ def _cfgv2(url='', **kwargs):
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
+    'convnext_large_mlp.clip_laion2b_augreg': _cfg(
+        hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
 })


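These pretrained tags become selectable through the usual timm factory. Note the two in1k fine-tune configs have hf_hub_id commented out (weights presumably not yet published), while the clip_laion2b_augreg tower loads straight from the laion OpenCLIP checkpoint and keeps the 768-dim projection as its num_classes. A usage sketch, assuming a timm build that includes these configs and network access to the hub:

import timm
from timm.data import resolve_data_config

# Original CLIP image tower: the classifier is the 768-dim CLIP projection
# remapped by checkpoint_filter_fn above.
model = timm.create_model('convnext_large_mlp.clip_laion2b_augreg', pretrained=True)

# The pretrained cfg also drives preprocessing: 256x256 input, crop_pct=1.0,
# and OpenAI CLIP mean/std instead of the ImageNet defaults.
print(resolve_data_config({}, model=model))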
@@ -854,6 +874,13 @@ def convnext_large(pretrained=False, **kwargs):
     return model


+@register_model
+def convnext_large_mlp(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs)
+    model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
 def convnext_xlarge(pretrained=False, **kwargs):
     model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
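The only change versus convnext_large is head_hidden_size=1536, which inserts the pre-logits Linear that the filter function maps fc1 into. A quick sketch to inspect the resulting head (module paths follow the 'head.pre_logits.fc' / 'head.fc' keys used in the remap; no pretrained weights needed):

import timm

# Same trunk as convnext_large; head_hidden_size adds a pre-logits layer
# between global pooling and the classifier.
model = timm.create_model('convnext_large_mlp', num_classes=1000)
print(model.head.pre_logits.fc)  # expected: Linear(1536 -> 1536)
print(model.head.fc)             # expected: Linear(1536 -> 1000)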