Skip to content

Commit b7cb8d0

Browse files
committed
Add Swin-V2 Small-NS weights (83.5 @ 224). Add layer-scale-like 'init_values' via post-norm LN weight scaling
1 parent 001688d commit b7cb8d0

File tree

1 file changed

+50
-2
lines changed

1 file changed

+50
-2
lines changed

timm/models/swin_transformer_v2_cr.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,15 @@ def _cfg(url='', **kwargs):
7676
'swin_v2_cr_small_224': _cfg(
7777
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
7878
input_size=(3, 224, 224), crop_pct=0.9),
79+
'swin_v2_cr_small_ns_224': _cfg(
80+
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
81+
input_size=(3, 224, 224), crop_pct=0.9),
7982
'swin_v2_cr_base_384': _cfg(
8083
url="", input_size=(3, 384, 384), crop_pct=1.0),
8184
'swin_v2_cr_base_224': _cfg(
8285
url="", input_size=(3, 224, 224), crop_pct=0.9),
86+
'swin_v2_cr_base_ns_224': _cfg(
87+
url="", input_size=(3, 224, 224), crop_pct=0.9),
8388
'swin_v2_cr_large_384': _cfg(
8489
url="", input_size=(3, 384, 384), crop_pct=1.0),
8590
'swin_v2_cr_large_224': _cfg(
@@ -179,7 +184,7 @@ def __init__(
179184
hidden_features=meta_hidden_dim,
180185
out_features=num_heads,
181186
act_layer=nn.ReLU,
182-
drop=0.1 # FIXME should there be stochasticity, appears to 'overfit' without?
187+
drop=(0.125, 0.) # FIXME should there be stochasticity, appears to 'overfit' without?
183188
)
184189
self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
185190
self._make_pair_wise_relative_positions()
@@ -304,6 +309,7 @@ def __init__(
304309
window_size: Tuple[int, int],
305310
shift_size: Tuple[int, int] = (0, 0),
306311
mlp_ratio: float = 4.0,
312+
init_values: float = 0,
307313
drop: float = 0.0,
308314
drop_attn: float = 0.0,
309315
drop_path: float = 0.0,
@@ -317,6 +323,7 @@ def __init__(
317323
self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
318324
self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
319325
self.window_area = self.window_size[0] * self.window_size[1]
326+
self.init_values: float = init_values
320327

321328
# attn branch
322329
self.attn = WindowMultiHeadAttention(
@@ -345,6 +352,7 @@ def __init__(
345352
self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()
346353

347354
self._make_attention_mask()
355+
self.init_weights()
348356

349357
def _calc_window_shift(self, target_window_size):
350358
window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)]
@@ -377,6 +385,12 @@ def _make_attention_mask(self) -> None:
377385
attn_mask = None
378386
self.register_buffer("attn_mask", attn_mask, persistent=False)
379387

388+
def init_weights(self):
389+
# extra, module specific weight init: when init_values is non-zero (truthy), fill the norm1 / norm2 weights (the norms applied to the attn and mlp branch outputs) with that constant; otherwise leave them untouched
390+
if self.init_values:
391+
nn.init.constant_(self.norm1.weight, self.init_values)
392+
nn.init.constant_(self.norm2.weight, self.init_values)
393+
380394
def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None:
381395
"""Method updates the image resolution to be processed and window size and so the pair-wise relative positions.
382396
@@ -435,7 +449,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
435449
Returns:
436450
output (torch.Tensor): Output tensor of the shape [B, C, H, W]
437451
"""
438-
# NOTE post-norm branches (op -> norm -> drop)
452+
# post-norm branches (op -> norm -> drop)
439453
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
440454
x = x + self.drop_path2(self.norm2(self.mlp(x)))
441455
x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
@@ -522,6 +536,7 @@ def __init__(
522536
feat_size: Tuple[int, int],
523537
window_size: Tuple[int, int],
524538
mlp_ratio: float = 4.0,
539+
init_values: float = 0.0,
525540
drop: float = 0.0,
526541
drop_attn: float = 0.0,
527542
drop_path: Union[List[float], float] = 0.0,
@@ -552,6 +567,7 @@ def _extra_norm(index):
552567
window_size=window_size,
553568
shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]),
554569
mlp_ratio=mlp_ratio,
570+
init_values=init_values,
555571
drop=drop,
556572
drop_attn=drop_attn,
557573
drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
@@ -634,6 +650,7 @@ def __init__(
634650
depths: Tuple[int, ...] = (2, 2, 6, 2),
635651
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
636652
mlp_ratio: float = 4.0,
653+
init_values: float = 0.0,
637654
drop_rate: float = 0.0,
638655
attn_drop_rate: float = 0.0,
639656
drop_path_rate: float = 0.0,
@@ -674,6 +691,7 @@ def __init__(
674691
num_heads=num_heads,
675692
window_size=window_size,
676693
mlp_ratio=mlp_ratio,
694+
init_values=init_values,
677695
drop=drop_rate,
678696
drop_attn=attn_drop_rate,
679697
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
@@ -786,6 +804,8 @@ def init_weights(module: nn.Module, name: str = ''):
786804
nn.init.xavier_uniform_(module.weight)
787805
if module.bias is not None:
788806
nn.init.zeros_(module.bias)
807+
elif hasattr(module, 'init_weights'):
808+
module.init_weights()
789809

790810

791811
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
@@ -863,6 +883,20 @@ def swin_v2_cr_small_224(pretrained=False, **kwargs):
863883
return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs)
864884

865885

886+
@register_model
887+
def swin_v2_cr_small_ns_224(pretrained=False, **kwargs):
888+
"""Swin-S V2 CR @ 224x224, trained ImageNet-1k. 'ns' variant: built with extra_norm_stage=True and init_values=1e-5."""
889+
model_kwargs = dict(
890+
embed_dim=96,
891+
depths=(2, 2, 18, 2),
892+
num_heads=(3, 6, 12, 24),
893+
init_values=1e-5,
894+
extra_norm_stage=True,
895+
**kwargs
896+
)
897+
return _create_swin_transformer_v2_cr('swin_v2_cr_small_ns_224', pretrained=pretrained, **model_kwargs)
898+
899+
866900
@register_model
867901
def swin_v2_cr_base_384(pretrained=False, **kwargs):
868902
"""Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
@@ -887,6 +921,20 @@ def swin_v2_cr_base_224(pretrained=False, **kwargs):
887921
return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs)
888922

889923

924+
@register_model
925+
def swin_v2_cr_base_ns_224(pretrained=False, **kwargs):
926+
"""Swin-B V2 CR @ 224x224, trained ImageNet-1k. 'ns' variant: built with extra_norm_stage=True and init_values=1e-6."""
927+
model_kwargs = dict(
928+
embed_dim=128,
929+
depths=(2, 2, 18, 2),
930+
num_heads=(4, 8, 16, 32),
931+
init_values=1e-6,
932+
extra_norm_stage=True,
933+
**kwargs
934+
)
935+
return _create_swin_transformer_v2_cr('swin_v2_cr_base_ns_224', pretrained=pretrained, **model_kwargs)
936+
937+
890938
@register_model
891939
def swin_v2_cr_large_384(pretrained=False, **kwargs):
892940
"""Swin-L V2 CR @ 384x384, trained ImageNet-1k"""

0 commit comments

Comments
 (0)