Skip to content

Commit e628ed7

Browse files
joao-alex-cunharwightman
authored andcommitted
device agnostic testing
1 parent 7c685a4 commit e628ed7

File tree

3 files changed

+62
-22
lines changed

3 files changed

+62
-22
lines changed

tests/test_layers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33

44
from timm.layers import create_act_layer, set_layer_config
55

6+
import importlib
7+
import os
8+
9+
torch_backend = os.environ.get('TORCH_BACKEND')
10+
if torch_backend is not None:
11+
importlib.import_module(torch_backend)
12+
torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
613

714
class MLP(nn.Module):
815
def __init__(self, act_layer="relu", inplace=True):
@@ -30,6 +37,9 @@ def _run(x, act_layer=''):
3037
l = (out - 0).pow(2).sum()
3138
return l
3239

40+
x = x.to(device=torch_device)
41+
m.to(device=torch_device)
42+
3343
out_me = _run(x)
3444

3545
with set_layer_config(scriptable=True):

tests/test_models.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@
3030
from timm.layers import Format, get_spatial_dim, get_channel_dim
3131
from timm.models import get_notrace_modules, get_notrace_functions
3232

33+
import importlib
34+
import os
35+
36+
torch_backend = os.environ.get('TORCH_BACKEND')
37+
if torch_backend is not None:
38+
importlib.import_module(torch_backend)
39+
torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
40+
timeout = os.environ.get('TIMEOUT')
41+
timeout120 = int(timeout) if timeout else 120
42+
timeout300 = int(timeout) if timeout else 300
43+
3344
if hasattr(torch._C, '_jit_set_profiling_executor'):
3445
# legacy executor is too slow to compile large models for unit tests
3546
# no need for the fusion performance here
@@ -100,7 +111,7 @@ def _get_input_size(model=None, model_name='', target=None):
100111

101112

102113
@pytest.mark.base
103-
@pytest.mark.timeout(120)
114+
@pytest.mark.timeout(timeout120)
104115
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
105116
@pytest.mark.parametrize('batch_size', [1])
106117
def test_model_forward(model_name, batch_size):
@@ -112,14 +123,16 @@ def test_model_forward(model_name, batch_size):
112123
if max(input_size) > MAX_FWD_SIZE:
113124
pytest.skip("Fixed input size model > limit.")
114125
inputs = torch.randn((batch_size, *input_size))
126+
inputs = inputs.to(torch_device)
127+
model.to(torch_device)
115128
outputs = model(inputs)
116129

117130
assert outputs.shape[0] == batch_size
118131
assert not torch.isnan(outputs).any(), 'Output included NaNs'
119132

120133

121134
@pytest.mark.base
122-
@pytest.mark.timeout(120)
135+
@pytest.mark.timeout(timeout120)
123136
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
124137
@pytest.mark.parametrize('batch_size', [2])
125138
def test_model_backward(model_name, batch_size):
@@ -133,6 +146,8 @@ def test_model_backward(model_name, batch_size):
133146
model.train()
134147

135148
inputs = torch.randn((batch_size, *input_size))
149+
inputs = inputs.to(torch_device)
150+
model.to(torch_device)
136151
outputs = model(inputs)
137152
if isinstance(outputs, tuple):
138153
outputs = torch.cat(outputs)
@@ -147,14 +162,15 @@ def test_model_backward(model_name, batch_size):
147162

148163

149164
@pytest.mark.cfg
150-
@pytest.mark.timeout(300)
165+
@pytest.mark.timeout(timeout300)
151166
@pytest.mark.parametrize('model_name', list_models(
152167
exclude_filters=EXCLUDE_FILTERS + NON_STD_FILTERS, include_tags=True))
153168
@pytest.mark.parametrize('batch_size', [1])
154169
def test_model_default_cfgs(model_name, batch_size):
155170
"""Run a single forward pass with each model"""
156171
model = create_model(model_name, pretrained=False)
157172
model.eval()
173+
model.to(torch_device)
158174
state_dict = model.state_dict()
159175
cfg = model.default_cfg
160176

@@ -169,7 +185,7 @@ def test_model_default_cfgs(model_name, batch_size):
169185
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
170186
# output sizes only checked if default res <= 448 * 448 to keep resource down
171187
input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
172-
input_tensor = torch.randn((batch_size, *input_size))
188+
input_tensor = torch.randn((batch_size, *input_size), device=torch_device)
173189

174190
# test forward_features (always unpooled)
175191
outputs = model.forward_features(input_tensor)
@@ -180,12 +196,14 @@ def test_model_default_cfgs(model_name, batch_size):
180196

181197
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
182198
model.reset_classifier(0)
199+
model.to(torch_device)
183200
outputs = model.forward(input_tensor)
184201
assert len(outputs.shape) == 2
185202
assert outputs.shape[1] == model.num_features
186203

187204
# test model forward without pooling and classifier
188205
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
206+
model.to(torch_device)
189207
outputs = model.forward(input_tensor)
190208
assert len(outputs.shape) == 4
191209
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
@@ -195,6 +213,7 @@ def test_model_default_cfgs(model_name, batch_size):
195213
if 'pruned' not in model_name: # FIXME better pruned model handling
196214
# test classifier + global pool deletion via __init__
197215
model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
216+
model.to(torch_device)
198217
outputs = model.forward(input_tensor)
199218
assert len(outputs.shape) == 4
200219
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
@@ -218,21 +237,22 @@ def test_model_default_cfgs(model_name, batch_size):
218237

219238

