Skip to content

Commit 42daa3b

Browse files
committed
Add full set of SigLIP models
1 parent b9dde58 commit 42daa3b

File tree

1 file changed

+93
-0
lines changed

1 file changed

+93
-0
lines changed

timm/models/vision_transformer.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ def __init__(
606606
self.attn_pool = AttentionPoolLatent(
607607
self.embed_dim,
608608
num_heads=num_heads,
609+
mlp_ratio=mlp_ratio,
609610
norm_layer=norm_layer,
610611
)
611612
else:
@@ -1644,6 +1645,39 @@ def _cfg(url='', **kwargs):
16441645
input_size=(3, 256, 256),
16451646
# hf_hub_id='timm/',
16461647
num_classes=0),
1648+
'vit_base_patch16_siglip_384': _cfg(
1649+
file='',
1650+
custom_load=True,
1651+
input_size=(3, 384, 384),
1652+
# hf_hub_id='timm/',
1653+
num_classes=0),
1654+
'vit_base_patch16_siglip_512': _cfg(
1655+
file='',
1656+
custom_load=True,
1657+
input_size=(3, 512, 512),
1658+
# hf_hub_id='timm/',
1659+
num_classes=0),
1660+
'vit_large_patch16_siglip_256': _cfg(
1661+
custom_load=True,
1662+
input_size=(3, 256, 256),
1663+
# hf_hub_id='timm/',
1664+
num_classes=0),
1665+
'vit_large_patch16_siglip_384': _cfg(
1666+
custom_load=True,
1667+
input_size=(3, 384, 384),
1668+
# hf_hub_id='timm/',
1669+
num_classes=0),
1670+
'vit_so400m_patch14_siglip_224': _cfg(
1671+
# file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
1672+
custom_load=True,
1673+
# hf_hub_id='timm/',
1674+
num_classes=0),
1675+
'vit_so400m_patch14_siglip_384': _cfg(
1676+
#file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
1677+
custom_load=True,
1678+
# hf_hub_id='timm/',
1679+
input_size=(3, 384, 384),
1680+
num_classes=0),
16471681
})
16481682

16491683

@@ -2290,6 +2324,65 @@ def vit_base_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer
22902324
return model
22912325

22922326

2327+
@register_model
def vit_base_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-B/16 SigLIP variant at 384x384 (no class token, attention-pooled 'map' head)."""
    args = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        class_token=False,
        global_pool='map',
    )
    # Caller kwargs take precedence over the base SigLIP configuration.
    return _create_vision_transformer(
        'vit_base_patch16_siglip_384', pretrained=pretrained, **{**args, **kwargs})
2335+
2336+
2337+
@register_model
def vit_base_patch16_siglip_512(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-B/16 SigLIP variant at 512x512 (no class token, attention-pooled 'map' head)."""
    args = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        class_token=False,
        global_pool='map',
    )
    # Caller kwargs take precedence over the base SigLIP configuration.
    return _create_vision_transformer(
        'vit_base_patch16_siglip_512', pretrained=pretrained, **{**args, **kwargs})
2345+
2346+
2347+
@register_model
def vit_large_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-L/16 SigLIP variant at 256x256 (no class token, attention-pooled 'map' head)."""
    args = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        class_token=False,
        global_pool='map',
    )
    # Caller kwargs take precedence over the base SigLIP configuration.
    return _create_vision_transformer(
        'vit_large_patch16_siglip_256', pretrained=pretrained, **{**args, **kwargs})
2355+
2356+
2357+
@register_model
def vit_large_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-L/16 SigLIP variant at 384x384 (no class token, attention-pooled 'map' head)."""
    args = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        class_token=False,
        global_pool='map',
    )
    # Caller kwargs take precedence over the base SigLIP configuration.
    return _create_vision_transformer(
        'vit_large_patch16_siglip_384', pretrained=pretrained, **{**args, **kwargs})
2365+
2366+
2367+
@register_model
def vit_so400m_patch14_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
    """SigLIP shape-optimized ~400M-param ViT, patch 14, at 224x224.

    NOTE: mlp_ratio=3.7362 is the non-standard ratio from the shape-optimized
    search; do not "clean it up" to 4.0.
    """
    args = dict(
        patch_size=14,
        embed_dim=1152,
        depth=27,
        num_heads=16,
        mlp_ratio=3.7362,
        class_token=False,
        global_pool='map',
    )
    # Caller kwargs take precedence over the base SigLIP configuration.
    return _create_vision_transformer(
        'vit_so400m_patch14_siglip_224', pretrained=pretrained, **{**args, **kwargs})
2375+
2376+
2377+
@register_model
def vit_so400m_patch14_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
    """SigLIP shape-optimized ~400M-param ViT, patch 14, at 384x384.

    NOTE: mlp_ratio=3.7362 is the non-standard ratio from the shape-optimized
    search; do not "clean it up" to 4.0.
    """
    args = dict(
        patch_size=14,
        embed_dim=1152,
        depth=27,
        num_heads=16,
        mlp_ratio=3.7362,
        class_token=False,
        global_pool='map',
    )
    # Caller kwargs take precedence over the base SigLIP configuration.
    return _create_vision_transformer(
        'vit_so400m_patch14_siglip_384', pretrained=pretrained, **{**args, **kwargs})
2385+
22932386

22942387
@register_model
22952388
def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer:

0 commit comments

Comments
 (0)