
Commit 316bdf8

Add mlp head support for convnext_large, add laion2b CLIP weights, prep fine-tuned weight tags
1 parent 6f28b56

File tree: 1 file changed (+27, -0)

timm/models/convnext.py: 27 additions & 0 deletions
@@ -397,6 +397,11 @@ def checkpoint_filter_fn(state_dict, model):
         if 'visual.head.proj.weight' in state_dict:
             out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
             out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+        elif 'visual.head.mlp.fc1.weight' in state_dict:
+            out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
+            out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
+            out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
+            out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
         return out_dict

     import re
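
For context: the new elif branch handles OpenCLIP checkpoints whose image tower ends in a two-layer MLP projection (mlp.fc1 -> mlp.fc2) rather than a single linear proj. fc1 maps onto timm's head.pre_logits layer and fc2 onto the final classifier; since the CLIP projection carries no output bias, a zero bias is synthesized so the tensors fit timm's Linear. A minimal standalone sketch of that mapping (tensor shapes here are illustrative, not taken from the commit):

    import torch

    # OpenCLIP-style head keys (toy tensors; real shapes depend on the model)
    clip_sd = {
        'visual.head.mlp.fc1.weight': torch.randn(1536, 1536),
        'visual.head.mlp.fc1.bias': torch.randn(1536),
        'visual.head.mlp.fc2.weight': torch.randn(768, 1536),  # no fc2 bias in CLIP
    }

    out_dict = {
        'head.pre_logits.fc.weight': clip_sd['visual.head.mlp.fc1.weight'],
        'head.pre_logits.fc.bias': clip_sd['visual.head.mlp.fc1.bias'],
        'head.fc.weight': clip_sd['visual.head.mlp.fc2.weight'],
        # synthesize the missing classifier bias as zeros
        'head.fc.bias': torch.zeros(clip_sd['visual.head.mlp.fc2.weight'].shape[0]),
    }
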
@@ -716,6 +721,16 @@ def _cfgv2(url='', **kwargs):

     'convnextv2_small.untrained': _cfg(),

+    # CLIP weights, fine-tuned on in1k or in12k + in1k
+    'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
+        # hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
+    'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
+        # hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+
     # CLIP based weights, original image tower weights and fine-tunes
     'convnext_base.clip_laion2b': _cfg(
         hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
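
Note the hf_hub_id entries in the two new ft_in1k cfgs are commented out: these are the "prep fine-tuned weight tags" part of the commit, registered ahead of the actual weight uploads. A hedged sketch of how a tag selects its pretrained cfg, with pretrained=False since no weights are published at this commit:

    import timm
    from timm.data import resolve_data_config

    # The '.tag' suffix picks the matching cfg entry from default_cfgs.
    model = timm.create_model('convnext_base.clip_laion2b_augreg_ft_in1k', pretrained=False)
    # Should report CLIP mean/std, 256x256 input, crop_pct=1.0 per the cfg above.
    print(resolve_data_config({}, model=model))
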
@@ -742,6 +757,11 @@ def _cfgv2(url='', **kwargs):
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
+    'convnext_large_mlp.clip_laion2b_augreg': _cfg(
+        hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
 })
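
This entry exposes the LAION-2B CLIP image tower itself; num_classes=768 is the CLIP projection width, not an ImageNet classifier. A hedged usage sketch, assuming huggingface_hub is available to fetch the open_clip weights from the hf_hub_id above:

    import timm

    # Loads the image tower; the 768-dim output lives in CLIP embedding space.
    model = timm.create_model('convnext_large_mlp.clip_laion2b_augreg', pretrained=True)
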
@@ -854,6 +874,13 @@ def convnext_large(pretrained=False, **kwargs):
     return model


+@register_model
+def convnext_large_mlp(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs)
+    model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
 def convnext_xlarge(pretrained=False, **kwargs):
     model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
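
The head_hidden_size=1536 argument is what gives this variant the extra pre_logits layer targeted by the checkpoint remap above. A quick hedged smoke test of the new entry point:

    import torch
    import timm

    # Untrained instantiation; the MLP head sits between pooling and classifier.
    m = timm.create_model('convnext_large_mlp', pretrained=False, num_classes=1000)
    out = m(torch.randn(1, 3, 256, 256))
    print(out.shape)  # torch.Size([1, 1000])
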
