Skip to content

Commit 4b30bae

Browse files
committed
Add updated vit_relpos weights, and impl w/ support for official swin-v2 differences for relpos. Add bias control support for MLP layers
1 parent d4c0588 commit 4b30bae

File tree

3 files changed

+178
-41
lines changed

3 files changed

+178
-41
lines changed

timm/models/layers/mlp.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,17 @@
1010
class Mlp(nn.Module):
1111
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
1212
"""
13-
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
13+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
1414
super().__init__()
1515
out_features = out_features or in_features
1616
hidden_features = hidden_features or in_features
17+
bias = to_2tuple(bias)
1718
drop_probs = to_2tuple(drop)
1819

19-
self.fc1 = nn.Linear(in_features, hidden_features)
20+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
2021
self.act = act_layer()
2122
self.drop1 = nn.Dropout(drop_probs[0])
22-
self.fc2 = nn.Linear(hidden_features, out_features)
23+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
2324
self.drop2 = nn.Dropout(drop_probs[1])
2425

2526
def forward(self, x):
@@ -35,17 +36,18 @@ class GluMlp(nn.Module):
3536
""" MLP w/ GLU style gating
3637
See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
3738
"""
38-
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
39+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.):
3940
super().__init__()
4041
out_features = out_features or in_features
4142
hidden_features = hidden_features or in_features
4243
assert hidden_features % 2 == 0
44+
bias = to_2tuple(bias)
4345
drop_probs = to_2tuple(drop)
4446

45-
self.fc1 = nn.Linear(in_features, hidden_features)
47+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
4648
self.act = act_layer()
4749
self.drop1 = nn.Dropout(drop_probs[0])
48-
self.fc2 = nn.Linear(hidden_features // 2, out_features)
50+
self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
4951
self.drop2 = nn.Dropout(drop_probs[1])
5052

5153
def init_weights(self):
@@ -67,14 +69,16 @@ def forward(self, x):
6769
class GatedMlp(nn.Module):
6870
""" MLP as used in gMLP
6971
"""
70-
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
71-
gate_layer=None, drop=0.):
72+
def __init__(
73+
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
74+
gate_layer=None, bias=True, drop=0.):
7275
super().__init__()
7376
out_features = out_features or in_features
7477
hidden_features = hidden_features or in_features
78+
bias = to_2tuple(bias)
7579
drop_probs = to_2tuple(drop)
7680

77-
self.fc1 = nn.Linear(in_features, hidden_features)
81+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
7882
self.act = act_layer()
7983
self.drop1 = nn.Dropout(drop_probs[0])
8084
if gate_layer is not None:
@@ -83,7 +87,7 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
8387
hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
8488
else:
8589
self.gate = nn.Identity()
86-
self.fc2 = nn.Linear(hidden_features, out_features)
90+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
8791
self.drop2 = nn.Dropout(drop_probs[1])
8892

8993
def forward(self, x):
@@ -100,15 +104,18 @@ class ConvMlp(nn.Module):
100104
""" MLP using 1x1 convs that keeps spatial dims
101105
"""
102106
def __init__(
103-
self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
107+
self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
108+
norm_layer=None, bias=True, drop=0.):
104109
super().__init__()
105110
out_features = out_features or in_features
106111
hidden_features = hidden_features or in_features
107-
self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
112+
bias = to_2tuple(bias)
113+
114+
self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
108115
self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
109116
self.act = act_layer()
110-
self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
111117
self.drop = nn.Dropout(drop)
118+
self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])
112119

113120
def forward(self, x):
114121
x = self.fc1(x)

timm/models/swin_transformer_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def __init__(
450450

451451
def forward(self, x):
452452
for blk in self.blocks:
453-
if not torch.jit.is_scripting() and self.grad_checkpointing:
453+
if self.grad_checkpointing and not torch.jit.is_scripting():
454454
x = checkpoint.checkpoint(blk, x)
455455
else:
456456
x = blk(x)

timm/models/vision_transformer_relpos.py

Lines changed: 157 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
""" Relative Position Vision Transformer (ViT) in PyTorch
22
3+
NOTE: these models are experimental / WIP, expect changes
4+
35
Hacked together by / Copyright 2022, Ross Wightman
46
"""
57
import math
@@ -37,9 +39,23 @@ def _cfg(url='', **kwargs):
3739
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
3840
input_size=(3, 256, 256)),
3941
'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)),
40-
'vit_relpos_base_patch16_rpn_224': _cfg(url=''),
42+
43+
'vit_relpos_small_patch16_224': _cfg(
44+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth'),
45+
'vit_relpos_medium_patch16_224': _cfg(
46+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth'),
4147
'vit_relpos_base_patch16_224': _cfg(
4248
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),
49+
50+
'vit_relpos_base_patch16_cls_224': _cfg(
51+
url=''),
52+
'vit_relpos_base_patch16_gapcls_224': _cfg(
53+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),
54+
55+
'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
56+
'vit_relpos_medium_patch16_rpn_224': _cfg(
57+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth'),
58+
'vit_relpos_base_patch16_rpn_224': _cfg(url=''),
4359
}
4460

4561

@@ -66,43 +82,84 @@ def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0)
6682
return relative_position_index
6783

6884

69-
def gen_relative_position_log(win_size: Tuple[int, int]) -> torch.Tensor:
70-
"""Method initializes the pair-wise relative positions to compute the positional biases."""
71-
coordinates = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1)
72-
relative_coords = coordinates[:, :, None] - coordinates[:, None, :]
73-
relative_coords = relative_coords.permute(1, 2, 0).float()
74-
relative_coordinates_log = torch.sign(relative_coords) * torch.log(1.0 + relative_coords.abs())
75-
return relative_coordinates_log
85+
def gen_relative_log_coords(
86+
win_size: Tuple[int, int],
87+
pretrained_win_size: Tuple[int, int] = (0, 0),
88+
mode='swin'
89+
):
90+
# as per official swin-v2 impl, supporting timm swin-v2-cr coords as well
91+
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
92+
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
93+
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
94+
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
95+
if mode == 'swin':
96+
if pretrained_win_size[0] > 0:
97+
relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
98+
relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
99+
else:
100+
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
101+
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
102+
relative_coords_table *= 8 # normalize to -8, 8
103+
scale = math.log2(8)
104+
else:
105+
# FIXME we should support a form of normalization (to -1/1) for this mode?
106+
scale = math.log2(math.e)
107+
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
108+
1.0 + relative_coords_table.abs()) / scale
109+
return relative_coords_table
76110

77111

78112
class RelPosMlp(nn.Module):
79-
# based on timm swin-v2 impl
80-
def __init__(self, window_size, num_heads=8, hidden_dim=32, class_token=False):
113+
def __init__(
114+
self,
115+
window_size,
116+
num_heads=8,
117+
hidden_dim=128,
118+
class_token=False,
119+
mode='cr',
120+
pretrained_window_size=(0, 0)
121+
):
81122
super().__init__()
82123
self.window_size = window_size
83124
self.window_area = self.window_size[0] * self.window_size[1]
84125
self.class_token = 1 if class_token else 0
85126
self.num_heads = num_heads
127+
self.bias_shape = (self.window_area,) * 2 + (num_heads,)
128+
self.apply_sigmoid = mode == 'swin'
86129

130+
mlp_bias = (True, False) if mode == 'swin' else True
87131
self.mlp = Mlp(
88132
2, # x, y
89-
hidden_features=min(128, hidden_dim * num_heads),
133+
hidden_features=hidden_dim,
90134
out_features=num_heads,
91135
act_layer=nn.ReLU,
136+
bias=mlp_bias,
92137
drop=(0.125, 0.)
93138
)
94139

95140
self.register_buffer(
96-
'rel_coords_log',
97-
gen_relative_position_log(window_size),
98-
persistent=False
99-
)
141+
"relative_position_index",
142+
gen_relative_position_index(window_size),
143+
persistent=False)
144+
145+
# get relative_coords_table
146+
self.register_buffer(
147+
"rel_coords_log",
148+
gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
149+
persistent=False)
100150

101151
def get_bias(self) -> torch.Tensor:
102-
relative_position_bias = self.mlp(self.rel_coords_log).permute(2, 0, 1).unsqueeze(0)
152+
relative_position_bias = self.mlp(self.rel_coords_log)
153+
if self.relative_position_index is not None:
154+
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
155+
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
156+
relative_position_bias = relative_position_bias.view(self.bias_shape)
157+
relative_position_bias = relative_position_bias.permute(2, 0, 1)
158+
if self.apply_sigmoid:
159+
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
103160
if self.class_token:
104161
relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0])
105-
return relative_position_bias
162+
return relative_position_bias.unsqueeze(0).contiguous()
106163

107164
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
108165
return attn + self.get_bias()
@@ -131,10 +188,10 @@ def init_weights(self):
131188
trunc_normal_(self.relative_position_bias_table, std=.02)
132189

133190
def get_bias(self) -> torch.Tensor:
134-
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
135-
self.bias_shape) # win_h * win_w, win_h * win_w, num_heads
136-
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
137-
return relative_position_bias
191+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
192+
# win_h * win_w, win_h * win_w, num_heads
193+
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
194+
return relative_position_bias.unsqueeze(0).contiguous()
138195

139196
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
140197
return attn + self.get_bias()
@@ -250,8 +307,8 @@ class VisionTransformerRelPos(nn.Module):
250307

251308
def __init__(
252309
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
253-
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
254-
class_token=False, rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
310+
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6,
311+
class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None,
255312
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
256313
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
257314
"""
@@ -268,9 +325,9 @@ def __init__(
268325
qkv_bias (bool): enable bias for qkv if True
269326
init_values: (float): layer-scale init values
270327
class_token (bool): use class token (default: False)
328+
fc_norm (bool): use pre classifier norm instead of pre-pool
271329
rel_pos_ty pe (str): type of relative position
272330
shared_rel_pos (bool): share relative pos across all blocks
273-
fc_norm (bool): use pre classifier norm instead of pre-pool
274331
drop_rate (float): dropout rate
275332
attn_drop_rate (float): attention dropout rate
276333
drop_path_rate (float): stochastic depth rate
@@ -295,8 +352,15 @@ def __init__(
295352
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
296353
feat_size = self.patch_embed.grid_size
297354

298-
rel_pos_cls = RelPosMlp if rel_pos_type == 'mlp' else RelPosBias
299-
rel_pos_cls = partial(rel_pos_cls, window_size=feat_size, class_token=class_token)
355+
rel_pos_args = dict(window_size=feat_size, class_token=class_token)
356+
if rel_pos_type.startswith('mlp'):
357+
if rel_pos_dim:
358+
rel_pos_args['hidden_dim'] = rel_pos_dim
359+
if 'swin' in rel_pos_type:
360+
rel_pos_args['mode'] = 'swin'
361+
rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
362+
else:
363+
rel_pos_cls = partial(RelPosBias, **rel_pos_args)
300364
self.shared_rel_pos = None
301365
if shared_rel_pos:
302366
self.shared_rel_pos = rel_pos_cls(num_heads=num_heads)
@@ -408,6 +472,26 @@ def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs):
408472
return model
409473

410474

475+
@register_model
476+
def vit_relpos_small_patch16_224(pretrained=False, **kwargs):
477+
""" ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
478+
"""
479+
model_kwargs = dict(
480+
patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=True, **kwargs)
481+
model = _create_vision_transformer_relpos('vit_relpos_small_patch16_224', pretrained=pretrained, **model_kwargs)
482+
return model
483+
484+
485+
@register_model
486+
def vit_relpos_medium_patch16_224(pretrained=False, **kwargs):
487+
""" ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
488+
"""
489+
model_kwargs = dict(
490+
patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=True, **kwargs)
491+
model = _create_vision_transformer_relpos('vit_relpos_medium_patch16_224', pretrained=pretrained, **model_kwargs)
492+
return model
493+
494+
411495
@register_model
412496
def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
413497
""" ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
@@ -418,11 +502,57 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
418502
return model
419503

420504

505+
@register_model
506+
def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
507+
""" ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
508+
"""
509+
model_kwargs = dict(
510+
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False,
511+
class_token=True, global_pool='token', **kwargs)
512+
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_cls_224', pretrained=pretrained, **model_kwargs)
513+
return model
514+
515+
516+
@register_model
517+
def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs):
518+
""" ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
519+
NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled
520+
Leaving here for comparisons w/ a future re-train as it performs quite well.
521+
"""
522+
model_kwargs = dict(
523+
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs)
524+
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs)
525+
return model
526+
527+
528+
@register_model
529+
def vit_relpos_small_patch16_rpn_224(pretrained=False, **kwargs):
530+
""" ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
531+
"""
532+
model_kwargs = dict(
533+
patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
534+
model = _create_vision_transformer_relpos(
535+
'vit_relpos_small_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
536+
return model
537+
538+
539+
@register_model
540+
def vit_relpos_medium_patch16_rpn_224(pretrained=False, **kwargs):
541+
""" ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
542+
"""
543+
model_kwargs = dict(
544+
patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
545+
model = _create_vision_transformer_relpos(
546+
'vit_relpos_medium_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
547+
return model
548+
549+
421550
@register_model
422551
def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs):
423552
""" ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
424553
"""
425554
model_kwargs = dict(
426555
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
427-
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
556+
model = _create_vision_transformer_relpos(
557+
'vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
428558
return model

0 commit comments

Comments
 (0)