Skip to content

Commit c9d093a

Browse files
committed
update norm eps for efficientvit large
1 parent 87ba43a commit c9d093a

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

timm/models/efficientvit_mit.py

Lines changed: 15 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -8,14 +8,15 @@
88

99
__all__ = ['EfficientVit']
1010
from typing import Optional
11+
from functools import partial
1112

1213
import torch
1314
import torch.nn as nn
1415
import torch.nn.functional as F
1516
from torch.nn.modules.batchnorm import _BatchNorm
1617

1718
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18-
from timm.layers import SelectAdaptivePool2d, create_conv2d
19+
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
1920
from ._builder import build_model_with_cfg
2021
from ._features_fx import register_notrace_module
2122
from ._manipulate import checkpoint_seq
@@ -71,10 +72,7 @@ def __init__(
7172
bias=bias,
7273
)
7374
self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
74-
if act_layer is not None:
75-
self.act = act_layer(inplace=True) if act_layer is not nn.GELU else act_layer(approximate="tanh")
76-
else:
77-
self.act = nn.Identity()
75+
self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()
7876

7977
def forward(self, x):
8078
x = self.dropout(x)
@@ -641,14 +639,15 @@ def __init__(
641639
norm_layer=nn.BatchNorm2d,
642640
act_layer=nn.Hardswish,
643641
global_pool='avg',
642+
norm_eps=1e-5,
644643
):
645644
super(ClassifierHead, self).__init__()
646645
self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer)
647646
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
648647
self.classifier = nn.Sequential(
649648
nn.Linear(widths[0], widths[1], bias=False),
650-
nn.LayerNorm(widths[1]),
651-
act_layer(inplace=True) if act_layer is not nn.GELU else act_layer(approximate="tanh"),
649+
nn.LayerNorm(widths[1], eps=norm_eps),
650+
act_layer(inplace=True) if act_layer is not None else nn.Identity(),
652651
nn.Dropout(dropout, inplace=False),
653652
nn.Linear(widths[1], n_classes, bias=True),
654653
)
@@ -784,17 +783,19 @@ def __init__(
784783
depths=(),
785784
head_dim=32,
786785
norm_layer=nn.BatchNorm2d,
787-
act_layer=nn.GELU,
786+
act_layer=GELUTanh,
788787
global_pool='avg',
789788
head_widths=(),
790789
drop_rate=0.0,
791790
num_classes=1000,
792-
eps=1e-7,
791+
norm_eps=1e-7,
793792
):
794793
super(EfficientVitLarge, self).__init__()
795794
self.grad_checkpointing = False
796795
self.global_pool = global_pool
797796
self.num_classes = num_classes
797+
self.norm_eps = norm_eps
798+
norm_layer = partial(norm_layer, eps=self.norm_eps)
798799

799800
# input stem
800801
self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large')
@@ -830,20 +831,13 @@ def __init__(
830831
dropout=self.head_dropout,
831832
global_pool=self.global_pool,
832833
act_layer=act_layer,
834+
norm_eps=self.norm_eps,
833835
)
834836
else:
835837
if self.global_pool == 'avg':
836838
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
837839
else:
838840
self.head = nn.Identity()
839-
self.set_norm_eps(eps)
840-
841-
@torch.jit.ignore
842-
def set_norm_eps(self, eps):
843-
for m in self.modules():
844-
if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
845-
if eps is not None:
846-
m.eps = eps
847841

848842
@torch.jit.ignore
849843
def group_matcher(self, coarse=False):
@@ -875,6 +869,7 @@ def reset_classifier(self, num_classes, global_pool=None):
875869
n_classes=num_classes,
876870
dropout=self.head_dropout,
877871
global_pool=self.global_pool,
872+
norm_eps=self.norm_eps
878873
)
879874
else:
880875
if self.global_pool == 'avg':
@@ -1056,19 +1051,19 @@ def efficientvit_l3(pretrained=False, **kwargs):
10561051
@register_model
10571052
def efficientvit_l0_sam(pretrained=False, **kwargs):
10581053
model_args = dict(
1059-
widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0) # only backbone for segment-anything-model weights
1054+
widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0, norm_eps=1e-6) # only backbone for segment-anything-model weights
10601055
return _create_efficientvit_large('efficientvit_l0_sam', pretrained=pretrained, **dict(model_args, **kwargs))
10611056

10621057

10631058
@register_model
10641059
def efficientvit_l1_sam(pretrained=False, **kwargs):
10651060
model_args = dict(
1066-
widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0) # only backbone for segment-anything-model weights
1061+
widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0, norm_eps=1e-6) # only backbone for segment-anything-model weights
10671062
return _create_efficientvit_large('efficientvit_l1_sam', pretrained=pretrained, **dict(model_args, **kwargs))
10681063

10691064

10701065
@register_model
10711066
def efficientvit_l2_sam(pretrained=False, **kwargs):
10721067
model_args = dict(
1073-
widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0) # only backbone for segment-anything-model weights
1068+
widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0, norm_eps=1e-6) # only backbone for segment-anything-model weights
10741069
return _create_efficientvit_large('efficientvit_l2_sam', pretrained=pretrained, **dict(model_args, **kwargs))

0 commit comments

Comments
 (0)