Skip to content

Commit 24720ab

Browse files
committed
Merge branch 'master' into attn_update
2 parents 4027412 + 1c9284c commit 24720ab

File tree

6 files changed

+928
-10
lines changed

6 files changed

+928
-10
lines changed

tests/test_models.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# transformer models don't support many of the spatial / feature based model functionalities
1818
NON_STD_FILTERS = [
1919
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
20-
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*']
20+
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*']
2121
NUM_NON_STD = len(NON_STD_FILTERS)
2222

2323
# exclude models that cause specific test failures
@@ -188,23 +188,22 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
188188

189189
input_tensor = torch.randn((batch_size, *input_size))
190190

191-
# test forward_features (always unpooled)
192191
outputs = model.forward_features(input_tensor)
193-
if isinstance(outputs, tuple):
192+
if isinstance(outputs, (tuple, list)):
194193
outputs = outputs[0]
195194
assert outputs.shape[1] == model.num_features
196195

197196
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
198197
model.reset_classifier(0)
199198
outputs = model.forward(input_tensor)
200-
if isinstance(outputs, tuple):
199+
if isinstance(outputs, (tuple, list)):
201200
outputs = outputs[0]
202201
assert len(outputs.shape) == 2
203202
assert outputs.shape[1] == model.num_features
204203

205204
model = create_model(model_name, pretrained=False, num_classes=0).eval()
206205
outputs = model.forward(input_tensor)
207-
if isinstance(outputs, tuple):
206+
if isinstance(outputs, (tuple, list)):
208207
outputs = outputs[0]
209208
assert len(outputs.shape) == 2
210209
assert outputs.shape[1] == model.num_features

tests/test_optim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,10 @@ def test_sgd(optimizer):
319319
# lambda opt: ReduceLROnPlateau(opt)]
320320
# )
321321
_test_basic_cases(
322-
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1)
322+
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1)
323323
)
324324
_test_basic_cases(
325-
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1, weight_decay=.1)
325+
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1, weight_decay=.1)
326326
)
327327
_test_rosenbrock(
328328
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)

timm/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from .beit import *
12
from .byoanet import *
23
from .byobnet import *
34
from .cait import *
45
from .coat import *
56
from .convit import *
7+
from .crossvit import *
68
from .cspnet import *
79
from .densenet import *
810
from .dla import *
@@ -36,6 +38,7 @@
3638
from .swin_transformer import *
3739
from .tnt import *
3840
from .tresnet import *
41+
from .twins import *
3942
from .vgg import *
4043
from .visformer import *
4144
from .vision_transformer import *
@@ -44,7 +47,6 @@
4447
from .xception import *
4548
from .xception_aligned import *
4649
from .xcit import *
47-
from .twins import *
4850

4951
from .factory import create_model, split_model_name, safe_model_name
5052
from .helpers import load_checkpoint, resume_checkpoint, model_parameters

0 commit comments

Comments
 (0)