Skip to content

Commit 678ba4e

Browse files
committed
Add NFNet-F model weights ported from DeepMind Haiku impl and new set of models w/ compatible config.
1 parent 4ea5931 commit 678ba4e

File tree

4 files changed

+234
-63
lines changed

4 files changed

+234
-63
lines changed

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,20 @@
22

33
## What's New
44

5+
### Feb 18, 2021
6+
* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
7+
* Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
8+
* These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants.
9+
* Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated).
10+
* Matching the original pre-processing as closely as possible I get these results:
11+
* `dm_nfnet_f6` - 86.352
12+
* `dm_nfnet_f5` - 86.100
13+
* `dm_nfnet_f4` - 85.834
14+
* `dm_nfnet_f3` - 85.676
15+
* `dm_nfnet_f2` - 85.178
16+
* `dm_nfnet_f1` - 84.696
17+
* `dm_nfnet_f0` - 83.464
18+
519
### Feb 16, 2021
620
* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py.
721
* AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`

timm/models/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,6 @@
2929
from .space_to_depth import SpaceToDepthModule
3030
from .split_attn import SplitAttnConv2d
3131
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
32-
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d
32+
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
3333
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
3434
from .weight_init import trunc_normal_

timm/models/layers/std_conv.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44

5-
from .padding import get_padding
6-
from .conv2d_same import conv2d_same
5+
from .padding import get_padding, get_padding_value, pad_same
76

87

98
def get_weight(module):
@@ -19,8 +18,8 @@ class StdConv2d(nn.Conv2d):
1918
https://arxiv.org/abs/1903.10520v2
2019
"""
2120
def __init__(
22-
self, in_channel, out_channels, kernel_size, stride=1,
23-
padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
21+
self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
22+
groups=1, bias=False, eps=1e-5):
2423
if padding is None:
2524
padding = get_padding(kernel_size, stride, dilation)
2625
super().__init__(
@@ -45,10 +44,13 @@ class StdConv2dSame(nn.Conv2d):
4544
https://arxiv.org/abs/1903.10520v2
4645
"""
4746
def __init__(
48-
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
47+
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
48+
groups=1, bias=False, eps=1e-5):
49+
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
4950
super().__init__(
50-
in_channel, out_channels, kernel_size, stride=stride,
51-
padding=0, dilation=dilation, groups=groups, bias=bias)
51+
in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
52+
groups=groups, bias=bias)
53+
self.same_pad = is_dynamic
5254
self.eps = eps
5355

5456
def get_weight(self):
@@ -57,7 +59,9 @@ def get_weight(self):
5759
return weight
5860

5961
def forward(self, x):
60-
x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
62+
if self.same_pad:
63+
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
64+
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
6165
return x
6266

6367

@@ -68,27 +72,71 @@ class ScaledStdConv2d(nn.Conv2d):
6872
https://arxiv.org/abs/2101.08692
6973
"""
7074

71-
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
72-
bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
75+
def __init__(
76+
self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
77+
bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
7378
if padding is None:
7479
padding = get_padding(kernel_size, stride, dilation)
7580
super().__init__(
76-
in_channels, out_channels, kernel_size, stride=stride,
77-
padding=padding, dilation=dilation, groups=groups, bias=bias)
78-
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
81+
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
82+
groups=groups, bias=bias)
83+
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
7984
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
8085
self.eps = eps ** 2 if use_layernorm else eps
81-
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory use
86+
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
8287

8388
def get_weight(self):
8489
if self.use_layernorm:
8590
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
8691
else:
8792
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
8893
weight = self.scale * (self.weight - mean) / (std + self.eps)
89-
if self.gain is not None:
90-
weight = weight * self.gain
91-
return weight
94+
return self.gain * weight
95+
96+
def forward(self, x):
97+
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
98+
99+
100+
class ScaledStdConv2dSame(nn.Conv2d):
101+
"""Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support
102+
103+
NOTE: operations and default eps slightly changed from non-SAME impl to closer match Deepmind Haiku impl.
104+
Fore the sake of completeness, numeric differences are minor with arprox .005 top-1 difference.
105+
106+
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
107+
https://arxiv.org/abs/2101.08692
108+
"""
109+
110+
def __init__(
111+
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
112+
bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
113+
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
114+
super().__init__(
115+
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
116+
groups=groups, bias=bias)
117+
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
118+
self.scale = gamma * self.weight[0].numel() ** -0.5
119+
self.same_pad = is_dynamic
120+
self.eps = eps ** 2 if use_layernorm else eps
121+
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
122+
123+
# NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem
124+
# to make much numerical difference (+/- .002 to .004) in top-1 during eval.
125+
# def get_weight(self):
126+
# var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
127+
# scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain
128+
# weight = (self.weight - mean) * scale
129+
# return self.gain * weight
130+
131+
def get_weight(self):
132+
if self.use_layernorm:
133+
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
134+
else:
135+
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
136+
weight = self.scale * (self.weight - mean) / (std + self.eps)
137+
return self.gain * weight
92138

93139
def forward(self, x):
140+
if self.same_pad:
141+
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
94142
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)

0 commit comments

Comments
 (0)