
Commit 27c42f0

Fix torchscript use for official Swin-V2, add support for non-square window/shift to WindowAttn/Block
1 parent 2f2b22d commit 27c42f0

File tree: 1 file changed

timm/models/swin_transformer_v2.py (+43 additions, −37 deletions)
@@ -13,6 +13,7 @@
 # Written by Ze Liu
 # --------------------------------------------------------
 import math
+from typing import Tuple, Optional
 
 import torch
 import torch.nn as nn
@@ -91,7 +92,7 @@ def _cfg(url='', **kwargs):
 }
 
 
-def window_partition(x, window_size):
+def window_partition(x, window_size: Tuple[int, int]):
     """
     Args:
         x: (B, H, W, C)
@@ -101,25 +102,25 @@ def window_partition(x, window_size):
         windows: (num_windows*B, window_size, window_size, C)
     """
     B, H, W, C = x.shape
-    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
     return windows
 
 
 @register_notrace_function  # reason: int argument is a Proxy
-def window_reverse(windows, window_size, H, W):
+def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
     """
     Args:
-        windows: (num_windows*B, window_size, window_size, C)
-        window_size (int): Window size
-        H (int): Height of image
-        W (int): Width of image
+        windows: (num_windows * B, window_size[0], window_size[1], C)
+        window_size (Tuple[int, int]): Window size
+        img_size (Tuple[int, int]): Image size
 
     Returns:
         x: (B, H, W, C)
     """
-    B = int(windows.shape[0] / (H * W / window_size / window_size))
-    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    H, W = img_size
+    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
+    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
     return x
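Since window_partition and window_reverse now take Tuple[int, int] sizes, a non-square round trip is worth sanity-checking. A minimal sketch (shapes are illustrative, not from the commit; assumes the two functions above are importable):

import torch
from timm.models.swin_transformer_v2 import window_partition, window_reverse

x = torch.randn(2, 16, 24, 32)  # B, H, W, C with H != W
window_size = (8, 4)            # non-square window: 8 rows, 4 cols

windows = window_partition(x, window_size)
# 2 * (16 // 8) * (24 // 4) = 24 windows of shape (8, 4, 32)
assert windows.shape == (24, 8, 4, 32)

x2 = window_reverse(windows, window_size, (16, 24))
assert torch.equal(x, x2)  # partition/reverse is an exact inverse (pure view/permute)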

@@ -148,7 +149,7 @@ def __init__(
         self.pretrained_window_size = pretrained_window_size
         self.num_heads = num_heads
 
-        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
 
         # mlp to generate continuous relative position bias
         self.cpb_mlp = nn.Sequential(
@@ -202,7 +203,7 @@ def __init__(
         self.proj_drop = nn.Dropout(proj_drop)
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x, mask=None):
+    def forward(self, x, mask: Optional[torch.Tensor] = None):
         """
         Args:
             x: input features with shape of (num_windows*B, N, C)
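The Optional[torch.Tensor] annotation on mask is another scripting fix: TorchScript assumes unannotated arguments are Tensor, which conflicts with the None default. A minimal illustration of the pattern, using a hypothetical function:

from typing import Optional
import torch

@torch.jit.script
def apply_mask(x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    if mask is not None:
        x = x + mask  # TorchScript narrows Optional[Tensor] to Tensor in this branch
    return x

print(apply_mask(torch.zeros(2, 3)))                     # default (no mask) path
print(apply_mask(torch.zeros(2, 3), torch.ones(2, 3)))   # masked path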
@@ -218,7 +219,7 @@ def forward(self, x, mask=None):
 
         # cosine attention
         attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
-        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
+        logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp()
         attn = attn * logit_scale
 
         relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
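The clamp change swaps a tensor constructed inside forward for a compile-time Python float, which is friendlier to scripting and skips a per-call host-side tensor allocation. A quick sketch of the resulting computation, assuming the (num_heads, 1, 1) parameter from __init__:

import math
import torch

logit_scale = torch.log(10 * torch.ones((4, 1, 1)))  # e.g. 4 heads, initialized as in __init__

# scalar bound: log(1 / 0.01) == log(100) caps the learned attention temperature at 100
scale = torch.clamp(logit_scale, max=math.log(1. / 0.01)).exp()
assert torch.allclose(scale, torch.full((4, 1, 1), 10.))  # log(10) < log(100), so untouched at init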
@@ -269,16 +270,13 @@ def __init__(
             act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
         super().__init__()
         self.dim = dim
-        self.input_resolution = input_resolution
+        self.input_resolution = to_2tuple(input_resolution)
         self.num_heads = num_heads
-        self.window_size = window_size
-        self.shift_size = shift_size
+        ws, ss = self._calc_window_shift(window_size, shift_size)
+        self.window_size: Tuple[int, int] = ws
+        self.shift_size: Tuple[int, int] = ss
+        self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio
-        if min(self.input_resolution) <= self.window_size:
-            # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
-        _assert(0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size")
 
         self.attn = WindowAttention(
             dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
@@ -291,56 +289,64 @@ def __init__(
         self.norm2 = norm_layer(dim)
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        if self.shift_size > 0:
+        if any(self.shift_size):
             # calculate attention mask for SW-MSA
             H, W = self.input_resolution
             img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
             cnt = 0
             for h in (
-                    slice(0, -self.window_size),
-                    slice(-self.window_size, -self.shift_size),
-                    slice(-self.shift_size, None)):
+                    slice(0, -self.window_size[0]),
+                    slice(-self.window_size[0], -self.shift_size[0]),
+                    slice(-self.shift_size[0], None)):
                 for w in (
-                        slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None)):
+                        slice(0, -self.window_size[1]),
+                        slice(-self.window_size[1], -self.shift_size[1]),
+                        slice(-self.shift_size[1], None)):
                     img_mask[:, h, w, :] = cnt
                     cnt += 1
             mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
-            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            mask_windows = mask_windows.view(-1, self.window_area)
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         else:
             attn_mask = None
 
         self.register_buffer("attn_mask", attn_mask)
 
+    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        target_window_size = to_2tuple(target_window_size)
+        target_shift_size = to_2tuple(target_shift_size)
+        window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
+        shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
+        return tuple(window_size), tuple(shift_size)
+
     def _attn(self, x):
         H, W = self.input_resolution
         B, L, C = x.shape
         _assert(L == H * W, "input feature has wrong size")
         x = x.view(B, H, W, C)
 
         # cyclic shift
-        if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        has_shift = any(self.shift_size)
+        if has_shift:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
         else:
             shifted_x = x
 
         # partition windows
         x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+        x_windows = x_windows.view(-1, self.window_area, C)  # nW*B, window_size*window_size, C
 
         # W-MSA/SW-MSA
         attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
 
         # merge windows
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
+        shifted_x = window_reverse(attn_windows, self.window_size, self.input_resolution)  # B H' W' C
 
         # reverse cyclic shift
-        if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        if has_shift:
+            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
         else:
             x = shifted_x
         x = x.view(B, H * W, C)
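_calc_window_shift replaces the old scalar fallback (collapsing everything to min(input_resolution) when the window did not fit): the window is now clamped per axis, and the shift is zeroed only on axes where a single window covers the whole resolution. A standalone sketch of the same logic, with hypothetical inputs:

from typing import Tuple
from timm.models.layers import to_2tuple  # same helper the block uses

def calc_window_shift(input_resolution, window_size, shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    window_size = to_2tuple(window_size)
    shift_size = to_2tuple(shift_size)
    # clamp the window to the resolution per axis ...
    ws = tuple(r if r <= w else w for r, w in zip(input_resolution, window_size))
    # ... and disable the shift only on axes where no partitioning happens
    ss = tuple(0 if r <= w else s for r, w, s in zip(input_resolution, ws, shift_size))
    return ws, ss

# height fits in one 7-tall window (shift disabled there); width still splits into two windows and keeps its shift
assert calc_window_shift((7, 14), 7, 3) == ((7, 7), (0, 3))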
@@ -445,7 +451,7 @@ def __init__(
 
     def forward(self, x):
         for blk in self.blocks:
-            if self.grad_checkpointing:
+            if not torch.jit.is_scripting() and self.grad_checkpointing:
                 x = checkpoint.checkpoint(blk, x)
             else:
                 x = blk(x)
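The not torch.jit.is_scripting() guard is the torchscript fix from the commit title: torch.utils.checkpoint is not scriptable, but is_scripting() is a compile-time constant under TorchScript, so the checkpoint branch is eliminated when scripting while eager mode keeps gradient checkpointing. A minimal sketch of the pattern, with a hypothetical stand-in stage:

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class Stage(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))
        self.grad_checkpointing = True  # hypothetical flag, mirroring the timm stage

    def forward(self, x):
        for blk in self.blocks:
            if not torch.jit.is_scripting() and self.grad_checkpointing:
                x = checkpoint.checkpoint(blk, x)  # eager only; dead code when scripted
            else:
                x = blk(x)
        return x

scripted = torch.jit.script(Stage())  # compiles because the checkpoint branch is pruned
print(scripted(torch.randn(1, 8)).shape)  # torch.Size([1, 8])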
