
Commit d4c0588

Remove persistent buffers from Swin-V2. Change SwinV2Cr cos attn + tau/logit_scale to match official, add ckpt convert, init_value zeros resid LN weight by default
1 parent 27c42f0 commit d4c0588

File tree: 2 files changed (+46, -22 lines)

  timm/models/swin_transformer_v2.py
  timm/models/swin_transformer_v2_cr.py

timm/models/swin_transformer_v2.py

Lines changed: 17 additions & 9 deletions
@@ -25,7 +25,6 @@
 from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
 from .registry import register_model
-from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit


 def _cfg(url='', **kwargs):
@@ -75,7 +74,7 @@ def _cfg(url='', **kwargs):
     ),
     'swinv2_base_window12to24_192to384_22kft1k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
-        input_size=(3, 384, 384)
+        input_size=(3, 384, 384), crop_pct=1.0,
     ),
     'swinv2_large_window12_192_22k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
@@ -87,7 +86,7 @@ def _cfg(url='', **kwargs):
     ),
     'swinv2_large_window12to24_192to384_22kft1k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
-        input_size=(3, 384, 384)
+        input_size=(3, 384, 384), crop_pct=1.0,
     ),
 }

@@ -174,7 +173,7 @@ def __init__(
         relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
             torch.abs(relative_coords_table) + 1.0) / math.log2(8)

-        self.register_buffer("relative_coords_table", relative_coords_table)
+        self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)

         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size[0])
@@ -187,7 +186,7 @@ def __init__(
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

         self.qkv = nn.Linear(dim, dim * 3, bias=False)
         if qkv_bias:
@@ -215,7 +214,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None):
             qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)

         # cosine attention
         attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
@@ -559,9 +558,6 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)

     @torch.jit.ignore
     def no_weight_decay(self):
@@ -621,6 +617,18 @@ def forward(self, x):
         return x


+def checkpoint_filter_fn(state_dict, model):
+    out_dict = {}
+    if 'model' in state_dict:
+        # For deit models
+        state_dict = state_dict['model']
+    for k, v in state_dict.items():
+        if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
+            continue  # skip buffers that should not be persistent
+        out_dict[k] = v
+    return out_dict
+
+
 def _create_swin_transformer_v2(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         SwinTransformerV2, variant, pretrained,
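
The `persistent=False` change above keeps `relative_position_index` and `relative_coords_table` out of the model's state_dict; both are rebuilt from the window configuration in `__init__`, so they never need to round-trip through checkpoints. Older checkpoints that still carry those keys are handled by the new local `checkpoint_filter_fn`, which simply drops them before `load_state_dict`. A minimal sketch of the buffer behaviour (toy module, hypothetical names, not timm code):

    # Toy illustration of persistent=False: non-persistent buffers are excluded from
    # state_dict(), so derived index tables are no longer saved or loaded.
    import torch
    import torch.nn as nn

    class TinyWindowIndex(nn.Module):  # hypothetical example module
        def __init__(self, window_size: int = 7):
            super().__init__()
            coords = torch.arange(window_size)
            # simple 1-D relative index table, recomputed from window_size every time
            rel_index = coords[None, :] - coords[:, None] + window_size - 1
            self.register_buffer("relative_position_index", rel_index, persistent=False)

    m = TinyWindowIndex()
    print("relative_position_index" in m.state_dict())  # False -> never serialized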

timm/models/swin_transformer_v2_cr.py

Lines changed: 29 additions & 13 deletions
@@ -34,14 +34,15 @@

 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
 from .layers import DropPath, Mlp, to_2tuple, _assert
 from .registry import register_model
-from .vision_transformer import checkpoint_filter_fn
+

 _logger = logging.getLogger(__name__)

@@ -186,12 +187,13 @@ def __init__(
             act_layer=nn.ReLU,
             drop=(0.125, 0.)  # FIXME should there be stochasticity, appears to 'overfit' without?
         )
-        self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
+        # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn
+        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads)))
         self._make_pair_wise_relative_positions()

     def _make_pair_wise_relative_positions(self) -> None:
         """Method initializes the pair-wise relative positions to compute the positional biases."""
-        device = self.tau.device
+        device = self.logit_scale.device
         coordinates = torch.stack(torch.meshgrid([
             torch.arange(self.window_size[0], device=device),
             torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)
@@ -250,10 +252,11 @@ def _forward_batch(
         query, key, value = qkv.unbind(0)

         # compute attention map with scaled cosine attention
-        denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
-        attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6)
-        attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1)
+        attn = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1))
+        logit_scale = torch.clamp(self.logit_scale.reshape(1, self.num_heads, 1, 1), max=math.log(1. / 0.01)).exp()
+        attn = attn * logit_scale
         attn = attn + self._relative_positional_encodings()
+
         if mask is not None:
             # Apply mask if utilized
             num_win: int = mask.shape[0]
@@ -309,7 +312,7 @@ def __init__(
         window_size: Tuple[int, int],
         shift_size: Tuple[int, int] = (0, 0),
         mlp_ratio: float = 4.0,
-        init_values: float = 0,
+        init_values: Optional[float] = 0,
         drop: float = 0.0,
         drop_attn: float = 0.0,
         drop_path: float = 0.0,
@@ -323,7 +326,7 @@ def __init__(
         self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
         self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
         self.window_area = self.window_size[0] * self.window_size[1]
-        self.init_values: float = init_values
+        self.init_values: Optional[float] = init_values

         # attn branch
         self.attn = WindowMultiHeadAttention(
@@ -387,7 +390,7 @@ def _make_attention_mask(self) -> None:

     def init_weights(self):
         # extra, module specific weight init
-        if self.init_values:
+        if self.init_values is not None:
             nn.init.constant_(self.norm1.weight, self.init_values)
             nn.init.constant_(self.norm2.weight, self.init_values)

@@ -536,7 +539,7 @@ def __init__(
         feat_size: Tuple[int, int],
         window_size: Tuple[int, int],
         mlp_ratio: float = 4.0,
-        init_values: float = 0.0,
+        init_values: Optional[float] = 0.0,
         drop: float = 0.0,
         drop_attn: float = 0.0,
         drop_path: Union[List[float], float] = 0.0,
@@ -650,7 +653,7 @@ def __init__(
         depths: Tuple[int, ...] = (2, 2, 6, 2),
         num_heads: Tuple[int, ...] = (3, 6, 12, 24),
         mlp_ratio: float = 4.0,
-        init_values: float = 0.0,
+        init_values: Optional[float] = 0.,
         drop_rate: float = 0.0,
         attn_drop_rate: float = 0.0,
         drop_path_rate: float = 0.0,
@@ -808,6 +811,21 @@ def init_weights(module: nn.Module, name: str = ''):
         module.init_weights()


+def checkpoint_filter_fn(state_dict, model):
+    """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    out_dict = {}
+    if 'model' in state_dict:
+        # For deit models
+        state_dict = state_dict['model']
+    for k, v in state_dict.items():
+        if 'tau' in k:
+            # convert old tau based checkpoints -> logit_scale (inverse)
+            v = torch.log(1 / v)
+            k = k.replace('tau', 'logit_scale')
+        out_dict[k] = v
+    return out_dict
+
+
 def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
@@ -890,7 +908,6 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
         embed_dim=96,
         depths=(2, 2, 18, 2),
         num_heads=(3, 6, 12, 24),
-        init_values=1e-5,
         extra_norm_stage=True,
         **kwargs
     )
@@ -928,7 +945,6 @@ def swinv2_cr_base_ns_224(pretrained=False, **kwargs):
         embed_dim=128,
         depths=(2, 2, 18, 2),
         num_heads=(4, 8, 16, 32),
-        init_values=1e-6,
         extra_norm_stage=True,
         **kwargs
     )
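
The cosine-attention rewrite above replaces the per-head temperature `tau` (divide by a clamped `tau`) with the parameterization used by the official Swin-V2 code (multiply by a clamped, exponentiated `logit_scale`, initialized to `log(10)` per head). The two forms agree under `logit_scale = log(1 / tau)`, which is exactly the conversion the new `checkpoint_filter_fn` applies to old checkpoints. A small numerical check with toy shapes (illustrative only, not part of the commit):

    # Verify: cos_sim / tau == cos_sim * exp(clamp(log(1 / tau), max=log(1 / 0.01)))
    import math
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    num_heads, tokens, head_dim = 4, 9, 8            # arbitrary toy sizes
    q = torch.randn(1, num_heads, tokens, head_dim)
    k = torch.randn(1, num_heads, tokens, head_dim)
    tau = torch.rand(num_heads) + 0.5                # old per-head temperature, > 0.01 here

    cos_sim = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)

    # old parameterization: divide by clamped tau
    attn_old = cos_sim / tau.clamp(min=0.01).reshape(1, num_heads, 1, 1)

    # new parameterization: multiply by clamped, exponentiated logit_scale
    logit_scale = torch.log(1 / tau)                 # the checkpoint conversion
    scale = torch.clamp(logit_scale.reshape(1, num_heads, 1, 1), max=math.log(1. / 0.01)).exp()
    attn_new = cos_sim * scale

    print(torch.allclose(attn_old, attn_new, atol=1e-6))  # True while tau >= 0.01

The `init_values` changes in the same file connect to the "init_value zeros resid LN weight by default" part of the commit message: with the check now `is not None`, the default `init_values=0.` zero-initializes the residual post-norm weights (previously `if self.init_values:` skipped that init when the value was 0), and the `_ns` model variants drop their explicit `init_values` overrides accordingly.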

0 commit comments
