Commit 086bd55

Add EfficientFormer-V2; refactor EfficientFormer and LeViT for more uniformity across the three related architectures. Add features_only support to the LeViT conv models and efficientformer_v2. All weights are on the Hugging Face Hub.
1 parent 29fda20 commit 086bd55
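
Reviewer note: the practical upshot of this commit is that the EfficientFormer-V2 family joins timm and the conv-style variants can serve as multi-scale feature backbones. A minimal usage sketch, assuming the registered name efficientformerv2_s0 and timm's standard features_only mechanism (names here are my reading of the commit, not quoted from it):

    import timm
    import torch

    # EfficientFormer-V2 as a classifier (pretrained weights on the HF Hub).
    model = timm.create_model('efficientformerv2_s0', pretrained=False)

    # As a multi-scale feature extractor via the standard features_only API.
    backbone = timm.create_model('efficientformerv2_s0', features_only=True)
    for fmap in backbone(torch.randn(1, 3, 224, 224)):
        print(fmap.shape)  # one NCHW feature map per stage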

5 files changed: +1336 -307 lines

timm/layers/create_norm.py

Lines changed: 3 additions & 3 deletions
@@ -23,9 +23,9 @@
 _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
 
 
-def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs):
-    layer = get_norm_layer(layer_name, act_layer=act_layer)
-    layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
+def create_norm_layer(layer_name, num_features, **kwargs):
+    layer = get_norm_layer(layer_name)
+    layer_instance = layer(num_features, **kwargs)
     return layer_instance
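
Reviewer note: the net effect of this hunk is that create_norm_layer no longer threads act_layer/apply_act through the factory; it just resolves a layer class by name and forwards constructor kwargs (fused norm+act creation presumably stays with the separate norm-act factory). A small usage sketch, assuming 'batchnorm2d' is among the names registered in _NORM_MAP:

    from timm.layers.create_norm import create_norm_layer

    # Name lookup plus kwarg passthrough; no activation plumbing.
    bn = create_norm_layer('batchnorm2d', 64, eps=1e-5)
    print(bn)  # -> BatchNorm2d(64, eps=1e-05, ...)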

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from .dpn import *
 from .edgenext import *
 from .efficientformer import *
+from .efficientformer_v2 import *
 from .efficientnet import *
 from .gcvit import *
 from .ghostnet import *

timm/models/efficientformer.py

Lines changed: 41 additions & 38 deletions
@@ -20,34 +20,12 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
 from ._builder import build_model_with_cfg
+from ._pretrained import generate_default_cfgs
 from ._registry import register_model
 
 __all__ = ['EfficientFormer']  # model_registry will add each entrypoint fn to this
 
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
-        'crop_pct': .95, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'),
-        **kwargs
-    }
-
-
-default_cfgs = dict(
-    efficientformer_l1=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l1_1000d_224-5b08fab0.pth",
-    ),
-    efficientformer_l3=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l3_300d_224-6816624f.pth",
-    ),
-    efficientformer_l7=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l7_300d_224-e957ab75.pth",
-    ),
-)
-
 EfficientFormer_width = {
     'l1': (48, 96, 224, 448),
     'l3': (64, 128, 320, 512),
@@ -99,7 +77,7 @@ def train(self, mode=True):
         self.attention_bias_cache = {}  # clear ab cache
 
     def get_attention_biases(self, device: torch.device) -> torch.Tensor:
-        if self.training:
+        if torch.jit.is_tracing() or self.training:
             return self.attention_biases[:, self.attention_bias_idxs]
         else:
             device_key = str(device)
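
Reviewer note on the torch.jit.is_tracing() guard: in eval mode this module memoizes the gathered bias tensor per device, and a traced graph must not bake that Python-side cache in as a constant, so tracing now takes the uncached branch just like training. Roughly the pattern (a paraphrase, not the exact timm code):

    import torch
    import torch.nn as nn

    class AttnWithBias(nn.Module):
        def __init__(self, num_heads: int, num_pos: int, idxs: torch.Tensor):
            super().__init__()
            self.attention_biases = nn.Parameter(torch.zeros(num_heads, num_pos))
            self.register_buffer('attention_bias_idxs', idxs, persistent=False)
            self.attention_bias_cache = {}  # plain dict, invisible to tracing

        def get_attention_biases(self, device: torch.device) -> torch.Tensor:
            if torch.jit.is_tracing() or self.training:
                # No caching: the parameter changes every training step, and
                # tracing must record the real gather, not a stale constant.
                return self.attention_biases[:, self.attention_bias_idxs]
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = \
                    self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]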
@@ -279,16 +257,17 @@ def __init__(
     ):
         super().__init__()
         self.token_mixer = Pooling(pool_size=pool_size)
+        self.ls1 = LayerScale2d(dim, layer_scale_init_value)
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
         self.mlp = ConvMlpWithNorm(
             dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop)
-
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.ls1 = LayerScale2d(dim, layer_scale_init_value)
         self.ls2 = LayerScale2d(dim, layer_scale_init_value)
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
-        x = x + self.drop_path(self.ls1(self.token_mixer(x)))
-        x = x + self.drop_path(self.ls2(self.mlp(x)))
+        x = x + self.drop_path1(self.ls1(self.token_mixer(x)))
+        x = x + self.drop_path2(self.ls2(self.mlp(x)))
         return x
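Reviewer note: this reordering is behavior-preserving. The single shared drop_path becomes drop_path1/drop_path2, one per residual branch, matching the block layout in LeViT and the new EfficientFormer-V2. For reference, LayerScale2d is just a learnable per-channel scale for NCHW tensors; a sketch of that layer (details such as inplace handling may differ from timm's actual implementation):

    import torch
    import torch.nn as nn

    class LayerScale2d(nn.Module):
        """Learnable per-channel scaling for NCHW feature maps (sketch)."""
        def __init__(self, dim: int, init_values: float = 1e-5):
            super().__init__()
            self.gamma = nn.Parameter(init_values * torch.ones(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Broadcast (dim,) -> (1, dim, 1, 1) across batch and spatial dims.
            return x * self.gamma.view(1, -1, 1, 1)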

@@ -514,6 +493,30 @@ def _checkpoint_filter_fn(state_dict, model):
     return out_dict
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
+        'crop_pct': .95, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'),
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'efficientformer_l1.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientformer_l3.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientformer_l7.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+})
+
+
 def _create_efficientformer(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
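
Reviewer note: with generate_default_cfgs, each model gets tagged pretrained configs, and hf_hub_id='timm/' points at a Hub repo under the timm org named after the full model.tag string. If I read the tag scheme right (the first listed tag acts as the default), these two calls should fetch the same weights:

    import timm

    m1 = timm.create_model('efficientformer_l1', pretrained=True)
    m2 = timm.create_model('efficientformer_l1.snap_dist_in1k', pretrained=True)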
@@ -524,30 +527,30 @@ def _create_efficientformer(variant, pretrained=False, **kwargs):
 
 @register_model
 def efficientformer_l1(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l1'],
         embed_dims=EfficientFormer_width['l1'],
         num_vit=1,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def efficientformer_l3(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l3'],
         embed_dims=EfficientFormer_width['l3'],
         num_vit=4,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def efficientformer_l7(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l7'],
         embed_dims=EfficientFormer_width['l7'],
         num_vit=8,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **dict(model_args, **kwargs))
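
Reviewer note: the switch from baking **kwargs into model_kwargs to a dict(model_args, **kwargs) merge is subtle but real. With the old form, a caller passing e.g. num_vit would hit a TypeError ('got multiple values for keyword argument'); with the merge, caller kwargs override the per-variant defaults. A toy illustration:

    defaults = dict(depths=(3, 2, 6, 4), num_vit=1)

    # Old style: dict(depths=..., num_vit=1, **{'num_vit': 2}) raises TypeError.
    # New style: later keys win, so callers can override defaults cleanly.
    merged = dict(defaults, **{'num_vit': 2, 'drop_rate': 0.1})
    assert merged['num_vit'] == 2 and merged['depths'] == (3, 2, 6, 4)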
