@@ -950,17 +950,6 @@ def _convert_dinov2(state_dict, model):
     return out_dict
 
 
-def _convert_ijepa(state_dict, model):
-    out_dict = {}
-    for k, v in state_dict['encoder'].items():
-        if k.startswith('module.'):
-            k = k[7:]
-        if k.startswith('norm.'):
-            k = 'fc_norm.' + k[5:]
-        out_dict[k] = v
-    return out_dict
-
-
 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -973,6 +962,7 @@ def checkpoint_filter_fn(
     out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
+    prefix = ''
 
     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
@@ -981,13 +971,17 @@ def checkpoint_filter_fn(
         state_dict = _convert_dinov2(state_dict, model)
 
     if "encoder" in state_dict:
-        state_dict = _convert_ijepa(state_dict, model)
+        state_dict = state_dict['encoder']
+        prefix = 'module.'
 
     if 'visual.trunk.pos_embed' in state_dict:
         # convert an OpenCLIP model with timm vision encoder
+        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         prefix = 'visual.trunk.'
+
+    if prefix:
+        # filter on & remove prefix string from keys
         state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
-        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
 
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
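For reference, a minimal sketch of what the consolidated prefix handling above does: once `prefix` is set (e.g. 'module.' for the unwrapped I-JEPA encoder), keys carrying the prefix are kept with it stripped and all other keys are dropped. The checkpoint keys below are hypothetical.

# Toy demonstration of the prefix filter/strip comprehension (hypothetical keys).
state_dict = {
    'module.patch_embed.proj.weight': 0,
    'module.norm.weight': 1,
    'predictor.blocks.0.mlp.fc1.weight': 2,  # not under the prefix, dropped
}
prefix = 'module.'
state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
assert sorted(state_dict) == ['norm.weight', 'patch_embed.proj.weight']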
@@ -1529,23 +1523,23 @@ def _cfg(url='', **kwargs):
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 
-    'vit_huge_patch14_ijepa_224.in1k': _cfg(
+    'vit_huge_patch14_gap_224.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch14_ijepa_224.in22k': _cfg(
+    'vit_huge_patch14_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch16_ijepa_448.in1k': _cfg(
+    'vit_huge_patch16_gap_448.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         input_size=(3, 448, 448), crop_pct=1.0,
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_gigantic_patch16_ijepa_224.in22k': _cfg(
+    'vit_giant_patch16_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
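Assuming a timm build that includes these renamed configs, the I-JEPA weights load through the usual model.tag syntax; since every cfg above sets num_classes=0, the result is a headless feature extractor. A sketch (weight download requires network access):

import timm

# 'in1k_ijepa' is the pretrained tag defined above; pretrained=True fetches the
# checkpoint from the configured URL and runs it through checkpoint_filter_fn.
model = timm.create_model('vit_huge_patch14_gap_224.in1k_ijepa', pretrained=True)
model.eval()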
@@ -1856,7 +1850,7 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs) -> VisionTransformer:
 
 @register_model
 def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
+    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
     """
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
@@ -1865,6 +1859,40 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_huge_patch14_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch16_gap_448(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch16_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Giant (little-g) model (ViT-g/16) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+        class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-B/32 CLIP image tower @ 224x224
@@ -2190,37 +2218,6 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
-@register_model
-def vit_huge_patch14_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg')
-    model = _create_vision_transformer(
-        'vit_huge_patch14_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_huge_patch16_ijepa_448(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448)
-    model = _create_vision_transformer(
-        'vit_huge_patch16_ijepa_448', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_gigantic_patch16_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
-    model = _create_vision_transformer(
-        'vit_gigantic_patch16_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(