Commit 90980de

Fix up a few details in NFResNet models; managed stable training. Add support for the gamma gain to be applied in the activation or in ScaledStdConv. Some tweaks to ScaledStdConv.
1 parent 5a8e1e6 commit 90980de

File tree

timm/models/layers/std_conv.py
timm/models/nfnet.py

2 files changed: +106 -88 lines changed

timm/models/layers/std_conv.py

Lines changed: 10 additions & 7 deletions
@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np

 from .padding import get_padding
 from .conv2d_same import conv2d_same
@@ -69,20 +68,24 @@ class ScaledStdConv2d(nn.Conv2d):
     https://arxiv.org/abs/2101.08692
     """

-    def __init__(self, in_channels, out_channels, kernel_size,
-                 stride=1, padding=None, dilation=1, groups=1, bias=True, gain=True, gamma=1.0, eps=1e-5):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
+                 bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
             in_channels, out_channels, kernel_size, stride=stride,
             padding=padding, dilation=dilation, groups=groups, bias=bias)
         self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
-        self.gamma = gamma * self.weight[0].numel() ** 0.5  # gamma * sqrt(fan-in)
-        self.eps = eps
+        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
+        self.eps = eps ** 2 if use_layernorm else eps
+        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory use

     def get_weight(self):
-        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-        weight = (self.weight - mean) / (self.gamma * std + self.eps)
+        if self.use_layernorm:
+            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+        else:
+            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+            weight = self.scale * (self.weight - mean) / (std + self.eps)
         if self.gain is not None:
             weight = weight * self.gain
         return weight
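As a reference for the change above, a standalone sketch (not part of the commit) of what the two get_weight() paths compute; the tensor shape and the use of the ReLU gamma constant are illustrative:

    import math
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    w = torch.randn(64, 32, 3, 3)                # conv weight: (out_ch, in_ch, kH, kW)
    gamma = (0.5 * (1. - 1. / math.pi)) ** -0.5  # ReLU entry of _nonlin_gamma (see nfnet.py below)
    eps = 1e-5
    scale = gamma * w[0].numel() ** -0.5         # gamma / sqrt(fan-in)

    # explicit path: standardize each output filter to zero mean / unit std
    std, mean = torch.std_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
    ref = scale * (w - mean) / (std + eps)

    # layer_norm path: same statistics in one fused op; layer_norm adds eps to the
    # variance rather than the std, hence the eps ** 2 in the constructor above
    fast = scale * F.layer_norm(w, w.shape[1:], eps=eps ** 2)

    print(torch.allclose(ref, fast, atol=1e-4))  # True, equal up to eps placement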

timm/models/nfnet.py

Lines changed: 96 additions & 81 deletions
@@ -18,7 +18,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
 from .registry import register_model
-from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible
+from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible, get_act_fn


 def _dcfg(url='', **kwargs):
@@ -40,17 +40,17 @@ def _dcfg(url='', **kwargs):
     'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320)),
     'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384)),

-    'nf_resnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_resnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_resnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_resnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_resnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_resnet101': _dcfg(url='', first_conv='stem.conv'),

-    'nf_seresnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_seresnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_seresnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'),

-    'nf_ecaresnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_ecaresnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_ecaresnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),
 }


@@ -59,6 +59,7 @@ class NfCfg:
     depths: Tuple[int, int, int, int]
     channels: Tuple[int, int, int, int]
     alpha: float = 0.2
+    gamma_in_act: bool = False
     stem_type: str = '3x3'
     stem_chs: Optional[int] = None
     group_size: Optional[int] = 8
@@ -84,68 +85,65 @@ class NfCfg:
     nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048),

     # ResNet (preact, D style deep stem/avg down) defs
-    nf_resnet26d=NfCfg(
+    nf_resnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None,),
-    nf_resnet50d=NfCfg(
+    nf_resnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None),
-    nf_resnet101d=NfCfg(
+    nf_resnet101=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None),


-    nf_seresnet26d=NfCfg(
+    nf_seresnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet50d=NfCfg(
+    nf_seresnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet101d=NfCfg(
+    nf_seresnet101=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),


-    nf_ecaresnet26d=NfCfg(
+    nf_ecaresnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
-    nf_ecaresnet50d=NfCfg(
+    nf_ecaresnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
-    nf_ecaresnet101d=NfCfg(
+    nf_ecaresnet101=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),

 )

-# class NormFreeSiLU(nn.Module):
-#     _K = 1. / 0.5595
-#     def __init__(self, inplace=False):
-#         super().__init__()
-#         self.inplace = inplace
-#
-#     def forward(self, x):
-#         return F.silu(x, inplace=self.inplace) * self._K
-#
-#
-# class NormFreeReLU(nn.Module):
-#     _K = (0.5 * (1. - 1. / math.pi)) ** -0.5
-#
-#     def __init__(self, inplace=False):
-#         super().__init__()
-#         self.inplace = inplace
-#
-#     def forward(self, x):
-#         return F.relu(x, inplace=self.inplace) * self._K
+
+class GammaAct(nn.Module):
+    def __init__(self, act_type='relu', gamma: float = 1.0, inplace=False):
+        super().__init__()
+        self.act_fn = get_act_fn(act_type)
+        self.gamma = gamma
+        self.inplace = inplace
+
+    def forward(self, x):
+        return self.gamma * self.act_fn(x, inplace=self.inplace)
+
+
+def act_with_gamma(act_type, gamma: float = 1.):
+    def _create(inplace=False):
+        return GammaAct(act_type, gamma=gamma, inplace=inplace)
+    return _create


 class DownsampleAvg(nn.Module):
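The commented-out one-off classes (NormFreeSiLU, NormFreeReLU) are replaced by the single generic GammaAct above. A minimal sketch of the factory pattern, hard-wired to F.relu instead of timm's get_act_fn (the *_demo names are hypothetical, not part of the commit):

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GammaActDemo(nn.Module):
        # stand-in for GammaAct, fixed to ReLU for the sketch
        def __init__(self, gamma: float = 1.0, inplace: bool = False):
            super().__init__()
            self.gamma = gamma
            self.inplace = inplace

        def forward(self, x):
            return self.gamma * F.relu(x, inplace=self.inplace)

    def act_with_gamma_demo(gamma: float = 1.):
        # the closure behaves like an act_layer class: model code can later call
        # act_layer() or act_layer(inplace=True) without knowing about gamma
        def _create(inplace=False):
            return GammaActDemo(gamma=gamma, inplace=inplace)
        return _create

    act_layer = act_with_gamma_demo(gamma=(0.5 * (1. - 1. / math.pi)) ** -0.5)
    print(act_layer(inplace=True)(torch.randn(8)))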
@@ -178,10 +176,9 @@ def __init__(
         out_chs = out_chs or in_chs
         # EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
         mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div)
-        groups = 1
-        if group_size is not None:
-            # NOTE: not correcting the mid_chs % group_size, fix model def if broken. I want % ch_div == 0 to stand.
-            groups = mid_chs // group_size
+        groups = 1 if group_size is None else mid_chs // group_size
+        if group_size and group_size % ch_div == 0:
+            mid_chs = group_size * groups  # correct mid_chs if group_size divisible by ch_div, otherwise error
         self.alpha = alpha
         self.beta = beta
         self.attn_gain = attn_gain
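A worked example of the new grouped-width correction above (the numbers are illustrative, not from a real config):

    ch_div, group_size = 8, 8
    mid_chs = 100  # hypothetical scaled width, not divisible by group_size
    groups = 1 if group_size is None else mid_chs // group_size  # 12
    if group_size and group_size % ch_div == 0:
        mid_chs = group_size * groups  # snapped to 96 so mid_chs % group_size == 0
    print(groups, mid_chs)  # 12 96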
@@ -229,10 +226,11 @@ def forward(self, x):


 def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
+    stem_stride = 2
     stem = OrderedDict()
-    assert stem_type in ('', 'deep', '3x3', '7x7')
+    assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
     if 'deep' in stem_type:
-        # 3 deep 3x3 conv stack as in ResNet V1D models
+        # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
         mid_chs = out_chs // 2
         stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
         stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
@@ -244,12 +242,16 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
         # 7x7 stem conv as in ResNet
         stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)

-    return nn.Sequential(stem)
+    if 'pool' in stem_type:
+        stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
+        stem_stride = 4
+
+    return nn.Sequential(stem), stem_stride


 _nonlin_gamma = dict(
-    silu=.5595,
-    relu=(0.5 * (1. - 1. / math.pi)) ** 0.5,
+    silu=1./.5595,
+    relu=(0.5 * (1. - 1. / math.pi)) ** -0.5,
     identity=1.0
 )
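With the new convention in this commit, gamma multiplies the standardized weight (or the activation output), so each _nonlin_gamma entry is now 1 / std(act(x)) for x ~ N(0, 1), making gamma * act(x) roughly unit variance. A quick Monte Carlo check of the constants (not part of the commit):

    import torch
    import torch.nn.functional as F

    x = torch.randn(10_000_000)
    print(F.silu(x).std())  # ~0.5595 -> silu gamma = 1. / .5595
    print(F.relu(x).std())  # ~0.5838 -> relu gamma = (0.5 * (1. - 1. / math.pi)) ** -0.5
    print(x.std())          # ~1.0    -> identity gamma = 1.0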

@@ -264,9 +266,12 @@ class NormalizerFreeNet(nn.Module):
     the (preact) ResNet models described earlier in the paper.

     There are a few differences:
-    * channels are rounded to be divisible by 8 by default (keep TC happy), this changes param counts
+    * channels are rounded to be divisible by 8 by default (keep tensor core kernels happy),
+      this changes channel dim and param counts slightly from the paper models
     * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
       impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
+    * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
+      apply it in each activation. This is slightly slower, and yields slightly different results.
     * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
       for what it is/does. Approx 8-10% throughput loss.
     """
@@ -275,29 +280,33 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        act_layer = get_act_layer(cfg.act_layer)
         assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
-        conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
+        if cfg.gamma_in_act:
+            act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
+            conv_layer = partial(ScaledStdConv2d, bias=True, gain=True)
+        else:
+            act_layer = get_act_layer(cfg.act_layer)
+            conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
         attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None

-        self.feature_info = []  # FIXME fill out feature info
-
         stem_chs = cfg.stem_chs or cfg.channels[0]
         stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
-        self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
+        self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)

-        prev_chs = stem_chs
+        self.feature_info = []  # NOTE: there will be no stride == 2 feature if stem_stride == 4
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
-        net_stride = 2
+        prev_chs = stem_chs
+        net_stride = stem_stride
         dilation = 1
         expected_var = 1.0
         stages = []
         for stage_idx, stage_depth in enumerate(cfg.depths):
-            if net_stride >= output_stride:
-                dilation *= 2
+            stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
+            self.feature_info += [dict(
+                num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')]
+            if net_stride >= output_stride and stride > 1:
+                dilation *= stride
                 stride = 1
-            else:
-                stride = 2
             net_stride *= stride
             first_dilation = 1 if dilation in (1, 2) else 2

@@ -338,7 +347,10 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
+        # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv
         self.final_act = act_layer()
+        self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]
+
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

         for n, m in self.named_modules():
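Hypothetical usage of the new gamma_in_act option (not part of the commit; assumes the module-level names shown in this diff and that NfCfg is a dataclass):

    from dataclasses import replace

    import torch
    from timm.models.nfnet import NormalizerFreeNet, model_cfgs

    cfg = model_cfgs['nf_resnet50']
    m_conv = NormalizerFreeNet(cfg)                             # gamma folded into ScaledStdConv2d
    m_act = NormalizerFreeNet(replace(cfg, gamma_in_act=True))  # gamma applied by each GammaAct

    x = torch.randn(1, 3, 224, 224)
    print(m_conv(x).shape, m_act(x).shape)  # same architecture, slightly different numerics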
@@ -373,11 +385,14 @@ def forward(self, x):


 def _create_normfreenet(variant, pretrained=False, **kwargs):
+    model_cfg = model_cfgs[variant]
     feature_cfg = dict(flatten_sequential=True)
     feature_cfg['feature_cls'] = 'hook'  # pre-act models need hooks to grab feat from act1 in bottleneck blocks
+    if 'pool' in model_cfg.stem_type:
+        feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet

     return build_model_with_cfg(
-        NormalizerFreeNet, variant, pretrained, model_cfg=model_cfgs[variant], default_cfg=default_cfgs[variant],
+        NormalizerFreeNet, variant, pretrained, model_cfg=model_cfg, default_cfg=default_cfgs[variant],
         feature_cfg=feature_cfg, **kwargs)

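A hypothetical smoke test of the hook-based feature extraction for the new pool-stem models, via timm's standard features_only API (not part of the commit):

    import torch
    import timm

    # stride 4 maxpool stem -> no stride 2 feature level; out_indices = (1, 2, 3, 4)
    m = timm.create_model('nf_resnet50', features_only=True, pretrained=False)
    x = torch.randn(1, 3, 224, 224)
    for f in m(x):
        print(f.shape)  # reductions 4, 8, 16, 32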

@@ -412,30 +427,30 @@ def nf_regnet_b5(pretrained=False, **kwargs):


 @register_model
-def nf_resnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_resnet26d', pretrained=pretrained, **kwargs)
+def nf_resnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs)


 @register_model
-def nf_resnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_resnet50d', pretrained=pretrained, **kwargs)
+def nf_resnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs)


 @register_model
-def nf_seresnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_seresnet26d', pretrained=pretrained, **kwargs)
+def nf_seresnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs)


 @register_model
-def nf_seresnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_seresnet50d', pretrained=pretrained, **kwargs)
+def nf_seresnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs)


 @register_model
-def nf_ecaresnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_ecaresnet26d', pretrained=pretrained, **kwargs)
+def nf_ecaresnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs)


 @register_model
-def nf_ecaresnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_ecaresnet50d', pretrained=pretrained, **kwargs)
+def nf_ecaresnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)