220239
@pytest.mark.cfg
221-
@pytest.mark.timeout(300)
240+
@pytest.mark.timeout(timeout300)
222241
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
223242
@pytest.mark.parametrize('batch_size', [1])
224243
def test_model_default_cfgs_non_std(model_name, batch_size):
225244
"""Run a single forward pass with each model"""
226245
model = create_model(model_name, pretrained=False)
227246
model.eval()
247+
model.to(torch_device)
228248
state_dict = model.state_dict()
229249
cfg = model.default_cfg
230250

231251
input_size = _get_input_size(model=model)
232252
if max(input_size) > 320: # FIXME const
233253
pytest.skip("Fixed input size model > limit.")
234254

235-
input_tensor = torch.randn((batch_size, *input_size))
255+
input_tensor = torch.randn((batch_size, *input_size), device=torch_device)
236256
feat_dim = getattr(model, 'feature_dim', None)
237257

238258
outputs = model.forward_features(input_tensor)
@@ -246,6 +266,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
246266

247267
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
248268
model.reset_classifier(0)
269+
model.to(torch_device)
249270
outputs = model.forward(input_tensor)
250271
if isinstance(outputs, (tuple, list)):
251272
outputs = outputs[0]
@@ -254,6 +275,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
254275
assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
255276

256277
model = create_model(model_name, pretrained=False, num_classes=0).eval()
278+
model.to(torch_device)
257279
outputs = model.forward(input_tensor)
258280
if isinstance(outputs, (tuple, list)):
259281
outputs = outputs[0]
@@ -297,7 +319,7 @@ def test_model_features_pretrained(model_name, batch_size):
297319

298320

299321
@pytest.mark.torchscript
300-
@pytest.mark.timeout(120)
322+
@pytest.mark.timeout(timeout120)
301323
@pytest.mark.parametrize(
302324
'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
303325
@pytest.mark.parametrize('batch_size', [1])
@@ -312,6 +334,7 @@ def test_model_forward_torchscript(model_name, batch_size):
312334
model.eval()
313335

314336
model = torch.jit.script(model)
337+
model.to(torch_device)
315338
outputs = model(torch.randn((batch_size, *input_size)))
316339

317340
assert outputs.shape[0] == batch_size

tests/test_optim.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@
1515

1616
from timm.optim import create_optimizer_v2
1717

18+
import importlib
19+
import os
20+
21+
torch_backend = os.environ.get('TORCH_BACKEND')
22+
if torch_backend is not None:
23+
importlib.import_module(torch_backend)
24+
torch_device = os.environ.get('TORCH_DEVICE', 'cuda')
1825

1926
# HACK relying on internal PyTorch test functionality for comparisons that I don't want to write
2027
torch_tc = TestCase()
@@ -61,7 +68,7 @@ def _test_state_dict(weight, bias, input, constructor):
6168

6269
def fn_base(optimizer, weight, bias):
6370
optimizer.zero_grad()
64-
i = input_cuda if weight.is_cuda else input
71+
i = input_device if weight.device.type != 'cpu' else input
6572
loss = (weight.mv(i) + bias).pow(2).sum()
6673
loss.backward()
6774
return loss
@@ -97,28 +104,28 @@ def fn_base(optimizer, weight, bias):
97104

98105
# Check that state dict can be loaded even when we cast parameters
99106
# to a different type and move to a different device.
100-
if not torch.cuda.is_available():
107+
if torch_device == 'cpu':
101108
return
102109

103110
with torch.no_grad():
104-
input_cuda = Parameter(input.clone().detach().float().cuda())
105-
weight_cuda = Parameter(weight.clone().detach().cuda())
106-
bias_cuda = Parameter(bias.clone().detach().cuda())
107-
optimizer_cuda = constructor(weight_cuda, bias_cuda)
108-
fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
111+
input_device = Parameter(input.clone().detach().float().to(torch_device))
112+
weight_device = Parameter(weight.clone().detach().to(torch_device))
113+
bias_device = Parameter(bias.clone().detach().to(torch_device))
114+
optimizer_device = constructor(weight_device, bias_device)
115+
fn_device = functools.partial(fn_base, optimizer_device, weight_device, bias_device)
109116

110117
state_dict = deepcopy(optimizer.state_dict())
111118
state_dict_c = deepcopy(optimizer.state_dict())
112-
optimizer_cuda.load_state_dict(state_dict_c)
119+
optimizer_device.load_state_dict(state_dict_c)
113120

114121
# Make sure state dict wasn't modified
115122
torch_tc.assertEqual(state_dict, state_dict_c)
116123

117124
for _i in range(20):
118125
optimizer.step(fn)
119-
optimizer_cuda.step(fn_cuda)
120-
torch_tc.assertEqual(weight, weight_cuda)
121-
torch_tc.assertEqual(bias, bias_cuda)
126+
optimizer_device.step(fn_device)
127+
torch_tc.assertEqual(weight, weight_device)
128+
torch_tc.assertEqual(bias, bias_device)
122129

123130
# validate deepcopy() copies all public attributes
124131
def getPublicAttr(obj):
@@ -152,12 +159,12 @@ def _test_basic_cases(constructor, scheduler_constructors=None):
152159
scheduler_constructors
153160
)
154161
# CUDA
155-
if not torch.cuda.is_available():
162+
if torch_device == 'cpu':
156163
return
157164
_test_basic_cases_template(
158-
torch.randn(10, 5).cuda(),
159-
torch.randn(10).cuda(),
160-
torch.randn(5).cuda(),
165+
torch.randn(10, 5).to(torch_device),
166+
torch.randn(10).to(torch_device),
167+
torch.randn(5).to(torch_device),
161168
constructor,
162169
scheduler_constructors
163170
)

0 commit comments

Comments
 (0)