Skip to content

Commit f332fc2

Browse files
committed
Fix some test failures, torchscript issues
1 parent 6e559e9 commit f332fc2

File tree

4 files changed

+5
-6
lines changed

4 files changed

+5
-6
lines changed

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
NON_STD_FILTERS = [
2828
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
2929
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
30-
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*']
30+
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*']
3131
NUM_NON_STD = len(NON_STD_FILTERS)
3232

3333
# exclude models that cause specific test failures

timm/models/efficientformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
def _cfg(url='', **kwargs):
2727
return {
2828
'url': url,
29-
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
29+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
3030
'crop_pct': .95, 'interpolation': 'bicubic',
3131
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
3232
'first_conv': 'stem.conv1', 'classifier': 'head',

timm/models/gcvit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,7 @@ def __init__(
209209

210210
def forward(self, x, q_global: Optional[torch.Tensor] = None):
211211
B, N, C = x.shape
212-
if self.use_global:
213-
_assert(q_global is not None, 'q_global must be passed in global mode')
212+
if self.use_global and q_global is not None:
214213
_assert(x.shape[-1] == q_global.shape[-1], 'x and q_global seq lengths should be equal')
215214

216215
kv = self.qkv(x)

timm/models/pvt_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,6 @@ def __init__(
286286
self.num_classes = num_classes
287287
assert global_pool in ('avg', '')
288288
self.global_pool = global_pool
289-
self.img_size = to_2tuple(img_size) if img_size is not None else None
290289
self.depths = depths
291290
num_stages = len(depths)
292291
mlp_ratios = to_ntuple(num_stages)(mlp_ratios)
@@ -324,7 +323,8 @@ def __init__(
324323
cur += depths[i]
325324

326325
# classification head
327-
self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
326+
self.num_features = embed_dims[-1]
327+
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
328328

329329
self.apply(self._init_weights)
330330

0 commit comments

Comments
 (0)