
Commit 39b725e

Fix tests for rank-4 output where feature channels dim is -1 (3) and not 1
1 parent d79f3d9
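
This commit targets models whose `forward_features` output is rank 4 but channels-last (NHWC), where the feature channel dimension is the last axis (index 3, i.e. -1) rather than dimension 1 as in the usual NCHW layout. A minimal sketch of the distinction the tests need to handle (shapes invented for illustration, not from the commit):

```python
# Why a fixed channel index of 1 breaks for channels-last models.
import torch

num_features = 384
nchw = torch.randn(2, num_features, 7, 7)   # conv-style layout: channels at dim 1
nhwc = torch.randn(2, 7, 7, num_features)   # Sequencer-style layout: channels at dim -1 (index 3)

assert nchw.shape[1] == num_features        # old test assumption holds here
assert nhwc.shape[-1] == num_features       # ...but NHWC needs dim -1, not 1
```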

2 files changed: +9 −4 lines changed

tests/test_models.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -202,28 +202,32 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
         pytest.skip("Fixed input size model > limit.")
 
     input_tensor = torch.randn((batch_size, *input_size))
+    feat_dim = getattr(model, 'feature_dim', None)
 
     outputs = model.forward_features(input_tensor)
     if isinstance(outputs, (tuple, list)):
         # cannot currently verify multi-tensor output.
         pass
     else:
-        feat_dim = -1 if outputs.ndim == 3 else 1
+        if feat_dim is None:
+            feat_dim = -1 if outputs.ndim == 3 else 1
         assert outputs.shape[feat_dim] == model.num_features
 
     # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
     model.reset_classifier(0)
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
-    feat_dim = -1 if outputs.ndim == 3 else 1
+    if feat_dim is None:
+        feat_dim = -1 if outputs.ndim == 3 else 1
     assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
 
     model = create_model(model_name, pretrained=False, num_classes=0).eval()
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
-    feat_dim = -1 if outputs.ndim == 3 else 1
+    if feat_dim is None:
+        feat_dim = -1 if outputs.ndim == 3 else 1
     assert outputs.shape[feat_dim] == model.num_features
 
     # check classifier name matches default_cfg
```
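
With this change the test prefers an explicit `feature_dim` attribute on the model and only falls back to the old ndim-based guess when the attribute is absent. A rough sketch of that resolution order against a made-up channels-last model (`DummyNHWCModel` is hypothetical, not part of timm):

```python
import torch
import torch.nn as nn

class DummyNHWCModel(nn.Module):
    num_features = 384
    feature_dim = -1  # declares that features are channels-last

    def forward_features(self, x):
        # rank-4 NHWC feature map: channels in the last dimension
        return torch.randn(x.shape[0], 7, 7, self.num_features)

model = DummyNHWCModel()
outputs = model.forward_features(torch.randn(2, 3, 224, 224))

feat_dim = getattr(model, 'feature_dim', None)   # explicit attribute wins
if feat_dim is None:
    feat_dim = -1 if outputs.ndim == 3 else 1    # old heuristic: only rank 3 used -1
assert outputs.shape[feat_dim] == model.num_features
```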

timm/models/sequencer.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -288,6 +288,7 @@ def __init__(
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = embed_dims[-1]  # num_features for consistency with other models
+        self.feature_dim = -1  # channel dim index for feature outputs (rank 4, NHWC)
         self.embed_dims = embed_dims
         self.stem = PatchEmbed(
             img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans,
@@ -333,7 +334,7 @@ def get_classifier(self):
 
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
-        if self.global_pool is not None:
+        if global_pool is not None:
             assert global_pool in ('', 'avg')
             self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
```
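
Besides advertising the channels-last feature layout via `feature_dim = -1`, the second hunk fixes `reset_classifier` to test the `global_pool` argument rather than the stored `self.global_pool` attribute: with the old check, `reset_classifier(0)` (argument left at its default of None) entered the branch and tripped `assert global_pool in ('', 'avg')`. A simplified stand-in (class name hypothetical, not the real model) showing the behavioral difference:

```python
class PoolOnly:
    def __init__(self):
        self.global_pool = 'avg'

    def reset_classifier_old(self, num_classes, global_pool=None):
        if self.global_pool is not None:       # 'avg' is not None -> always taken
            assert global_pool in ('', 'avg')  # AssertionError when argument is None
            self.global_pool = global_pool

    def reset_classifier_new(self, num_classes, global_pool=None):
        if global_pool is not None:            # only act when a pooling mode is given
            assert global_pool in ('', 'avg')
            self.global_pool = global_pool

m = PoolOnly()
m.reset_classifier_new(0)          # fine: stored pooling mode stays 'avg'
try:
    m.reset_classifier_old(0)      # raises AssertionError under the old check
except AssertionError:
    pass
```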
