Commit 3f02392

Add DINOv2 models with register tokens. Convert pos embed to non-overlapping for consistency.
1 parent fe92fd9 commit 3f02392

1 file changed: +97 −15 lines changed

timm/models/vision_transformer.py

Lines changed: 97 additions & 15 deletions
@@ -567,7 +567,11 @@ def get_classifier(self):
     def reset_classifier(self, num_classes: int, global_pool=None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token')
+            assert global_pool in ('', 'avg', 'token', 'map')
+            if global_pool == 'map' and self.attn_pool is None:
+                assert False, "Cannot currently add attention pooling in reset_classifier()."
+            elif global_pool != 'map' and self.attn_pool is not None:
+                self.attn_pool = None  # remove attention pooling
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
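For context, a hypothetical usage sketch of the updated reset_classifier() (the model name is illustrative; any timm ViT built without an attention-pool head behaves the same):

import timm

model = timm.create_model('vit_base_patch16_224', pretrained=False)
# Valid: switch to average pooling and attach a fresh 10-class head.
model.reset_classifier(num_classes=10, global_pool='avg')
# Would raise AssertionError: attention pooling cannot be added after construction.
# model.reset_classifier(num_classes=10, global_pool='map')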

@@ -937,10 +941,14 @@ def _convert_openai_clip(state_dict, model):
 def _convert_dinov2(state_dict, model):
     import re
     out_dict = {}
+    state_dict.pop("mask_token", None)
+    if 'register_tokens' in state_dict:
+        # convert dinov2 w/ registers to a no_embed_class timm model (neither cls nor reg tokens overlap pos embed)
+        out_dict['reg_token'] = state_dict.pop('register_tokens')
+        out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0]
+        out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:]
     for k, v in state_dict.items():
-        if k == "mask_token":
-            continue
-        elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
+        if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
             out_dict[k.replace("w12", "fc1")] = v
             continue
         elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
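The register branch above folds the class token's positional embedding into the class token itself, so the converted pos_embed covers patch tokens only and neither cls nor reg tokens overlap it. A minimal standalone sketch with illustrative shapes (ViT-S/14 at 518x518: 37x37 = 1369 patch positions, embed_dim 384):

import torch

# Fake DINOv2-style checkpoint tensors; shapes are illustrative only.
state_dict = {
    'cls_token': torch.zeros(1, 1, 384),
    'register_tokens': torch.zeros(1, 4, 384),
    'pos_embed': torch.randn(1, 1 + 1369, 384),  # cls position + patch positions
}

out_dict = {}
out_dict['reg_token'] = state_dict.pop('register_tokens')
# Fold the cls position into the cls token, then drop it from pos_embed.
out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0]
out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:]

assert out_dict['pos_embed'].shape == (1, 1369, 384)  # patch positions only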
@@ -1229,6 +1237,32 @@ def _cfg(url='', **kwargs):
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 518, 518), crop_pct=1.0),
 
+    # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only)
+    'vit_small_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_base_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_large_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_giant_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+
     # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil.in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
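Creating one of the new variants should then be standard timm usage; a sketch, assuming a build that includes this commit and reachable weights (the hf_hub_id entries are still commented out above):

import timm
import torch

model = timm.create_model('vit_small_patch14_reg4_dinov2.lvd142m', pretrained=True).eval()
x = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    tokens = model.forward_features(x)  # (1, 1 + 4 + 1369, 384): cls + 4 reg + patch tokens
    feats = model(x)                    # pooled features; num_classes=0, so no classifier head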
@@ -2173,9 +2207,7 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-S/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2185,9 +2217,7 @@ def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-B/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2197,9 +2227,7 @@ def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-L/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2209,12 +2237,10 @@ def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-G/14 for DINOv2
     """
-
     # The hidden_features of SwiGLU is calculated by:
     # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
     # When embed_dim=1536, hidden_features=4096
     # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
-
     model_args = dict(
         patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
         mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
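The SwiGLU sizing comment can be sanity-checked with a few lines of plain Python (a standalone check mirroring the formula, not library code):

embed_dim = 1536
hidden = int(embed_dim * 4 * 2 / 3)   # start from the usual 4x MLP width -> 4096
hidden = (hidden + 7) // 8 * 8        # round up to a multiple of 8 -> still 4096
packed = 2 * hidden                   # SwiGLUPacked fuses w1/w2 into one fc1 -> 8192
assert int(embed_dim * 2.66667 * 2) == packed  # mlp_ratio=2.66667 * 2 reproduces it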
@@ -2224,6 +2250,62 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_small_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-S/14 for DINOv2 w/ 4 registers
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_small_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-B/14 for DINOv2 w/ 4 registers
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-L/14 for DINOv2 w/ 4 registers
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-G/14 for DINOv2 w/ 4 registers
+    """
+    # The hidden_features of SwiGLU is calculated by:
+    # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+    # When embed_dim=1536, hidden_features=4096
+    # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2,
+        mlp_layer=SwiGLUPacked, act_layer=nn.SiLU, reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_giant_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(
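The reg4 entrypoints pair reg_tokens=4 with no_embed_class=True; a rough sketch of the token layout that implies, based on how timm's VisionTransformer assembles tokens (shapes illustrative, internals may differ across versions):

import torch

B, N, D = 1, 1369, 384                  # batch, patch tokens (37x37), embed dim
patches = torch.randn(B, N, D)
pos_embed = torch.randn(1, N, D)        # covers patch tokens only
cls_token = torch.zeros(1, 1, D).expand(B, -1, -1)
reg_token = torch.zeros(1, 4, D).expand(B, -1, -1)

# Position embeddings are added to patch tokens first; cls/reg tokens are
# prepended afterwards and never receive a positional embedding.
x = torch.cat([cls_token, reg_token, patches + pos_embed], dim=1)
assert x.shape == (B, 1 + 4 + N, D)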
