Skip to content

Commit 0ea53ce

Browse files
committed
Merge branch 'master' into densenet_update_and_more
2 parents 6441e9c + d79ac48 commit 0ea53ce

File tree

16 files changed

+607
-38
lines changed

16 files changed

+607
-38
lines changed

sotabench.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
126126
_entry('skresnet34', 'SK-ResNet-34', '1903.06586'),
127127
_entry('skresnext50_32x4d', 'SKNet-50', '1903.06586'),
128128

129+
_entry('ecaresnetlight', 'ECA-ResNet-Light', '1910.03151',
130+
model_desc='A tweaked ResNet50d with ECA attn.'),
131+
_entry('ecaresnet50d', 'ECA-ResNet-50d', '1910.03151',
132+
model_desc='A ResNet50d with ECA attn'),
133+
_entry('ecaresnet101d', 'ECA-ResNet-101d', '1910.03151',
134+
model_desc='A ResNet101d with ECA attn'),
135+
136+
_entry('resnetblur50', 'ResNet-Blur-50', '1904.11486'),
137+
129138
_entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946',
130139
model_desc='Ported from official Google AI Tensorflow weights'),
131140
_entry('tf_efficientnet_b1', 'EfficientNet-B1 (AutoAugment)', '1905.11946',

tests/test_inference.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

tests/test_models.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pytest
2+
import torch
3+
import platform
4+
import os
5+
import fnmatch
6+
7+
from timm import list_models, create_model
8+
9+
10+
if 'GITHUB_ACTIONS' in os.environ and 'Linux' in platform.system():
11+
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
12+
EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d']
13+
else:
14+
EXCLUDE_FILTERS = []
15+
MAX_FWD_SIZE = 384
16+
MAX_BWD_SIZE = 128
17+
MAX_FWD_FEAT_SIZE = 448
18+
19+
20+
@pytest.mark.timeout(120)
21+
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
22+
@pytest.mark.parametrize('batch_size', [1])
23+
def test_model_forward(model_name, batch_size):
24+
"""Run a single forward pass with each model"""
25+
model = create_model(model_name, pretrained=False)
26+
model.eval()
27+
28+
input_size = model.default_cfg['input_size']
29+
if any([x > MAX_FWD_SIZE for x in input_size]):
30+
# cap forward test at max res 448 * 448 to keep resource down
31+
input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
32+
inputs = torch.randn((batch_size, *input_size))
33+
outputs = model(inputs)
34+
35+
assert outputs.shape[0] == batch_size
36+
assert not torch.isnan(outputs).any(), 'Output included NaNs'
37+
38+
39+
@pytest.mark.timeout(120)
40+
# DLA models have an issue TBD, add them to exclusions
41+
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + ['dla*']))
42+
@pytest.mark.parametrize('batch_size', [2])
43+
def test_model_backward(model_name, batch_size):
44+
"""Run a single forward pass with each model"""
45+
model = create_model(model_name, pretrained=False, num_classes=42)
46+
num_params = sum([x.numel() for x in model.parameters()])
47+
model.eval()
48+
49+
input_size = model.default_cfg['input_size']
50+
if any([x > MAX_BWD_SIZE for x in input_size]):
51+
# cap backward test at 128 * 128 to keep resource usage down
52+
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
53+
inputs = torch.randn((batch_size, *input_size))
54+
outputs = model(inputs)
55+
outputs.mean().backward()
56+
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
57+
58+
assert outputs.shape[-1] == 42
59+
assert num_params == num_grad, 'Some parameters are missing gradients'
60+
assert not torch.isnan(outputs).any(), 'Output included NaNs'
61+
62+
63+
@pytest.mark.timeout(120)
64+
@pytest.mark.parametrize('model_name', list_models())
65+
@pytest.mark.parametrize('batch_size', [1])
66+
def test_model_default_cfgs(model_name, batch_size):
67+
"""Run a single forward pass with each model"""
68+
model = create_model(model_name, pretrained=False)
69+
model.eval()
70+
state_dict = model.state_dict()
71+
cfg = model.default_cfg
72+
73+
classifier = cfg['classifier']
74+
first_conv = cfg['first_conv']
75+
pool_size = cfg['pool_size']
76+
input_size = model.default_cfg['input_size']
77+
78+
if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \
79+
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
80+
# pool size only checked if default res <= 448 * 448 to keep resource down
81+
input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
82+
outputs = model.forward_features(torch.randn((batch_size, *input_size)))
83+
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
84+
assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
85+
assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .sknet import *
2020
from .tresnet import *
2121
from .resnest import *
22+
from .regnet import *
2223

