@@ -606,6 +606,7 @@ def __init__(
             self.attn_pool = AttentionPoolLatent(
                 self.embed_dim,
                 num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
                 norm_layer=norm_layer,
             )
         else:
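Not stated in the commit, but a plausible reading of this change: forwarding mlp_ratio lets the attention-pool head's MLP width follow the backbone's, which matters for the SigLIP so400m variants registered below (mlp_ratio=3.7362 instead of the default 4.0). A minimal sketch of the resulting widths, assuming AttentionPoolLatent sizes its hidden layer as int(embed_dim * mlp_ratio):

    # Illustrative arithmetic only, not part of the commit.
    embed_dim = 1152       # so400m embed_dim from the model args added below
    mlp_ratio = 3.7362     # so400m mlp_ratio from the model args added below

    pool_hidden = int(embed_dim * mlp_ratio)   # 4304, tracks the backbone MLP width
    pool_default = int(embed_dim * 4.0)        # 4608, what a fixed ratio of 4 would give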
@@ -1644,6 +1645,39 @@ def _cfg(url='', **kwargs):
         input_size=(3, 256, 256),
         # hf_hub_id='timm/',
         num_classes=0),
+    'vit_base_patch16_siglip_384': _cfg(
+        file='',
+        custom_load=True,
+        input_size=(3, 384, 384),
+        # hf_hub_id='timm/',
+        num_classes=0),
+    'vit_base_patch16_siglip_512': _cfg(
+        file='',
+        custom_load=True,
+        input_size=(3, 512, 512),
+        # hf_hub_id='timm/',
+        num_classes=0),
+    'vit_large_patch16_siglip_256': _cfg(
+        custom_load=True,
+        input_size=(3, 256, 256),
+        # hf_hub_id='timm/',
+        num_classes=0),
+    'vit_large_patch16_siglip_384': _cfg(
+        custom_load=True,
+        input_size=(3, 384, 384),
+        # hf_hub_id='timm/',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_224': _cfg(
+        # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
+        custom_load=True,
+        # hf_hub_id='timm/',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_384': _cfg(
+        #file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
+        custom_load=True,
+        # hf_hub_id='timm/',
+        input_size=(3, 384, 384),
+        num_classes=0),
 })


@@ -2290,6 +2324,65 @@ def vit_base_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer
     return model


+@register_model
+def vit_base_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_512(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+

 @register_model
 def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer:
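For reference, a minimal usage sketch (not part of the commit) showing how the newly registered variants can be instantiated through timm's model registry. pretrained=False and an explicit num_classes=0 are assumed here, since the _cfg entries above leave file='' and keep hf_hub_id commented out, so no pretrained weight source is wired up yet; with num_classes=0 the model returns pooled features of size embed_dim rather than classification logits.

    # Minimal usage sketch; assumes timm (with this commit) and torch are installed.
    import timm
    import torch

    model = timm.create_model('vit_so400m_patch14_siglip_384', pretrained=False, num_classes=0)
    model.eval()

    x = torch.randn(1, 3, 384, 384)   # input_size from the cfg above
    with torch.no_grad():
        features = model(x)           # pooled via the 'map' attention-pool head
    print(features.shape)             # torch.Size([1, 1152]) for the so400m embed_dim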