3030from timm .layers import Format , get_spatial_dim , get_channel_dim
3131from timm .models import get_notrace_modules , get_notrace_functions
3232
# Environment-driven knobs for the test run:
#   TORCH_BACKEND - dotted module path of an out-of-tree torch backend plugin;
#                   importing it (presumably) registers the backend with torch
#                   as a side effect — nothing else is done with the module.
#   TORCH_DEVICE  - device string the tests target (defaults to 'cpu').
#   TIMEOUT       - single override, in seconds, applied to BOTH the 120s and
#                   300s per-test timeout tiers when set and non-empty.
import importlib
import os

torch_backend = os.environ.get('TORCH_BACKEND')
if torch_backend is not None:
    importlib.import_module(torch_backend)

torch_device = os.environ.get('TORCH_DEVICE', 'cpu')

timeout = os.environ.get('TIMEOUT')
if timeout:
    # One override value replaces both timeout tiers.
    timeout120 = timeout300 = int(timeout)
else:
    timeout120, timeout300 = 120, 300
3344if hasattr (torch ._C , '_jit_set_profiling_executor' ):
3445 # legacy executor is too slow to compile large models for unit tests
3546 # no need for the fusion performance here
@@ -100,7 +111,7 @@ def _get_input_size(model=None, model_name='', target=None):
100111
101112
102113@pytest .mark .base
103- @pytest .mark .timeout (120 )
114+ @pytest .mark .timeout (timeout120 )
104115@pytest .mark .parametrize ('model_name' , list_models (exclude_filters = EXCLUDE_FILTERS ))
105116@pytest .mark .parametrize ('batch_size' , [1 ])
106117def test_model_forward (model_name , batch_size ):
@@ -112,14 +123,16 @@ def test_model_forward(model_name, batch_size):
112123 if max (input_size ) > MAX_FWD_SIZE :
113124 pytest .skip ("Fixed input size model > limit." )
114125 inputs = torch .randn ((batch_size , * input_size ))
126+ inputs = inputs .to (torch_device )
127+ model .to (torch_device )
115128 outputs = model (inputs )
116129
117130 assert outputs .shape [0 ] == batch_size
118131 assert not torch .isnan (outputs ).any (), 'Output included NaNs'
119132
120133
121134@pytest .mark .base
122- @pytest .mark .timeout (120 )
135+ @pytest .mark .timeout (timeout120 )
123136@pytest .mark .parametrize ('model_name' , list_models (exclude_filters = EXCLUDE_FILTERS , name_matches_cfg = True ))
124137@pytest .mark .parametrize ('batch_size' , [2 ])
125138def test_model_backward (model_name , batch_size ):
@@ -133,6 +146,8 @@ def test_model_backward(model_name, batch_size):
133146 model .train ()
134147
135148 inputs = torch .randn ((batch_size , * input_size ))
149+ inputs = inputs .to (torch_device )
150+ model .to (torch_device )
136151 outputs = model (inputs )
137152 if isinstance (outputs , tuple ):
138153 outputs = torch .cat (outputs )
@@ -147,14 +162,15 @@ def test_model_backward(model_name, batch_size):
147162
148163
149164@pytest .mark .cfg
150- @pytest .mark .timeout (300 )
165+ @pytest .mark .timeout (timeout300 )
151166@pytest .mark .parametrize ('model_name' , list_models (
152167 exclude_filters = EXCLUDE_FILTERS + NON_STD_FILTERS , include_tags = True ))
153168@pytest .mark .parametrize ('batch_size' , [1 ])
154169def test_model_default_cfgs (model_name , batch_size ):
155170 """Run a single forward pass with each model"""
156171 model = create_model (model_name , pretrained = False )
157172 model .eval ()
173+ model .to (torch_device )
158174 state_dict = model .state_dict ()
159175 cfg = model .default_cfg
160176
@@ -169,7 +185,7 @@ def test_model_default_cfgs(model_name, batch_size):
169185 not any ([fnmatch .fnmatch (model_name , x ) for x in EXCLUDE_FILTERS ]):
170186 # output sizes only checked if default res <= 448 * 448 to keep resource down
171187 input_size = tuple ([min (x , MAX_FWD_OUT_SIZE ) for x in input_size ])
172- input_tensor = torch .randn ((batch_size , * input_size ))
188+ input_tensor = torch .randn ((batch_size , * input_size ), device = torch_device )
173189
174190 # test forward_features (always unpooled)
175191 outputs = model .forward_features (input_tensor )
@@ -180,12 +196,14 @@ def test_model_default_cfgs(model_name, batch_size):
180196
181197 # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
182198 model .reset_classifier (0 )
199+ model .to (torch_device )
183200 outputs = model .forward (input_tensor )
184201 assert len (outputs .shape ) == 2
185202 assert outputs .shape [1 ] == model .num_features
186203
187204 # test model forward without pooling and classifier
188205 model .reset_classifier (0 , '' ) # reset classifier and set global pooling to pass-through
206+ model .to (torch_device )
189207 outputs = model .forward (input_tensor )
190208 assert len (outputs .shape ) == 4
191209 if not isinstance (model , (timm .models .MobileNetV3 , timm .models .GhostNet , timm .models .RepGhostNet , timm .models .VGG )):
@@ -195,6 +213,7 @@ def test_model_default_cfgs(model_name, batch_size):
195213 if 'pruned' not in model_name : # FIXME better pruned model handling
196214 # test classifier + global pool deletion via __init__
197215 model = create_model (model_name , pretrained = False , num_classes = 0 , global_pool = '' ).eval ()
216+ model .to (torch_device )
198217 outputs = model .forward (input_tensor )
199218 assert len (outputs .shape ) == 4
200219 if not isinstance (model , (timm .models .MobileNetV3 , timm .models .GhostNet , timm .models .RepGhostNet , timm .models .VGG )):
@@ -218,21 +237,22 @@ def test_model_default_cfgs(model_name, batch_size):
218237
219238
220239@pytest .mark .cfg
221- @pytest .mark .timeout (300 )
240+ @pytest .mark .timeout (timeout300 )
222241@pytest .mark .parametrize ('model_name' , list_models (filter = NON_STD_FILTERS , exclude_filters = NON_STD_EXCLUDE_FILTERS , include_tags = True ))
223242@pytest .mark .parametrize ('batch_size' , [1 ])
224243def test_model_default_cfgs_non_std (model_name , batch_size ):
225244 """Run a single forward pass with each model"""
226245 model = create_model (model_name , pretrained = False )
227246 model .eval ()
247+ model .to (torch_device )
228248 state_dict = model .state_dict ()
229249 cfg = model .default_cfg
230250
231251 input_size = _get_input_size (model = model )
232252 if max (input_size ) > 320 : # FIXME const
233253 pytest .skip ("Fixed input size model > limit." )
234254
235- input_tensor = torch .randn ((batch_size , * input_size ))
255+ input_tensor = torch .randn ((batch_size , * input_size ), device = torch_device )
236256 feat_dim = getattr (model , 'feature_dim' , None )
237257
238258 outputs = model .forward_features (input_tensor )
@@ -246,6 +266,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
246266
247267 # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
248268 model .reset_classifier (0 )
269+ model .to (torch_device )
249270 outputs = model .forward (input_tensor )
250271 if isinstance (outputs , (tuple , list )):
251272 outputs = outputs [0 ]
@@ -254,6 +275,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
254275 assert outputs .shape [feat_dim ] == model .num_features , 'pooled num_features != config'
255276
256277 model = create_model (model_name , pretrained = False , num_classes = 0 ).eval ()
278+ model .to (torch_device )
257279 outputs = model .forward (input_tensor )
258280 if isinstance (outputs , (tuple , list )):
259281 outputs = outputs [0 ]
@@ -297,7 +319,7 @@ def test_model_features_pretrained(model_name, batch_size):
297319
298320
299321@pytest .mark .torchscript
300- @pytest .mark .timeout (120 )
322+ @pytest .mark .timeout (timeout120 )
301323@pytest .mark .parametrize (
302324 'model_name' , list_models (exclude_filters = EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS , name_matches_cfg = True ))
303325@pytest .mark .parametrize ('batch_size' , [1 ])
@@ -312,6 +334,7 @@ def test_model_forward_torchscript(model_name, batch_size):
312334 model .eval ()
313335
314336 model = torch .jit .script (model )
337+ model .to (torch_device )
315338 outputs = model (torch .randn ((batch_size , * input_size )))
316339
317340 assert outputs .shape [0 ] == batch_size
0 commit comments