Skip to content

Commit ba2ca4b

Browse files
committed
Use a single code path for StdConv: switch the weight standardization from F.layer_norm to F.batch_norm so that the gain can be applied in the same op. Tweak epsilon values for nfnet, resnetv2, and the ViT hybrid.
1 parent 2f5ed2d commit ba2ca4b

File tree

4 files changed

+33
-65
lines changed

4 files changed

+33
-65
lines changed

timm/models/layers/std_conv.py

Lines changed: 26 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -18,27 +18,20 @@ class StdConv2d(nn.Conv2d):
1818
https://arxiv.org/abs/1903.10520v2
1919
"""
2020
def __init__(
21-
self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
22-
groups=1, bias=False, eps=1e-5, use_layernorm=True):
21+
self, in_channel, out_channels, kernel_size, stride=1, padding=None,
22+
dilation=1, groups=1, bias=False, eps=1e-6):
2323
if padding is None:
2424
padding = get_padding(kernel_size, stride, dilation)
2525
super().__init__(
2626
in_channel, out_channels, kernel_size, stride=stride,
2727
padding=padding, dilation=dilation, groups=groups, bias=bias)
2828
self.eps = eps
29-
self.use_layernorm = use_layernorm
30-
31-
def get_weight(self):
32-
if self.use_layernorm:
33-
# NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
34-
weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
35-
else:
36-
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
37-
weight = (self.weight - mean) / (std + self.eps)
38-
return weight
3929

4030
def forward(self, x):
41-
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
31+
weight = F.batch_norm(
32+
self.weight.view(1, self.out_channels, -1), None, None,
33+
eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
34+
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
4235
return x
4336

4437

@@ -49,29 +42,22 @@ class StdConv2dSame(nn.Conv2d):
4942
https://arxiv.org/abs/1903.10520v2
5043
"""
5144
def __init__(
52-
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
53-
groups=1, bias=False, eps=1e-5, use_layernorm=True):
45+
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME',
46+
dilation=1, groups=1, bias=False, eps=1e-6):
5447
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
5548
super().__init__(
5649
in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
5750
groups=groups, bias=bias)
5851
self.same_pad = is_dynamic
5952
self.eps = eps
60-
self.use_layernorm = use_layernorm
61-
62-
def get_weight(self):
63-
if self.use_layernorm:
64-
# NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
65-
weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
66-
else:
67-
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
68-
weight = (self.weight - mean) / (std + self.eps)
69-
return weight
7053

7154
def forward(self, x):
7255
if self.same_pad:
7356
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
74-
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
57+
weight = F.batch_norm(
58+
self.weight.view(1, self.out_channels, -1), None, None,
59+
eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
60+
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
7561
return x
7662

7763

@@ -85,8 +71,8 @@ class ScaledStdConv2d(nn.Conv2d):
8571
"""
8672

8773
def __init__(
88-
self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
89-
bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
74+
self, in_channels, out_channels, kernel_size, stride=1, padding=None,
75+
dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
9076
if padding is None:
9177
padding = get_padding(kernel_size, stride, dilation)
9278
super().__init__(
@@ -95,19 +81,13 @@ def __init__(
9581
self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
9682
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
9783
self.eps = eps
98-
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
99-
100-
def get_weight(self):
101-
if self.use_layernorm:
102-
# NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
103-
weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
104-
else:
105-
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
106-
weight = (self.weight - mean) / (std + self.eps)
107-
return weight.mul_(self.gain * self.scale)
10884

10985
def forward(self, x):
110-
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
86+
weight = F.batch_norm(
87+
self.weight.view(1, self.out_channels, -1), None, None,
88+
weight=(self.gain * self.scale).view(-1),
89+
eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
90+
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
11191

11292

11393
class ScaledStdConv2dSame(nn.Conv2d):
@@ -120,8 +100,8 @@ class ScaledStdConv2dSame(nn.Conv2d):
120100
"""
121101

122102
def __init__(
123-
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
124-
bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
103+
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
104+
dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
125105
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
126106
super().__init__(
127107
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
@@ -130,18 +110,12 @@ def __init__(
130110
self.scale = gamma * self.weight[0].numel() ** -0.5
131111
self.same_pad = is_dynamic
132112
self.eps = eps
133-
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
134-
135-
def get_weight(self):
136-
if self.use_layernorm:
137-
# NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
138-
weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
139-
else:
140-
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
141-
weight = (self.weight - mean) / (std + self.eps)
142-
return weight.mul_(self.gain * self.scale)
143113

144114
def forward(self, x):
145115
if self.same_pad:
146116
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
147-
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
117+
weight = F.batch_norm(
118+
self.weight.view(1, self.out_channels, -1), None, None,
119+
weight=(self.gain * self.scale).view(-1),
120+
eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
121+
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

timm/models/nfnet.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,6 @@ class NfCfg:
167167
gamma_in_act: bool = False
168168
same_padding: bool = False
169169
std_conv_eps: float = 1e-5
170-
std_conv_ln: bool = True # use layer-norm impl to normalize in std-conv, works in PyTorch XLA, slightly faster
171170
skipinit: bool = False # disabled by default, non-trivial performance impact
172171
zero_init_fc: bool = False
173172
act_layer: str = 'silu'
@@ -484,11 +483,10 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
484483
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
485484
if cfg.gamma_in_act:
486485
act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
487-
conv_layer = partial(conv_layer, eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln)
486+
conv_layer = partial(conv_layer, eps=cfg.std_conv_eps)
488487
else:
489488
act_layer = get_act_layer(cfg.act_layer)
490-
conv_layer = partial(
491-
conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln)
489+
conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps)
492490
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
493491

494492
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)

timm/models/resnetv2.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -276,7 +276,7 @@ def forward(self, x):
276276

277277
def create_resnetv2_stem(
278278
in_chs, out_chs=64, stem_type='', preact=True,
279-
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
279+
conv_layer=partial(StdConv2d, eps=1e-8), norm_layer=partial(GroupNormAct, num_groups=32)):
280280
stem = OrderedDict()
281281
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
282282

@@ -315,8 +315,8 @@ class ResNetV2(nn.Module):
315315
def __init__(self, layers, channels=(256, 512, 1024, 2048),
316316
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
317317
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
318-
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
319-
drop_rate=0., drop_path_rate=0.):
318+
act_layer=nn.ReLU, conv_layer=partial(StdConv2d, eps=1e-8),
319+
norm_layer=partial(GroupNormAct, num_groups=32), drop_rate=0., drop_path_rate=0.):
320320
super().__init__()
321321
self.num_classes = num_classes
322322
self.drop_rate = drop_rate

timm/models/vision_transformer_hybrid.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -116,12 +116,8 @@ def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwa
116116
def _resnetv2(layers=(3, 4, 9), **kwargs):
117117
""" ResNet-V2 backbone helper"""
118118
padding_same = kwargs.get('padding_same', True)
119-
if padding_same:
120-
stem_type = 'same'
121-
conv_layer = partial(StdConv2dSame, eps=1e-5)
122-
else:
123-
stem_type = ''
124-
conv_layer = StdConv2d
119+
stem_type = 'same' if padding_same else ''
120+
conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8)
125121
if len(layers):
126122
backbone = ResNetV2(
127123
layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),

0 commit comments

Comments
 (0)