@@ -35,26 +35,34 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = {
-    # hybrid in-21k models (weights ported from official Google JAX impl where they exist)
-    'vit_base_r50_s16_224_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
-        num_classes=21843, crop_pct=0.9),
-
-    # hybrid in-1k models (weights ported from official JAX impl)
-    'vit_base_r50_s16_384': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
-
-    # hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
+    # hybrid in-1k models (weights ported from official JAX impl where they exist)
     'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
+    'vit_tiny_r_s16_p8_384': _cfg(
+        first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0),
     'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
     'vit_small_r20_s16_p2_224': _cfg(),
     'vit_small_r20_s16_224': _cfg(),
     'vit_small_r26_s32_224': _cfg(),
+    'vit_small_r26_s32_384': _cfg(
+        input_size=(3, 384, 384), crop_pct=1.0),
     'vit_base_r20_s16_224': _cfg(),
     'vit_base_r26_s32_224': _cfg(),
     'vit_base_r50_s16_224': _cfg(),
+    'vit_base_r50_s16_384': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
     'vit_large_r50_s32_224': _cfg(),
+    'vit_large_r50_s32_384': _cfg(),
+
+    # hybrid in-21k models (weights ported from official Google JAX impl where they exist)
+    'vit_small_r26_s32_224_in21k': _cfg(
+        num_classes=21843, crop_pct=0.9),
+    'vit_small_r26_s32_384_in21k': _cfg(
+        num_classes=21843, input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_base_r50_s16_224_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
+        num_classes=21843, crop_pct=0.9),
+    'vit_large_r50_s32_224_in21k': _cfg(num_classes=21843, crop_pct=0.9),
 
     # hybrid models (using timm resnet backbones)
     'vit_small_resnet26d_224': _cfg(
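Reviewer note: a quick sanity check that the reorganized configs carry the intended metadata (a sketch using existing timm APIs, `create_model` and the model's `default_cfg`; assumes this branch is installed):

import timm

# The in-21k entries should expose the 21843-class head and 0.9 crop fraction.
model = timm.create_model('vit_base_r50_s16_224_in21k', pretrained=False)
print(model.default_cfg['num_classes'])  # 21843
print(model.default_cfg['crop_pct'])     # 0.9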
@@ -99,7 +107,8 @@ def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_c
         else:
             feature_dim = self.backbone.num_features
         assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
-        self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
+        self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
 
     def forward(self, x):
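Keeping `grid_size` alongside `num_patches` preserves the 2D token layout, which is what position-embedding resizing needs when fine-tuning at a new resolution. A minimal sketch of the arithmetic; the 24x24 feature map is an illustrative assumption for a stride-16 R50 backbone at 384x384:

# With patch_size=1 over a 24x24 backbone feature map:
feature_size, patch_size = (24, 24), (1, 1)
grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
num_patches = grid_size[0] * grid_size[1]
print(grid_size, num_patches)  # (24, 24) 576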
@@ -133,37 +142,24 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
 
 
 @register_model
-def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
-    """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
-    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
+    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
     """
-    backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    backbone = _resnetv2(layers=(), **kwargs)
+    model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer_hybrid(
-        'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+        'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
     return model
 
 
 @register_model
-def vit_base_r50_s16_384(pretrained=False, **kwargs):
-    """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
-    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
+    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
     """
-    backbone = _resnetv2((3, 4, 9), **kwargs)
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+    backbone = _resnetv2(layers=(), **kwargs)
+    model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer_hybrid(
-        'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+        'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
     return model
-
-
-@register_model
-def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
-    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
-    """
-    backbone = _resnetv2(layers=(), **kwargs)
-    model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
-    model = _create_vision_transformer_hybrid(
-        'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
-    return model
 
 
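Since both tiny variants go through `@register_model`, they become discoverable by name. A sketch using timm's existing `list_models` filter (assumes this branch is installed):

import timm

# Both entrypoints defined above should be registered.
print(timm.list_models('vit_tiny_r_s16_p8_*'))
# ['vit_tiny_r_s16_p8_224', 'vit_tiny_r_s16_p8_384']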
@@ -212,6 +219,17 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_small_r26_s32_384(pretrained=False, **kwargs):
+    """ R26+ViT-S/S32 hybrid.
+    """
+    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_base_r20_s16_224(pretrained=False, **kwargs):
     """ R20+ViT-B/S16 hybrid.
@@ -245,17 +263,85 @@ def vit_base_r50_s16_224(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_r50_s16_384(pretrained=False, **kwargs):
+    """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+    """
+    backbone = _resnetv2((3, 4, 9), **kwargs)
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_large_r50_s32_224(pretrained=False, **kwargs):
     """ R50+ViT-L/S32 hybrid.
     """
     backbone = _resnetv2((3, 4, 6, 3), **kwargs)
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer_hybrid(
         'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
     return model
 
 
+@register_model
+def vit_large_r50_s32_384(pretrained=False, **kwargs):
+    """ R50+ViT-L/S32 hybrid.
+    """
+    backbone = _resnetv2((3, 4, 6, 3), **kwargs)
+    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
+    """ R26+ViT-S/S32 hybrid. ImageNet-21k.
+    """
+    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_r26_s32_384_in21k(pretrained=False, **kwargs):
+    """ R26+ViT-S/S32 hybrid. ImageNet-21k.
+    """
+    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_small_r26_s32_384_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
+    """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
+    """ R50+ViT-L/S32 hybrid. ImageNet-21k.
+    """
+    backbone = _resnetv2((3, 4, 6, 3), **kwargs)
+    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_small_resnet26d_224(pretrained=False, **kwargs):
     """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.