Skip to content

Commit d07d015

Browse files
authored
Merge pull request #1249 from okojoalg/sequencer
Add Sequencer
2 parents d30685c + 39b725e commit d07d015

File tree

3 files changed

+426
-4
lines changed

3 files changed

+426
-4
lines changed

tests/test_models.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
NON_STD_FILTERS = [
2626
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
2727
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
28-
'poolformer_*', 'volo_*']
28+
'poolformer_*', 'volo_*', 'sequencer2d_*']
2929
NUM_NON_STD = len(NON_STD_FILTERS)
3030

3131
# exclude models that cause specific test failures
@@ -202,28 +202,32 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
202202
pytest.skip("Fixed input size model > limit.")
203203

204204
input_tensor = torch.randn((batch_size, *input_size))
205+
feat_dim = getattr(model, 'feature_dim', None)
205206

206207
outputs = model.forward_features(input_tensor)
207208
if isinstance(outputs, (tuple, list)):
208209
# cannot currently verify multi-tensor output.
209210
pass
210211
else:
211-
feat_dim = -1 if outputs.ndim == 3 else 1
212+
if feat_dim is None:
213+
feat_dim = -1 if outputs.ndim == 3 else 1
212214
assert outputs.shape[feat_dim] == model.num_features
213215

214216
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
215217
model.reset_classifier(0)
216218
outputs = model.forward(input_tensor)
217219
if isinstance(outputs, (tuple, list)):
218220
outputs = outputs[0]
219-
feat_dim = -1 if outputs.ndim == 3 else 1
221+
if feat_dim is None:
222+
feat_dim = -1 if outputs.ndim == 3 else 1
220223
assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
221224

222225
model = create_model(model_name, pretrained=False, num_classes=0).eval()
223226
outputs = model.forward(input_tensor)
224227
if isinstance(outputs, (tuple, list)):
225228
outputs = outputs[0]
226-
feat_dim = -1 if outputs.ndim == 3 else 1
229+
if feat_dim is None:
230+
feat_dim = -1 if outputs.ndim == 3 else 1
227231
assert outputs.shape[feat_dim] == model.num_features
228232

229233
# check classifier name matches default_cfg

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .rexnet import *
4040
from .selecsls import *
4141
from .senet import *
42+
from .sequencer import *
4243
from .sknet import *
4344
from .swin_transformer import *
4445
from .swin_transformer_v2_cr import *

0 commit comments

Comments
 (0)