Skip to content

Commit 914544f

Browse files
committed
Add beitv2 224x224 checkpoints from https://github.com/microsoft/unilm/tree/master/beit2
1 parent dc90816 commit 914544f

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

timm/models/beit.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
11
""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
22
33
Model from official source: https://github.com/microsoft/unilm/tree/master/beit
4+
and
5+
https://github.com/microsoft/unilm/tree/master/beit2
6+
7+
@inproceedings{beit,
8+
title={{BEiT}: {BERT} Pre-Training of Image Transformers},
9+
author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
10+
booktitle={International Conference on Learning Representations},
11+
year={2022},
12+
url={https://openreview.net/forum?id=p-BhZSz59o4}
13+
}
14+
15+
@article{beitv2,
16+
title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
17+
author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
18+
year={2022},
19+
eprint={2208.06366},
20+
archivePrefix={arXiv},
21+
primaryClass={cs.CV}
22+
}
423
524
At this point only the 1k fine-tuned classification weights and model configs have been added,
625
see original source above for pre-training models and procedure.
@@ -27,6 +46,7 @@
2746
import torch.nn.functional as F
2847
from torch.utils.checkpoint import checkpoint
2948

49+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
3050
from .helpers import build_model_with_cfg
3151
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
3252
from .registry import register_model
@@ -69,6 +89,26 @@ def _cfg(url='', **kwargs):
6989
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
7090
num_classes=21841,
7191
),
92+
93+
'beitv2_base_patch16_224': _cfg(
94+
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
95+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
96+
),
97+
'beitv2_base_patch16_224_in22k': _cfg(
98+
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
99+
num_classes=21841,
100+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
101+
),
102+
'beitv2_large_patch16_224': _cfg(
103+
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
104+
crop_pct=0.95,
105+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
106+
),
107+
'beitv2_large_patch16_224_in22k': _cfg(
108+
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
109+
num_classes=21841,
110+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
111+
),
72112
}
73113

74114

@@ -417,3 +457,39 @@ def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
417457
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
418458
model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
419459
return model
460+
461+
462+
@register_model
def beitv2_base_patch16_224(pretrained=False, **kwargs):
    """BEiT v2 base (ViT-B/16) at 224x224.

    Weights from https://github.com/microsoft/unilm/tree/master/beit2
    (fine-tuned for 1k-class classification per the registered default cfg).
    """
    base_args = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        use_abs_pos_emb=False,      # BEiT uses relative position bias instead
        use_rel_pos_bias=True,
        init_values=1e-5,           # layer-scale init
        **kwargs,
    )
    return _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **base_args)
469+
470+
471+
@register_model
def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
    """BEiT v2 base (ViT-B/16) at 224x224 with the 21841-class (IN-22k) head.

    Weights from https://github.com/microsoft/unilm/tree/master/beit2.
    """
    base_args = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        use_abs_pos_emb=False,      # relative position bias replaces absolute pos-emb
        use_rel_pos_bias=True,
        init_values=1e-5,           # layer-scale init
        **kwargs,
    )
    return _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **base_args)
478+
479+
480+
@register_model
def beitv2_large_patch16_224(pretrained=False, **kwargs):
    """BEiT v2 large (ViT-L/16) at 224x224.

    Weights from https://github.com/microsoft/unilm/tree/master/beit2
    (fine-tuned for 1k-class classification per the registered default cfg).
    """
    large_args = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,      # relative position bias replaces absolute pos-emb
        use_rel_pos_bias=True,
        init_values=1e-5,           # layer-scale init
        **kwargs,
    )
    return _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **large_args)
487+
488+
489+
@register_model
def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
    """BEiT v2 large (ViT-L/16) at 224x224 with the 21841-class (IN-22k) head.

    Weights from https://github.com/microsoft/unilm/tree/master/beit2.
    """
    large_args = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,      # relative position bias replaces absolute pos-emb
        use_rel_pos_bias=True,
        init_values=1e-5,           # layer-scale init
        **kwargs,
    )
    return _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **large_args)

0 commit comments

Comments
 (0)