|
8 | 8 |
|
9 | 9 | __all__ = ['EfficientVit'] |
10 | 10 | from typing import Optional |
| 11 | +from functools import partial |
11 | 12 |
|
12 | 13 | import torch |
13 | 14 | import torch.nn as nn |
14 | 15 | import torch.nn.functional as F |
15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm |
16 | 17 |
|
17 | 18 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
18 | | -from timm.layers import SelectAdaptivePool2d, create_conv2d |
| 19 | +from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh |
19 | 20 | from ._builder import build_model_with_cfg |
20 | 21 | from ._features_fx import register_notrace_module |
21 | 22 | from ._manipulate import checkpoint_seq |
@@ -71,10 +72,7 @@ def __init__( |
71 | 72 | bias=bias, |
72 | 73 | ) |
73 | 74 | self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity() |
74 | | - if act_layer is not None: |
75 | | - self.act = act_layer(inplace=True) if act_layer is not nn.GELU else act_layer(approximate="tanh") |
76 | | - else: |
77 | | - self.act = nn.Identity() |
| 75 | + self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity() |
78 | 76 |
|
79 | 77 | def forward(self, x): |
80 | 78 | x = self.dropout(x) |
@@ -641,14 +639,15 @@ def __init__( |
641 | 639 | norm_layer=nn.BatchNorm2d, |
642 | 640 | act_layer=nn.Hardswish, |
643 | 641 | global_pool='avg', |
| 642 | + norm_eps=1e-5, |
644 | 643 | ): |
645 | 644 | super(ClassifierHead, self).__init__() |
646 | 645 | self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer) |
647 | 646 | self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW') |
648 | 647 | self.classifier = nn.Sequential( |
649 | 648 | nn.Linear(widths[0], widths[1], bias=False), |
650 | | - nn.LayerNorm(widths[1]), |
651 | | - act_layer(inplace=True) if act_layer is not nn.GELU else act_layer(approximate="tanh"), |
| 649 | + nn.LayerNorm(widths[1], eps=norm_eps), |
| 650 | + act_layer(inplace=True) if act_layer is not None else nn.Identity(), |
652 | 651 | nn.Dropout(dropout, inplace=False), |
653 | 652 | nn.Linear(widths[1], n_classes, bias=True), |
654 | 653 | ) |
@@ -784,17 +783,19 @@ def __init__( |
784 | 783 | depths=(), |
785 | 784 | head_dim=32, |
786 | 785 | norm_layer=nn.BatchNorm2d, |
787 | | - act_layer=nn.GELU, |
| 786 | + act_layer=GELUTanh, |
788 | 787 | global_pool='avg', |
789 | 788 | head_widths=(), |
790 | 789 | drop_rate=0.0, |
791 | 790 | num_classes=1000, |
792 | | - eps=1e-7, |
| 791 | + norm_eps=1e-7, |
793 | 792 | ): |
794 | 793 | super(EfficientVitLarge, self).__init__() |
795 | 794 | self.grad_checkpointing = False |
796 | 795 | self.global_pool = global_pool |
797 | 796 | self.num_classes = num_classes |
| 797 | + self.norm_eps = norm_eps |
| 798 | + norm_layer = partial(norm_layer, eps=self.norm_eps) |
798 | 799 |
|
799 | 800 | # input stem |
800 | 801 | self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large') |
@@ -830,20 +831,13 @@ def __init__( |
830 | 831 | dropout=self.head_dropout, |
831 | 832 | global_pool=self.global_pool, |
832 | 833 | act_layer=act_layer, |
| 834 | + norm_eps=self.norm_eps, |
833 | 835 | ) |
834 | 836 | else: |
835 | 837 | if self.global_pool == 'avg': |
836 | 838 | self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) |
837 | 839 | else: |
838 | 840 | self.head = nn.Identity() |
839 | | - self.set_norm_eps(eps) |
840 | | - |
841 | | - @torch.jit.ignore |
842 | | - def set_norm_eps(self, eps): |
843 | | - for m in self.modules(): |
844 | | - if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)): |
845 | | - if eps is not None: |
846 | | - m.eps = eps |
847 | 841 |
|
848 | 842 | @torch.jit.ignore |
849 | 843 | def group_matcher(self, coarse=False): |
@@ -875,6 +869,7 @@ def reset_classifier(self, num_classes, global_pool=None): |
875 | 869 | n_classes=num_classes, |
876 | 870 | dropout=self.head_dropout, |
877 | 871 | global_pool=self.global_pool, |
| 872 | + norm_eps=self.norm_eps |
878 | 873 | ) |
879 | 874 | else: |
880 | 875 | if self.global_pool == 'avg': |
@@ -1056,19 +1051,19 @@ def efficientvit_l3(pretrained=False, **kwargs): |
1056 | 1051 | @register_model |
1057 | 1052 | def efficientvit_l0_sam(pretrained=False, **kwargs): |
1058 | 1053 | model_args = dict( |
1059 | | - widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0) # only backbone for segment-anything-model weights |
| 1054 | + widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0, norm_eps=1e-6) # only backbone for segment-anything-model weights |
1060 | 1055 | return _create_efficientvit_large('efficientvit_l0_sam', pretrained=pretrained, **dict(model_args, **kwargs)) |
1061 | 1056 |
|
1062 | 1057 |
|
1063 | 1058 | @register_model |
1064 | 1059 | def efficientvit_l1_sam(pretrained=False, **kwargs): |
1065 | 1060 | model_args = dict( |
1066 | | - widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0) # only backbone for segment-anything-model weights |
| 1061 | + widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0, norm_eps=1e-6) # only backbone for segment-anything-model weights |
1067 | 1062 | return _create_efficientvit_large('efficientvit_l1_sam', pretrained=pretrained, **dict(model_args, **kwargs)) |
1068 | 1063 |
|
1069 | 1064 |
|
1070 | 1065 | @register_model |
1071 | 1066 | def efficientvit_l2_sam(pretrained=False, **kwargs): |
1072 | 1067 | model_args = dict( |
1073 | | - widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0) # only backbone for segment-anything-model weights |
| 1068 | + widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0, norm_eps=1e-6) # only backbone for segment-anything-model weights |
1074 | 1069 | return _create_efficientvit_large('efficientvit_l2_sam', pretrained=pretrained, **dict(model_args, **kwargs)) |
0 commit comments