
Commit b9cfb64

Support npz custom load for vision transformer hybrid models. Add posembed rescale for npz load.
1 parent 8319e0c commit b9cfb64
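
For orientation, the npz path is opted into per model: `_create_vision_transformer` (see the vision_transformer.py diff below) now passes `pretrained_custom_load=True` whenever the default cfg URL points at an .npz checkpoint, as the ViT-H/14 in-21k entry now does, so those weights are read by the JAX-style loader with pos-embed rescaling rather than a converted .pth state dict. A minimal, hedged sketch of that gating (no download is triggered here):

import timm

# The in-21k ViT-H/14 cfg in this commit carries an .npz URL, so pretrained=True should
# route through the custom npz loader (assuming the checkpoint is reachable).
cfg_url = 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz'
print('npz' in cfg_url)  # True -> pretrained_custom_load=True

model = timm.create_model('vit_huge_patch14_224_in21k', pretrained=False)  # random init, no download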

3 files changed: +181, -56 lines

timm/models/layers/pool2d_same.py

Lines changed: 6 additions & 4 deletions
@@ -27,7 +27,8 @@ def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, co
         super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
 
     def forward(self, x):
-        return avg_pool2d_same(
+        x = pad_same(x, self.kernel_size, self.stride)
+        return F.avg_pool2d(
             x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
 
 
@@ -41,14 +42,15 @@ def max_pool2d_same(
 class MaxPool2dSame(nn.MaxPool2d):
     """ Tensorflow like 'SAME' wrapper for 2D max pooling
     """
-    def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True):
+    def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
         kernel_size = to_2tuple(kernel_size)
         stride = to_2tuple(stride)
         dilation = to_2tuple(dilation)
-        super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad)
+        super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
 
     def forward(self, x):
-        return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode)
+        x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
+        return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
 
 
 def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
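
The change above drops the avg_pool2d_same / max_pool2d_same calls from forward and instead pads explicitly with pad_same before the stock F.avg_pool2d / F.max_pool2d, using -inf padding for max pooling so padded values can never win. A self-contained sketch of the same 'SAME' padding arithmetic; same_pad_2d below is an illustrative stand-in, not timm's pad_same:

import math
import torch
import torch.nn.functional as F

def same_pad_2d(x, kernel_size, stride, value=0.):
    # TF-style 'SAME': pad so the output size is ceil(input / stride) in each spatial dim.
    ih, iw = x.shape[-2:]
    pad_h = max((math.ceil(ih / stride[0]) - 1) * stride[0] + kernel_size[0] - ih, 0)
    pad_w = max((math.ceil(iw / stride[1]) - 1) * stride[1] + kernel_size[1] - iw, 0)
    return F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)

x = torch.randn(1, 3, 7, 7)
x = same_pad_2d(x, (3, 3), (2, 2), value=-float('inf'))   # -inf padding, as in MaxPool2dSame.forward
out = F.max_pool2d(x, kernel_size=3, stride=2)
print(out.shape)  # torch.Size([1, 3, 4, 4]) == ceil(7 / 2) per side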

timm/models/vision_transformer.py

Lines changed: 72 additions & 24 deletions
@@ -52,6 +52,10 @@ def _cfg(url='', **kwargs):
         url='',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
     ),
+    'vit_tiny_patch16_384': _cfg(
+        url='',
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
+    ),
     'vit_small_patch16_224': _cfg(
         url='',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
@@ -60,6 +64,14 @@ def _cfg(url='', **kwargs):
         url='',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
     ),
+    'vit_small_patch16_384': _cfg(
+        url='',
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
+    ),
+    'vit_small_patch32_384': _cfg(
+        url='',
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
+    ),
 
     # patch models (weights ported from official Google JAX impl)
     'vit_base_patch16_224': _cfg(
@@ -102,6 +114,7 @@ def _cfg(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_huge_patch14_224_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
         hf_hub='timm/vit_huge_patch14_224_in21k',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
@@ -371,48 +384,72 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     import numpy as np
 
     def _n2p(w, t=True):
-        if t and w.ndim == 4:
-            w = w.transpose([3, 2, 0, 1])
-        elif t and w.ndim == 3:
-            w = w.transpose([2, 0, 1])
-        elif t and w.ndim == 2:
-            w = w.transpose([1, 0])
+        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+            w = w.flatten()
+        if t:
+            if w.ndim == 4:
+                w = w.transpose([3, 2, 0, 1])
+            elif w.ndim == 3:
+                w = w.transpose([2, 0, 1])
+            elif w.ndim == 2:
+                w = w.transpose([1, 0])
         return torch.from_numpy(w)
 
     w = np.load(checkpoint_path)
-    if not prefix:
-        prefix = 'opt/target/' if 'opt/target/embedding/kernel' in w else prefix
-
-    input_conv_w = adapt_input_conv(
-        model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
-    model.patch_embed.proj.weight.copy_(input_conv_w)
+    if not prefix and 'opt/target/embedding/kernel' in w:
+        prefix = 'opt/target/'
+
+    if hasattr(model.patch_embed, 'backbone'):
+        # hybrid
+        backbone = model.patch_embed.backbone
+        stem_only = not hasattr(backbone, 'stem')
+        stem = backbone if stem_only else backbone.stem
+        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
+        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
+        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
+        if not stem_only:
+            for i, stage in enumerate(backbone.stages):
+                for j, block in enumerate(stage.blocks):
+                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
+                    for r in range(3):
+                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
+                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
+                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
+                    if block.downsample is not None:
+                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
+                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
+                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
+        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
+    else:
+        embed_conv_w = adapt_input_conv(
+            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+    model.patch_embed.proj.weight.copy_(embed_conv_w)
     model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
     model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
-    model.pos_embed.copy_(_n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False))
+    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+    if pos_embed_w.shape != model.pos_embed.shape:
+        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
+            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+    model.pos_embed.copy_(pos_embed_w)
     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
     if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
         block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
         block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
-        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
         block.attn.qkv.weight.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T,
-            _n2p(w[f'{mha_prefix}key/kernel'], t=False).flatten(1).T,
-            _n2p(w[f'{mha_prefix}value/kernel'], t=False).flatten(1).T]))
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
         block.attn.qkv.bias.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1),
-            _n2p(w[f'{mha_prefix}key/bias'], t=False).reshape(-1),
-            _n2p(w[f'{mha_prefix}value/bias'], t=False).reshape(-1)]))
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
         block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
         block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
-        block.mlp.fc1.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/kernel']))
-        block.mlp.fc1.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/bias']))
-        block.mlp.fc2.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/kernel']))
-        block.mlp.fc2.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/bias']))
+        for r in range(2):
+            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
+            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
         block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
         block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
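
A hedged usage sketch of the loader above: given a local JAX-format .npz checkpoint (the filename below is a placeholder), _load_weights copies weights in place and, with this commit, rescales the position embedding when the checkpoint grid differs from the model's:

import torch
import timm
from timm.models.vision_transformer import _load_weights

# Placeholder checkpoint path; any matching-architecture ViT .npz from
# google-research/vision_transformer would do. no_grad() guards the in-place
# copy_ calls in case the function is not already decorated (an assumption).
model = timm.create_model('vit_base_patch16_384', pretrained=False)
with torch.no_grad():
    _load_weights(model, 'ViT-B_16.npz')  # a 224-trained pos embed would be rescaled to the 24x24 grid here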

@@ -478,6 +515,7 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
         default_cfg=default_cfg,
         representation_size=repr_size,
         pretrained_filter_fn=checkpoint_filter_fn,
+        pretrained_custom_load='npz' in default_cfg['url'],
         **kwargs)
     return model

@@ -510,6 +548,16 @@ def vit_small_patch32_224(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_small_patch16_384(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/16)
+    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_base_patch16_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).

timm/models/vision_transformer_hybrid.py

Lines changed: 103 additions & 28 deletions
@@ -35,26 +35,34 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = {
-    # hybrid in-21k models (weights ported from official Google JAX impl where they exist)
-    'vit_base_r50_s16_224_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
-        num_classes=21843, crop_pct=0.9),
-
-    # hybrid in-1k models (weights ported from official JAX impl)
-    'vit_base_r50_s16_384': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
-
-    # hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
+    # hybrid in-1k models (weights ported from official JAX impl where they exist)
     'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
+    'vit_tiny_r_s16_p8_384': _cfg(
+        first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0),
     'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
     'vit_small_r20_s16_p2_224': _cfg(),
     'vit_small_r20_s16_224': _cfg(),
     'vit_small_r26_s32_224': _cfg(),
+    'vit_small_r26_s32_384': _cfg(
+        input_size=(3, 384, 384), crop_pct=1.0),
     'vit_base_r20_s16_224': _cfg(),
     'vit_base_r26_s32_224': _cfg(),
     'vit_base_r50_s16_224': _cfg(),
+    'vit_base_r50_s16_384': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
     'vit_large_r50_s32_224': _cfg(),
+    'vit_large_r50_s32_384': _cfg(),
+
+    # hybrid in-21k models (weights ported from official Google JAX impl where they exist)
+    'vit_small_r26_s32_224_in21k': _cfg(
+        num_classes=21843, crop_pct=0.9),
+    'vit_small_r26_s32_384_in21k': _cfg(
+        num_classes=21843, input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_base_r50_s16_224_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
+        num_classes=21843, crop_pct=0.9),
+    'vit_large_r50_s32_224_in21k': _cfg(num_classes=21843, crop_pct=0.9),
 
     # hybrid models (using timm resnet backbones)
     'vit_small_resnet26d_224': _cfg(
@@ -99,7 +107,8 @@ def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_c
         else:
             feature_dim = self.backbone.num_features
         assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
-        self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
+        self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
 
     def forward(self, x):
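
The new grid_size attribute mirrors what the plain PatchEmbed exposes and is what resize_pos_embed consumes during npz loading. A small sketch of the arithmetic, assuming an R50-style hybrid at 384x384 with a 1/16 feature stride and patch_size=1:

feature_size = (384 // 16, 384 // 16)   # assumed backbone feature map (H, W) at stride 16
patch_size = (1, 1)
grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
num_patches = grid_size[0] * grid_size[1]
print(grid_size, num_patches)   # (24, 24) 576 -> target grid for pos embed rescaling
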
@@ -133,37 +142,35 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
 
 
 @register_model
-def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
-    """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
-    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
+    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
     """
-    backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    backbone = _resnetv2(layers=(), **kwargs)
+    model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer_hybrid(
-        'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+        'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
     return model
 
 
 @register_model
-def vit_base_r50_s16_384(pretrained=False, **kwargs):
-    """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
-    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
+    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
     """
-    backbone = _resnetv2((3, 4, 9), **kwargs)
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+    backbone = _resnetv2(layers=(), **kwargs)
+    model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer_hybrid(
-        'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+        'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
     return model
 
 
 @register_model
-def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
-    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
+def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
+    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
     """
     backbone = _resnetv2(layers=(), **kwargs)
     model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer_hybrid(
-        'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
+        'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
     return model
 
 
@@ -212,6 +219,17 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_small_r26_s32_384(pretrained=False, **kwargs):
+    """ R26+ViT-S/S32 hybrid.
+    """
+    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_base_r20_s16_224(pretrained=False, **kwargs):
     """ R20+ViT-B/S16 hybrid.
@@ -245,17 +263,74 @@ def vit_base_r50_s16_224(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_r50_s16_384(pretrained=False, **kwargs):
+    """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+    """
+    backbone = _resnetv2((3, 4, 9), **kwargs)
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_large_r50_s32_224(pretrained=False, **kwargs):
     """ R50+ViT-L/S32 hybrid.
     """
     backbone = _resnetv2((3, 4, 6, 3), **kwargs)
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer_hybrid(
         'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
     return model
 
 
+@register_model
+def vit_large_r50_s32_384(pretrained=False, **kwargs):
+    """ R50+ViT-L/S32 hybrid.
+    """
+    backbone = _resnetv2((3, 4, 6, 3), **kwargs)
+    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
+    """ R26+ViT-S/S32 hybrid.
+    """
+    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_r26_s32_384_in21k(pretrained=False, **kwargs):
+    """ R26+ViT-S/S32 hybrid.
+    """
+    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_small_r26_s32_384_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
+    """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_small_resnet26d_224(pretrained=False, **kwargs):
     """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
