Skip to content

Commit 59b6222

Browse files
committed
Change ijepa names, add pretrain cfg for reg experiments
1 parent 7136516 commit 59b6222

File tree

1 file changed

+27
-31
lines changed

1 file changed

+27
-31
lines changed

timm/models/vision_transformer.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,23 +1529,23 @@ def _cfg(url='', **kwargs):
15291529
license='cc-by-nc-4.0',
15301530
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
15311531

1532-
'vit_huge_patch14_224_ijepa.in1k': _cfg(
1532+
'vit_huge_patch14_ijepa_224.in1k': _cfg(
15331533
url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
15341534
# hf_hub_id='timm/',
15351535
license='cc-by-nc-4.0',
15361536
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
1537-
'vit_huge_patch14_224_ijepa.in22k': _cfg(
1537+
'vit_huge_patch14_ijepa_224.in22k': _cfg(
15381538
url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
15391539
# hf_hub_id='timm/',
15401540
license='cc-by-nc-4.0',
15411541
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
1542-
'vit_huge_patch16_448_ijepa.in1k': _cfg(
1542+
'vit_huge_patch16_ijepa_448.in1k': _cfg(
15431543
url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
15441544
# hf_hub_id='timm/',
15451545
license='cc-by-nc-4.0',
15461546
input_size=(3, 448, 448), crop_pct=1.0,
15471547
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
1548-
'vit_gigantic_patch16_224_ijepa.in22k': _cfg(
1548+
'vit_gigantic_patch16_ijepa_224.in22k': _cfg(
15491549
url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
15501550
# hf_hub_id='timm/',
15511551
license='cc-by-nc-4.0',
@@ -1589,6 +1589,12 @@ def _cfg(url='', **kwargs):
15891589
hf_hub_filename='open_clip_pytorch_model.bin',
15901590
input_size=(3, 384, 384),
15911591
num_classes=0),
1592+
1593+
'vit_medium_patch16_reg4_256': _cfg(
1594+
input_size=(3, 256, 256)),
1595+
'vit_medium_patch16_reg4_gap_256': _cfg(
1596+
input_size=(3, 256, 256)),
1597+
'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)),
15921598
})
15931599

15941600

@@ -2185,33 +2191,33 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
21852191

21862192

21872193
@register_model
def vit_huge_patch14_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Huge (ViT-H/14) backbone from I-JEPA - https://arxiv.org/abs/2301.08243

    No class token; average pooling over patch tokens is used as the global feature.
    """
    args = dict(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        class_token=False,
        global_pool='avg',
    )
    # Caller kwargs take precedence over the architecture defaults.
    args.update(kwargs)
    return _create_vision_transformer('vit_huge_patch14_ijepa_224', pretrained=pretrained, **args)
21952201

21962202

21972203
@register_model
def vit_huge_patch16_ijepa_448(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Huge (ViT-H/16) at 448x448 from I-JEPA - https://arxiv.org/abs/2301.08243

    No class token; average pooling over patch tokens is used as the global feature.
    """
    args = dict(
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        class_token=False,
        global_pool='avg',
        img_size=448,
    )
    # Caller kwargs take precedence over the architecture defaults.
    args.update(kwargs)
    return _create_vision_transformer('vit_huge_patch16_ijepa_448', pretrained=pretrained, **args)
22062212

22072213

22082214
@register_model
def vit_gigantic_patch16_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Gigantic (big-G, ViT-G/16) from I-JEPA - https://arxiv.org/abs/2301.08243"""
    args = dict(
        patch_size=16,
        embed_dim=1664,
        mlp_ratio=64 / 13,  # non-standard MLP ratio used by the I-JEPA ViT-G release
        depth=48,
        num_heads=16,
    )
    # Caller kwargs take precedence over the architecture defaults.
    args.update(kwargs)
    return _create_vision_transformer('vit_gigantic_patch16_ijepa_224', pretrained=pretrained, **args)
22162222

22172223

@@ -2296,45 +2302,35 @@ def vit_so400m_patch14_siglip_384(pretrained=False, **kwargs) -> VisionTransform
22962302

22972303

22982304
@register_model
def vit_medium_patch16_reg4_256(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Medium (ViT-M/16) with 4 register tokens, class token retained."""
    args = dict(
        patch_size=16,
        embed_dim=512,
        depth=12,
        num_heads=8,
        class_token=True,
        no_embed_class=True,
        reg_tokens=4,
    )
    # Caller kwargs take precedence over the architecture defaults.
    args.update(kwargs)
    return _create_vision_transformer('vit_medium_patch16_reg4_256', pretrained=pretrained, **args)
23182313

23192314

23202315
@register_model
def vit_medium_patch16_reg4_gap_256(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Medium (ViT-M/16) with 4 register tokens, no class token, global average pooling."""
    args = dict(
        patch_size=16,
        embed_dim=512,
        depth=12,
        num_heads=8,
        class_token=False,
        no_embed_class=True,
        reg_tokens=4,
        global_pool='avg',
    )
    # Caller kwargs take precedence over the architecture defaults.
    args.update(kwargs)
    return _create_vision_transformer('vit_medium_patch16_reg4_gap_256', pretrained=pretrained, **args)
23292324

23302325

23312326
@register_model
def vit_base_patch16_reg8_gap_256(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Base (ViT-B/16) with 8 register tokens, no class token, global average pooling."""
    args = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        class_token=False,
        no_embed_class=True,
        global_pool='avg',
        reg_tokens=8,
    )
    # Caller kwargs take precedence over the architecture defaults.
    args.update(kwargs)
    return _create_vision_transformer('vit_base_patch16_reg8_gap_256', pretrained=pretrained, **args)
23392335

23402336

0 commit comments

Comments
 (0)