2324
from .registry import *
2425
from .factory import create_model

timm/models/dla.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,11 @@ def __init__(self, levels, block, in_channels, out_channels, stride=1,
237237

238238
def forward(self, x, residual=None, children=None):
239239
children = [] if children is None else children
240-
bottom = self.downsample(x) if self.downsample else x
241-
residual = self.project(bottom) if self.project else bottom
240+
# FIXME the way downsample / project are used here and residual is passed to next level up
241+
# the tree, the residual is overridden and some project weights are thus never used and
242+
# have no gradients. This appears to be an issue with the original model / weights.
243+
bottom = self.downsample(x) if self.downsample is not None else x
244+
residual = self.project(bottom) if self.project is not None else bottom
242245
if self.level_root:
243246
children.append(bottom)
244247
x1 = self.tree1(x, residual)
@@ -355,7 +358,8 @@ def dla60_res2next(pretrained=None, num_classes=1000, in_chans=3, **kwargs):
355358
@register_model
356359
def dla34(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-34
357360
default_cfg = default_cfgs['dla34']
358-
model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic, **kwargs)
361+
model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic,
362+
num_classes=num_classes, in_chans=in_chans, **kwargs)
359363
model.default_cfg = default_cfg
360364
if pretrained:
361365
load_pretrained(model, default_cfg, num_classes, in_chans)

timm/models/gluon_xception.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
'url': '',
3737
'input_size': (3, 299, 299),
3838
'crop_pct': 0.875,
39-
'pool_size': (10, 10),
39+
'pool_size': (5, 5),
4040
'interpolation': 'bicubic',
4141
'mean': IMAGENET_DEFAULT_MEAN,
4242
'std': IMAGENET_DEFAULT_STD,

timm/models/hrnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _cfg(url='', **kwargs):
3434
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
3535
'crop_pct': 0.875, 'interpolation': 'bilinear',
3636
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
37-
'first_conv': 'conv1', 'classifier': 'fc',
37+
'first_conv': 'conv1', 'classifier': 'classifier',
3838
**kwargs
3939
}
4040

timm/models/inception_v3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def _cfg(url='', **kwargs):
1414
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
1515
'crop_pct': 0.875, 'interpolation': 'bicubic',
1616
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
17-
'first_conv': 'conv1', 'classifier': 'fc',
17+
'first_conv': 'Conv2d_1a_3x3', 'classifier': 'fc',
1818
**kwargs
1919
}
2020

timm/models/layers/se.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
class SEModule(nn.Module):
55

6-
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
6+
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None):
77
super(SEModule, self).__init__()
88
self.avg_pool = nn.AdaptiveAvgPool2d(1)
9-
reduction_channels = max(channels // reduction, 8)
9+
reduction_channels = reduction_channels or max(channels // reduction, min_channels)
1010
self.fc1 = nn.Conv2d(
1111
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
1212
self.act = act_layer(inplace=True)

timm/models/mobilenetv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
def _cfg(url='', **kwargs):
2323
return {
24-
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
24+
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
2525
'crop_pct': 0.875, 'interpolation': 'bilinear',
2626
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
2727
'first_conv': 'conv_stem', 'classifier': 'classifier',

0 commit comments

Comments
 (0)