Commit 4b383e8

Merge pull request #1655 from rwightman/levit_efficientformer_redux
Add EfficientFormer-V2, refactor EfficientFormer and Levit

2 parents 2cb2699 + 13acac8

5 files changed: +1391, -308 lines

timm/layers/create_norm.py

Lines changed: 3 additions & 3 deletions
@@ -23,9 +23,9 @@
 _NORM_TYPES = {m for n, m in _NORM_MAP.items()}


-def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs):
-    layer = get_norm_layer(layer_name, act_layer=act_layer)
-    layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
+def create_norm_layer(layer_name, num_features, **kwargs):
+    layer = get_norm_layer(layer_name)
+    layer_instance = layer(num_features, **kwargs)
     return layer_instance


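The signature change narrows create_norm_layer to plain norm layers: apply_act / act_layer are arguments of timm's norm+act factory, not of the norm classes themselves, so the pure-norm factory now just resolves the class by name and forwards whatever keyword arguments that class accepts. A minimal usage sketch (assuming create_norm_layer is exported from timm.layers and the 'layernorm2d' / 'batchnorm2d' names are registered, as in current timm):

    from timm.layers import create_norm_layer

    # resolve by name and instantiate in one call; extra kwargs go straight to the norm class
    ln = create_norm_layer('layernorm2d', 64, eps=1e-6)
    bn = create_norm_layer('batchnorm2d', 64)
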
timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from .dpn import *
 from .edgenext import *
 from .efficientformer import *
+from .efficientformer_v2 import *
 from .efficientnet import *
 from .gcvit import *
 from .ghostnet import *

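The star-import is what exposes the new EfficientFormer-V2 entrypoints to timm's model registry. A quick check, without assuming the exact model names added by this PR:

    import timm

    # the new registrations appear once timm.models has been imported
    names = timm.list_models('efficientformer*')
    print(names)
    model = timm.create_model(names[0], pretrained=False)
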
timm/models/efficientformer.py

Lines changed: 46 additions & 39 deletions
@@ -20,34 +20,13 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._pretrained import generate_default_cfgs
 from ._registry import register_model

 __all__ = ['EfficientFormer']  # model_registry will add each entrypoint fn to this


-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
-        'crop_pct': .95, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'),
-        **kwargs
-    }
-
-
-default_cfgs = dict(
-    efficientformer_l1=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l1_1000d_224-5b08fab0.pth",
-    ),
-    efficientformer_l3=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l3_300d_224-6816624f.pth",
-    ),
-    efficientformer_l7=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l7_300d_224-e957ab75.pth",
-    ),
-)
-
 EfficientFormer_width = {
     'l1': (48, 96, 224, 448),
     'l3': (64, 128, 320, 512),
@@ -99,7 +78,7 @@ def train(self, mode=True):
         self.attention_bias_cache = {}  # clear ab cache

     def get_attention_biases(self, device: torch.device) -> torch.Tensor:
-        if self.training:
+        if torch.jit.is_tracing() or self.training:
             return self.attention_biases[:, self.attention_bias_idxs]
         else:
             device_key = str(device)
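The torch.jit.is_tracing() guard is added because the eval-mode path memoizes the gathered bias tensor in a plain Python dict keyed by str(device); routing a trace through the direct indexing branch keeps the gather as a real op on the parameter instead of recording a cached constant. A standalone sketch of the same caching pattern (the class and shapes here are illustrative, not the timm Attention module):

    import torch
    import torch.nn as nn

    class BiasCacheDemo(nn.Module):
        """Illustrative re-creation of the attention-bias caching pattern."""
        def __init__(self, num_heads=4, resolution=7):
            super().__init__()
            pos = torch.stack(torch.meshgrid(
                torch.arange(resolution), torch.arange(resolution), indexing='ij')).flatten(1)
            rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
            rel_pos = rel_pos[0] * resolution + rel_pos[1]          # (N, N) index map
            self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution ** 2))
            self.register_buffer('attention_bias_idxs', rel_pos)
            self.attention_bias_cache = {}                          # keyed by str(device)

        def get_attention_biases(self, device: torch.device) -> torch.Tensor:
            if torch.jit.is_tracing() or self.training:
                # direct gather: traceable, and training always sees fresh parameter values
                return self.attention_biases[:, self.attention_bias_idxs]
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    m = BiasCacheDemo().eval()
    b = m.get_attention_biases(torch.device('cpu'))   # cached on repeat eval-mode calls
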
@@ -279,16 +258,17 @@ def __init__(
     ):
         super().__init__()
         self.token_mixer = Pooling(pool_size=pool_size)
+        self.ls1 = LayerScale2d(dim, layer_scale_init_value)
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
         self.mlp = ConvMlpWithNorm(
             dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop)
-
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.ls1 = LayerScale2d(dim, layer_scale_init_value)
         self.ls2 = LayerScale2d(dim, layer_scale_init_value)
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x):
-        x = x + self.drop_path(self.ls1(self.token_mixer(x)))
-        x = x + self.drop_path(self.ls2(self.mlp(x)))
+        x = x + self.drop_path1(self.ls1(self.token_mixer(x)))
+        x = x + self.drop_path2(self.ls2(self.mlp(x)))
         return x

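Splitting the shared drop_path into drop_path1 / drop_path2 gives each residual branch its own stochastic-depth module, mirroring the ls1 / ls2 layer-scale pair; with equal probabilities the statistics are unchanged, but the branches now drop independently and can be configured or inspected separately. A generic sketch of the resulting block layout (illustrative names, not the timm class):

    import torch
    import torch.nn as nn
    from timm.layers import DropPath  # same helper the diff above uses

    class TwoBranchBlock(nn.Module):
        """Each residual branch owns its layer scale and its DropPath."""
        def __init__(self, dim, token_mixer, mlp, layer_scale_init_value=1e-5, drop_path=0.):
            super().__init__()
            self.token_mixer = token_mixer
            self.ls1 = nn.Parameter(layer_scale_init_value * torch.ones(dim, 1, 1))
            self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.mlp = mlp
            self.ls2 = nn.Parameter(layer_scale_init_value * torch.ones(dim, 1, 1))
            self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        def forward(self, x):  # x: (B, C, H, W)
            x = x + self.drop_path1(self.ls1 * self.token_mixer(x))
            x = x + self.drop_path2(self.ls2 * self.mlp(x))
            return x
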
@@ -356,7 +336,10 @@ def __init__(

     def forward(self, x):
         x = self.downsample(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
         return x

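This wires the stage into timm's sequential gradient checkpointing: when enabled, checkpoint_seq re-runs the blocks during backward instead of keeping their intermediate activations, trading compute for memory. A usage sketch, assuming the model exposes timm's usual set_grad_checkpointing toggle (which this refactor adds alongside the grad_checkpointing flag):

    import timm
    import torch

    model = timm.create_model('efficientformer_l1', pretrained=False)
    model.set_grad_checkpointing(True)   # stages now route their blocks through checkpoint_seq
    model.train()

    x = torch.randn(2, 3, 224, 224, requires_grad=True)  # an input requiring grad keeps checkpointing happy
    model(x).sum().backward()                             # activations inside each stage are recomputed here
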
@@ -514,6 +497,30 @@ def _checkpoint_filter_fn(state_dict, model):
     return out_dict


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
+        'crop_pct': .95, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'),
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'efficientformer_l1.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientformer_l3.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientformer_l7.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+})
+
+
 def _create_efficientformer(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
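The default_cfgs move from direct release URLs to generate_default_cfgs with pretrained tags ('.snap_dist_in1k') and hf_hub_id='timm/', so weights now resolve via the Hugging Face Hub under the timm organization, and the untagged model name maps to its default tag. A usage sketch (the pretrained_cfg layout follows current timm conventions and may differ by version):

    import timm

    # both resolve to the same weights; the tag just makes the training recipe explicit
    m1 = timm.create_model('efficientformer_l1', pretrained=True)
    m2 = timm.create_model('efficientformer_l1.snap_dist_in1k', pretrained=True)

    print(m1.pretrained_cfg)   # resolved cfg; hf_hub_id points at the timm org on the HF Hub
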
@@ -524,30 +531,30 @@ def _create_efficientformer(variant, pretrained=False, **kwargs):

 @register_model
 def efficientformer_l1(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l1'],
         embed_dims=EfficientFormer_width['l1'],
         num_vit=1,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def efficientformer_l3(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l3'],
         embed_dims=EfficientFormer_width['l3'],
         num_vit=4,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def efficientformer_l7(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l7'],
         embed_dims=EfficientFormer_width['l7'],
         num_vit=8,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **dict(model_args, **kwargs))


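The entrypoints now keep their architecture defaults in model_args and merge caller kwargs last via dict(model_args, **kwargs), so a user-supplied keyword overrides the default instead of raising a duplicate-keyword error as the old dict(..., **kwargs) form would. A small sketch of the merge behaviour:

    model_args = dict(num_vit=1, embed_dims=(48, 96, 224, 448))

    # caller override wins; remaining defaults pass through untouched
    merged = dict(model_args, num_vit=2)
    assert merged['num_vit'] == 2 and merged['embed_dims'] == (48, 96, 224, 448)

    # old style: dict(num_vit=1, **{'num_vit': 2}) -> TypeError (multiple values for 'num_vit')
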