Skip to content

Commit 7ab2491

Browse files
committed
Better handling of crossvit for tests / forward_features, fix torchscript regression in my changes
1 parent 702982d commit 7ab2491

File tree

2 files changed

+12
-21
lines changed

2 files changed

+12
-21
lines changed

tests/test_models.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -188,25 +188,22 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
188188

189189
input_tensor = torch.randn((batch_size, *input_size))
190190

191-
# test forward_features (always unpooled)
192-
if 'crossvit' not in model_name:
193-
# FIXME remove crossvit exception
194-
outputs = model.forward_features(input_tensor)
195-
if isinstance(outputs, tuple):
196-
outputs = outputs[0]
197-
assert outputs.shape[1] == model.num_features
191+
outputs = model.forward_features(input_tensor)
192+
if isinstance(outputs, (tuple, list)):
193+
outputs = outputs[0]
194+
assert outputs.shape[1] == model.num_features
198195

199196
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
200197
model.reset_classifier(0)
201198
outputs = model.forward(input_tensor)
202-
if isinstance(outputs, tuple):
199+
if isinstance(outputs, (tuple, list)):
203200
outputs = outputs[0]
204201
assert len(outputs.shape) == 2
205202
assert outputs.shape[1] == model.num_features
206203

207204
model = create_model(model_name, pretrained=False, num_classes=0).eval()
208205
outputs = model.forward(input_tensor)
209-
if isinstance(outputs, tuple):
206+
if isinstance(outputs, (tuple, list)):
210207
outputs = outputs[0]
211208
assert len(outputs.shape) == 2
212209
assert outputs.shape[1] == model.num_features

timm/models/crossvit.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -268,12 +268,9 @@ def __init__(
268268
super().__init__()
269269

270270
self.num_classes = num_classes
271-
if not isinstance(img_size, (tuple, list)):
272-
img_size = to_2tuple(img_size)
273-
self.img_size = img_size
274-
if not isinstance(img_scale, (tuple, list)):
275-
img_scale = to_2tuple(img_scale)
276-
self.img_size_scaled = [tuple([int(sj * si) for sj in img_size]) for si in img_scale]
271+
self.img_size = to_2tuple(img_size)
272+
img_scale = to_2tuple(img_scale)
273+
self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
277274
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
278275
self.num_branches = len(patch_size)
279276
self.embed_dim = embed_dim
@@ -346,7 +343,7 @@ def forward_features(self, x):
346343
xs = []
347344
for i, patch_embed in enumerate(self.patch_embed):
348345
ss = self.img_size_scaled[i]
349-
x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic') if H != ss[0] else x
346+
x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) if H != ss[0] else x
350347
tmp = patch_embed(x_)
351348
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
352349
cls_tokens = cls_tokens.expand(B, -1, -1)
@@ -361,15 +358,12 @@ def forward_features(self, x):
361358

362359
# NOTE: was before branch token section, move to here to assure all branch token are before layer norm
363360
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
364-
return tuple([x[:, 0] for x in xs])
361+
return [x[:, 0] for x in xs]
365362

366363
def forward(self, x):
367364
xs = self.forward_features(x)
368365
ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
369-
if isinstance(self.head[0], nn.Identity):
370-
# FIXME to pass current passthrough features tests, could use better approach
371-
ce_logits = tuple(ce_logits)
372-
else:
366+
if not isinstance(self.head[0], nn.Identity):
373367
ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
374368
return ce_logits
375369

0 commit comments

Comments
 (0)