
Commit e728f3e

committed
Cleanup ijepa models; they're just gap (global-avg-pool) models w/o heads. The fc_norm conversion was wrong, and 'gigantic' should have been 'giant'.
1 parent 49a459e commit e728f3e
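
For reference, a minimal usage sketch under the renamed registrations (a sketch only: it assumes a timm build that includes this commit; the model/tag name, 224x224 input size, and 1280-dim ViT-H/14 width are taken from the diff below, and `timm.create_model` is the standard timm entrypoint):

import torch
import timm  # assumes a timm version containing this commit

# Old name (removed here): 'vit_huge_patch14_ijepa_224.in1k'
# New name: the same I-JEPA weights loaded into a plain GAP ViT without a head.
model = timm.create_model(
    'vit_huge_patch14_gap_224.in1k_ijepa',
    pretrained=True,
    num_classes=0,  # pretrained cfg carries no classifier; keep the model head-less
)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 1280]), avg-pooled features for ViT-H/14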

File tree: 1 file changed (+46, −49 lines)

timm/models/vision_transformer.py

Lines changed: 46 additions & 49 deletions
@@ -950,17 +950,6 @@ def _convert_dinov2(state_dict, model):
     return out_dict


-def _convert_ijepa(state_dict, model):
-    out_dict = {}
-    for k, v in state_dict['encoder'].items():
-        if k.startswith('module.'):
-            k = k[7:]
-        if k.startswith('norm.'):
-            k = 'fc_norm.' + k[5:]
-        out_dict[k] = v
-    return out_dict
-
-
 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -973,6 +962,7 @@ def checkpoint_filter_fn(
     out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
+    prefix = ''

     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
@@ -981,13 +971,17 @@ def checkpoint_filter_fn(
         state_dict = _convert_dinov2(state_dict, model)

     if "encoder" in state_dict:
-        state_dict = _convert_ijepa(state_dict, model)
+        state_dict = state_dict['encoder']
+        prefix = 'module.'

     if 'visual.trunk.pos_embed' in state_dict:
         # convert an OpenCLIP model with timm vision encoder
+        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         prefix = 'visual.trunk.'
+
+    if prefix:
+        # filter on & remove prefix string from keys
         state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
-        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)

     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
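
A toy illustration of the consolidated prefix handling above (hypothetical key names and values; actual I-JEPA checkpoints nest the encoder weights under 'encoder' with 'module.'-prefixed keys, as the removed _convert_ijepa helper assumed):

# Standalone sketch of the new shared prefix-stripping path in checkpoint_filter_fn.
state_dict = {'encoder': {'module.patch_embed.proj.weight': 'w0', 'module.norm.weight': 'w1'}}

prefix = ''
if 'encoder' in state_dict:
    state_dict = state_dict['encoder']
    prefix = 'module.'

if prefix:
    # filter on & remove prefix string from keys
    state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

print(state_dict)  # {'patch_embed.proj.weight': 'w0', 'norm.weight': 'w1'}
# Unlike the removed helper, 'norm.' keys stay as 'norm.' rather than being remapped to
# 'fc_norm.'; the new *_gap_* registrations below set fc_norm=False to match.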
@@ -1529,23 +1523,23 @@ def _cfg(url='', **kwargs):
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

-    'vit_huge_patch14_ijepa_224.in1k': _cfg(
+    'vit_huge_patch14_gap_224.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch14_ijepa_224.in22k': _cfg(
+    'vit_huge_patch14_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch16_ijepa_448.in1k': _cfg(
+    'vit_huge_patch16_gap_448.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         input_size=(3, 448, 448), crop_pct=1.0,
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_gigantic_patch16_ijepa_224.in22k': _cfg(
+    'vit_giant_patch16_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
@@ -1856,7 +1850,7 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs) -> VisionTransformer:

 @register_model
 def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
+    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
     """
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
@@ -1865,6 +1859,40 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_huge_patch14_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch16_gap_448(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch16_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+        class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-B/32 CLIP image tower @ 224x224
@@ -2190,37 +2218,6 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     return model


-@register_model
-def vit_huge_patch14_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg')
-    model = _create_vision_transformer(
-        'vit_huge_patch14_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_huge_patch16_ijepa_448(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448)
-    model = _create_vision_transformer(
-        'vit_huge_patch16_ijepa_448', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_gigantic_patch16_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
-    model = _create_vision_transformer(
-        'vit_gigantic_patch16_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(
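
For reference, a quick check of the giant vs. gigantic distinction behind the rename (dimensions taken from the registrations added and removed in the hunks above; the naming correction itself comes from the commit message):

# ViT-g ('giant') vs ViT-G ('gigantic') dims as they appear in this diff.
giant = dict(embed_dim=1408, depth=40, mlp_ratio=48 / 11)      # new vit_giant_patch16_gap_224
gigantic = dict(embed_dim=1664, depth=48, mlp_ratio=64 / 13)   # removed vit_gigantic_patch16_ijepa_224

for name, cfg in (('giant', giant), ('gigantic', gigantic)):
    mlp_hidden = int(cfg['embed_dim'] * cfg['mlp_ratio'])
    print(f"{name}: embed_dim={cfg['embed_dim']} depth={cfg['depth']} mlp_hidden={mlp_hidden}")
# giant: embed_dim=1408 depth=40 mlp_hidden=6144
# gigantic: embed_dim=1664 depth=48 mlp_hidden=8192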
