Commit 8bf63b6

Able to use other attn layer in EfficientNet now. Create test ECA + GC B0 configs. Make ECA more configurable.
1 parent bcec14d commit 8bf63b6
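
Both new configs register through timm's standard model factory, so with this commit applied they can be built like any other model. A minimal usage sketch (assuming the existing timm.create_model API; the new cfgs have empty weight urls, so only random init is available):

import torch
import timm

# experimental configs added by this commit; no pretrained weights exist yet
model = timm.create_model('eca_efficientnet_b0', pretrained=False)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])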

4 files changed: +54 −12 lines changed

tests/test_models.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
     # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
     EXCLUDE_FILTERS = [
-        '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm',
+        '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
         '*resnetrs350*', '*resnetrs420*']
 else:
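
These filter entries are shell-style wildcard patterns matched against model names when the test suite enumerates models to run. A quick illustration of why the new '*50x3_bitm' entry catches the intended model (a sketch assuming fnmatch semantics, which timm's list_models filtering uses):

import fnmatch

EXCLUDE_FILTERS = ['*efficientnet_l2*', '*50x3_bitm']
print(any(fnmatch.fnmatch('resnetv2_50x3_bitm', f) for f in EXCLUDE_FILTERS))  # True
print(any(fnmatch.fnmatch('efficientnet_b0', f) for f in EXCLUDE_FILTERS))     # False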

timm/models/efficientnet.py

Lines changed: 24 additions & 0 deletions
@@ -91,6 +91,12 @@ def _cfg(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
         interpolation='bilinear'),

+    # NOTE experimenting with alternate attention
+    'eca_efficientnet_b0': _cfg(
+        url=''),
+    'gc_efficientnet_b0': _cfg(
+        url=''),
+
     'efficientnet_b0': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'),
     'efficientnet_b1': _cfg(

@@ -1223,6 +1229,24 @@ def efficientnet_b0(pretrained=False, **kwargs):
     return model


+@register_model
+def eca_efficientnet_b0(pretrained=False, **kwargs):
+    """ EfficientNet-B0 w/ ECA attn """
+    # NOTE experimental config
+    model = _gen_efficientnet(
+        'eca_efficientnet_b0', se_layer='eca', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def gc_efficientnet_b0(pretrained=False, **kwargs):
+    """ EfficientNet-B0 w/ GlobalContext """
+    # NOTE experimental config
+    model = _gen_efficientnet(
+        'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
 @register_model
 def efficientnet_b1(pretrained=False, **kwargs):
     """ EfficientNet-B1 """

timm/models/efficientnet_builder.py

Lines changed: 2 additions & 2 deletions
@@ -278,9 +278,9 @@ def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, s
         self.norm_layer = norm_layer
         self.se_layer = get_attn(se_layer)
         try:
-            self.se_layer(8, rd_ratio=1.0)
+            self.se_layer(8, rd_ratio=1.0)  # test if attn layer accepts rd_ratio arg
             self.se_has_ratio = True
-        except RuntimeError as e:
+        except TypeError:
             self.se_has_ratio = False
         self.drop_path_rate = drop_path_rate
         if feature_location == 'depthwise':
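
The narrowed except clause is what actually enables non-SE attention layers here: calling a class with an unexpected keyword argument raises TypeError, not RuntimeError, so the old probe crashed on any layer lacking rd_ratio instead of falling back. A standalone sketch of the probing idiom (both layer classes are hypothetical stand-ins, for illustration only):

import torch.nn as nn

class SeLike(nn.Module):
    # hypothetical stand-in for an SE layer that accepts rd_ratio
    def __init__(self, channels, rd_ratio=0.25):
        super().__init__()
        self.fc = nn.Conv2d(channels, max(1, int(channels * rd_ratio)), 1)

class EcaLike(nn.Module):
    # hypothetical stand-in for an attn layer with no rd_ratio argument
    def __init__(self, channels=None):
        super().__init__()

def accepts_rd_ratio(layer_cls):
    try:
        layer_cls(8, rd_ratio=1.0)  # same probe as the builder
        return True
    except TypeError:  # unexpected keyword -> no rd_ratio support
        return False

print(accepts_rd_ratio(SeLike))   # True
print(accepts_rd_ratio(EcaLike))  # False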

timm/models/layers/eca.py

Lines changed: 27 additions & 9 deletions
@@ -38,6 +38,9 @@
 import torch.nn.functional as F


+from .create_act import create_act_layer
+
+
 class EcaModule(nn.Module):
     """Constructs an ECA module.

@@ -48,20 +51,27 @@ class EcaModule(nn.Module):
             refer to original paper https://arxiv.org/pdf/1910.03151.pdf
             (default=None. if channel size not given, use k_size given for kernel size.)
         kernel_size: Adaptive selection of kernel size (default=3)
+        gamma: used in kernel_size calc, see above
+        beta: used in kernel_size calc, see above
+        act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
+        gate_layer: gating non-linearity to use
     """
-    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
+    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
         super(EcaModule, self).__init__()
-        assert kernel_size % 2 == 1
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)
-
-        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
+        assert kernel_size % 2 == 1
+        has_act = act_layer is not None
+        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=has_act)
+        self.act = create_act_layer(act_layer) if has_act else nn.Identity()
+        self.gate = create_act_layer(gate_layer)

     def forward(self, x):
         y = x.mean((2, 3)).view(x.shape[0], 1, -1)  # view for 1d conv
         y = self.conv(y)
-        y = y.view(x.shape[0], -1, 1, 1).sigmoid()
+        y = self.act(y)  # NOTE: usually a no-op, added for experimentation
+        y = self.gate(y).view(x.shape[0], -1, 1, 1)
         return x * y.expand_as(x)

@@ -86,27 +96,35 @@ class CecaModule(nn.Module):
             refer to original paper https://arxiv.org/pdf/1910.03151.pdf
             (default=None. if channel size not given, use k_size given for kernel size.)
         kernel_size: Adaptive selection of kernel size (default=3)
+        gamma: used in kernel_size calc, see above
+        beta: used in kernel_size calc, see above
+        act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
+        gate_layer: gating non-linearity to use
     """

-    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
+    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
         super(CecaModule, self).__init__()
-        assert kernel_size % 2 == 1
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)
+        has_act = act_layer is not None
+        assert kernel_size % 2 == 1

         # PyTorch circular padding mode is buggy as of pytorch 1.4
         # see https://github.com/pytorch/pytorch/pull/17240
         # implement manual circular padding
-        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
         self.padding = (kernel_size - 1) // 2
+        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)
+        self.act = create_act_layer(act_layer) if has_act else nn.Identity()
+        self.gate = create_act_layer(gate_layer)

     def forward(self, x):
         y = x.mean((2, 3)).view(x.shape[0], 1, -1)
         # Manually implement circular padding, F.pad does not seem to be bugged
         y = F.pad(y, (self.padding, self.padding), mode='circular')
         y = self.conv(y)
-        y = y.view(x.shape[0], -1, 1, 1).sigmoid()
+        y = self.act(y)  # NOTE: usually a no-op, added for experimentation
+        y = self.gate(y).view(x.shape[0], -1, 1, 1)
         return x * y.expand_as(x)
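
Taken together, the eca.py changes keep default behaviour intact (bias-free conv, act as nn.Identity, sigmoid gate) while exposing the experimental knobs. A minimal exercise of the new arguments, with the adaptive kernel-size arithmetic spelled out (assuming EcaModule is exported from timm.models.layers and that create_act_layer resolves string names like 'relu' and 'hard_sigmoid', as elsewhere in timm):

import math
import torch
from timm.models.layers import EcaModule

# adaptive kernel size for channels=256, gamma=2, beta=1:
# t = int(abs(log2(256) + 1) / 2) = int(4.5) = 4, bumped to the next odd value -> 5
t = int(abs(math.log(256, 2) + 1) / 2)
print(max(t if t % 2 else t + 1, 3))  # 5

eca = EcaModule(channels=256)  # defaults: no act layer, conv bias off, sigmoid gate
eca_exp = EcaModule(channels=256, act_layer='relu', gate_layer='hard_sigmoid')  # experimental path

x = torch.randn(2, 256, 7, 7)
print(eca(x).shape, eca_exp(x).shape)  # both torch.Size([2, 256, 7, 7])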
