11""" Relative Position Vision Transformer (ViT) in PyTorch
22
3+ NOTE: these models are experimental / WIP, expect changes
4+
35Hacked together by / Copyright 2022, Ross Wightman
46"""
57import math
@@ -37,9 +39,23 @@ def _cfg(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
         input_size=(3, 256, 256)),
     'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)),
-    'vit_relpos_base_patch16_rpn_224': _cfg(url=''),
+
+    'vit_relpos_small_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth'),
+    'vit_relpos_medium_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth'),
     'vit_relpos_base_patch16_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),
+
+    'vit_relpos_base_patch16_cls_224': _cfg(
+        url=''),
+    'vit_relpos_base_patch16_gapcls_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),
+
+    'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
+    'vit_relpos_medium_patch16_rpn_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth'),
+    'vit_relpos_base_patch16_rpn_224': _cfg(url=''),
 }


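The new `default_cfgs` entries above are consumed through the usual timm factory; a minimal usage sketch, assuming a timm build that includes this change:

    import timm

    # pretrained=True resolves the url in the matching default_cfgs entry
    model = timm.create_model('vit_relpos_medium_patch16_224', pretrained=True).eval()

Entries left with `url=''` are registered but have no released weights yet, so `pretrained=True` has nothing to download for them.
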
@@ -66,43 +82,84 @@ def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0)
     return relative_position_index


-def gen_relative_position_log(win_size: Tuple[int, int]) -> torch.Tensor:
-    """Method initializes the pair-wise relative positions to compute the positional biases."""
-    coordinates = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1)
-    relative_coords = coordinates[:, :, None] - coordinates[:, None, :]
-    relative_coords = relative_coords.permute(1, 2, 0).float()
-    relative_coordinates_log = torch.sign(relative_coords) * torch.log(1.0 + relative_coords.abs())
-    return relative_coordinates_log
+def gen_relative_log_coords(
+        win_size: Tuple[int, int],
+        pretrained_win_size: Tuple[int, int] = (0, 0),
+        mode='swin'
+):
+    # as per official swin-v2 impl, supporting timm swin-v2-cr coords as well
+    relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
+    relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
+    relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
+    relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous()  # 2*Wh-1, 2*Ww-1, 2
+    if mode == 'swin':
+        if pretrained_win_size[0] > 0:
+            relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
+            relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
+        else:
+            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
+            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
+        relative_coords_table *= 8  # normalize to -8, 8
+        scale = math.log2(8)
+    else:
+        # FIXME we should support a form of normalization (to -1/1) for this mode?
+        scale = math.log2(math.e)
+    relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+        1.0 + relative_coords_table.abs()) / scale
+    return relative_coords_table


 class RelPosMlp(nn.Module):
-    # based on timm swin-v2 impl
-    def __init__(self, window_size, num_heads=8, hidden_dim=32, class_token=False):
+    def __init__(
+            self,
+            window_size,
+            num_heads=8,
+            hidden_dim=128,
+            class_token=False,
+            mode='cr',
+            pretrained_window_size=(0, 0)
+    ):
         super().__init__()
         self.window_size = window_size
         self.window_area = self.window_size[0] * self.window_size[1]
         self.class_token = 1 if class_token else 0
         self.num_heads = num_heads
+        self.bias_shape = (self.window_area,) * 2 + (num_heads,)
+        self.apply_sigmoid = mode == 'swin'

+        mlp_bias = (True, False) if mode == 'swin' else True
         self.mlp = Mlp(
             2,  # x, y
-            hidden_features=min(128, hidden_dim * num_heads),
+            hidden_features=hidden_dim,
             out_features=num_heads,
             act_layer=nn.ReLU,
+            bias=mlp_bias,
             drop=(0.125, 0.)
         )

         self.register_buffer(
-            'rel_coords_log',
-            gen_relative_position_log(window_size),
-            persistent=False
-        )
+            "relative_position_index",
+            gen_relative_position_index(window_size),
+            persistent=False)
+
+        # get relative_coords_table
+        self.register_buffer(
+            "rel_coords_log",
+            gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
+            persistent=False)

     def get_bias(self) -> torch.Tensor:
-        relative_position_bias = self.mlp(self.rel_coords_log).permute(2, 0, 1).unsqueeze(0)
+        relative_position_bias = self.mlp(self.rel_coords_log)
+        if self.relative_position_index is not None:
+            relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
+                self.relative_position_index.view(-1)]  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.view(self.bias_shape)
+        relative_position_bias = relative_position_bias.permute(2, 0, 1)
+        if self.apply_sigmoid:
+            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
         if self.class_token:
             relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0])
-        return relative_position_bias
+        return relative_position_bias.unsqueeze(0).contiguous()

     def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
         return attn + self.get_bias()
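As a quick sanity check on the shapes flowing through the new `RelPosMlp`, a minimal sketch (assumes the definitions above are in scope; the 8x8 feature grid is illustrative only):

    win_size = (8, 8)
    coords = gen_relative_log_coords(win_size, mode='swin')
    # one log-spaced (dy, dx) pair per unique relative offset
    assert coords.shape == (2 * 8 - 1, 2 * 8 - 1, 2)  # (15, 15, 2)

    rp = RelPosMlp(win_size, num_heads=8, hidden_dim=128, mode='swin').eval()
    bias = rp.get_bias()  # (1, num_heads, Wh*Ww, Wh*Ww), broadcast over batch
    assert bias.shape == (1, 8, 64, 64)

In 'swin' mode the coordinate table is normalized to roughly [-1, 1] and the MLP output is clamped via `16 * sigmoid(...)`, matching the Swin-V2 log-CPB formulation; 'cr' mode skips both the normalization and the sigmoid.
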
@@ -131,10 +188,10 @@ def init_weights(self):
         trunc_normal_(self.relative_position_bias_table, std=.02)

     def get_bias(self) -> torch.Tensor:
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.bias_shape)  # win_h * win_w, win_h * win_w, num_heads
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
-        return relative_position_bias
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
+        # win_h * win_w, win_h * win_w, num_heads
+        relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
+        return relative_position_bias.unsqueeze(0).contiguous()

     def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
         return attn + self.get_bias()
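`RelPosBias` follows the same indexing pattern minus the MLP: a learned table with one row per unique relative offset, gathered by the precomputed index. A toy illustration of just the gather, in plain torch with hypothetical sizes:

    import torch

    # 2x2 window -> 4 tokens, (2*2 - 1) ** 2 = 9 unique relative offsets
    num_heads, area, num_offsets = 2, 4, 9
    table = torch.randn(num_offsets, num_heads)       # a learned nn.Parameter in the real class
    index = torch.randint(num_offsets, (area, area))  # stand-in for gen_relative_position_index
    bias = table[index.view(-1)].view(area, area, num_heads).permute(2, 0, 1)
    print(bias.unsqueeze(0).shape)  # torch.Size([1, 2, 4, 4]) -> added to attn logits
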
@@ -250,8 +307,8 @@ class VisionTransformerRelPos(nn.Module):

     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
-            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
-            class_token=False, rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
+            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6,
+            class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None,
             drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
             embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
         """
@@ -268,9 +325,9 @@ def __init__(
             qkv_bias (bool): enable bias for qkv if True
             init_values (float): layer-scale init values
             class_token (bool): use class token (default: False)
+            fc_norm (bool): use pre classifier norm instead of pre-pool
             rel_pos_type (str): type of relative position
             shared_rel_pos (bool): share relative pos across all blocks
-            fc_norm (bool): use pre classifier norm instead of pre-pool
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
@@ -295,8 +352,15 @@ def __init__(
             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
         feat_size = self.patch_embed.grid_size

-        rel_pos_cls = RelPosMlp if rel_pos_type == 'mlp' else RelPosBias
-        rel_pos_cls = partial(rel_pos_cls, window_size=feat_size, class_token=class_token)
+        rel_pos_args = dict(window_size=feat_size, class_token=class_token)
+        if rel_pos_type.startswith('mlp'):
+            if rel_pos_dim:
+                rel_pos_args['hidden_dim'] = rel_pos_dim
+            if 'swin' in rel_pos_type:
+                rel_pos_args['mode'] = 'swin'
+            rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
+        else:
+            rel_pos_cls = partial(RelPosBias, **rel_pos_args)
         self.shared_rel_pos = None
         if shared_rel_pos:
             self.shared_rel_pos = rel_pos_cls(num_heads=num_heads)
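With this change `rel_pos_type` is a string switch rather than a binary flag: anything starting with 'mlp' builds a `RelPosMlp` ('mlp_swin' additionally selects the sigmoid-clamped Swin-V2 mode, and `rel_pos_dim` overrides the MLP hidden width), while any other value falls back to `RelPosBias`. A hedged construction sketch, argument names as in this diff:

    # shared_rel_pos=True builds one bias module reused by every block
    model = VisionTransformerRelPos(
        rel_pos_type='mlp_swin', rel_pos_dim=256, shared_rel_pos=True)
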
@@ -408,6 +472,26 @@ def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_relpos_small_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/16) w/ relative log-coord position, no class token
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=True, **kwargs)
+    model = _create_vision_transformer_relpos('vit_relpos_small_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_relpos_medium_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Medium (ViT-M/16) w/ relative log-coord position, no class token
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=True, **kwargs)
+    model = _create_vision_transformer_relpos('vit_relpos_medium_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
@@ -418,11 +502,57 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False,
+        class_token=True, global_pool='token', **kwargs)
+    model = _create_vision_transformer_relpos('vit_relpos_base_patch16_cls_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
+    NOTE this config is a bit of a mistake: the class token was enabled, but global avg-pool w/ fc-norm was not disabled
+    Leaving here for comparisons w/ a future re-train as it performs quite well.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs)
+    model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_relpos_small_patch16_rpn_224(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/16) w/ relative log-coord position and residual post-norm, no class token
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
+    model = _create_vision_transformer_relpos(
+        'vit_relpos_small_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_relpos_medium_patch16_rpn_224(pretrained=False, **kwargs):
+    """ ViT-Medium (ViT-M/16) w/ relative log-coord position and residual post-norm, no class token
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
+    model = _create_vision_transformer_relpos(
+        'vit_relpos_medium_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
-    model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer_relpos(
+        'vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
     return model
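
All of the variants registered above can be enumerated and instantiated through the standard timm entry points; a short sketch:

    import timm, torch

    print(timm.list_models('vit_relpos*'))  # includes the new small/medium/cls/gapcls/rpn variants
    model = timm.create_model('vit_relpos_base_patch16_224', pretrained=False).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))  # (1, 1000)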