11import pytest
22import torch
3+ import platform
4+ import os
5+ import fnmatch
36
47from timm import list_models , create_model
58
6- MAX_FWD_SIZE = 320
9+
10+ if 'GITHUB_ACTIONS' in os .environ and 'Linux' in platform .system ():
11+ # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
12+ EXCLUDE_FILTERS = ['*efficientnet_l2*' ]
13+ else :
14+ EXCLUDE_FILTERS = []
15+ MAX_FWD_SIZE = 384
716MAX_BWD_SIZE = 128
817MAX_FWD_FEAT_SIZE = 448
918
1019
1120@pytest .mark .timeout (120 )
12- @pytest .mark .parametrize ('model_name' , list_models ())
21+ @pytest .mark .parametrize ('model_name' , list_models (exclude_filters = EXCLUDE_FILTERS ))
1322@pytest .mark .parametrize ('batch_size' , [1 ])
1423def test_model_forward (model_name , batch_size ):
1524 """Run a single forward pass with each model"""
@@ -28,7 +37,8 @@ def test_model_forward(model_name, batch_size):
2837
2938
3039@pytest .mark .timeout (120 )
31- @pytest .mark .parametrize ('model_name' , list_models (exclude_filters = 'dla*' )) # DLA models have an issue TBD
40+ # DLA models have an issue TBD, add them to exclusions
41+ @pytest .mark .parametrize ('model_name' , list_models (exclude_filters = EXCLUDE_FILTERS + ['dla*' ]))
3242@pytest .mark .parametrize ('batch_size' , [2 ])
3343def test_model_backward (model_name , batch_size ):
3444 """Run a single forward pass with each model"""
@@ -65,7 +75,8 @@ def test_model_default_cfgs(model_name, batch_size):
6575 pool_size = cfg ['pool_size' ]
6676 input_size = model .default_cfg ['input_size' ]
6777
68- if all ([x <= MAX_FWD_FEAT_SIZE for x in input_size ]) and 'efficientnet_l2' not in model_name :
78+ if all ([x <= MAX_FWD_FEAT_SIZE for x in input_size ]) and \
79+ not any ([fnmatch .fnmatch (model_name , x ) for x in EXCLUDE_FILTERS ]):
6980 # pool size only checked if default res <= 448 * 448 to keep resource down
7081 input_size = tuple ([min (x , MAX_FWD_FEAT_SIZE ) for x in input_size ])
7182 outputs = model .forward_features (torch .randn ((batch_size , * input_size )))
0 commit comments