@@ -240,35 +240,41 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
240240
241241class VisionTransformerRelPos (nn .Module ):
242242 """ Vision Transformer w/ Relative Position Bias
243+
244+ Differing from classic vit, this impl
245+ * uses relative position index (swin v1 / beit) or relative log coord + mlp (swin v2) pos embed
246+ * defaults to no class token (can be enabled)
247+ * defaults to global avg pool for head (can be changed)
248+ * layer-scale (residual branch gain) enabled
243249 """
244250
245251 def __init__ (
246252 self , img_size = 224 , patch_size = 16 , in_chans = 3 , num_classes = 1000 , global_pool = 'avg' ,
247- embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4. , qkv_bias = True , init_values = None ,
248- drop_rate = 0. , attn_drop_rate = 0. , drop_path_rate = 0. , weight_init = 'skip' , class_token = False ,
249- rel_pos_type = 'mlp' , shared_rel_pos = False , fc_norm = False ,
253+ embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4. , qkv_bias = True , init_values = 1e-5 ,
254+ class_token = False , rel_pos_type = 'mlp' , shared_rel_pos = False , fc_norm = False ,
255+ drop_rate = 0. , attn_drop_rate = 0. , drop_path_rate = 0. , weight_init = 'skip' ,
250256 embed_layer = PatchEmbed , norm_layer = None , act_layer = None , block_fn = RelPosBlock ):
251257 """
252258 Args:
253259 img_size (int, tuple): input image size
254260 patch_size (int, tuple): patch size
255261 in_chans (int): number of input channels
256262 num_classes (int): number of classes for classification head
257- global_pool (str): type of global pooling for final sequence (default: 'token ')
263+ global_pool (str): type of global pooling for final sequence (default: 'avg ')
258264 embed_dim (int): embedding dimension
259265 depth (int): depth of transformer
260266 num_heads (int): number of attention heads
261267 mlp_ratio (int): ratio of mlp hidden dim to embedding dim
262268 qkv_bias (bool): enable bias for qkv if True
263269 init_values: (float): layer-scale init values
270+ class_token (bool): use class token (default: False)
271+ rel_pos_ty pe (str): type of relative position
272+ shared_rel_pos (bool): share relative pos across all blocks
273+ fc_norm (bool): use pre classifier norm instead of pre-pool
264274 drop_rate (float): dropout rate
265275 attn_drop_rate (float): attention dropout rate
266276 drop_path_rate (float): stochastic depth rate
267277 weight_init (str): weight init scheme
268- class_token (bool): use class token (default: False)
269- rel_pos_ty pe (str): type of relative position
270- shared_rel_pos (bool): share relative pos across all blocks
271- fc_norm (bool): use pre classifier norm
272278 embed_layer (nn.Module): patch embedding layer
273279 norm_layer: (nn.Module): normalization layer
274280 act_layer: (nn.Module): MLP activation layer
@@ -384,11 +390,10 @@ def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
384390
385391@register_model
386392def vit_relpos_base_patch32_plus_rpn_256 (pretrained = False , ** kwargs ):
387- """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
393+ """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
388394 """
389395 model_kwargs = dict (
390- patch_size = 32 , embed_dim = 896 , depth = 12 , num_heads = 14 , init_values = 1e-5 ,
391- block_fn = ResPostRelPosBlock , ** kwargs )
396+ patch_size = 32 , embed_dim = 896 , depth = 12 , num_heads = 14 , block_fn = ResPostRelPosBlock , ** kwargs )
392397 model = _create_vision_transformer_relpos (
393398 'vit_relpos_base_patch32_plus_rpn_256' , pretrained = pretrained , ** model_kwargs )
394399 return model
@@ -398,7 +403,7 @@ def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
398403def vit_relpos_base_patch16_plus_240 (pretrained = False , ** kwargs ):
399404 """ ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token
400405 """
401- model_kwargs = dict (patch_size = 16 , embed_dim = 896 , depth = 12 , num_heads = 14 , init_values = 1e-5 , ** kwargs )
406+ model_kwargs = dict (patch_size = 16 , embed_dim = 896 , depth = 12 , num_heads = 14 , ** kwargs )
402407 model = _create_vision_transformer_relpos ('vit_relpos_base_patch16_plus_240' , pretrained = pretrained , ** model_kwargs )
403408 return model
404409
@@ -408,8 +413,7 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
408413 """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
409414 """
410415 model_kwargs = dict (
411- patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , qkv_bias = False , init_values = 1e-5 ,
412- fc_norm = True , ** kwargs )
416+ patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , qkv_bias = False , fc_norm = True , ** kwargs )
413417 model = _create_vision_transformer_relpos ('vit_relpos_base_patch16_224' , pretrained = pretrained , ** model_kwargs )
414418 return model
415419
@@ -419,7 +423,6 @@ def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs):
419423 """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
420424 """
421425 model_kwargs = dict (
422- patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , qkv_bias = False , init_values = 1e-5 ,
423- block_fn = ResPostRelPosBlock , ** kwargs )
426+ patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , qkv_bias = False , block_fn = ResPostRelPosBlock , ** kwargs )
424427 model = _create_vision_transformer_relpos ('vit_relpos_base_patch16_rpn_224' , pretrained = pretrained , ** model_kwargs )
425428 return model
0 commit comments