
Commit 5a8e1e6 (parent 9a38416)

Initial Normalizer-Free Reg/ResNet impl. A bit of related layer refactoring.

File tree: 10 files changed, 572 additions, 58 deletions


timm/models/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@
 from .inception_v4 import *
 from .mobilenetv3 import *
 from .nasnet import *
+from .nfnet import *
 from .pnasnet import *
 from .regnet import *
 from .res2net import *

timm/models/layers/__init__.py
Lines changed: 3 additions & 2 deletions

@@ -10,13 +10,13 @@
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
-from .create_attn import create_attn
+from .create_attn import get_attn, create_attn
 from .create_conv2d import create_conv2d
 from .create_norm_act import create_norm_act, get_norm_act_layer
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
-from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
+from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
@@ -29,5 +29,6 @@
 from .space_to_depth import SpaceToDepthModule
 from .split_attn import SplitAttnConv2d
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
+from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .weight_init import trunc_normal_

timm/models/layers/create_attn.py
Lines changed: 7 additions & 1 deletion

@@ -8,7 +8,7 @@
 from .cbam import CbamModule, LightCbamModule


-def create_attn(attn_type, channels, **kwargs):
+def get_attn(attn_type):
     module_cls = None
     if attn_type is not None:
         if isinstance(attn_type, str):
@@ -32,6 +32,12 @@ def create_attn(attn_type, channels, **kwargs):
             module_cls = SEModule
     else:
         module_cls = attn_type
+    return module_cls
+
+
+def create_attn(attn_type, channels, **kwargs):
+    module_cls = get_attn(attn_type)
     if module_cls is not None:
+        # NOTE: it's expected the first (positional) argument of all attention layers is the input channels
         return module_cls(channels, **kwargs)
     return None
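
The refactor splits attention lookup from instantiation: get_attn resolves an attn_type spec (string name, bool, or a module class) to a class, while create_attn additionally instantiates it with the channel count as the first positional argument. A minimal usage sketch, not part of the commit, assuming the exports added to timm/models/layers/__init__.py above:

    from timm.models.layers import get_attn, create_attn

    se_cls = get_attn('se')          # resolve the SEModule class without instantiating
    se_64 = create_attn('se', 64)    # SE block for a 64-channel feature map
    per_stage = [create_attn('se', c) for c in (64, 128, 256)]
    no_attn = create_attn(None, 64)  # no attention requested -> returns None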

timm/models/layers/helpers.py
Lines changed: 7 additions & 3 deletions

@@ -22,6 +22,10 @@ def parse(x):
 to_ntuple = _ntuple


-
-
-
+def make_divisible(v, divisor=8, min_value=None):
+    min_value = min_value or divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
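
make_divisible rounds v to the nearest multiple of divisor, floors the result at min_value (defaulting to the divisor itself), and adds one extra divisor if rounding would land more than 10% below v. A few worked values, traced directly from the code above for illustration:

    make_divisible(30)                         # int(30 + 4) // 8 * 8 = 32
    make_divisible(17)                         # rounds to 16; 16 >= 0.9 * 17, no bump
    make_divisible(11)                         # rounds to 8; 8 < 0.9 * 11 = 9.9, bumped to 16
    make_divisible(4, divisor=1, min_value=8)  # min_value floor applies -> 8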

timm/models/layers/se.py
Lines changed: 18 additions & 4 deletions

@@ -1,13 +1,27 @@
 from torch import nn as nn
+import torch.nn.functional as F
+
 from .create_act import create_act_layer
+from .helpers import make_divisible


 class SEModule(nn.Module):
-
-    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
-                 gate_layer='sigmoid'):
+    """ SE Module as defined in original SE-Nets with a few additions
+    Additions include:
+        * min_channels can be specified to keep reduced channel count at a minimum (default: 8)
+        * divisor can be specified to keep channels rounded to specified values (default: 1)
+        * reduction channels can be specified directly by arg (if reduction_channels is set)
+        * reduction channels can be specified by float ratio (if reduction_ratio is set)
+    """
+    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid',
+                 reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1):
         super(SEModule, self).__init__()
-        reduction_channels = reduction_channels or max(channels // reduction, min_channels)
+        if reduction_channels is not None:
+            reduction_channels = reduction_channels  # direct specification highest priority, no rounding/min done
+        elif reduction_ratio is not None:
+            reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels)
+        else:
+            reduction_channels = make_divisible(channels // reduction, divisor, min_channels)
         self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.act = act_layer(inplace=True)
         self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
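
The reworked __init__ sizes the squeeze layer through three mutually exclusive paths, in priority order: an explicit reduction_channels (used as-is), a float reduction_ratio, or the classic integer reduction divisor, the latter two rounded via make_divisible. A quick sketch of the three paths for a 64-channel input (illustrative, not part of the diff):

    SEModule(64, reduction_channels=20)  # exact: 20 squeeze channels, no rounding/min applied
    SEModule(64, reduction_ratio=0.25)   # make_divisible(64 * 0.25, 1, 8) -> 16
    SEModule(64)                         # make_divisible(64 // 16, 1, 8) -> 8 (min_channels floor)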

timm/models/layers/std_conv.py
Lines changed: 91 additions & 0 deletions (new file)

@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from .padding import get_padding
+from .conv2d_same import conv2d_same
+
+
+def get_weight(module):
+    std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+    weight = (module.weight - mean) / (std + module.eps)
+    return weight
+
+
+class StdConv2d(nn.Conv2d):
+    """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
+
+    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
+        https://arxiv.org/abs/1903.10520v2
+    """
+    def __init__(
+            self, in_channel, out_channels, kernel_size, stride=1,
+            padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
+        if padding is None:
+            padding = get_padding(kernel_size, stride, dilation)
+        super().__init__(
+            in_channel, out_channels, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias)
+        self.eps = eps
+
+    def get_weight(self):
+        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+        weight = (self.weight - mean) / (std + self.eps)
+        return weight
+
+    def forward(self, x):
+        x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
+        return x
+
+
+class StdConv2dSame(nn.Conv2d):
+    """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.
+
+    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
+        https://arxiv.org/abs/1903.10520v2
+    """
+    def __init__(
+            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
+        super().__init__(
+            in_channel, out_channels, kernel_size, stride=stride,
+            padding=0, dilation=dilation, groups=groups, bias=bias)
+        self.eps = eps
+
+    def get_weight(self):
+        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+        weight = (self.weight - mean) / (std + self.eps)
+        return weight
+
+    def forward(self, x):
+        x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
+        return x
+
+
+class ScaledStdConv2d(nn.Conv2d):
+    """Conv2d layer with Scaled Weight Standardization.
+
+    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
+        https://arxiv.org/abs/2101.08692
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride=1, padding=None, dilation=1, groups=1, bias=True, gain=True, gamma=1.0, eps=1e-5):
+        if padding is None:
+            padding = get_padding(kernel_size, stride, dilation)
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias)
+        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
+        self.gamma = gamma * self.weight[0].numel() ** 0.5  # gamma * sqrt(fan-in)
+        self.eps = eps
+
+    def get_weight(self):
+        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+        weight = (self.weight - mean) / (self.gamma * std + self.eps)
+        if self.gain is not None:
+            weight = weight * self.gain
+        return weight
+
+    def forward(self, x):
+        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
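
All three layers re-standardize the weight on every forward pass: each output filter has its mean subtracted and is divided by its standard deviation (plus eps). ScaledStdConv2d further scales the denominator by gamma * sqrt(fan-in) and applies an optional learnable per-filter gain, following the NF-ResNet paper linked in its docstring. A small shape-and-statistics check; a sketch only, assuming the classes are imported from the new module:

    import torch
    from timm.models.layers.std_conv import StdConv2d, ScaledStdConv2d

    conv = StdConv2d(8, 16, kernel_size=3)         # padding resolved to 1 by get_padding
    w = conv.get_weight()
    print(w.mean(dim=[1, 2, 3]).abs().max())       # ~0: each output filter is zero-mean
    print(w.std(dim=[1, 2, 3], unbiased=False))    # ~1 per filter, up to the eps term

    y = conv(torch.randn(2, 8, 32, 32))            # -> torch.Size([2, 16, 32, 32])

    nf = ScaledStdConv2d(8, 16, kernel_size=3)
    print(nf.gamma)                                # 1.0 * sqrt(8 * 3 * 3) ≈ 8.485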
