Commit f5ca414

Adjust arg order for recent vit model args, add a few comments

1 parent: 41dc49a

2 files changed: +23, −20 lines

timm/models/vision_transformer.py

Lines changed: 4 additions & 4 deletions

@@ -325,8 +325,8 @@ class VisionTransformer(nn.Module):
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
             embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', class_token=True,
-            fc_norm=None, embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
+            class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
+            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
         """
         Args:
             img_size (int, tuple): input image size
@@ -340,12 +340,12 @@ def __init__(
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
             init_values: (float): layer-scale init values
+            class_token (bool): use class token
+            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
             weight_init (str): weight init scheme
-            class_token (bool): use class token
-            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
             embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
             act_layer: (nn.Module): MLP activation layer
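
Note: because this commit reorders keyword arguments in the middle of the VisionTransformer signature (drop_rate and friends now come after class_token and fc_norm), any caller passing these positionally would silently misassign values. A minimal sketch of the safe, keyword-only call style (assuming a timm install that exposes this module path):

# Hypothetical usage sketch; pass everything by keyword so an arg
# reorder like this one cannot silently shift positional values.
from timm.models.vision_transformer import VisionTransformer

model = VisionTransformer(
    img_size=224,
    patch_size=16,
    num_classes=1000,
    global_pool='token',
    class_token=True,   # moved earlier in the signature by this commit
    fc_norm=None,       # None -> resolved from global_pool at init time
    drop_rate=0.1,      # moved later in the signature by this commit
)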

timm/models/vision_transformer_relpos.py

Lines changed: 19 additions & 16 deletions

@@ -240,35 +240,41 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
 
 class VisionTransformerRelPos(nn.Module):
     """ Vision Transformer w/ Relative Position Bias
+
+    Differing from classic vit, this impl
+    * uses relative position index (swin v1 / beit) or relative log coord + mlp (swin v2) pos embed
+    * defaults to no class token (can be enabled)
+    * defaults to global avg pool for head (can be changed)
+    * layer-scale (residual branch gain) enabled
     """
 
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
-            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', class_token=False,
-            rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
+            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
+            class_token=False, rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
+            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
             embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
         """
         Args:
             img_size (int, tuple): input image size
             patch_size (int, tuple): patch size
             in_chans (int): number of input channels
             num_classes (int): number of classes for classification head
-            global_pool (str): type of global pooling for final sequence (default: 'token')
+            global_pool (str): type of global pooling for final sequence (default: 'avg')
             embed_dim (int): embedding dimension
             depth (int): depth of transformer
             num_heads (int): number of attention heads
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
             init_values: (float): layer-scale init values
+            class_token (bool): use class token (default: False)
+            rel_pos_type (str): type of relative position
+            shared_rel_pos (bool): share relative pos across all blocks
+            fc_norm (bool): use pre classifier norm instead of pre-pool
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
             weight_init (str): weight init scheme
-            class_token (bool): use class token (default: False)
-            rel_pos_type (str): type of relative position
-            shared_rel_pos (bool): share relative pos across all blocks
-            fc_norm (bool): use pre classifier norm
             embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
             act_layer: (nn.Module): MLP activation layer
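
The substantive change in this hunk, beyond the arg reorder, is that init_values now defaults to 1e-5 instead of None, so layer-scale is enabled by default for the relpos models. A quick sketch of constructing one directly (module path as in this diff; the shape check is illustrative):

import torch
from timm.models.vision_transformer_relpos import VisionTransformerRelPos

# Defaults after this commit: no class token, global avg pool head,
# layer-scale on (init_values=1e-5), MLP relative position bias.
model = VisionTransformerRelPos(img_size=224, patch_size=16, num_classes=10)

x = torch.randn(2, 3, 224, 224)
logits = model(x)
print(logits.shape)  # expected: torch.Size([2, 10])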
@@ -384,11 +390,10 @@ def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
 
 @register_model
 def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
+    """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5,
-        block_fn=ResPostRelPosBlock, **kwargs)
+        patch_size=32, embed_dim=896, depth=12, num_heads=14, block_fn=ResPostRelPosBlock, **kwargs)
     model = _create_vision_transformer_relpos(
         'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **model_kwargs)
     return model

@@ -398,7 +403,7 @@ def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
 def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token
     """
-    model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, **kwargs)
     model = _create_vision_transformer_relpos('vit_relpos_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
     return model

@@ -408,8 +413,7 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
-        fc_norm=True, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, **kwargs)
     model = _create_vision_transformer_relpos('vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs)
     return model

@@ -419,7 +423,6 @@ def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
-        block_fn=ResPostRelPosBlock, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
     model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
     return model
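
With init_values=1e-5 now the class default, these factory functions can drop it from model_kwargs without changing behavior. A caller who wants a different layer-scale init can still override it, since create_model forwards extra kwargs to the model constructor (a sketch, assuming these registered names exist in the installed timm):

import timm

# init_values no longer needs to be passed; the default now matches
# what the factories previously set explicitly.
model = timm.create_model('vit_relpos_base_patch16_224', pretrained=False)

# Overrides still flow through **kwargs into VisionTransformerRelPos.
model_alt = timm.create_model('vit_relpos_base_patch16_224', init_values=1e-6)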
