88import logging
99from functools import partial
1010from collections import OrderedDict
11+ from dataclasses import dataclass
1112from typing import Optional , Tuple
1213
1314import torch
1617from torch .utils .checkpoint import checkpoint
1718
1819from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD , IMAGENET_INCEPTION_MEAN , IMAGENET_INCEPTION_STD
19- from .helpers import build_model_with_cfg , named_apply
20+ from .helpers import build_model_with_cfg , resolve_pretrained_cfg , named_apply
2021from .layers import PatchEmbed , Mlp , DropPath , trunc_normal_ , lecun_normal_ , to_2tuple
2122from .registry import register_model
2223
@@ -47,9 +48,16 @@ def _cfg(url='', **kwargs):
4748 'vit_relpos_base_patch16_224' : _cfg (
4849 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth' ),
4950
51+ 'vit_srelpos_small_patch16_224' : _cfg (
52+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth' ),
53+ 'vit_srelpos_medium_patch16_224' : _cfg (
54+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth' ),
55+
56+ 'vit_relpos_medium_patch16_cls_224' : _cfg (
57+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth' ),
5058 'vit_relpos_base_patch16_cls_224' : _cfg (
5159 url = '' ),
52- 'vit_relpos_base_patch16_gapcls_224 ' : _cfg (
60+ 'vit_relpos_base_patch16_clsgap_224 ' : _cfg (
5361 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth' ),
5462
5563 'vit_relpos_small_patch16_rpn_224' : _cfg (url = '' ),
@@ -59,35 +67,43 @@ def _cfg(url='', **kwargs):
5967}
6068
6169
62- def gen_relative_position_index (win_size : Tuple [int , int ], class_token : int = 0 ) -> torch .Tensor :
63- # cut and paste w/ modifications from swin / beit codebase
64- # cls to token & token 2 cls & cls to cls
70+ def gen_relative_position_index (
71+ q_size : Tuple [int , int ],
72+ k_size : Tuple [int , int ] = None ,
73+ class_token : bool = False ) -> torch .Tensor :
74+ # Adapted with significant modifications from Swin / BeiT codebases
6575 # get pair-wise relative position index for each token inside the window
66- window_area = win_size [0 ] * win_size [1 ]
67- coords = torch .stack (torch .meshgrid ([torch .arange (win_size [0 ]), torch .arange (win_size [1 ])])).flatten (1 ) # 2, Wh, Ww
68- relative_coords = coords [:, :, None ] - coords [:, None , :] # 2, Wh*Ww, Wh*Ww
69- relative_coords = relative_coords .permute (1 , 2 , 0 ).contiguous () # Wh*Ww, Wh*Ww, 2
70- relative_coords [:, :, 0 ] += win_size [0 ] - 1 # shift to start from 0
71- relative_coords [:, :, 1 ] += win_size [1 ] - 1
72- relative_coords [:, :, 0 ] *= 2 * win_size [1 ] - 1
76+ q_coords = torch .stack (torch .meshgrid ([torch .arange (q_size [0 ]), torch .arange (q_size [1 ])])).flatten (1 ) # 2, Wh, Ww
77+ if k_size is None :
78+ k_coords = q_coords
79+ k_size = q_size
80+ else :
81+ # different q vs k sizes is a WIP
82+ k_coords = torch .stack (torch .meshgrid ([torch .arange (k_size [0 ]), torch .arange (k_size [1 ])])).flatten (1 )
83+ relative_coords = q_coords [:, :, None ] - k_coords [:, None , :] # 2, Wh*Ww, Wh*Ww
84+ relative_coords = relative_coords .permute (1 , 2 , 0 ) # Wh*Ww, Wh*Ww, 2
85+ _ , relative_position_index = torch .unique (relative_coords .view (- 1 , 2 ), return_inverse = True , dim = 0 )
86+
7387 if class_token :
74- num_relative_distance = (2 * win_size [0 ] - 1 ) * (2 * win_size [1 ] - 1 ) + 3
75- relative_position_index = torch .zeros (size = (window_area + 1 ,) * 2 , dtype = relative_coords .dtype )
76- relative_position_index [1 :, 1 :] = relative_coords .sum (- 1 ) # Wh*Ww, Wh*Ww
88+ # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
89+ # NOTE not intended or tested with MLP log-coords
90+ max_size = (max (q_size [0 ], k_size [0 ]), max (q_size [1 ], k_size [1 ]))
91+ num_relative_distance = (2 * max_size [0 ] - 1 ) * (2 * max_size [1 ] - 1 ) + 3
92+ relative_position_index = F .pad (relative_position_index , [1 , 0 , 1 , 0 ])
7793 relative_position_index [0 , 0 :] = num_relative_distance - 3
7894 relative_position_index [0 :, 0 ] = num_relative_distance - 2
7995 relative_position_index [0 , 0 ] = num_relative_distance - 1
80- else :
81- relative_position_index = relative_coords .sum (- 1 ) # Wh*Ww, Wh*Ww
82- return relative_position_index
96+
97+ return relative_position_index .contiguous ()
8398
8499
85100def gen_relative_log_coords (
86101 win_size : Tuple [int , int ],
87102 pretrained_win_size : Tuple [int , int ] = (0 , 0 ),
88- mode = 'swin'
103+ mode = 'swin' ,
89104):
90- # as per official swin-v2 impl, supporting timm swin-v2-cr coords as well
105+ assert mode in ('swin' , 'cr' , 'rw' )
106+ # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
91107 relative_coords_h = torch .arange (- (win_size [0 ] - 1 ), win_size [0 ], dtype = torch .float32 )
92108 relative_coords_w = torch .arange (- (win_size [1 ] - 1 ), win_size [1 ], dtype = torch .float32 )
93109 relative_coords_table = torch .stack (torch .meshgrid ([relative_coords_h , relative_coords_w ]))
@@ -100,12 +116,22 @@ def gen_relative_log_coords(
100116 relative_coords_table [:, :, 0 ] /= (win_size [0 ] - 1 )
101117 relative_coords_table [:, :, 1 ] /= (win_size [1 ] - 1 )
102118 relative_coords_table *= 8 # normalize to -8, 8
103- scale = math .log2 (8 )
119+ relative_coords_table = torch .sign (relative_coords_table ) * torch .log2 (
120+ 1.0 + relative_coords_table .abs ()) / math .log2 (8 )
104121 else :
105- # FIXME we should support a form of normalization (to -1/1) for this mode?
106- scale = math .log2 (math .e )
107- relative_coords_table = torch .sign (relative_coords_table ) * torch .log2 (
108- 1.0 + relative_coords_table .abs ()) / scale
122+ if mode == 'rw' :
123+ # cr w/ window size normalization -> [-1,1] log coords
124+ relative_coords_table [:, :, 0 ] /= (win_size [0 ] - 1 )
125+ relative_coords_table [:, :, 1 ] /= (win_size [1 ] - 1 )
126+ relative_coords_table *= 8 # scale to -8, 8
127+ relative_coords_table = torch .sign (relative_coords_table ) * torch .log2 (
128+ 1.0 + relative_coords_table .abs ())
129+ relative_coords_table /= math .log2 (9 ) # -> [-1, 1]
130+ else :
131+ # mode == 'cr'
132+ relative_coords_table = torch .sign (relative_coords_table ) * torch .log (
133+ 1.0 + relative_coords_table .abs ())
134+
109135 return relative_coords_table
110136
111137
@@ -115,19 +141,29 @@ def __init__(
115141 window_size ,
116142 num_heads = 8 ,
117143 hidden_dim = 128 ,
118- class_token = False ,
144+ prefix_tokens = 0 ,
119145 mode = 'cr' ,
120146 pretrained_window_size = (0 , 0 )
121147 ):
122148 super ().__init__ ()
123149 self .window_size = window_size
124150 self .window_area = self .window_size [0 ] * self .window_size [1 ]
125- self .class_token = 1 if class_token else 0
151+ self .prefix_tokens = prefix_tokens
126152 self .num_heads = num_heads
127153 self .bias_shape = (self .window_area ,) * 2 + (num_heads ,)
128- self .apply_sigmoid = mode == 'swin'
154+ if mode == 'swin' :
155+ self .bias_act = nn .Sigmoid ()
156+ self .bias_gain = 16
157+ mlp_bias = (True , False )
158+ elif mode == 'rw' :
159+ self .bias_act = nn .Tanh ()
160+ self .bias_gain = 4
161+ mlp_bias = True
162+ else :
163+ self .bias_act = nn .Identity ()
164+ self .bias_gain = None
165+ mlp_bias = True
129166
130- mlp_bias = (True , False ) if mode == 'swin' else True
131167 self .mlp = Mlp (
132168 2 , # x, y
133169 hidden_features = hidden_dim ,
@@ -155,10 +191,11 @@ def get_bias(self) -> torch.Tensor:
155191 self .relative_position_index .view (- 1 )] # Wh*Ww,Wh*Ww,nH
156192 relative_position_bias = relative_position_bias .view (self .bias_shape )
157193 relative_position_bias = relative_position_bias .permute (2 , 0 , 1 )
158- if self .apply_sigmoid :
159- relative_position_bias = 16 * torch .sigmoid (relative_position_bias )
160- if self .class_token :
161- relative_position_bias = F .pad (relative_position_bias , [self .class_token , 0 , self .class_token , 0 ])
194+ relative_position_bias = self .bias_act (relative_position_bias )
195+ if self .bias_gain is not None :
196+ relative_position_bias = self .bias_gain * relative_position_bias
197+ if self .prefix_tokens :
198+ relative_position_bias = F .pad (relative_position_bias , [self .prefix_tokens , 0 , self .prefix_tokens , 0 ])
162199 return relative_position_bias .unsqueeze (0 ).contiguous ()
163200
164201 def forward (self , attn , shared_rel_pos : Optional [torch .Tensor ] = None ):
@@ -167,18 +204,18 @@ def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
167204
168205class RelPosBias (nn .Module ):
169206
170- def __init__ (self , window_size , num_heads , class_token = False ):
207+ def __init__ (self , window_size , num_heads , prefix_tokens = 0 ):
171208 super ().__init__ ()
209+ assert prefix_tokens <= 1
172210 self .window_size = window_size
173211 self .window_area = window_size [0 ] * window_size [1 ]
174- self .class_token = 1 if class_token else 0
175- self .bias_shape = (self .window_area + self .class_token ,) * 2 + (num_heads ,)
212+ self .bias_shape = (self .window_area + prefix_tokens ,) * 2 + (num_heads ,)
176213
177- num_relative_distance = (2 * window_size [0 ] - 1 ) * (2 * window_size [1 ] - 1 ) + 3 * self . class_token
214+ num_relative_distance = (2 * window_size [0 ] - 1 ) * (2 * window_size [1 ] - 1 ) + 3 * prefix_tokens
178215 self .relative_position_bias_table = nn .Parameter (torch .zeros (num_relative_distance , num_heads ))
179216 self .register_buffer (
180217 "relative_position_index" ,
181- gen_relative_position_index (self .window_size , class_token = self . class_token ),
218+ gen_relative_position_index (self .window_size , class_token = prefix_tokens > 0 ),
182219 persistent = False ,
183220 )
184221
@@ -306,11 +343,32 @@ class VisionTransformerRelPos(nn.Module):
306343 """
307344
308345 def __init__ (
309- self , img_size = 224 , patch_size = 16 , in_chans = 3 , num_classes = 1000 , global_pool = 'avg' ,
310- embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4. , qkv_bias = True , init_values = 1e-6 ,
311- class_token = False , fc_norm = False , rel_pos_type = 'mlp' , shared_rel_pos = False , rel_pos_dim = None ,
312- drop_rate = 0. , attn_drop_rate = 0. , drop_path_rate = 0. , weight_init = 'skip' ,
313- embed_layer = PatchEmbed , norm_layer = None , act_layer = None , block_fn = RelPosBlock ):
346+ self ,
347+ img_size = 224 ,
348+ patch_size = 16 ,
349+ in_chans = 3 ,
350+ num_classes = 1000 ,
351+ global_pool = 'avg' ,
352+ embed_dim = 768 ,
353+ depth = 12 ,
354+ num_heads = 12 ,
355+ mlp_ratio = 4. ,
356+ qkv_bias = True ,
357+ init_values = 1e-6 ,
358+ class_token = False ,
359+ fc_norm = False ,
360+ rel_pos_type = 'mlp' ,
361+ rel_pos_dim = None ,
362+ shared_rel_pos = False ,
363+ drop_rate = 0. ,
364+ attn_drop_rate = 0. ,
365+ drop_path_rate = 0. ,
366+ weight_init = 'skip' ,
367+ embed_layer = PatchEmbed ,
368+ norm_layer = None ,
369+ act_layer = None ,
370+ block_fn = RelPosBlock
371+ ):
314372 """
315373 Args:
316374 img_size (int, tuple): input image size
@@ -345,19 +403,22 @@ def __init__(
345403 self .num_classes = num_classes
346404 self .global_pool = global_pool
347405 self .num_features = self .embed_dim = embed_dim # num_features for consistency with other models
348- self .num_tokens = 1 if class_token else 0
406+ self .num_prefix_tokens = 1 if class_token else 0
349407 self .grad_checkpointing = False
350408
351409 self .patch_embed = embed_layer (
352410 img_size = img_size , patch_size = patch_size , in_chans = in_chans , embed_dim = embed_dim )
353411 feat_size = self .patch_embed .grid_size
354412
355- rel_pos_args = dict (window_size = feat_size , class_token = class_token )
413+ rel_pos_args = dict (window_size = feat_size , prefix_tokens = self . num_prefix_tokens )
356414 if rel_pos_type .startswith ('mlp' ):
357415 if rel_pos_dim :
358416 rel_pos_args ['hidden_dim' ] = rel_pos_dim
417+ # FIXME experimenting with different relpos log coord configs
359418 if 'swin' in rel_pos_type :
360419 rel_pos_args ['mode' ] = 'swin'
420+ elif 'rw' in rel_pos_type :
421+ rel_pos_args ['mode' ] = 'rw'
361422 rel_pos_cls = partial (RelPosMlp , ** rel_pos_args )
362423 else :
363424 rel_pos_cls = partial (RelPosBias , ** rel_pos_args )
@@ -367,7 +428,7 @@ def __init__(
367428 # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both...
368429 rel_pos_cls = None
369430
370- self .cls_token = nn .Parameter (torch .zeros (1 , self .num_tokens , embed_dim )) if self . num_tokens else None
431+ self .cls_token = nn .Parameter (torch .zeros (1 , self .num_prefix_tokens , embed_dim )) if class_token else None
371432
372433 dpr = [x .item () for x in torch .linspace (0 , drop_path_rate , depth )] # stochastic depth decay rule
373434 self .blocks = nn .ModuleList ([
@@ -434,7 +495,7 @@ def forward_features(self, x):
434495
435496 def forward_head (self , x , pre_logits : bool = False ):
436497 if self .global_pool :
437- x = x [:, self .num_tokens :].mean (dim = 1 ) if self .global_pool == 'avg' else x [:, 0 ]
498+ x = x [:, self .num_prefix_tokens :].mean (dim = 1 ) if self .global_pool == 'avg' else x [:, 0 ]
438499 x = self .fc_norm (x )
439500 return x if pre_logits else self .head (x )
440501
@@ -502,6 +563,41 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
502563 return model
503564
504565
566+ @register_model
567+ def vit_srelpos_small_patch16_224 (pretrained = False , ** kwargs ):
568+ """ ViT-Small (ViT-S/16) w/ shared relative log-coord position, no class token
569+ """
570+ model_kwargs = dict (
571+ patch_size = 16 , embed_dim = 384 , depth = 12 , num_heads = 6 , qkv_bias = False , fc_norm = False ,
572+ rel_pos_dim = 384 , shared_rel_pos = True , ** kwargs )
573+ model = _create_vision_transformer_relpos ('vit_srelpos_small_patch16_224' , pretrained = pretrained , ** model_kwargs )
574+ return model
575+
576+
577+ @register_model
578+ def vit_srelpos_medium_patch16_224 (pretrained = False , ** kwargs ):
579+ """ ViT-Medium (ViT-M/16) w/ shared relative log-coord position, no class token
580+ """
581+ model_kwargs = dict (
582+ patch_size = 16 , embed_dim = 512 , depth = 12 , num_heads = 8 , qkv_bias = False , fc_norm = False ,
583+ rel_pos_dim = 512 , shared_rel_pos = True , ** kwargs )
584+ model = _create_vision_transformer_relpos (
585+ 'vit_srelpos_medium_patch16_224' , pretrained = pretrained , ** model_kwargs )
586+ return model
587+
588+
589+ @register_model
590+ def vit_relpos_medium_patch16_cls_224 (pretrained = False , ** kwargs ):
591+ """ ViT-Medium (ViT-M/16) w/ relative log-coord position, class token present
592+ """
593+ model_kwargs = dict (
594+ patch_size = 16 , embed_dim = 512 , depth = 12 , num_heads = 8 , qkv_bias = False , fc_norm = False ,
595+ rel_pos_dim = 256 , class_token = True , global_pool = 'token' , ** kwargs )
596+ model = _create_vision_transformer_relpos (
597+ 'vit_relpos_medium_patch16_cls_224' , pretrained = pretrained , ** model_kwargs )
598+ return model
599+
600+
505601@register_model
506602def vit_relpos_base_patch16_cls_224 (pretrained = False , ** kwargs ):
507603 """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
@@ -514,14 +610,14 @@ def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
514610
515611
516612@register_model
517- def vit_relpos_base_patch16_gapcls_224 (pretrained = False , ** kwargs ):
613+ def vit_relpos_base_patch16_clsgap_224 (pretrained = False , ** kwargs ):
518614 """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
519615 NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled
520616 Leaving here for comparisons w/ a future re-train as it performs quite well.
521617 """
522618 model_kwargs = dict (
523619 patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , qkv_bias = False , fc_norm = True , class_token = True , ** kwargs )
524- model = _create_vision_transformer_relpos ('vit_relpos_base_patch16_gapcls_224 ' , pretrained = pretrained , ** model_kwargs )
620+ model = _create_vision_transformer_relpos ('vit_relpos_base_patch16_clsgap_224 ' , pretrained = pretrained , ** model_kwargs )
525621 return model
526622
527623
0 commit comments