Commit 14f370c

Merge pull request #1259 from rwightman/swin_v2
Official SwinV2 models
2 parents d07d015 + 347308f commit 14f370c

File tree: 10 files changed (+1029, -110 lines)

README.md

Lines changed: 12 additions & 0 deletions
@@ -23,6 +23,17 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 
 ## What's New
 
+### May 13, 2022
+* Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript.
+* Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects.
+* More Vision Transformer relative position / residual post-norm experiments w/ 512 dim
+  * `vit_relpos_small_patch16_224` - 81.5 @ 224, 82.5 @ 320 -- rel pos, layer scale, no class token, avg pool
+  * `vit_relpos_medium_patch16_rpn_224` - 82.3 @ 224, 83.1 @ 320 -- rel pos + res-post-norm, no class token, avg pool
+  * `vit_relpos_medium_patch16_224` - 82.5 @ 224, 83.3 @ 320 -- rel pos, layer scale, no class token, avg pool
+  * `vit_relpos_base_patch16_gapcls_224` - 82.8 @ 224, 83.9 @ 320 -- rel pos, layer scale, class token, avg pool (by mistake)
+* Bring 512 dim, 8-head 'medium' ViT model variant back to life (after using in a pre DeiT 'small' model for first ViT impl back in 2020)
+* Add ViT relative position support for switching btw existing impl and some additions in official Swin-V2 impl for future trials
+* Sequencer2D impl (https://arxiv.org/abs/2205.01972), added via PR from author (https://github.com/okojoalg)
 
 ### May 2, 2022
 * Vision Transformer experiments adding Relative Position (Swin-V2 log-coord) (`vision_transformer_relpos.py`) and Residual Post-Norm branches (from Swin-V2) (`vision_transformer*.py`)
@@ -390,6 +401,7 @@ A full version of the list below with source links can be found in the [document
 * ReXNet - https://arxiv.org/abs/2007.00992
 * SelecSLS - https://arxiv.org/abs/1907.00837
 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586
+* Sequencer2D - https://arxiv.org/abs/2205.01972
 * Swin S3 (AutoFormerV2) - https://arxiv.org/abs/2111.14725
 * Swin Transformer - https://arxiv.org/abs/2103.14030
 * Swin Transformer V2 - https://arxiv.org/abs/2111.09883
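
The May 13 notes above list the new model names but not how they are used. The sketch below is not part of this commit's diff; it assumes a timm install that includes these changes, uses one of the model names from the list above, and passes pretrained=False so no particular published weights are assumed.

import torch
import timm

# Build one of the ViT relative-position variants named in the notes above.
model = timm.create_model('vit_relpos_medium_patch16_224', pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)   # dummy ImageNet-sized input
with torch.no_grad():
    logits = model(x)             # class logits, shape (1, num_classes)
print(logits.shape)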

tests/test_models.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
-    'poolformer_*', 'volo_*', 'sequencer2d_*']
+    'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*']
 NUM_NON_STD = len(NON_STD_FILTERS)
 
 # exclude models that cause specific test failures
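
For context, the entries in NON_STD_FILTERS are shell-style glob patterns used to route models with non-standard interfaces onto a separate test path, so adding 'swinv2_*' places the new SwinV2 models on that path. The sketch below only illustrates the matching semantics; it assumes fnmatch-style globbing and an example model name, and is not the test suite's actual code.

from fnmatch import fnmatch

# Subset of the patterns from the diff above.
NON_STD_FILTERS = ['poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*']

def is_non_std(model_name: str) -> bool:
    # A model is treated as non-standard if any pattern matches its name.
    return any(fnmatch(model_name, pat) for pat in NON_STD_FILTERS)

print(is_non_std('swinv2_tiny_window8_256'))  # True (example SwinV2 name)
print(is_non_std('resnet50'))                 # False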

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@
 from .sequencer import *
 from .sknet import *
 from .swin_transformer import *
+from .swin_transformer_v2 import *
 from .swin_transformer_v2_cr import *
 from .tnt import *
 from .tresnet import *
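
The wildcard import added above is what exposes the new SwinV2 entrypoints through timm's model registry. A small sketch, assuming a timm build that contains this commit:

import timm

# Registered model names become discoverable once timm.models imports the module.
swinv2_names = timm.list_models('swinv2_*')
print(len(swinv2_names), swinv2_names[:3])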

timm/models/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -477,7 +477,7 @@ def build_model_with_cfg(
         pretrained_cfg: Optional[Dict] = None,
         model_cfg: Optional[Any] = None,
         feature_cfg: Optional[Dict] = None,
-        pretrained_strict: bool = False,
+        pretrained_strict: bool = True,
         pretrained_filter_fn: Optional[Callable] = None,
         pretrained_custom_load: bool = False,
         kwargs_filter: Optional[Tuple[str]] = None,
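
The only change here flips the default of pretrained_strict from False to True, so pretrained checkpoint loading now fails loudly on key mismatches unless a model opts out. How build_model_with_cfg threads the flag through to its loading helpers is not shown in this diff; the sketch below just illustrates strict vs. non-strict loading at the plain PyTorch level.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
state = {'weight': torch.zeros(2, 4)}          # 'bias' key deliberately missing

model.load_state_dict(state, strict=False)     # tolerated: missing keys are ignored
try:
    model.load_state_dict(state, strict=True)  # raises on the missing 'bias' key
except RuntimeError as err:
    print('strict load failed:', err)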

timm/models/layers/mlp.py

Lines changed: 20 additions & 13 deletions
@@ -10,16 +10,17 @@
 class Mlp(nn.Module):
     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
     """
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        bias = to_2tuple(bias)
         drop_probs = to_2tuple(drop)
 
-        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
         self.act = act_layer()
         self.drop1 = nn.Dropout(drop_probs[0])
-        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
         self.drop2 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
@@ -35,17 +36,18 @@ class GluMlp(nn.Module):
     """ MLP w/ GLU style gating
     See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
     """
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         assert hidden_features % 2 == 0
+        bias = to_2tuple(bias)
         drop_probs = to_2tuple(drop)
 
-        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
         self.act = act_layer()
         self.drop1 = nn.Dropout(drop_probs[0])
-        self.fc2 = nn.Linear(hidden_features // 2, out_features)
+        self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
         self.drop2 = nn.Dropout(drop_probs[1])
 
     def init_weights(self):
@@ -67,14 +69,16 @@ def forward(self, x):
 class GatedMlp(nn.Module):
     """ MLP as used in gMLP
     """
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
-                 gate_layer=None, drop=0.):
+    def __init__(
+            self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+            gate_layer=None, bias=True, drop=0.):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        bias = to_2tuple(bias)
         drop_probs = to_2tuple(drop)
 
-        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
         self.act = act_layer()
         self.drop1 = nn.Dropout(drop_probs[0])
         if gate_layer is not None:
@@ -83,7 +87,7 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
             hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
         else:
             self.gate = nn.Identity()
-        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
         self.drop2 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
@@ -100,15 +104,18 @@ class ConvMlp(nn.Module):
     """ MLP using 1x1 convs that keeps spatial dims
     """
     def __init__(
-            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
+            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
+            norm_layer=None, bias=True, drop=0.):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
+        bias = to_2tuple(bias)
+
+        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
         self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
         self.act = act_layer()
-        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
         self.drop = nn.Dropout(drop)
+        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])
 
     def forward(self, x):
         x = self.fc1(x)
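
Across all four MLP variants the change is the same: a new bias argument, normalized with to_2tuple, lets callers disable the bias of either projection independently. A small usage sketch, assuming a timm build that includes this change:

import torch
from timm.models.layers import Mlp

# bias may be a single bool applied to both layers, or a (fc1, fc2) pair.
mlp = Mlp(in_features=512, hidden_features=2048, bias=(True, False))
print(mlp.fc1.bias is not None)   # True: fc1 keeps its bias
print(mlp.fc2.bias is None)       # True: fc2 built without a bias term

x = torch.randn(2, 197, 512)      # (batch, tokens, dim), e.g. a ViT block input
print(mlp(x).shape)               # torch.Size([2, 197, 512])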
