Commit afb6bd0 (parent ea2e59c)

Add backward and default_cfg tests and fix a few issues found. Fix #153

File tree: 12 files changed, +95 -36 lines

tests/test_inference.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

tests/test_models.py

Lines changed: 70 additions & 0 deletions. New file.

@@ -0,0 +1,70 @@
+import pytest
+import torch
+
+from timm import list_models, create_model
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models())
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward(model_name, batch_size):
+    """Run a single forward pass with each model."""
+    model = create_model(model_name, pretrained=False)
+    model.eval()
+
+    input_size = model.default_cfg['input_size']
+    if any([x > 448 for x in input_size]):
+        # cap forward test at max res 448 * 448 to keep resource usage down
+        input_size = tuple([min(x, 448) for x in input_size])
+    inputs = torch.randn((batch_size, *input_size))
+    outputs = model(inputs)
+
+    assert outputs.shape[0] == batch_size
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(exclude_filters='dla*'))  # DLA models have an issue TBD
+@pytest.mark.parametrize('batch_size', [2])
+def test_model_backward(model_name, batch_size):
+    """Run a single forward and backward pass with each model."""
+    model = create_model(model_name, pretrained=False, num_classes=42)
+    num_params = sum([x.numel() for x in model.parameters()])
+    model.eval()
+
+    input_size = model.default_cfg['input_size']
+    if any([x > 128 for x in input_size]):
+        # cap backward test at 128 * 128 to keep resource usage down
+        input_size = tuple([min(x, 128) for x in input_size])
+    inputs = torch.randn((batch_size, *input_size))
+    outputs = model(inputs)
+    outputs.mean().backward()
+    num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
+
+    assert outputs.shape[-1] == 42
+    assert num_params == num_grad, 'Some parameters are missing gradients'
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models())
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_default_cfgs(model_name, batch_size):
+    """Check that each model's default_cfg is consistent with its params and feature map size."""
+    model = create_model(model_name, pretrained=False)
+    model.eval()
+    state_dict = model.state_dict()
+    cfg = model.default_cfg
+
+    classifier = cfg['classifier']
+    first_conv = cfg['first_conv']
+    pool_size = cfg['pool_size']
+    input_size = model.default_cfg['input_size']
+
+    if all([x <= 448 for x in input_size]):
+        # pool size only checked if default res <= 448 * 448 to keep resource usage down
+        input_size = tuple([min(x, 448) for x in input_size])
+        outputs = model.forward_features(torch.randn((batch_size, *input_size)))
+        assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
+    assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
+    assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'
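
For context, what test_model_default_cfgs enforces can be reproduced standalone for a single model. A minimal sketch, assuming torch and timm are installed ('resnest50d' is just an example model name):

import torch
from timm import create_model

model = create_model('resnest50d', pretrained=False)
model.eval()
cfg = model.default_cfg

# spatial dims of the final feature map must match the cfg's pool_size
feats = model.forward_features(torch.randn(1, *cfg['input_size']))
assert feats.shape[-2:] == cfg['pool_size']

# the names in the cfg must actually prefix keys in the state dict
sd = model.state_dict()
assert any(k.startswith(cfg['classifier']) for k in sd)
assert any(k.startswith(cfg['first_conv']) for k in sd)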

timm/models/dla.py

Lines changed: 7 additions & 3 deletions

@@ -237,8 +237,11 @@ def __init__(self, levels, block, in_channels, out_channels, stride=1,

     def forward(self, x, residual=None, children=None):
         children = [] if children is None else children
-        bottom = self.downsample(x) if self.downsample else x
-        residual = self.project(bottom) if self.project else bottom
+        # FIXME the way downsample / project are used here and residual is passed to next level up
+        # the tree, the residual is overridden and some project weights are thus never used and
+        # have no gradients. This appears to be an issue with the original model / weights.
+        bottom = self.downsample(x) if self.downsample is not None else x
+        residual = self.project(bottom) if self.project is not None else bottom
         if self.level_root:
             children.append(bottom)
         x1 = self.tree1(x, residual)

@@ -354,7 +357,8 @@ def dla60_res2next(pretrained=None, num_classes=1000, in_chans=3, **kwargs):
 @register_model
 def dla34(pretrained=None, num_classes=1000, in_chans=3, **kwargs):  # DLA-34
     default_cfg = default_cfgs['dla34']
-    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic, **kwargs)
+    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic,
+                num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
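
One subtlety in the downsample / project change above: `if self.downsample` relies on nn.Module truthiness, and container modules such as nn.Sequential define it via their length, so an empty container would be falsy even though it is a valid module. The explicit `is not None` test is the safer idiom. A small illustration (a sketch, independent of the DLA code):

import torch.nn as nn

# nn.Sequential defines __len__, so Python's truth test falls back to it:
empty = nn.Sequential()
print(bool(empty))        # False: an empty container is falsy
print(empty is not None)  # True: the explicit check the fix switches to

proj = nn.Sequential(nn.Conv2d(16, 32, 1), nn.BatchNorm2d(32))
print(bool(proj))         # True: non-empty containers are truthy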

timm/models/gluon_xception.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@
         'url': '',
         'input_size': (3, 299, 299),
         'crop_pct': 0.875,
-        'pool_size': (10, 10),
+        'pool_size': (5, 5),
         'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN,
         'std': IMAGENET_DEFAULT_STD,

timm/models/hrnet.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'conv1', 'classifier': 'fc',
+        'first_conv': 'conv1', 'classifier': 'classifier',
         **kwargs
     }
4040

timm/models/inception_v3.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
         'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'conv1', 'classifier': 'fc',
+        'first_conv': 'Conv2d_1a_3x3', 'classifier': 'fc',
         **kwargs
     }
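
These first_conv / classifier corrections are exactly what the new test_model_default_cfgs asserts. A quick way to see the real parameter names for a model is to inspect its state dict; a short sketch (the key shown in the comment is illustrative, not guaranteed output):

from timm import create_model

model = create_model('inception_v3', pretrained=False)
sd = model.state_dict()
# the first key reveals the stem conv's real name,
# e.g. something like 'Conv2d_1a_3x3.conv.weight'
print(next(iter(sd)))
# and the cfg's classifier name must prefix at least one key
print([k for k in sd if k.startswith('fc')])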

timm/models/mobilenetv3.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@

 def _cfg(url='', **kwargs):
     return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'conv_stem', 'classifier': 'classifier',

timm/models/nasnet.py

Lines changed: 2 additions & 2 deletions

@@ -19,7 +19,7 @@
         'mean': (0.5, 0.5, 0.5),
         'std': (0.5, 0.5, 0.5),
         'num_classes': 1001,
-        'first_conv': 'conv_0.conv',
+        'first_conv': 'conv0.conv',
         'classifier': 'last_linear',
     },
 }

@@ -612,7 +612,7 @@ def nasnetalarge(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """NASNet-A large model architecture.
     """
     default_cfg = default_cfgs['nasnetalarge']
-    model = NASNetALarge(num_classes=1000, in_chans=in_chans, **kwargs)
+    model = NASNetALarge(num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)

timm/models/resnest.py

Lines changed: 6 additions & 3 deletions

@@ -38,11 +38,14 @@ def _cfg(url='', **kwargs):
     'resnest50d': _cfg(
         url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'),
     'resnest101e': _cfg(
-        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', input_size=(3, 256, 256)),
+        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'resnest200e': _cfg(
-        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)),
+        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth',
+        input_size=(3, 320, 320), pool_size=(10, 10)),
     'resnest269e': _cfg(
-        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)),
+        url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth',
+        input_size=(3, 416, 416), pool_size=(13, 13)),
     'resnest50d_4s2x40d': _cfg(
         url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_4s2x40d-41d14ed0.pth',
         interpolation='bicubic'),
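
The added pool_size values are consistent with the stride-32 reduction of these ResNet-style backbones: the final feature map is the default input resolution divided by 32. A quick check of the numbers from the diff:

# pool_size should equal input resolution // 32 for a stride-32 backbone
for res in (256, 320, 416):
    print(res, '->', (res // 32, res // 32))  # (8, 8), (10, 10), (13, 13)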

timm/models/selecsls.py

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@
 def _cfg(url='', **kwargs):
     return {
         'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (3, 3),
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'stem', 'classifier': 'fc',
