Commit eb76536

Monster commit, activation refactor, VoVNet, norm_act improvements, more
* refactor activations into basic PyTorch, jit scripted, and memory efficient custom autograd versions
* implement hard-mish, better grad for hard-swish
* add initial VovNet V1/V2 impl, fix #151
* VovNet and DenseNet first models to use NormAct layers (support BatchNormAct2d, EvoNorm, InplaceIABN)
* Wrap IABN for any models that use it
* make more models torchscript compatible (DPN, PNasNet, Res2Net, SelecSLS) and add tests
1 parent ff94ffc commit eb76536

37 files changed: +1467 −316 lines
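The activation items in the commit message translate roughly to the sketch below: each activation gets a plain PyTorch form, a jit-scriptable form, and a memory-efficient custom autograd form. This is an illustrative sketch, not the exact timm code; the hard-mish expression 0.5 * x * clamp(x + 2, 0, 2) and the hand-written hard-swish gradient are assumptions based on the usual piecewise-linear approximations.

```python
import torch
import torch.nn.functional as F


def hard_swish(x):
    # x * ReLU6(x + 3) / 6
    return x * F.relu6(x + 3.) / 6.


def hard_mish(x):
    # 0.5 * x * clamp(x + 2, 0, 2), a "hard" approximation of mish (assumed form)
    return 0.5 * x * (x + 2).clamp(min=0, max=2)


class HardSwishFn(torch.autograd.Function):
    """Memory-efficient variant: recompute the gradient from the saved input
    instead of keeping intermediate tensors alive for backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * F.relu6(x + 3.) / 6.

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        # piecewise gradient: 0 for x < -3, 1 for x > 3, x/3 + 0.5 in between
        m = torch.ones_like(x)
        m = torch.where(x < -3., torch.zeros_like(x), m)
        m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
        return grad_output * m
```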

tests/test_models.py

Lines changed: 25 additions & 1 deletion
@@ -4,7 +4,7 @@
 import os
 import fnmatch
 
-from timm import list_models, create_model
+from timm import list_models, create_model, set_scriptable
 
 
 if 'GITHUB_ACTIONS' in os.environ and 'Linux' in platform.system():
@@ -53,6 +53,8 @@ def test_model_backward(model_name, batch_size):
     inputs = torch.randn((batch_size, *input_size))
     outputs = model(inputs)
     outputs.mean().backward()
+    for n, x in model.named_parameters():
+        assert x.grad is not None, f'No gradient for {n}'
     num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
 
     assert outputs.shape[-1] == 42
@@ -83,3 +85,25 @@ def test_model_default_cfgs(model_name, batch_size):
     assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
     assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
     assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'
+
+
+EXCLUDE_JIT_FILTERS = [
+    '*iabn*', 'tresnet*',  # models using inplace abn unlikely to ever be scriptable
+    'dla*', 'hrnet*',  # hopefully fix at some point
+]
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_torchscript(model_name, batch_size):
+    """Run a single forward pass with each model"""
+    with set_scriptable(True):
+        model = create_model(model_name, pretrained=False)
+    model.eval()
+    input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...
+    model = torch.jit.script(model)
+    outputs = model(torch.randn((batch_size, *input_size)))
+
+    assert outputs.shape[0] == batch_size
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
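The new torchscript test above also doubles as a usage recipe: `set_scriptable(True)` is a context manager that steers layer and activation selection toward jit-scriptable implementations while the model is being built. A minimal sketch outside of pytest (model name chosen arbitrarily):

```python
import torch
from timm import create_model, set_scriptable

# Build the model with scriptability enforced, then compile with torchscript.
with set_scriptable(True):
    model = create_model('resnet18', pretrained=False)
model.eval()

scripted = torch.jit.script(model)
out = scripted(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])
```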

timm/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,2 +1,3 @@
 from .version import __version__
-from .models import create_model, list_models, is_model, list_modules, model_entrypoint
+from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
+    is_scriptable, is_exportable, set_scriptable, set_exportable

timm/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -20,9 +20,11 @@
 from .tresnet import *
 from .resnest import *
 from .regnet import *
+from .vovnet import *
 
 from .registry import *
 from .factory import create_model
 from .helpers import load_checkpoint, resume_checkpoint
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model
+from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
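With `from .vovnet import *` wired into the package init, the new VoVNet variants become discoverable through the registry like any other model. A quick sketch; the entrypoint name `ese_vovnet39b` is an assumption and may differ between versions:

```python
import timm

# wildcard filtering over registered model entrypoints
print(timm.list_models('*vovnet*'))

# create one of the VoVNet V2 variants (name assumed, no pretrained weights requested)
model = timm.create_model('ese_vovnet39b', pretrained=False, num_classes=10)
```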

timm/models/densenet.py

Lines changed: 18 additions & 18 deletions
@@ -41,13 +41,13 @@ def _cfg(url=''):
 
 
 class DenseLayer(nn.Module):
-    def __init__(self, num_input_features, growth_rate, bn_size, norm_act_layer=BatchNormAct2d,
+    def __init__(self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
                  drop_rate=0., memory_efficient=False):
         super(DenseLayer, self).__init__()
-        self.add_module('norm1', norm_act_layer(num_input_features)),
+        self.add_module('norm1', norm_layer(num_input_features)),
         self.add_module('conv1', nn.Conv2d(
             num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
-        self.add_module('norm2', norm_act_layer(bn_size * growth_rate)),
+        self.add_module('norm2', norm_layer(bn_size * growth_rate)),
         self.add_module('conv2', nn.Conv2d(
             bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
         self.drop_rate = float(drop_rate)
@@ -109,15 +109,15 @@ def forward(self, x):  # noqa: F811
 class DenseBlock(nn.ModuleDict):
     _version = 2
 
-    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_act_layer=nn.ReLU,
+    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
                  drop_rate=0., memory_efficient=False):
         super(DenseBlock, self).__init__()
         for i in range(num_layers):
             layer = DenseLayer(
                 num_input_features + i * growth_rate,
                 growth_rate=growth_rate,
                 bn_size=bn_size,
-                norm_act_layer=norm_act_layer,
+                norm_layer=norm_layer,
                 drop_rate=drop_rate,
                 memory_efficient=memory_efficient,
             )
@@ -132,9 +132,9 @@ def forward(self, init_features):
 
 
 class DenseTransition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d, aa_layer=None):
+    def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None):
         super(DenseTransition, self).__init__()
-        self.add_module('norm', norm_act_layer(num_input_features))
+        self.add_module('norm', norm_layer(num_input_features))
         self.add_module('conv', nn.Conv2d(
             num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
         if aa_layer is not None:
@@ -160,7 +160,7 @@ class DenseNet(nn.Module):
 
     def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='',
                  num_classes=1000, in_chans=3, global_pool='avg',
-                 norm_act_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False):
+                 norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False):
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         super(DenseNet, self).__init__()
@@ -181,17 +181,17 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem
             stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4)
             self.features = nn.Sequential(OrderedDict([
                 ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)),
-                ('norm0', norm_act_layer(stem_chs_1)),
+                ('norm0', norm_layer(stem_chs_1)),
                 ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)),
-                ('norm1', norm_act_layer(stem_chs_2)),
+                ('norm1', norm_layer(stem_chs_2)),
                 ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)),
-                ('norm2', norm_act_layer(num_init_features)),
+                ('norm2', norm_layer(num_init_features)),
                 ('pool0', stem_pool),
             ]))
         else:
             self.features = nn.Sequential(OrderedDict([
                 ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
-                ('norm0', norm_act_layer(num_init_features)),
+                ('norm0', norm_layer(num_init_features)),
                 ('pool0', stem_pool),
             ]))
 
@@ -203,7 +203,7 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem
                 num_input_features=num_features,
                 bn_size=bn_size,
                 growth_rate=growth_rate,
-                norm_act_layer=norm_act_layer,
+                norm_layer=norm_layer,
                 drop_rate=drop_rate,
                 memory_efficient=memory_efficient
             )
@@ -212,12 +212,12 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem
             if i != len(block_config) - 1:
                 trans = DenseTransition(
                     num_input_features=num_features, num_output_features=num_features // 2,
-                    norm_act_layer=norm_act_layer)
+                    norm_layer=norm_layer)
                 self.features.add_module('transition%d' % (i + 1), trans)
                 num_features = num_features // 2
 
         # Final batch norm
-        self.features.add_module('norm5', norm_act_layer(num_features))
+        self.features.add_module('norm5', norm_layer(num_features))
 
         # Linear layer
         self.num_features = num_features
@@ -346,7 +346,7 @@ def norm_act_fn(num_features, **kwargs):
         return create_norm_act('EvoNormBatch', num_features, jit=True, **kwargs)
     model = _densenet(
         'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
+        norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model
 
 
@@ -359,7 +359,7 @@ def norm_act_fn(num_features, **kwargs):
         return create_norm_act('EvoNormSample', num_features, jit=True, **kwargs)
     model = _densenet(
         'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
+        norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model
 
 
@@ -372,7 +372,7 @@ def norm_act_fn(num_features, **kwargs):
         return create_norm_act('iabn', num_features, **kwargs)
     model = _densenet(
         'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
+        norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model
 
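The renamed `norm_layer` argument threaded through DenseNet above accepts any callable that maps a channel count to a combined norm + activation module; the registered variants at the bottom of the file wrap `create_norm_act` in small factories. A condensed sketch of that pattern (illustrative only; the import path is assumed):

```python
from timm.models.layers import create_norm_act


def evonorm_fn(num_features, **kwargs):
    # combined norm + activation module, a jit-scripted EvoNorm in this case
    return create_norm_act('EvoNormBatch', num_features, jit=True, **kwargs)

# Any constructor that takes norm_layer (DenseNet, and per the commit message VoVNet)
# can now swap its normalization + activation in one place, e.g.:
# model = _densenet('densenet121d', ..., norm_layer=evonorm_fn, pretrained=False)
```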

timm/models/dpn.py

Lines changed: 48 additions & 16 deletions
@@ -10,6 +10,7 @@
 from __future__ import print_function
 
 from collections import OrderedDict
+from typing import Union, Optional, List, Tuple
 
 import torch
 import torch.nn as nn
@@ -54,8 +55,19 @@ def __init__(self, in_chs, activation_fn=nn.ReLU(inplace=True)):
         self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
         self.act = activation_fn
 
+    @torch.jit._overload_method  # noqa: F811
     def forward(self, x):
-        x = torch.cat(x, dim=1) if isinstance(x, tuple) else x
+        # type: (Tuple[torch.Tensor, torch.Tensor]) -> (torch.Tensor)
+        pass
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, x):
+        # type: (torch.Tensor) -> (torch.Tensor)
+        pass
+
+    def forward(self, x):
+        if isinstance(x, tuple):
+            x = torch.cat(x, dim=1)
         return self.act(self.bn(x))
 
 
@@ -107,6 +119,8 @@ def __init__(
             self.key_stride = 1
             self.has_proj = False
 
+        self.c1x1_w_s1 = None
+        self.c1x1_w_s2 = None
         if self.has_proj:
             # Using different member names here to allow easier parameter key matching for conversion
             if self.key_stride == 2:
@@ -115,6 +129,7 @@ def __init__(
             else:
                 self.c1x1_w_s1 = BnActConv2d(
                     in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1)
+
         self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1)
         self.c3x3_b = BnActConv2d(
             in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3,
@@ -125,27 +140,46 @@ def __init__(
             self.c1x1_c2 = nn.Conv2d(num_3x3_b, inc, kernel_size=1, bias=False)
         else:
             self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1)
+            self.c1x1_c1 = None
+            self.c1x1_c2 = None
 
+    @torch.jit._overload_method  # noqa: F811
     def forward(self, x):
-        x_in = torch.cat(x, dim=1) if isinstance(x, tuple) else x
-        if self.has_proj:
-            if self.key_stride == 2:
-                x_s = self.c1x1_w_s2(x_in)
-            else:
-                x_s = self.c1x1_w_s1(x_in)
-            x_s1 = x_s[:, :self.num_1x1_c, :, :]
-            x_s2 = x_s[:, self.num_1x1_c:, :, :]
+        # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
+        pass
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, x):
+        # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
+        pass
+
+    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
+        if isinstance(x, tuple):
+            x_in = torch.cat(x, dim=1)
         else:
+            x_in = x
+        if self.c1x1_w_s1 is None and self.c1x1_w_s2 is None:
+            # self.has_proj == False, torchscript requires condition on module == None
            x_s1 = x[0]
            x_s2 = x[1]
+        else:
+            # self.has_proj == True
+            if self.c1x1_w_s1 is not None:
+                # self.key_stride = 1
+                x_s = self.c1x1_w_s1(x_in)
+            else:
+                # self.key_stride = 2
+                x_s = self.c1x1_w_s2(x_in)
+            x_s1 = x_s[:, :self.num_1x1_c, :, :]
+            x_s2 = x_s[:, self.num_1x1_c:, :, :]
         x_in = self.c1x1_a(x_in)
         x_in = self.c3x3_b(x_in)
-        if self.b:
-            x_in = self.c1x1_c(x_in)
+        x_in = self.c1x1_c(x_in)
+        if self.c1x1_c1 is not None:
+            # self.b == True, using None check for torchscript compat
             out1 = self.c1x1_c1(x_in)
             out2 = self.c1x1_c2(x_in)
         else:
-            x_in = self.c1x1_c(x_in)
             out1 = x_in[:, :self.num_1x1_c, :, :]
             out2 = x_in[:, self.num_1x1_c:, :, :]
         resid = x_s1 + out1
@@ -167,11 +201,9 @@ def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
 
         # conv1
         if small:
-            blocks['conv1_1'] = InputBlock(
-                num_init_features, in_chans=in_chans, kernel_size=3, padding=1)
+            blocks['conv1_1'] = InputBlock(num_init_features, in_chans=in_chans, kernel_size=3, padding=1)
         else:
-            blocks['conv1_1'] = InputBlock(
-                num_init_features, in_chans=in_chans, kernel_size=7, padding=3)
+            blocks['conv1_1'] = InputBlock(num_init_features, in_chans=in_chans, kernel_size=7, padding=3)
 
         # conv2
         bw = 64 * bw_factor
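The DPN rework above targets a torchscript limitation: a scripted forward cannot take a loose union of Tensor and Tuple[Tensor, Tensor], so the accepted input types are declared with `@torch.jit._overload_method` stubs and optional submodules are branched on `is None` checks the compiler can resolve statically. A stripped-down sketch of the same pattern (class name and shapes illustrative):

```python
from typing import Tuple

import torch
import torch.nn as nn


class CatBnAct(nn.Module):
    """Concat-then-BN-then-act block that accepts a tensor or a 2-tuple of tensors."""

    def __init__(self, in_chs):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_chs)
        self.act = nn.ReLU(inplace=True)

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x):
        # type: (Tuple[torch.Tensor, torch.Tensor]) -> (torch.Tensor)
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x):
        # type: (torch.Tensor) -> (torch.Tensor)
        pass

    def forward(self, x):  # noqa: F811
        # torchscript resolves the overload from the call site's static type;
        # eager mode simply runs this implementation with the isinstance branch.
        if isinstance(x, tuple):
            x = torch.cat(x, dim=1)
        return self.act(self.bn(x))


scripted = torch.jit.script(CatBnAct(8))
y = scripted(torch.randn(2, 8, 4, 4))  # the tuple form is covered by the first overload
```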

timm/models/efficientnet.py

Lines changed: 11 additions & 7 deletions
@@ -24,11 +24,15 @@
 
 Hacked together by Ross Wightman
 """
+import torch.nn as nn
+import torch.nn.functional as F
+
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .efficientnet_builder import *
+from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
+from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .feature_hooks import FeatureHooks
 from .helpers import load_pretrained, adapt_model_from_file
-from .layers import SelectAdaptivePool2d
+from .layers import SelectAdaptivePool2d, create_conv2d
 from .registry import register_model
 
 __all__ = ['EfficientNet']
@@ -631,7 +635,7 @@ def _gen_mobilenet_v2(
         fix_stem=fix_stem_head,
         channel_multiplier=channel_multiplier,
         norm_kwargs=resolve_bn_args(kwargs),
-        act_layer=nn.ReLU6,
+        act_layer=resolve_act_layer(kwargs, 'relu6'),
         **kwargs
     )
     model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
@@ -741,7 +745,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
         num_features=round_channels(1280, channel_multiplier, 8, None),
         stem_size=32,
         channel_multiplier=channel_multiplier,
-        act_layer=Swish,
+        act_layer=resolve_act_layer(kwargs, 'swish'),
         norm_kwargs=resolve_bn_args(kwargs),
         variant=variant,
         **kwargs,
@@ -772,7 +776,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
         stem_size=32,
         channel_multiplier=channel_multiplier,
         norm_kwargs=resolve_bn_args(kwargs),
-        act_layer=nn.ReLU,
+        act_layer=resolve_act_layer(kwargs, 'relu'),
         **kwargs,
     )
     model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
@@ -802,7 +806,7 @@ def _gen_efficientnet_condconv(
         stem_size=32,
         channel_multiplier=channel_multiplier,
         norm_kwargs=resolve_bn_args(kwargs),
-        act_layer=Swish,
+        act_layer=resolve_act_layer(kwargs, 'swish'),
         **kwargs,
     )
     model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
@@ -842,7 +846,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
         stem_size=32,
         fix_stem=True,
         channel_multiplier=channel_multiplier,
-        act_layer=nn.ReLU6,
+        act_layer=resolve_act_layer(kwargs, 'relu6'),
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs,
     )
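The repeated `act_layer=resolve_act_layer(kwargs, 'swish')` change means each generator's activation is no longer hard-coded: an `act_layer` passed by the caller wins, otherwise the named default is looked up. A rough sketch of what such a resolver does (the real helper lives in `efficientnet_blocks.py`; this body and the small registry below are assumptions, not the timm implementation):

```python
import torch.nn as nn

# hypothetical registry standing in for timm's activation factory
_ACT_LAYERS = {'relu': nn.ReLU, 'relu6': nn.ReLU6, 'swish': nn.SiLU}


def resolve_act_layer(kwargs, default='relu'):
    # pop act_layer from the caller's kwargs so it isn't also forwarded via **kwargs;
    # fall back to the named default when the caller didn't override it
    act = kwargs.pop('act_layer', default)
    if isinstance(act, str):
        act = _ACT_LAYERS[act]
    return act

# e.g. a generator called with act_layer='relu' in kwargs would now build with ReLU
# instead of its ReLU6/Swish default.
```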
