
Commit 307a935

Add non-local and BAT attention. Merge attn and self-attn factories into one. Add attention references to README. Add mlp 'mode' to ECA.
1 parent 17dc47c commit 307a935

13 files changed: +276 additions, -92 deletions

README.md

Lines changed: 15 additions & 1 deletion
@@ -295,10 +295,24 @@ Several (less common) features that I often utilize in my projects are included.
 * SplitBatchNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data
 * DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
 * DropBlock (https://arxiv.org/abs/1810.12890)
-* Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151)
 * Blur Pooling (https://arxiv.org/abs/1904.11486)
 * Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
 * Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
+* An extensive selection of channel and/or spatial attention modules:
+    * Bottleneck Transformer - https://arxiv.org/abs/2101.11605
+    * CBAM - https://arxiv.org/abs/1807.06521
+    * Effective Squeeze-Excitation (ESE) - https://arxiv.org/abs/1911.06667
+    * Efficient Channel Attention (ECA) - https://arxiv.org/abs/1910.03151
+    * Gather-Excite (GE) - https://arxiv.org/abs/1810.12348
+    * Global Context (GC) - https://arxiv.org/abs/1904.11492
+    * Halo - https://arxiv.org/abs/2103.12731
+    * Involution - https://arxiv.org/abs/2103.06255
+    * Lambda Layer - https://arxiv.org/abs/2102.08602
+    * Non-Local (NL) - https://arxiv.org/abs/1711.07971
+    * Squeeze-and-Excitation (SE) - https://arxiv.org/abs/1709.01507
+    * Selective Kernel (SK) - https://arxiv.org/abs/1903.06586
+    * Split (SPLAT) - https://arxiv.org/abs/2004.08955
+    * Shifted Window (SWIN) - https://arxiv.org/abs/2103.14030
 
 ## Results

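A minimal usage sketch (not part of the commit) showing how the attention modules listed above can be instantiated by name through timm's layer factory; `create_attn` is the helper exported from `timm.models.layers`, and the tensor shapes are just illustrative:

```python
import torch
from timm.models.layers import create_attn

x = torch.randn(2, 64, 32, 32)   # NCHW feature map

# create_attn() looks the module class up by name via get_attn() and instantiates
# it with the channel count as the first positional argument.
eca = create_attn('eca', 64)     # Efficient Channel Attention
se = create_attn('se', 64)       # Squeeze-and-Excitation

# Channel attention modules gate the input, so the output shape is unchanged.
print(eca(x).shape, se(x).shape)  # torch.Size([2, 64, 32, 32]) for both
```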
timm/models/byobnet.py

Lines changed: 3 additions & 3 deletions
@@ -35,7 +35,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
 from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
-    create_conv2d, get_act_layer, convert_norm_act, get_attn, get_self_attn, make_divisible, to_2tuple
+    create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
 from .registry import register_model
 
 __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
@@ -935,7 +935,7 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo
         else:
             self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
             self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
-            self_attn_layer = partial(get_self_attn(self_attn_layer), **self_attn_kwargs) \
+            self_attn_layer = partial(get_attn(self_attn_layer), **self_attn_kwargs) \
                 if self_attn_layer is not None else None
         layer_fns = replace(layer_fns, self_attn=self_attn_layer)
 
@@ -1010,7 +1010,7 @@ def get_layer_fns(cfg: ByoModelCfg):
     norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
     conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
     attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
-    self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
+    self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
     layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
     return layer_fn

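With `get_self_attn` removed, byobnet now resolves both `attn_layer` and `self_attn_layer` names through the single `get_attn` factory and binds any config kwargs with `functools.partial`. A hedged sketch of that pattern (the `gate_layer` kwarg below is illustrative, not taken from a real model config):

```python
from functools import partial
from timm.models.layers import get_attn

# Lightweight attention ('se', 'eca', 'gc', ...) and self-attention
# ('bottleneck', 'halo', 'lambda', 'swin', 'involution', ...) now share one lookup.
print(get_attn('eca'))    # -> EcaModule class
print(get_attn('halo'))   # -> HaloAttn class

# byobnet binds model/block-level kwargs once, then each block calls the bound
# factory with its own channel arguments when the block is built.
attn_factory = partial(get_attn('eca'), gate_layer='hard_sigmoid')
print(attn_factory(256))  # EcaModule configured for a 256-channel block
```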
timm/models/efficientnet.py

Lines changed: 4 additions & 2 deletions
@@ -1234,7 +1234,8 @@ def eca_efficientnet_b0(pretrained=False, **kwargs):
     """ EfficientNet-B0 w/ ECA attn """
     # NOTE experimental config
     model = _gen_efficientnet(
-        'eca_efficientnet_b0', se_layer='eca', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+        'eca_efficientnet_b0', se_layer='ecam', channel_multiplier=1.0, depth_multiplier=1.0,
+        pretrained=pretrained, **kwargs)
     return model
 
 
@@ -1243,7 +1244,8 @@ def gc_efficientnet_b0(pretrained=False, **kwargs):
     """ EfficientNet-B0 w/ GlobalContext """
     # NOTE experimental config
     model = _gen_efficientnet(
-        'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+        'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0,
+        pretrained=pretrained, **kwargs)
     return model
 

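For reference, both experimental B0 variants are registered models and can be built by name. A quick smoke-test sketch (assumes a checkout containing this commit; neither config has pretrained weights, so `pretrained` must stay `False`):

```python
import torch
import timm

eca_b0 = timm.create_model('eca_efficientnet_b0', pretrained=False)  # se_layer='ecam'
gc_b0 = timm.create_model('gc_efficientnet_b0', pretrained=False)    # se_layer='gc'

x = torch.randn(1, 3, 224, 224)
print(eca_b0(x).shape, gc_b0(x).shape)  # torch.Size([1, 1000]) for both
```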
timm/models/layers/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,6 @@
 from .create_attn import get_attn, create_attn
 from .create_conv2d import create_conv2d
 from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
-from .create_self_attn import get_self_attn, create_self_attn
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
@@ -24,16 +23,17 @@
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp
+from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, LayerNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
-from .selective_kernel import SelectiveKernelConv
+from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvBnAct
 from .space_to_depth import SpaceToDepthModule
-from .split_attn import SplitAttnConv2d
+from .split_attn import SplitAttn
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool

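A quick import check against the renamed and newly exported layers above (sketch, not part of the commit):

```python
from timm.models.layers import (
    NonLocalAttn, BatNonLocalAttn,  # new non-local / BAT attention modules added in this commit
    SelectiveKernel,                # exported as SelectiveKernelConv before this commit
    SplitAttn,                      # exported as SplitAttnConv2d before this commit
)

for cls in (NonLocalAttn, BatNonLocalAttn, SelectiveKernel, SplitAttn):
    print(cls.__module__, cls.__name__)
```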
timm/models/layers/create_attn.py

Lines changed: 43 additions & 2 deletions
@@ -1,14 +1,23 @@
-""" Select AttentionFactory Method
+""" Attention Factory
 
-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2021 Ross Wightman
 """
 import torch
+from functools import partial
 
+from .bottleneck_attn import BottleneckAttn
 from .cbam import CbamModule, LightCbamModule
 from .eca import EcaModule, CecaModule
 from .gather_excite import GatherExcite
 from .global_context import GlobalContext
+from .halo_attn import HaloAttn
+from .involution import Involution
+from .lambda_layer import LambdaLayer
+from .non_local_attn import NonLocalAttn, BatNonLocalAttn
+from .selective_kernel import SelectiveKernel
+from .split_attn import SplitAttn
 from .squeeze_excite import SEModule, EffectiveSEModule
+from .swin_attn import WindowAttention
 
 
 def get_attn(attn_type):
@@ -18,12 +27,16 @@ def get_attn(attn_type):
     if attn_type is not None:
         if isinstance(attn_type, str):
             attn_type = attn_type.lower()
+            # Lightweight attention modules (channel and/or coarse spatial).
+            # Typically added to existing network architecture blocks in addition to existing convolutions.
             if attn_type == 'se':
                 module_cls = SEModule
             elif attn_type == 'ese':
                 module_cls = EffectiveSEModule
             elif attn_type == 'eca':
                 module_cls = EcaModule
+            elif attn_type == 'ecam':
+                module_cls = partial(EcaModule, use_mlp=True)
             elif attn_type == 'ceca':
                 module_cls = CecaModule
             elif attn_type == 'ge':
@@ -34,6 +47,34 @@
                 module_cls = CbamModule
             elif attn_type == 'lcbam':
                 module_cls = LightCbamModule
+
+            # Attention / attention-like modules w/ significant params
+            # Typically replace some of the existing workhorse convs in a network architecture.
+            # All of these accept a stride argument and can spatially downsample the input.
+            elif attn_type == 'sk':
+                module_cls = SelectiveKernel
+            elif attn_type == 'splat':
+                module_cls = SplitAttn
+
+            # Self-attention / attention-like modules w/ significant compute and/or params
+            # Typically replace some of the existing workhorse convs in a network architecture.
+            # All of these accept a stride argument and can spatially downsample the input.
+            elif attn_type == 'lambda':
+                return LambdaLayer
+            elif attn_type == 'bottleneck':
+                return BottleneckAttn
+            elif attn_type == 'halo':
+                return HaloAttn
+            elif attn_type == 'swin':
+                return WindowAttention
+            elif attn_type == 'involution':
+                return Involution
+            elif attn_type == 'nl':
+                module_cls = NonLocalAttn
+            elif attn_type == 'bat':
+                module_cls = BatNonLocalAttn
+
+            # Woops!
             else:
                 assert False, "Invalid attn module (%s)" % attn_type
         elif isinstance(attn_type, bool):

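The merged factory keeps two behaviours: lightweight modules fall through to `module_cls` (and are instantiated by `create_attn` with the channel count), while the heavier self-attention entries are returned as classes for the caller to construct with block-specific dim/stride arguments. A small sketch exercising both paths, including the new `'ecam'` alias (shapes illustrative):

```python
import torch
from timm.models.layers import get_attn, create_attn

# Lightweight path: create_attn() instantiates the resolved class/partial with
# the channel count as the first positional argument.
x = torch.randn(2, 64, 16, 16)
ecam = create_attn('ecam', 64)   # EcaModule with the new use_mlp=True mode
print(ecam(x).shape)             # torch.Size([2, 64, 16, 16]), shape preserved

# Self-attention path: classes are returned directly; byobnet constructs them
# later with block-specific arguments (dim, stride, etc.).
print(get_attn('bottleneck'), get_attn('halo'))
```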
timm/models/layers/create_self_attn.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

timm/models/layers/eca.py

Lines changed: 21 additions & 7 deletions
@@ -39,6 +39,7 @@
 
 
 from .create_act import create_act_layer
+from .helpers import make_divisible
 
 
 class EcaModule(nn.Module):
@@ -56,21 +57,36 @@ class EcaModule(nn.Module):
         act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
         gate_layer: gating non-linearity to use
     """
-    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
+    def __init__(
+            self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid',
+            rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False):
         super(EcaModule, self).__init__()
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)
             assert kernel_size % 2 == 1
-        has_act = act_layer is not None
-        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=has_act)
-        self.act = create_act_layer(act_layer) if has_act else nn.Identity()
+        padding = (kernel_size - 1) // 2
+        if use_mlp:
+            # NOTE 'mlp' mode is a timm experiment, not in paper
+            assert channels is not None
+            if rd_channels is None:
+                rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor)
+            act_layer = act_layer or nn.ReLU
+            self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True)
+            self.act = create_act_layer(act_layer)
+            self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True)
+        else:
+            self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
+            self.act = None
+            self.conv2 = None
         self.gate = create_act_layer(gate_layer)
 
     def forward(self, x):
         y = x.mean((2, 3)).view(x.shape[0], 1, -1)  # view for 1d conv
         y = self.conv(y)
-        y = self.act(y)  # NOTE: usually a no-op, added for experimentation
+        if self.conv2 is not None:
+            y = self.act(y)
+            y = self.conv2(y)
         y = self.gate(y).view(x.shape[0], -1, 1, 1)
         return x * y.expand_as(x)
 
@@ -115,15 +131,13 @@ def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None
         # implement manual circular padding
         self.padding = (kernel_size - 1) // 2
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)
-        self.act = create_act_layer(act_layer) if has_act else nn.Identity()
         self.gate = create_act_layer(gate_layer)
 
     def forward(self, x):
         y = x.mean((2, 3)).view(x.shape[0], 1, -1)
         # Manually implement circular padding, F.pad does not seem to be bugged
         y = F.pad(y, (self.padding, self.padding), mode='circular')
         y = self.conv(y)
-        y = self.act(y)  # NOTE: usually a no-op, added for experimentation
         y = self.gate(y).view(x.shape[0], -1, 1, 1)
         return x * y.expand_as(x)
 

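To illustrate the new ECA 'mlp' mode in isolation: `use_mlp=True` swaps the single k-sized 1d conv for a small bottleneck (1x1 conv down to `rd_channels`, activation, then the k-sized conv back to one channel) ahead of the gate. A brief sketch (shapes illustrative):

```python
import torch
from timm.models.layers import EcaModule

x = torch.randn(2, 64, 32, 32)

eca = EcaModule(channels=64)                    # paper ECA: single Conv1d over the channel descriptor
eca_mlp = EcaModule(channels=64, use_mlp=True)  # experimental 'ecam' mode added in this commit

# Both gate the input, so output shapes are unchanged; the mlp mode just carries a few more params.
print(eca(x).shape, eca_mlp(x).shape)
print(sum(p.numel() for p in eca.parameters()), sum(p.numel() for p in eca_mlp.parameters()))
```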