@@ -1,6 +1,25 @@
 """ BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
 
 Model from official source: https://github.com/microsoft/unilm/tree/master/beit
+and
+https://github.com/microsoft/unilm/tree/master/beit2
+
+@inproceedings{beit,
+title={{BEiT}: {BERT} Pre-Training of Image Transformers},
+author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
+booktitle={International Conference on Learning Representations},
+year={2022},
+url={https://openreview.net/forum?id=p-BhZSz59o4}
+}
+
+@article{beitv2,
+title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
+author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
+year={2022},
+eprint={2208.06366},
+archivePrefix={arXiv},
+primaryClass={cs.CV}
+}
 
 At this point only the 1k fine-tuned classification weights and model configs have been added,
 see original source above for pre-training models and procedure.
@@ -27,6 +46,7 @@
 import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
 from .registry import register_model
@@ -69,6 +89,26 @@ def _cfg(url='', **kwargs):
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
         num_classes=21841,
     ),
+
+    'beitv2_base_patch16_224': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_base_patch16_224_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
+        num_classes=21841,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_large_patch16_224': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
+        crop_pct=0.95,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_large_patch16_224_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
+        num_classes=21841,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
 }
 
 
@@ -417,3 +457,39 @@ def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
     model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
     return model
+
+
+@register_model
+def beitv2_base_patch16_224(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def beitv2_large_patch16_224(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
+    return model
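
For reference, a minimal usage sketch (not part of the diff) showing how the new entrypoints are reached through the timm registry. The timm.create_model call, the 224x224 input size, and the default 1000-class head are standard timm behavior assumed here, not something this patch adds:

    import timm
    import torch

    # Build one of the newly registered BEiT v2 variants; pretrained=True
    # would instead download the fine-tuned weights from the cfg URL above.
    model = timm.create_model('beitv2_base_patch16_224', pretrained=False)
    model.eval()

    # The 224x224 cfgs above imply a (3, 224, 224) input.
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 1000])

One design note: the beitv2 cfgs pass mean/std explicitly as IMAGENET_DEFAULT_MEAN/IMAGENET_DEFAULT_STD, presumably because the v2 weights expect standard ImageNet normalization rather than the 0.5-centered defaults the v1 entries inherit from _cfg.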