Commit 43aa84e

Add 'fast' layer norm that doesn't cast to float32, support APEX LN impl for slight speed gain, update norm and act factories, tweak SE for ability to disable bias (needed by GCVit)
1 parent c486aa7

File tree: 7 files changed (+201, -15 lines)
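
The fast-norm path is opt-in and off by default. A minimal sketch of enabling it (assuming a CUDA device with AMP; the imports are the names exported in the diffs below):

import torch
from timm.models.layers import set_fast_norm, LayerNorm2d

set_fast_norm(True)  # call before constructing layers; each layer snapshots the flag in __init__

norm = LayerNorm2d(64).cuda()
with torch.cuda.amp.autocast():
    y = norm(torch.randn(2, 64, 7, 7, device='cuda'))
# without APEX, the normalization now runs in the autocast dtype instead of float32
print(y.dtype)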

timm/models/layers/__init__.py
Lines changed: 3 additions & 1 deletion

@@ -11,11 +11,13 @@
 from .create_act import create_act_layer, get_act_layer, get_act_fn
 from .create_attn import get_attn, create_attn
 from .create_conv2d import create_conv2d
+from .create_norm import get_norm_layer, create_norm_layer
 from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
 from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
     EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
 from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
 from .gather_excite import GatherExcite
 from .global_context import GlobalContext
@@ -25,7 +27,7 @@
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed
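
A quick sanity check (sketch) that the re-exported names above resolve from the package namespace:

from timm.models.layers import LayerNorm, fast_layer_norm, get_norm_layer, create_norm_layer, is_fast_norm

print(is_fast_norm())  # False unless set_fast_norm() was called earlier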

timm/models/layers/create_act.py
Lines changed: 7 additions & 1 deletion

@@ -145,4 +145,10 @@ def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
     act_layer = get_act_layer(name)
     if act_layer is None:
         return None
-    return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
+    if inplace is None:
+        return act_layer(**kwargs)
+    try:
+        return act_layer(inplace=inplace, **kwargs)
+    except TypeError:
+        # recover if act layer doesn't have inplace arg
+        return act_layer(**kwargs)
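
The try/except matters when the resolved activation class has no inplace argument. A hedged sketch (assumes get_act_layer returns a class unchanged when given a type rather than a string):

import torch.nn as nn
from timm.models.layers import create_act_layer

relu = create_act_layer('relu', inplace=True)   # nn.ReLU accepts inplace, passed through as before
gelu = create_act_layer(nn.GELU, inplace=True)  # nn.GELU has no inplace arg; previously a TypeError, now falls back to nn.GELU()
print(type(relu).__name__, type(gelu).__name__)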

timm/models/layers/create_norm.py
Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+""" Norm Layer Factory
+
+Create norm modules by string (to mirror create_act and create_norm_act fns)
+
+Copyright 2022 Ross Wightman
+"""
+import types
+import functools
+
+import torch.nn as nn
+
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
+
+_NORM_MAP = dict(
+    batchnorm=nn.BatchNorm2d,
+    batchnorm2d=nn.BatchNorm2d,
+    batchnorm1d=nn.BatchNorm1d,
+    groupnorm=GroupNorm,
+    groupnorm1=GroupNorm1,
+    layernorm=LayerNorm,
+    layernorm2d=LayerNorm2d,
+)
+_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
+
+
+def create_norm_layer(layer_name, num_features, **kwargs):
+    layer = get_norm_layer(layer_name)
+    layer_instance = layer(num_features, **kwargs)
+    return layer_instance
+
+
+def get_norm_layer(norm_layer):
+    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
+    norm_kwargs = {}
+
+    # unbind partial fn, so args can be rebound later
+    if isinstance(norm_layer, functools.partial):
+        norm_kwargs.update(norm_layer.keywords)
+        norm_layer = norm_layer.func
+
+    if isinstance(norm_layer, str):
+        layer_name = norm_layer.replace('_', '')
+        norm_layer = _NORM_MAP.get(layer_name, None)
+    elif norm_layer in _NORM_TYPES:
+        norm_layer = norm_layer
+    elif isinstance(norm_layer, types.FunctionType):
+        # if function type, assume it is a lambda/fn that creates a norm layer
+        norm_layer = norm_layer
+    else:
+        type_name = norm_layer.__name__.lower().replace('_', '')
+        norm_layer = _NORM_MAP.get(type_name, None)
+        assert norm_layer is not None, f"No equivalent norm layer for {type_name}"
+
+    if norm_kwargs:
+        norm_layer = functools.partial(norm_layer, **norm_kwargs)  # bind/rebind args
+    return norm_layer
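
A short usage sketch of the factory (names as defined above; the partial form shows keyword args being unbound and rebound onto the mapped class):

import functools
import torch.nn as nn
from timm.models.layers import get_norm_layer, create_norm_layer, LayerNorm2d

ln_cls = get_norm_layer('layernorm2d')                             # string lookup -> timm LayerNorm2d
ln_eps = get_norm_layer(functools.partial(LayerNorm2d, eps=1e-5))  # partial: keywords rebound onto the resolved class
gn_cls = get_norm_layer(nn.GroupNorm)                              # type lookup by lowercased name -> timm GroupNorm
norm = create_norm_layer('groupnorm', 64)                          # instance: 64 channels, default 32 groups
print(ln_cls, gn_cls, norm)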

timm/models/layers/fast_norm.py
Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+from typing import List, Optional
+
+import torch
+from torch.nn import functional as F
+
+try:
+    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
+    has_apex = True
+except ImportError:
+    has_apex = False
+
+
+# fast (ie lower precision LN) can be disabled with this flag if issues crop up
+_USE_FAST_NORM = False  # defaulting to False for now
+
+
+def is_fast_norm():
+    return _USE_FAST_NORM
+
+
+def set_fast_norm(enable=True):
+    global _USE_FAST_NORM
+    _USE_FAST_NORM = enable
+
+
+def fast_group_norm(
+    x: torch.Tensor,
+    num_groups: int,
+    weight: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    eps: float = 1e-5
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # currently cannot use is_autocast_enabled within torchscript
+        return F.group_norm(x, num_groups, weight, bias, eps)
+
+    if torch.is_autocast_enabled():
+        # normally native AMP casts GN inputs to float32
+        # here we use the low precision autocast dtype
+        dt = torch.get_autocast_gpu_dtype()
+        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
+
+    with torch.cuda.amp.autocast(enabled=False):
+        return F.group_norm(x, num_groups, weight, bias, eps)
+
+
+def fast_layer_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    eps: float = 1e-5
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # currently cannot use is_autocast_enabled within torchscript
+        return F.layer_norm(x, normalized_shape, weight, bias, eps)
+
+    if has_apex:
+        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
+
+    if torch.is_autocast_enabled():
+        # normally native AMP casts LN inputs to float32
+        # apex LN does not, this is behaving like Apex
+        dt = torch.get_autocast_gpu_dtype()
+        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
+
+    with torch.cuda.amp.autocast(enabled=False):
+        return F.layer_norm(x, normalized_shape, weight, bias, eps)
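
A rough illustration of what the fast path changes under AMP (assumes a CUDA device and no APEX install, so the pure PyTorch branch runs):

import torch
from torch.nn import functional as F
from timm.models.layers import fast_layer_norm

x = torch.randn(2, 196, 384, device='cuda')
w = torch.ones(384, device='cuda')
b = torch.zeros(384, device='cuda')

with torch.cuda.amp.autocast():
    y_native = F.layer_norm(x, (384,), w, b, 1e-6)   # native AMP upcasts layer norm to float32
    y_fast = fast_layer_norm(x, (384,), w, b, 1e-6)  # stays in the low precision autocast dtype

print(y_native.dtype, y_fast.dtype)  # e.g. torch.float32 vs torch.float16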

timm/models/layers/norm.py
Lines changed: 46 additions & 5 deletions

@@ -1,17 +1,24 @@
 """ Normalization layers and wrappers
 """
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+
 
 class GroupNorm(nn.GroupNorm):
     def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
         # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
         super().__init__(num_groups, num_channels, eps=eps, affine=affine)
+        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x):
-        return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        if self.fast_norm:
+            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
 
 
 class GroupNorm1(nn.GroupNorm):
@@ -21,22 +28,48 @@ class GroupNorm1(nn.GroupNorm):
 
     def __init__(self, num_channels, **kwargs):
         super().__init__(1, num_channels, **kwargs)
+        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.fast_norm:
+            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+
+
+class LayerNorm(nn.LayerNorm):
+    """ LayerNorm w/ fast norm option
+    """
+    def __init__(self, num_channels, eps=1e-6, affine=True):
+        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return x
 
 
 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial NCHW tensors """
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.layer_norm(
-            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+        x = x.permute(0, 2, 3, 1)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
 
 
 def _is_contiguous(tensor: torch.Tensor) -> bool:
     # jit is oh so lovely :/
-    # if torch.jit.is_tracing():
-    #     return True
     if torch.jit.is_scripting():
         return tensor.is_contiguous()
     else:
@@ -51,6 +84,14 @@ def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
     return x
 
 
+def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
+    u = x.mean(dim=1, keepdim=True)
+    s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
+    x = (x - u) * torch.rsqrt(s + eps)
+    x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
+    return x
+
+
 class LayerNormExp2d(nn.LayerNorm):
     """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
 
timm/models/layers/norm_act.py
Lines changed: 18 additions & 5 deletions

@@ -6,8 +6,9 @@
 from torch import nn as nn
 from torch.nn import functional as F
 
-from .trace_utils import _assert
 from .create_act import get_act_layer
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+from .trace_utils import _assert
 
 
 class BatchNormAct2d(nn.BatchNorm2d):
@@ -177,9 +178,13 @@ def __init__(
             self.act = act_layer(**act_args)
         else:
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
-        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        if self._fast_norm:
+            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -197,9 +202,13 @@ def __init__(
             self.act = act_layer(**act_args)
         else:
            self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
-        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -219,8 +228,12 @@ def __init__(
             self.act = nn.Identity()
 
     def forward(self, x):
-        x = F.layer_norm(
-            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+        x = x.permute(0, 2, 3, 1)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
         x = self.drop(x)
         x = self.act(x)
         return x
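
Because each norm+act module captures the flag in __init__, set_fast_norm() has to be called before layers are built. A sketch with GroupNormAct (its constructor signature is not part of this diff; the usual (num_channels, num_groups=32, ...) form is assumed):

import torch
from timm.models.layers import set_fast_norm, GroupNormAct

set_fast_norm(True)          # toggling afterwards does not affect already-built layers
norm_act = GroupNormAct(64)  # assumed: 64 channels, default 32 groups, ReLU act
x = torch.randn(2, 64, 8, 8)
print(norm_act(x).shape)     # forward runs norm -> drop -> act as shown above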

timm/models/layers/squeeze_excite.py
Lines changed: 3 additions & 3 deletions

@@ -27,15 +27,15 @@ class SEModule(nn.Module):
     """
     def __init__(
             self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
-            act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
+            bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
        super(SEModule, self).__init__()
        self.add_maxpool = add_maxpool
        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
-        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
+        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias)
        self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
        self.act = create_act_layer(act_layer, inplace=True)
-        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
+        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias)
        self.gate = create_act_layer(gate_layer)
 
     def forward(self, x):
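
The new bias flag simply threads through to both 1x1 convs. A GCVit-style sketch (module path per the diff; rd_ratio chosen only for illustration):

import torch
from timm.models.layers.squeeze_excite import SEModule

se = SEModule(256, rd_ratio=1. / 4, bias=False)
print(se.fc1.bias, se.fc2.bias)             # both None with bias disabled
print(se(torch.randn(2, 256, 7, 7)).shape)  # torch.Size([2, 256, 7, 7])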
