
Commit e069249

Add hf hub entries for laion2b clip models, add huggingface_hub dependency, update some setup/reqs, torch >= 1.7
1 parent 9d65557 commit e069249

File tree

5 files changed, +30 -20 lines changed


requirements.txt

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
-torch>=1.4.0
-torchvision>=0.5.0
+torch>=1.7
+torchvision
 pyyaml
+huggingface_hub
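
For context, the new huggingface_hub requirement is what lets timm fetch pretrained weight files from the Hugging Face Hub. A minimal sketch of the underlying download call (illustrative only, not part of this diff), using one of the repo ids added below in vision_transformer.py:

from huggingface_hub import hf_hub_download

# Download the OpenCLIP checkpoint referenced by the new pretrained cfgs;
# hf_hub_download caches the file locally and returns its path.
weights_path = hf_hub_download(
    repo_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
    filename='open_clip_pytorch_model.bin',
)
print(weights_path)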

setup.py

Lines changed: 6 additions & 3 deletions
@@ -25,13 +25,15 @@
         # 3 - Alpha
         # 4 - Beta
         # 5 - Production/Stable
-        'Development Status :: 3 - Alpha',
+        'Development Status :: 4 - Beta',
         'Intended Audience :: Education',
         'Intended Audience :: Science/Research',
         'License :: OSI Approved :: Apache Software License',
         'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
+        'Programming Language :: Python :: 3.9',
+        'Programming Language :: Python :: 3.10',
         'Topic :: Scientific/Engineering',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
         'Topic :: Software Development',
@@ -40,9 +42,10 @@
     ],

     # Note that this is a string of words separated by whitespace, not a list.
-    keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
+    keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
     packages=find_packages(exclude=['convert', 'tests', 'results']),
     include_package_data=True,
-    install_requires=['torch >= 1.4', 'torchvision'],
+    install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
     python_requires='>=3.6',
 )
+

tests/test_models.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@
     EXCLUDE_FILTERS = [
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
-        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_g*', 'swin*huge*',
+        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
         'swin*giant*']
     NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*']
 else:
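
These exclude entries are shell-style globs matched against model names (timm's list_models resolves them with fnmatch). A quick illustrative check, using the renamed model from vision_transformer.py below plus a hypothetical extra name, showing what the tightened 'vit_gi*' pattern still covers:

from fnmatch import fnmatch

# 'vit_giant_patch14_224' and 'vit_gigantic_patch14_224' come from this commit;
# 'vit_g_other_demo' is a made-up name just to show the narrower match.
names = ['vit_giant_patch14_224', 'vit_gigantic_patch14_224', 'vit_g_other_demo']
for pattern in ('vit_g*', 'vit_gi*'):
    print(pattern, [n for n in names if fnmatch(n, pattern)])
# vit_g*  -> matches all three
# vit_gi* -> matches only the giant/gigantic variants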

timm/data/constants.py

Lines changed: 2 additions & 0 deletions
@@ -5,3 +5,5 @@
 IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
 IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
 IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
+OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
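
These are the RGB normalization statistics used by OpenAI CLIP preprocessing. A minimal sketch of applying them in a torchvision pipeline (illustrative only; timm's data config resolution would normally pick these values up from the model's pretrained cfg):

from torchvision import transforms

# Values added to timm/data/constants.py in this commit
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# Illustrative preprocessing for a 224x224 CLIP-pretrained ViT
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
])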

timm/models/vision_transformer.py

Lines changed: 18 additions & 14 deletions
@@ -30,7 +30,8 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint

-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\
+    OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
 from .registry import register_model
@@ -106,7 +107,7 @@ def _cfg(url='', **kwargs):
     'vit_large_patch14_224': _cfg(url=''),
     'vit_huge_patch14_224': _cfg(url=''),
     'vit_giant_patch14_224': _cfg(url=''),
-    'vit_gee_patch14_224': _cfg(url=''),
+    'vit_gigantic_patch14_224': _cfg(url=''),


     # patch models, imagenet21k (weights from official Google JAX impl)
@@ -179,17 +180,21 @@ def _cfg(url='', **kwargs):
     'vit_base_patch16_18x2_224': _cfg(url=''),

     'vit_base_patch32_224_clip_laion2b': _cfg(
-        hf_hub_id='',
-        num_classes=512),
+        hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_large_patch14_224_clip_laion2b': _cfg(
-        hf_hub_id='',
-        num_classes=768),
+        hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768),
     'vit_huge_patch14_224_clip_laion2b': _cfg(
-        hf_hub_id='',
-        num_classes=1024),
+        hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
     'vit_giant_patch14_224_clip_laion2b': _cfg(
-        hf_hub_id='',
-        num_classes=1024),
+        hf_hub_id='CLIP-ViT-g-14-laion2B-s12B-b42K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),

 }

@@ -960,12 +965,11 @@ def vit_giant_patch14_224(pretrained=False, **kwargs):


 @register_model
-def vit_gee_patch14_224(pretrained=False, **kwargs):
-    """ ViT-GEE (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
-    As per https://twitter.com/wightmanr/status/1570549064667889666
+def vit_gigantic_patch14_224(pretrained=False, **kwargs):
+    """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
     """
     model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_gee_patch14_224', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
     return model
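
With the hub ids and filenames filled in, the laion2b CLIP image towers can be pulled through the usual timm factory. A minimal usage sketch (model name taken from this diff; downloading the weights relies on the huggingface_hub dependency added above):

import torch
import timm

# Instantiate the ViT-B/32 image tower and load its laion2b OpenCLIP weights
# from the Hugging Face Hub entry added in this commit.
model = timm.create_model('vit_base_patch32_224_clip_laion2b', pretrained=True)
model.eval()

# The head projects to the CLIP embedding size (num_classes=512 in the cfg).
with torch.no_grad():
    embedding = model(torch.randn(1, 3, 224, 224))
print(embedding.shape)  # expected: torch.Size([1, 512])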
