 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
 from .layers import DropPath, Mlp, to_2tuple, _assert
 from .registry import register_model
-from .vision_transformer import checkpoint_filter_fn
+
 
 _logger = logging.getLogger(__name__)
 
@@ -186,12 +187,13 @@ def __init__(
             act_layer=nn.ReLU,
             drop=(0.125, 0.)  # FIXME should there be stochasticity, appears to 'overfit' without?
         )
-        self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
+        # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn
+        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads)))
         self._make_pair_wise_relative_positions()
 
     def _make_pair_wise_relative_positions(self) -> None:
         """Method initializes the pair-wise relative positions to compute the positional biases."""
-        device = self.tau.device
+        device = self.logit_scale.device
         coordinates = torch.stack(torch.meshgrid([
             torch.arange(self.window_size[0], device=device),
             torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)
@@ -250,10 +252,11 @@ def _forward_batch(
         query, key, value = qkv.unbind(0)
 
         # compute attention map with scaled cosine attention
-        denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
-        attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6)
-        attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1)
+        attn = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1))
+        logit_scale = torch.clamp(self.logit_scale.reshape(1, self.num_heads, 1, 1), max=math.log(1. / 0.01)).exp()
+        attn = attn * logit_scale
         attn = attn + self._relative_positional_encodings()
+
         if mask is not None:
             # Apply mask if utilized
             num_win: int = mask.shape[0]
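
Review note: a minimal standalone sketch of the new scaled cosine attention path in the hunk above, with made-up shapes (B_, H, N, D) that are not part of the patch. Clamping logit_scale at log(1/0.01) before exponentiating mirrors the old tau.clamp(min=0.01) floor, since logit_scale stands in for log(1/tau):

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    B_, H, N, D = 2, 3, 49, 32                  # hypothetical: windows, heads, tokens, head dim
    query = torch.randn(B_, H, N, D)
    key = torch.randn(B_, H, N, D)
    logit_scale = nn.Parameter(torch.log(10 * torch.ones(H)))  # learnable per-head log scale

    # cosine similarity between queries and keys (unit-normalized dot product)
    attn = F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1)
    # cap the learned scale at 1/0.01 = 100, the counterpart of the old tau >= 0.01 floor
    scale = torch.clamp(logit_scale.reshape(1, H, 1, 1), max=math.log(1. / 0.01)).exp()
    attn = attn * scale                         # (B_, H, N, N), before bias, mask and softmax
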
@@ -309,7 +312,7 @@ def __init__(
         window_size: Tuple[int, int],
         shift_size: Tuple[int, int] = (0, 0),
         mlp_ratio: float = 4.0,
-        init_values: float = 0,
+        init_values: Optional[float] = 0,
         drop: float = 0.0,
         drop_attn: float = 0.0,
         drop_path: float = 0.0,
@@ -323,7 +326,7 @@ def __init__(
         self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
         self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
         self.window_area = self.window_size[0] * self.window_size[1]
-        self.init_values: float = init_values
+        self.init_values: Optional[float] = init_values
 
         # attn branch
         self.attn = WindowMultiHeadAttention(
@@ -387,7 +390,7 @@ def _make_attention_mask(self) -> None:
 
     def init_weights(self):
         # extra, module specific weight init
-        if self.init_values:
+        if self.init_values is not None:
             nn.init.constant_(self.norm1.weight, self.init_values)
             nn.init.constant_(self.norm2.weight, self.init_values)
 
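
Review note: the guard change above matters because the default init_values is now 0.; a toy check with values chosen purely for illustration:

    init_values = 0.0
    assert not init_values            # old `if self.init_values:` would skip the norm weight init
    assert init_values is not None    # new check still zero-initializes norm1/norm2 weights
    init_values = None                # passing None is now the explicit way to opt out
    assert init_values is None
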
@@ -536,7 +539,7 @@ def __init__(
         feat_size: Tuple[int, int],
         window_size: Tuple[int, int],
         mlp_ratio: float = 4.0,
-        init_values: float = 0.0,
+        init_values: Optional[float] = 0.0,
         drop: float = 0.0,
         drop_attn: float = 0.0,
         drop_path: Union[List[float], float] = 0.0,
@@ -650,7 +653,7 @@ def __init__(
         depths: Tuple[int, ...] = (2, 2, 6, 2),
         num_heads: Tuple[int, ...] = (3, 6, 12, 24),
         mlp_ratio: float = 4.0,
-        init_values: float = 0.0,
+        init_values: Optional[float] = 0.,
         drop_rate: float = 0.0,
         attn_drop_rate: float = 0.0,
         drop_path_rate: float = 0.0,
@@ -808,6 +811,21 @@ def init_weights(module: nn.Module, name: str = ''):
         module.init_weights()
 
 
+def checkpoint_filter_fn(state_dict, model):
+    """ convert old checkpoints: remap 'tau' keys to 'logit_scale' (stored as log(1 / tau)) """
+    out_dict = {}
+    if 'model' in state_dict:
+        # For deit models
+        state_dict = state_dict['model']
+    for k, v in state_dict.items():
+        if 'tau' in k:
+            # convert old tau based checkpoints -> logit_scale (inverse)
+            v = torch.log(1 / v)
+            k = k.replace('tau', 'logit_scale')
+        out_dict[k] = v
+    return out_dict
+
+
 def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
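
Review note: the remapping done by the new checkpoint_filter_fn can be sanity-checked in isolation; a rough sketch with made-up per-head tau values, not taken from any real checkpoint:

    import torch

    tau = torch.tensor([0.5, 1.0, 2.0])   # hypothetical old per-head temperatures
    logit_scale = torch.log(1 / tau)      # the conversion applied to 'tau' tensors
    assert torch.allclose(logit_scale.exp(), 1 / tau)   # exp(logit_scale) recovers 1/tau
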
@@ -890,7 +908,6 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
         embed_dim=96,
         depths=(2, 2, 18, 2),
         num_heads=(3, 6, 12, 24),
-        init_values=1e-5,
         extra_norm_stage=True,
         **kwargs
     )
@@ -928,7 +945,6 @@ def swinv2_cr_base_ns_224(pretrained=False, **kwargs):
         embed_dim=128,
         depths=(2, 2, 18, 2),
         num_heads=(4, 8, 16, 32),
-        init_values=1e-6,
         extra_norm_stage=True,
         **kwargs
     )