Commit 010b486

Add Dino pretrained weights (no head) for vit models. Add support to tests and helpers for models w/ no classifier (num_classes=0 in pretrained cfg)

1 parent 738a9cd
File tree

3 files changed: +86 additions, -34 deletions

tests/test_models.py

12 additions & 10 deletions
@@ -170,11 +170,12 @@ def test_model_default_cfgs(model_name, batch_size):
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
 
     # check classifier name matches default_cfg
-    classifier = cfg['classifier']
-    if not isinstance(classifier, (tuple, list)):
-        classifier = classifier,
-    for c in classifier:
-        assert c + ".weight" in state_dict.keys(), f'{c} not in model params'
+    if cfg.get('num_classes', None):
+        classifier = cfg['classifier']
+        if not isinstance(classifier, (tuple, list)):
+            classifier = classifier,
+        for c in classifier:
+            assert c + ".weight" in state_dict.keys(), f'{c} not in model params'
 
     # check first conv(s) names match default_cfg
     first_conv = cfg['first_conv']
@@ -222,11 +223,12 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     assert outputs.shape[1] == model.num_features
 
     # check classifier name matches default_cfg
-    classifier = cfg['classifier']
-    if not isinstance(classifier, (tuple, list)):
-        classifier = classifier,
-    for c in classifier:
-        assert c + ".weight" in state_dict.keys(), f'{c} not in model params'
+    if cfg.get('num_classes', None):
+        classifier = cfg['classifier']
+        if not isinstance(classifier, (tuple, list)):
+            classifier = classifier,
+        for c in classifier:
+            assert c + ".weight" in state_dict.keys(), f'{c} not in model params'
 
     # check first conv(s) names match default_cfg
     first_conv = cfg['first_conv']
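With this guard, a pretrained cfg that declares num_classes=0 (as the new DINO entries do) skips the classifier-name assertions entirely, since there is no head to look up in the state dict. A minimal sketch of the behavior, using hypothetical cfg dicts for illustration:

# Hypothetical cfgs, not taken from the commit
cfg_headless = {'num_classes': 0, 'classifier': 'head'}     # e.g. a DINO backbone
cfg_standard = {'num_classes': 1000, 'classifier': 'head'}

# cfg.get('num_classes', None) is falsy both for 0 and for a missing key,
# so the classifier assertions only run for cfgs that actually have a head
assert not cfg_headless.get('num_classes', None)
assert cfg_standard.get('num_classes', None)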

timm/models/helpers.py

2 additions & 2 deletions
@@ -221,8 +221,8 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
         if num_classes != default_cfg['num_classes']:
             for classifier_name in classifiers:
                 # completely discard fully connected if model num_classes doesn't match pretrained weights
-                del state_dict[classifier_name + '.weight']
-                del state_dict[classifier_name + '.bias']
+                state_dict.pop(classifier_name + '.weight', None)
+                state_dict.pop(classifier_name + '.bias', None)
             strict = False
         elif label_offset > 0:
             for classifier_name in classifiers:
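dict.pop(key, default) quietly returns the default when the key is absent, whereas del raises KeyError. That distinction matters now that head-less checkpoints (num_classes=0 cfgs) contain no classifier tensors to discard. A standalone illustration, with made-up keys:

# A DINO-style checkpoint: backbone weights only, no 'head.*' entries
state_dict = {'patch_embed.proj.weight': None, 'blocks.0.attn.qkv.weight': None}

state_dict.pop('head.weight', None)   # no-op, key not present
state_dict.pop('head.bias', None)     # no-op, key not present
# del state_dict['head.weight']       # would raise KeyError here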

timm/models/vision_transformer.py

72 additions & 22 deletions
@@ -140,11 +140,25 @@ def _cfg(url='', **kwargs):
         num_classes=21843),
 
     # SAM trained models (https://arxiv.org/abs/2106.01548)
-    'vit_base_patch32_sam_224': _cfg(
+    'vit_base_patch32_224_sam': _cfg(
         url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
-    'vit_base_patch16_sam_224': _cfg(
+    'vit_base_patch16_224_sam': _cfg(
         url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
 
+    # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only)
+    'vit_small_patch16_224_dino': _cfg(
+        url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_small_patch8_224_dino': _cfg(
+        url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_base_patch16_224_dino': _cfg(
+        url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_base_patch8_224_dino': _cfg(
+        url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+
     # deit models (FB weights)
     'deit_tiny_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
@@ -699,26 +713,6 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
     return model
 
 
-@register_model
-def vit_base_patch16_sam_224(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
-    """
-    # NOTE original SAM weights release worked with representation_size=768
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs)
-    return model
-
-
-@register_model
-def vit_base_patch32_sam_224(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
-    """
-    # NOTE original SAM weights release worked with representation_size=768
-    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
-    model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs)
-    return model
-
-
 @register_model
 def vit_huge_patch14_224(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
@@ -851,6 +845,62 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_patch16_224_sam(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
+    """
+    # NOTE original SAM weights release worked with representation_size=768
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_patch32_224_sam(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
+    """
+    # NOTE original SAM weights release worked with representation_size=768
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_patch16_224_dino(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer('vit_small_patch16_224_dino', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_patch8_224_dino(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
+    """
+    model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer('vit_small_patch8_224_dino', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_patch16_224_dino(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_patch8_224_dino(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
+    """
+    model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer('vit_base_patch8_224_dino', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def deit_tiny_patch16_224(pretrained=False, **kwargs):
     """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
