@@ -13,6 +13,7 @@
 # Written by Ze Liu
 # --------------------------------------------------------
 import math
+from typing import Tuple, Optional
 
 import torch
 import torch.nn as nn
@@ -91,7 +92,7 @@ def _cfg(url='', **kwargs):
 }
 
 
-def window_partition(x, window_size):
+def window_partition(x, window_size: Tuple[int, int]):
     """
     Args:
         x: (B, H, W, C)
@@ -101,25 +102,25 @@ def window_partition(x, window_size):
         windows: (num_windows*B, window_size, window_size, C)
     """
     B, H, W, C = x.shape
-    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
     return windows
 
 
 @register_notrace_function  # reason: int argument is a Proxy
-def window_reverse(windows, window_size, H, W):
+def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
     """
     Args:
-        windows: (num_windows*B, window_size, window_size, C)
-        window_size (int): Window size
-        H (int): Height of image
-        W (int): Width of image
+        windows: (num_windows * B, window_size[0], window_size[1], C)
+        window_size (Tuple[int, int]): Window size
+        img_size (Tuple[int, int]): Image size
 
     Returns:
         x: (B, H, W, C)
     """
-    B = int(windows.shape[0] / (H * W / window_size / window_size))
-    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    H, W = img_size
+    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
+    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
     return x
 
@@ -148,7 +149,7 @@ def __init__(
         self.pretrained_window_size = pretrained_window_size
         self.num_heads = num_heads
 
-        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
 
         # mlp to generate continuous relative position bias
         self.cpb_mlp = nn.Sequential(
@@ -202,7 +203,7 @@ def __init__(
         self.proj_drop = nn.Dropout(proj_drop)
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x, mask=None):
+    def forward(self, x, mask: Optional[torch.Tensor] = None):
         """
         Args:
             x: input features with shape of (num_windows*B, N, C)
@@ -218,7 +219,7 @@ def forward(self, x, mask=None):
 
         # cosine attention
         attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
-        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
+        logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp()
         attn = attn * logit_scale
 
         relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
@@ -269,16 +270,13 @@ def __init__(
             act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
         super().__init__()
         self.dim = dim
-        self.input_resolution = input_resolution
+        self.input_resolution = to_2tuple(input_resolution)
         self.num_heads = num_heads
-        self.window_size = window_size
-        self.shift_size = shift_size
+        ws, ss = self._calc_window_shift(window_size, shift_size)
+        self.window_size: Tuple[int, int] = ws
+        self.shift_size: Tuple[int, int] = ss
+        self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio
-        if min(self.input_resolution) <= self.window_size:
-            # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
-        _assert(0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size")
 
         self.attn = WindowAttention(
             dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
@@ -291,56 +289,64 @@ def __init__(
         self.norm2 = norm_layer(dim)
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        if self.shift_size > 0:
+        if any(self.shift_size):
             # calculate attention mask for SW-MSA
             H, W = self.input_resolution
             img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
             cnt = 0
             for h in (
-                    slice(0, -self.window_size),
-                    slice(-self.window_size, -self.shift_size),
-                    slice(-self.shift_size, None)):
+                    slice(0, -self.window_size[0]),
+                    slice(-self.window_size[0], -self.shift_size[0]),
+                    slice(-self.shift_size[0], None)):
                 for w in (
-                        slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None)):
+                        slice(0, -self.window_size[1]),
+                        slice(-self.window_size[1], -self.shift_size[1]),
+                        slice(-self.shift_size[1], None)):
                     img_mask[:, h, w, :] = cnt
                     cnt += 1
             mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
-            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            mask_windows = mask_windows.view(-1, self.window_area)
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         else:
             attn_mask = None
 
         self.register_buffer("attn_mask", attn_mask)
 
+    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        target_window_size = to_2tuple(target_window_size)
+        target_shift_size = to_2tuple(target_shift_size)
+        window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
+        shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
+        return tuple(window_size), tuple(shift_size)
+
     def _attn(self, x):
         H, W = self.input_resolution
         B, L, C = x.shape
         _assert(L == H * W, "input feature has wrong size")
         x = x.view(B, H, W, C)
 
         # cyclic shift
-        if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        has_shift = any(self.shift_size)
+        if has_shift:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
         else:
             shifted_x = x
 
         # partition windows
         x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+        x_windows = x_windows.view(-1, self.window_area, C)  # nW*B, window_size*window_size, C
 
         # W-MSA/SW-MSA
         attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
 
         # merge windows
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
+        shifted_x = window_reverse(attn_windows, self.window_size, self.input_resolution)  # B H' W' C
 
         # reverse cyclic shift
-        if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        if has_shift:
+            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
         else:
             x = shifted_x
         x = x.view(B, H * W, C)
@@ -445,7 +451,7 @@ def __init__(
 
     def forward(self, x):
         for blk in self.blocks:
-            if self.grad_checkpointing:
+            if not torch.jit.is_scripting() and self.grad_checkpointing:
                 x = checkpoint.checkpoint(blk, x)
             else:
                 x = blk(x)
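
For reference, a minimal standalone sketch (not part of the diff) of what the new tuple-based helpers do: it copies the window_partition / window_reverse logic added above and checks that partitioning a non-square feature map with a rectangular window and reversing it is an exact round trip.

import torch
from typing import Tuple


def window_partition(x, window_size: Tuple[int, int]):
    # same logic as the tuple-based version added in the diff
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)


def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
    # same logic as the tuple-based version added in the diff
    H, W = img_size
    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)


x = torch.randn(2, 12, 16, 8)          # (B, H, W, C) with a non-square feature map
win = (6, 8)                           # rectangular window; H % 6 == 0 and W % 8 == 0
windows = window_partition(x, win)
assert windows.shape == (8, 6, 8, 8)   # num_windows * B = (12 // 6) * (16 // 8) * 2 = 8
y = window_reverse(windows, win, (12, 16))
assert torch.equal(x, y)               # partition followed by reverse reproduces the input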