@@ -567,7 +567,11 @@ def get_classifier(self):
     def reset_classifier(self, num_classes: int, global_pool=None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token')
+            assert global_pool in ('', 'avg', 'token', 'map')
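+            # 'map' selects attention pooling (self.attn_pool); it can be removed here but not added after the fact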
+            if global_pool == 'map' and self.attn_pool is None:
+                assert False, "Cannot currently add attention pooling in reset_classifier()."
+            elif global_pool != 'map' and self.attn_pool is not None:
+                self.attn_pool = None  # remove attention pooling
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
@@ -937,10 +941,14 @@ def _convert_openai_clip(state_dict, model):
 def _convert_dinov2(state_dict, model):
     import re
     out_dict = {}
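+    # DINOv2 checkpoints carry a mask_token with no timm equivalent, drop it up front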
+    state_dict.pop("mask_token", None)
+    if 'register_tokens' in state_dict:
+        # convert dinov2 w/ registers to no_embed_class timm model (neither cls nor reg tokens overlap pos embed)
+        out_dict['reg_token'] = state_dict.pop('register_tokens')
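+        # fold the cls token's pos embed into cls_token itself, leaving pos_embed to cover patch tokens only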
+        out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0]
+        out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:]
     for k, v in state_dict.items():
-        if k == "mask_token":
-            continue
-        elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
+        if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
             out_dict[k.replace("w12", "fc1")] = v
             continue
         elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
@@ -1229,6 +1237,32 @@ def _cfg(url='', **kwargs):
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 518, 518), crop_pct=1.0),
 
+    # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only)
+    'vit_small_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_base_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_large_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_giant_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+
     # ViT ImageNet-21K-P pretraining by MIIL
     'vit_base_patch16_224_miil.in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
@@ -2173,9 +2207,7 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-S/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2185,9 +2217,7 @@ def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-B/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2197,9 +2227,7 @@ def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-L/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2209,12 +2237,10 @@ def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-G/14 for DINOv2
     """
-
     # The hidden_features of SwiGLU is calculated by:
     # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
     # When embed_dim=1536, hidden_features=4096
     # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
-
     model_args = dict(
         patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
         mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
@@ -2224,6 +2250,62 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_small_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-S/14 for DINOv2 w/ 4 registers
+    """
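+    # reg_tokens=4 adds the register tokens; no_embed_class=True means pos_embed is not added to cls/reg
+    # tokens, matching the converted checkpoints above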
+    model_args = dict(
+        patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_small_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-B/14 for DINOv2 w/ 4 registers
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-L/14 for DINOv2 w/ 4 registers
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-G/14 for DINOv2 w/ 4 registers
+    """
+    # The hidden_features of SwiGLU is calculated by:
+    # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+    # When embed_dim=1536, hidden_features=4096
+    # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
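+    # i.e. mlp_ratio = 2 * 4096 / 1536, approximated as 2.66667 * 2 below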
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2,
+        mlp_layer=SwiGLUPacked, act_layer=nn.SiLU, reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_giant_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(