
Commit 4c531be

set_input_size(), always_partition, strict_img_size, dynamic mask option for all swin models. More flexibility in resolution, window resizing.
1 parent 2b3f1a4 commit 4c531be

File tree

3 files changed: +416 -148 lines changed
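
The headline addition is a working set_input_size() at the model, stage, and block level, so a Swin model can be retargeted to a new resolution with windows re-derived from the new patch grid. A minimal usage sketch, assuming a timm build that includes this commit (the model name and target resolution below are illustrative only):

    import timm
    import torch

    # build a Swin model at its default resolution, then retarget it;
    # window sizes are re-derived from the new patch grid (window_ratio=8 default)
    model = timm.create_model('swin_base_patch4_window7_224', pretrained=False)
    model.set_input_size(img_size=(256, 256), window_ratio=8)
    out = model(torch.randn(1, 3, 256, 256))
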

timm/models/swin_transformer.py

Lines changed: 74 additions & 37 deletions
@@ -219,6 +219,7 @@ def __init__(
             window_size: _int_or_tuple_2_t = 7,
             shift_size: int = 0,
             always_partition: bool = False,
+            dynamic_mask: bool = False,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             proj_drop: float = 0.,
@@ -235,6 +236,7 @@ def __init__(
             num_heads: Number of attention heads.
             head_dim: Enforce the number of channels per head
             shift_size: Shift size for SW-MSA.
+            always_partition: Always partition into full windows and shift
             mlp_ratio: Ratio of mlp hidden dim to embedding dim.
             qkv_bias: If True, add a learnable bias to query, key, value.
             proj_drop: Dropout rate.
@@ -246,9 +248,10 @@ def __init__(
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
-        self.target_shift_size = to_2tuple(shift_size)
+        self.target_shift_size = to_2tuple(shift_size)  # store for later resize
         self.always_partition = always_partition
-        self.window_size, self.shift_size = self._calc_window_shift(window_size, target_shift_size=shift_size)
+        self.dynamic_mask = dynamic_mask
+        self.window_size, self.shift_size = self._calc_window_shift(window_size, shift_size)
         self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio

@@ -257,7 +260,7 @@ def __init__(
             dim,
             num_heads=num_heads,
             head_dim=head_dim,
-            window_size=to_2tuple(self.window_size),
+            window_size=self.window_size,
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
@@ -273,33 +276,46 @@ def __init__(
         )
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-        self._make_attention_mask()
+        self.register_buffer(
+            "attn_mask",
+            None if self.dynamic_mask else self.get_attn_mask(),
+            persistent=False,
+        )

-    def _make_attention_mask(self):
+    def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
         if any(self.shift_size):
             # calculate attention mask for SW-MSA
-            H, W = self.input_resolution
+            if x is not None:
+                H, W = x.shape[1], x.shape[2]
+                device = x.device
+                dtype = x.dtype
+            else:
+                H, W = self.input_resolution
+                device = None
+                dtype = None
             H = math.ceil(H / self.window_size[0]) * self.window_size[0]
             W = math.ceil(W / self.window_size[1]) * self.window_size[1]
-            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
+            img_mask = torch.zeros((1, H, W, 1), dtype=dtype, device=device)  # 1 H W 1
             cnt = 0
             for h in (
-                    slice(0, -self.window_size[0]),
-                    slice(-self.window_size[0], -self.shift_size[0]),
-                    slice(-self.shift_size[0], None)):
+                    (0, -self.window_size[0]),
+                    (-self.window_size[0], -self.shift_size[0]),
+                    (-self.shift_size[0], None),
+            ):
                 for w in (
-                        slice(0, -self.window_size[1]),
-                        slice(-self.window_size[1], -self.shift_size[1]),
-                        slice(-self.shift_size[1], None)):
-                    img_mask[:, h, w, :] = cnt
+                        (0, -self.window_size[1]),
+                        (-self.window_size[1], -self.shift_size[1]),
+                        (-self.shift_size[1], None),
+                ):
+                    img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt
                     cnt += 1
             mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
             mask_windows = mask_windows.view(-1, self.window_area)
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         else:
             attn_mask = None
-        self.register_buffer("attn_mask", attn_mask, persistent=False)
+        return attn_mask

     def _calc_window_shift(
             self,
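
_make_attention_mask() is replaced by get_attn_mask(), which builds the shifted-window mask either once from self.input_resolution (stored in the non-persistent attn_mask buffer) or per forward from the padded input when dynamic_mask is set. A rough block-level sketch of the intent, assuming the constructor arguments shown above (the sizes are arbitrary):

    import torch
    from timm.models.swin_transformer import SwinTransformerBlock

    # with dynamic_mask=True the buffer stays None and the mask is rebuilt
    # from the actual (padded) feature map on every forward
    blk = SwinTransformerBlock(
        dim=96, input_resolution=(56, 56), num_heads=3,
        window_size=7, shift_size=3, dynamic_mask=True,
    )
    assert blk.attn_mask is None
    y = blk(torch.randn(1, 64, 48, 96))  # NHWC features, not the 56x56 build size
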
@@ -308,14 +324,16 @@ def _calc_window_shift(
     ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         target_window_size = to_2tuple(target_window_size)
         if target_shift_size is None:
-            # if passed value is None, recalculate from default window_size // 2 if it was active
+            # if passed value is None, recalculate from default window_size // 2 if it was previously non-zero
            target_shift_size = self.target_shift_size
            if any(target_shift_size):
-                target_shift_size = [target_window_size[0] // 2, target_window_size[1] // 2]
+                target_shift_size = (target_window_size[0] // 2, target_window_size[1] // 2)
         else:
             target_shift_size = to_2tuple(target_shift_size)
+
         if self.always_partition:
             return target_window_size, target_shift_size
+
         window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
         shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
         return tuple(window_size), tuple(shift_size)
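
The clamping rule in _calc_window_shift() is unchanged in spirit: unless always_partition is set, each window axis is shrunk to the feature resolution and the shift is dropped on any axis that fits in a single window. A standalone illustration of that rule (a simplified stand-in, not the method itself):

    def calc_window_shift(feat_size, window_size, shift_size, always_partition=False):
        # mirrors SwinTransformerBlock._calc_window_shift for illustration
        if always_partition:
            return tuple(window_size), tuple(shift_size)
        window_size = [r if r <= w else w for r, w in zip(feat_size, window_size)]
        shift_size = [0 if r <= w else s for r, w, s in zip(feat_size, window_size, shift_size)]
        return tuple(window_size), tuple(shift_size)

    print(calc_window_shift((6, 24), (7, 7), (3, 3)))        # ((6, 7), (0, 3))
    print(calc_window_shift((6, 24), (7, 7), (3, 3), True))  # ((7, 7), (3, 3))
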
@@ -338,7 +356,11 @@ def set_input_size(
         self.window_size, self.shift_size = self._calc_window_shift(window_size)
         self.window_area = self.window_size[0] * self.window_size[1]
         self.attn.set_window_size(self.window_size)
-        self._make_attention_mask()
+        self.register_buffer(
+            "attn_mask",
+            None if self.dynamic_mask else self.get_attn_mask(),
+            persistent=False,
+        )

     def _attn(self, x):
         B, H, W, C = x.shape
@@ -354,14 +376,18 @@ def _attn(self, x):
         pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
         pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
         shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
-        Hp, Wp = H + pad_h, W + pad_w
+        _, Hp, Wp, _ = shifted_x.shape

         # partition windows
         x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
         x_windows = x_windows.view(-1, self.window_area, C)  # nW*B, window_size*window_size, C

         # W-MSA/SW-MSA
-        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
+        if getattr(self, 'dynamic_mask', False):
+            attn_mask = self.get_attn_mask(shifted_x)
+        else:
+            attn_mask = self.attn_mask
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

         # merge windows
         attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
@@ -408,8 +434,11 @@ def __init__(

     def forward(self, x):
         B, H, W, C = x.shape
-        _assert(H % 2 == 0, f"x height ({H}) is not even.")
-        _assert(W % 2 == 0, f"x width ({W}) is not even.")
+
+        pad_values = (0, 0, 0, H % 2, 0, W % 2)
+        x = nn.functional.pad(x, pad_values)
+        _, H, W, _ = x.shape
+
         x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
         x = self.norm(x)
         x = self.reduction(x)
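
PatchMerging now pads odd feature maps out to even dimensions instead of asserting, so the 2x2 merge works for arbitrary grid sizes. A quick illustration of the padding arithmetic on NHWC features (an odd square grid, so the pad ordering is immaterial):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 57, 57, 96)            # NHWC features with odd H and W
    H, W = x.shape[1], x.shape[2]
    x = F.pad(x, (0, 0, 0, H % 2, 0, W % 2))  # pad trailing rows/columns to even dims
    print(x.shape)                            # torch.Size([1, 58, 58, 96])
    # the 2x2 merge then yields a 29x29 grid with 4*96 channels before reduction
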
@@ -431,6 +460,7 @@ def __init__(
             head_dim: Optional[int] = None,
             window_size: _int_or_tuple_2_t = 7,
             always_partition: bool = False,
+            dynamic_mask: bool = False,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             proj_drop: float = 0.,
@@ -485,6 +515,7 @@ def __init__(
                 window_size=window_size,
                 shift_size=0 if (i % 2 == 0) else shift_size,
                 always_partition=always_partition,
+                dynamic_mask=dynamic_mask,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop,
@@ -500,11 +531,12 @@ def set_input_size(
             window_size: int,
             always_partition: Optional[bool] = None,
     ):
-        """Method updates the resolution to utilize and the window size and so the pair-wise relative positions.
+        """ Updates the resolution, window size and so the pair-wise relative positions.

         Args:
-            feat_size (Tuple[int, int]): New input resolution
-            window_size (int): New window size
+            feat_size: New input (feature) resolution
+            window_size: New window size
+            always_partition: Always partition / shift the window
         """
         self.input_resolution = feat_size
         if isinstance(self.downsample, nn.Identity):
@@ -548,6 +580,7 @@ def __init__(
             head_dim: Optional[int] = None,
             window_size: _int_or_tuple_2_t = 7,
             always_partition: bool = False,
+            strict_img_size: bool = True,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             drop_rate: float = 0.,
@@ -599,9 +632,10 @@ def __init__(
             in_chans=in_chans,
             embed_dim=embed_dim[0],
             norm_layer=norm_layer,
+            strict_img_size=strict_img_size,
             output_fmt='NHWC',
         )
-        self.patch_grid = self.patch_embed.grid_size
+        patch_grid = self.patch_embed.grid_size

         # build layers
         head_dim = to_ntuple(self.num_layers)(head_dim)
@@ -621,15 +655,16 @@ def __init__(
                 dim=in_dim,
                 out_dim=out_dim,
                 input_resolution=(
-                    self.patch_grid[0] // scale,
-                    self.patch_grid[1] // scale
+                    patch_grid[0] // scale,
+                    patch_grid[1] // scale
                 ),
                 depth=depths[i],
                 downsample=i > 0,
                 num_heads=num_heads[i],
                 head_dim=head_dim[i],
                 window_size=window_size[i],
                 always_partition=always_partition,
+                dynamic_mask=not strict_img_size,
                 mlp_ratio=mlp_ratio[i],
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop_rate,
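
Wiring dynamic_mask = not strict_img_size means a model built with strict_img_size=False computes its shift masks on the fly, so it can run at resolutions other than the one it was constructed for (the input still has to be divisible by the patch size). A hedged sketch of the intended use; the model name and input size are illustrative:

    import timm
    import torch

    # non-strict patch embed + dynamic masks: other input sizes are accepted
    model = timm.create_model('swin_tiny_patch4_window7_224', strict_img_size=False, num_classes=0)
    feats = model(torch.randn(1, 3, 256, 192))  # not the 224x224 build size
    print(feats.shape)                          # pooled feature vector
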
@@ -673,27 +708,29 @@ def set_input_size(
             img_size: Optional[Tuple[int, int]] = None,
             patch_size: Optional[Tuple[int, int]] = None,
             window_size: Optional[Tuple[int, int]] = None,
-            window_ratio: int = 32,
+            window_ratio: int = 8,
             always_partition: Optional[bool] = None,
     ) -> None:
         """ Updates the image resolution and window size.

         Args:
-            img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
-            window_size (Optional[int]): New window size, if None based on new_img_size // window_div
-            window_ratio (int): divisor for calculating window size from image size
+            img_size: New input resolution, if None current resolution is used
+            patch_size (Optional[Tuple[int, int]): New patch size, if None use current patch size
+            window_size: New window size, if None based on new_img_size // window_div
+            window_ratio: divisor for calculating window size from grid size
+            always_partition: always partition into windows and shift (even if window size < feat size)
         """
         if img_size is not None or patch_size is not None:
             self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
-            self.patch_grid = self.patch_embed.grid_size
+        patch_grid = self.patch_embed.grid_size
+
         if window_size is None:
-            img_size = self.patch_embed.img_size
-            window_size = tuple([s // window_ratio for s in img_size])
+            window_size = tuple([pg // window_ratio for pg in patch_grid])
+
         for index, stage in enumerate(self.layers):
             stage_scale = 2 ** max(index - 1, 0)
-            print(self.patch_grid, stage_scale)
             stage.set_input_size(
-                feat_size=(self.patch_grid[0] // stage_scale, self.patch_grid[1] // stage_scale),
+                feat_size=(patch_grid[0] // stage_scale, patch_grid[1] // stage_scale),
                 window_size=window_size,
                 always_partition=always_partition,
             )
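
The window_ratio default drops from 32 to 8 because it now divides the patch grid rather than the raw image size; with a patch size of 4 the two conventions give the same answer (256 // 32 == (256 // 4) // 8 == 8). A small worked example of the new derivation:

    img_size = (384, 384)
    patch_size = (4, 4)
    window_ratio = 8

    patch_grid = tuple(i // p for i, p in zip(img_size, patch_size))  # (96, 96)
    window_size = tuple(pg // window_ratio for pg in patch_grid)      # (12, 12)
    print(patch_grid, window_size)
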
