@@ -76,10 +76,15 @@ def _cfg(url='', **kwargs):
7676 'swin_v2_cr_small_224' : _cfg (
7777 url = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth" ,
7878 input_size = (3 , 224 , 224 ), crop_pct = 0.9 ),
79+ 'swin_v2_cr_small_ns_224' : _cfg (
80+ url = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth" ,
81+ input_size = (3 , 224 , 224 ), crop_pct = 0.9 ),
7982 'swin_v2_cr_base_384' : _cfg (
8083 url = "" , input_size = (3 , 384 , 384 ), crop_pct = 1.0 ),
8184 'swin_v2_cr_base_224' : _cfg (
8285 url = "" , input_size = (3 , 224 , 224 ), crop_pct = 0.9 ),
86+ 'swin_v2_cr_base_ns_224' : _cfg (
87+ url = "" , input_size = (3 , 224 , 224 ), crop_pct = 0.9 ),
8388 'swin_v2_cr_large_384' : _cfg (
8489 url = "" , input_size = (3 , 384 , 384 ), crop_pct = 1.0 ),
8590 'swin_v2_cr_large_224' : _cfg (
@@ -179,7 +184,7 @@ def __init__(
179184 hidden_features = meta_hidden_dim ,
180185 out_features = num_heads ,
181186 act_layer = nn .ReLU ,
182- drop = 0.1 # FIXME should there be stochasticity, appears to 'overfit' without?
187+ drop = ( 0.125 , 0. ) # FIXME should there be stochasticity, appears to 'overfit' without?
183188 )
184189 self .register_parameter ("tau" , torch .nn .Parameter (torch .ones (num_heads )))
185190 self ._make_pair_wise_relative_positions ()
@@ -304,6 +309,7 @@ def __init__(
304309 window_size : Tuple [int , int ],
305310 shift_size : Tuple [int , int ] = (0 , 0 ),
306311 mlp_ratio : float = 4.0 ,
312+ init_values : float = 0 ,
307313 drop : float = 0.0 ,
308314 drop_attn : float = 0.0 ,
309315 drop_path : float = 0.0 ,
@@ -317,6 +323,7 @@ def __init__(
317323 self .target_shift_size : Tuple [int , int ] = to_2tuple (shift_size )
318324 self .window_size , self .shift_size = self ._calc_window_shift (to_2tuple (window_size ))
319325 self .window_area = self .window_size [0 ] * self .window_size [1 ]
326+ self .init_values : float = init_values
320327
321328 # attn branch
322329 self .attn = WindowMultiHeadAttention (
@@ -345,6 +352,7 @@ def __init__(
345352 self .norm3 = norm_layer (dim ) if extra_norm else nn .Identity ()
346353
347354 self ._make_attention_mask ()
355+ self .init_weights ()
348356
349357 def _calc_window_shift (self , target_window_size ):
350358 window_size = [f if f <= w else w for f , w in zip (self .feat_size , target_window_size )]
@@ -377,6 +385,12 @@ def _make_attention_mask(self) -> None:
377385 attn_mask = None
378386 self .register_buffer ("attn_mask" , attn_mask , persistent = False )
379387
388+ def init_weights (self ):
389+ # extra, module specific weight init
390+ if self .init_values :
391+ nn .init .constant_ (self .norm1 .weight , self .init_values )
392+ nn .init .constant_ (self .norm2 .weight , self .init_values )
393+
380394 def update_input_size (self , new_window_size : Tuple [int , int ], new_feat_size : Tuple [int , int ]) -> None :
381395 """Method updates the image resolution to be processed and window size and so the pair-wise relative positions.
382396
@@ -435,7 +449,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
435449 Returns:
436450 output (torch.Tensor): Output tensor of the shape [B, C, H, W]
437451 """
438- # NOTE post-norm branches (op -> norm -> drop)
452+ # post-norm branches (op -> norm -> drop)
439453 x = x + self .drop_path1 (self .norm1 (self ._shifted_window_attn (x )))
440454 x = x + self .drop_path2 (self .norm2 (self .mlp (x )))
441455 x = self .norm3 (x ) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
@@ -522,6 +536,7 @@ def __init__(
522536 feat_size : Tuple [int , int ],
523537 window_size : Tuple [int , int ],
524538 mlp_ratio : float = 4.0 ,
539+ init_values : float = 0.0 ,
525540 drop : float = 0.0 ,
526541 drop_attn : float = 0.0 ,
527542 drop_path : Union [List [float ], float ] = 0.0 ,
@@ -552,6 +567,7 @@ def _extra_norm(index):
552567 window_size = window_size ,
553568 shift_size = tuple ([0 if ((index % 2 ) == 0 ) else w // 2 for w in window_size ]),
554569 mlp_ratio = mlp_ratio ,
570+ init_values = init_values ,
555571 drop = drop ,
556572 drop_attn = drop_attn ,
557573 drop_path = drop_path [index ] if isinstance (drop_path , list ) else drop_path ,
@@ -634,6 +650,7 @@ def __init__(
634650 depths : Tuple [int , ...] = (2 , 2 , 6 , 2 ),
635651 num_heads : Tuple [int , ...] = (3 , 6 , 12 , 24 ),
636652 mlp_ratio : float = 4.0 ,
653+ init_values : float = 0.0 ,
637654 drop_rate : float = 0.0 ,
638655 attn_drop_rate : float = 0.0 ,
639656 drop_path_rate : float = 0.0 ,
@@ -674,6 +691,7 @@ def __init__(
674691 num_heads = num_heads ,
675692 window_size = window_size ,
676693 mlp_ratio = mlp_ratio ,
694+ init_values = init_values ,
677695 drop = drop_rate ,
678696 drop_attn = attn_drop_rate ,
679697 drop_path = drop_path_rate [sum (depths [:index ]):sum (depths [:index + 1 ])],
@@ -786,6 +804,8 @@ def init_weights(module: nn.Module, name: str = ''):
786804 nn .init .xavier_uniform_ (module .weight )
787805 if module .bias is not None :
788806 nn .init .zeros_ (module .bias )
807+ elif hasattr (module , 'init_weights' ):
808+ module .init_weights ()
789809
790810
791811def _create_swin_transformer_v2_cr (variant , pretrained = False , ** kwargs ):
@@ -863,6 +883,20 @@ def swin_v2_cr_small_224(pretrained=False, **kwargs):
863883 return _create_swin_transformer_v2_cr ('swin_v2_cr_small_224' , pretrained = pretrained , ** model_kwargs )
864884
865885
886+ @register_model
887+ def swin_v2_cr_small_ns_224 (pretrained = False , ** kwargs ):
888+ """Swin-S V2 CR @ 224x224, trained ImageNet-1k"""
889+ model_kwargs = dict (
890+ embed_dim = 96 ,
891+ depths = (2 , 2 , 18 , 2 ),
892+ num_heads = (3 , 6 , 12 , 24 ),
893+ init_values = 1e-5 ,
894+ extra_norm_stage = True ,
895+ ** kwargs
896+ )
897+ return _create_swin_transformer_v2_cr ('swin_v2_cr_small_ns_224' , pretrained = pretrained , ** model_kwargs )
898+
899+
866900@register_model
867901def swin_v2_cr_base_384 (pretrained = False , ** kwargs ):
868902 """Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
@@ -887,6 +921,20 @@ def swin_v2_cr_base_224(pretrained=False, **kwargs):
887921 return _create_swin_transformer_v2_cr ('swin_v2_cr_base_224' , pretrained = pretrained , ** model_kwargs )
888922
889923
924+ @register_model
925+ def swin_v2_cr_base_ns_224 (pretrained = False , ** kwargs ):
926+ """Swin-B V2 CR @ 224x224, trained ImageNet-1k"""
927+ model_kwargs = dict (
928+ embed_dim = 128 ,
929+ depths = (2 , 2 , 18 , 2 ),
930+ num_heads = (4 , 8 , 16 , 32 ),
931+ init_values = 1e-6 ,
932+ extra_norm_stage = True ,
933+ ** kwargs
934+ )
935+ return _create_swin_transformer_v2_cr ('swin_v2_cr_base_ns_224' , pretrained = pretrained , ** model_kwargs )
936+
937+
890938@register_model
891939def swin_v2_cr_large_384 (pretrained = False , ** kwargs ):
892940 """Swin-L V2 CR @ 384x384, trained ImageNet-1k"""
0 commit comments