@@ -219,6 +219,7 @@ def __init__(
             window_size: _int_or_tuple_2_t = 7,
             shift_size: int = 0,
             always_partition: bool = False,
+            dynamic_mask: bool = False,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             proj_drop: float = 0.,
@@ -235,6 +236,7 @@ def __init__(
             num_heads: Number of attention heads.
             head_dim: Enforce the number of channels per head
             shift_size: Shift size for SW-MSA.
+            always_partition: Always partition into full windows and shift
             mlp_ratio: Ratio of mlp hidden dim to embedding dim.
             qkv_bias: If True, add a learnable bias to query, key, value.
             proj_drop: Dropout rate.
@@ -246,9 +248,10 @@ def __init__(
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
-        self.target_shift_size = to_2tuple(shift_size)
+        self.target_shift_size = to_2tuple(shift_size)  # store for later resize
         self.always_partition = always_partition
-        self.window_size, self.shift_size = self._calc_window_shift(window_size, target_shift_size=shift_size)
+        self.dynamic_mask = dynamic_mask
+        self.window_size, self.shift_size = self._calc_window_shift(window_size, shift_size)
         self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio

@@ -257,7 +260,7 @@ def __init__(
             dim,
             num_heads=num_heads,
             head_dim=head_dim,
-            window_size=to_2tuple(self.window_size),
+            window_size=self.window_size,
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
@@ -273,33 +276,46 @@ def __init__(
         )
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-        self._make_attention_mask()
+        self.register_buffer(
+            "attn_mask",
+            None if self.dynamic_mask else self.get_attn_mask(),
+            persistent=False,
+        )

-    def _make_attention_mask(self):
+    def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
         if any(self.shift_size):
             # calculate attention mask for SW-MSA
-            H, W = self.input_resolution
+            if x is not None:
+                H, W = x.shape[1], x.shape[2]
+                device = x.device
+                dtype = x.dtype
+            else:
+                H, W = self.input_resolution
+                device = None
+                dtype = None
             H = math.ceil(H / self.window_size[0]) * self.window_size[0]
             W = math.ceil(W / self.window_size[1]) * self.window_size[1]
-            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
+            img_mask = torch.zeros((1, H, W, 1), dtype=dtype, device=device)  # 1 H W 1
             cnt = 0
             for h in (
-                    slice(0, -self.window_size[0]),
-                    slice(-self.window_size[0], -self.shift_size[0]),
-                    slice(-self.shift_size[0], None)):
+                    (0, -self.window_size[0]),
+                    (-self.window_size[0], -self.shift_size[0]),
+                    (-self.shift_size[0], None),
+            ):
                 for w in (
-                        slice(0, -self.window_size[1]),
-                        slice(-self.window_size[1], -self.shift_size[1]),
-                        slice(-self.shift_size[1], None)):
-                    img_mask[:, h, w, :] = cnt
+                        (0, -self.window_size[1]),
+                        (-self.window_size[1], -self.shift_size[1]),
+                        (-self.shift_size[1], None),
+                ):
+                    img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt
                     cnt += 1
             mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
             mask_windows = mask_windows.view(-1, self.window_area)
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         else:
             attn_mask = None
-        self.register_buffer("attn_mask", attn_mask, persistent=False)
+        return attn_mask

     def _calc_window_shift(
             self,
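For intuition, the masking logic introduced above can be exercised standalone. The following is a minimal sketch (not part of the commit; the local window_partition helper and the toy sizes are illustrative): it labels the nine shift regions, partitions the labels into windows, and converts label mismatches into -100 attention biases.

import math
import torch

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows * B, wh, ww, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size[0], window_size[1], C)

def build_sw_msa_mask(H, W, window_size, shift_size):
    # pad H/W up to window multiples, then label the 3x3 grid of shift regions
    H = math.ceil(H / window_size[0]) * window_size[0]
    W = math.ceil(W / window_size[1]) * window_size[1]
    img_mask = torch.zeros((1, H, W, 1))
    cnt = 0
    for h in ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None)):
        for w in ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None)):
            img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt
            cnt += 1
    mask_windows = window_partition(img_mask, window_size).view(-1, window_size[0] * window_size[1])
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    return attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

mask = build_sw_msa_mask(8, 8, (4, 4), (2, 2))
print(mask.shape)  # torch.Size([4, 16, 16]) -- one additive mask per window

Threading dtype and device through, as get_attn_mask does when handed a tensor, lets the dynamic path build the mask directly on the input's device and dtype each forward.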
@@ -308,14 +324,16 @@ def _calc_window_shift(
     ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         target_window_size = to_2tuple(target_window_size)
         if target_shift_size is None:
-            # if passed value is None, recalculate from default window_size // 2 if it was active
+            # if passed value is None, recalculate from default window_size // 2 if it was previously non-zero
             target_shift_size = self.target_shift_size
             if any(target_shift_size):
-                target_shift_size = [target_window_size[0] // 2, target_window_size[1] // 2]
+                target_shift_size = (target_window_size[0] // 2, target_window_size[1] // 2)
         else:
             target_shift_size = to_2tuple(target_shift_size)
+
         if self.always_partition:
             return target_window_size, target_shift_size
+
         window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
         shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
         return tuple(window_size), tuple(shift_size)
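The clamping is easiest to see with concrete numbers. This is a free-function restating of the same logic (sketch only, names illustrative):

def calc_window_shift(input_resolution, target_window_size, target_shift_size, always_partition=False):
    if always_partition:
        # keep the full window and shift; _attn pads the feature map instead
        return target_window_size, target_shift_size
    # shrink the window to the resolution, and disable shifting where it shrank
    window_size = tuple(r if r <= w else w for r, w in zip(input_resolution, target_window_size))
    shift_size = tuple(0 if r <= w else s for r, w, s in zip(input_resolution, window_size, target_shift_size))
    return window_size, shift_size

print(calc_window_shift((6, 6), (7, 7), (3, 3)))        # ((6, 6), (0, 0)) -- window clamped, shift disabled
print(calc_window_shift((6, 6), (7, 7), (3, 3), True))  # ((7, 7), (3, 3)) -- always_partition keeps both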
@@ -338,7 +356,11 @@ def set_input_size(
         self.window_size, self.shift_size = self._calc_window_shift(window_size)
         self.window_area = self.window_size[0] * self.window_size[1]
         self.attn.set_window_size(self.window_size)
-        self._make_attention_mask()
+        self.register_buffer(
+            "attn_mask",
+            None if self.dynamic_mask else self.get_attn_mask(),
+            persistent=False,
+        )

     def _attn(self, x):
         B, H, W, C = x.shape
@@ -354,14 +376,18 @@ def _attn(self, x):
         pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
         pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
         shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
-        Hp, Wp = H + pad_h, W + pad_w
+        _, Hp, Wp, _ = shifted_x.shape

         # partition windows
         x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
         x_windows = x_windows.view(-1, self.window_area, C)  # nW*B, window_size*window_size, C

         # W-MSA/SW-MSA
-        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
+        if getattr(self, 'dynamic_mask', False):
+            attn_mask = self.get_attn_mask(shifted_x)
+        else:
+            attn_mask = self.attn_mask
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

         # merge windows
         attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
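Reading Hp and Wp off the padded tensor keeps them consistent with what get_attn_mask(shifted_x) sees on the dynamic path. With illustrative numbers (not from the commit):

# a 7x7 window over a 13x11 feature map
window = (7, 7)
H, W = 13, 11
pad_h = (window[0] - H % window[0]) % window[0]  # 1
pad_w = (window[1] - W % window[1]) % window[1]  # 3
Hp, Wp = H + pad_h, W + pad_w
print(Hp, Wp, (Hp // window[0]) * (Wp // window[1]))  # 14 14 4 -> four 7x7 windows

The getattr(self, 'dynamic_mask', False) form, rather than self.dynamic_mask, reads as a guard for older serialized instances that lack the new attribute.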
@@ -408,8 +434,11 @@ def __init__(

     def forward(self, x):
         B, H, W, C = x.shape
-        _assert(H % 2 == 0, f"x height ({H}) is not even.")
-        _assert(W % 2 == 0, f"x width ({W}) is not even.")
+
+        pad_values = (0, 0, 0, W % 2, 0, H % 2)
+        x = nn.functional.pad(x, pad_values)
+        _, H, W, _ = x.shape
+
         x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
         x = self.norm(x)
         x = self.reduction(x)
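torch.nn.functional.pad fills its pad tuple from the last dimension backward, so for an NHWC tensor the three pairs apply to C, then W, then H; that ordering is why W % 2 precedes H % 2 in the tuple above. A quick sanity check:

import torch
import torch.nn as nn

x = torch.randn(2, 7, 10, 32)  # B, H, W, C with odd H
x = nn.functional.pad(x, (0, 0, 0, 10 % 2, 0, 7 % 2))  # pairs: (C), (W), (H)
print(x.shape)  # torch.Size([2, 8, 10, 32]) -- H padded to even, W untouched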
@@ -431,6 +460,7 @@ def __init__(
             head_dim: Optional[int] = None,
             window_size: _int_or_tuple_2_t = 7,
             always_partition: bool = False,
+            dynamic_mask: bool = False,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             proj_drop: float = 0.,
@@ -485,6 +515,7 @@ def __init__(
                 window_size=window_size,
                 shift_size=0 if (i % 2 == 0) else shift_size,
                 always_partition=always_partition,
+                dynamic_mask=dynamic_mask,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop,
@@ -500,11 +531,12 @@ def set_input_size(
             window_size: int,
             always_partition: Optional[bool] = None,
     ):
-        """Method updates the resolution to utilize and the window size and so the pair-wise relative positions.
+        """ Updates the resolution and window size, and with them the pair-wise relative positions.

         Args:
-            feat_size (Tuple[int, int]): New input resolution
-            window_size (int): New window size
+            feat_size: New input (feature) resolution
+            window_size: New window size
+            always_partition: Always partition / shift the window
         """
         self.input_resolution = feat_size
         if isinstance(self.downsample, nn.Identity):
@@ -548,6 +580,7 @@ def __init__(
             head_dim: Optional[int] = None,
             window_size: _int_or_tuple_2_t = 7,
             always_partition: bool = False,
+            strict_img_size: bool = True,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             drop_rate: float = 0.,
@@ -599,9 +632,10 @@ def __init__(
             in_chans=in_chans,
             embed_dim=embed_dim[0],
             norm_layer=norm_layer,
+            strict_img_size=strict_img_size,
             output_fmt='NHWC',
         )
-        self.patch_grid = self.patch_embed.grid_size
+        patch_grid = self.patch_embed.grid_size

         # build layers
         head_dim = to_ntuple(self.num_layers)(head_dim)
@@ -621,15 +655,16 @@ def __init__(
                 dim=in_dim,
                 out_dim=out_dim,
                 input_resolution=(
-                    self.patch_grid[0] // scale,
-                    self.patch_grid[1] // scale
+                    patch_grid[0] // scale,
+                    patch_grid[1] // scale
                 ),
                 depth=depths[i],
                 downsample=i > 0,
                 num_heads=num_heads[i],
                 head_dim=head_dim[i],
                 window_size=window_size[i],
                 always_partition=always_partition,
+                dynamic_mask=not strict_img_size,
                 mlp_ratio=mlp_ratio[i],
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop_rate,
@@ -673,27 +708,29 @@ def set_input_size(
             img_size: Optional[Tuple[int, int]] = None,
             patch_size: Optional[Tuple[int, int]] = None,
             window_size: Optional[Tuple[int, int]] = None,
-            window_ratio: int = 32,
+            window_ratio: int = 8,
             always_partition: Optional[bool] = None,
     ) -> None:
         """ Updates the image resolution and window size.

         Args:
-            img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
-            window_size (Optional[int]): New window size, if None based on new_img_size // window_div
-            window_ratio (int): divisor for calculating window size from image size
+            img_size: New input resolution, if None current resolution is used
+            patch_size: New patch size, if None current patch size is used
+            window_size: New window size, if None computed as patch grid size // window_ratio
+            window_ratio: Divisor for calculating window size from patch grid size
+            always_partition: Always partition into windows and shift (even if window size < feat size)
         """
         if img_size is not None or patch_size is not None:
             self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
-            self.patch_grid = self.patch_embed.grid_size
+        patch_grid = self.patch_embed.grid_size
+
         if window_size is None:
-            img_size = self.patch_embed.img_size
-            window_size = tuple([s // window_ratio for s in img_size])
+            window_size = tuple([pg // window_ratio for pg in patch_grid])
+
         for index, stage in enumerate(self.layers):
             stage_scale = 2 ** max(index - 1, 0)
-            print(self.patch_grid, stage_scale)
             stage.set_input_size(
-                feat_size=(self.patch_grid[0] // stage_scale, self.patch_grid[1] // stage_scale),
+                feat_size=(patch_grid[0] // stage_scale, patch_grid[1] // stage_scale),
                 window_size=window_size,
                 always_partition=always_partition,
             )
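Taken together, a plausible usage sketch once this commit is applied (model name and sizes are illustrative; per the constructor wiring above, strict_img_size=False turns on dynamic_mask in every stage):

import timm
import torch

# relax the patch embed's size check and compute attention masks dynamically
model = timm.create_model('swin_tiny_patch4_window7_224', strict_img_size=False)

# retarget stage resolutions and window sizes for 256x256 input;
# window_size defaults to patch_grid // window_ratio = 64 // 8 = 8
model.set_input_size(img_size=(256, 256), window_ratio=8)

with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # expected: torch.Size([1, 1000])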