Commit 1618527
Add layer scale and parallel blocks to vision_transformer
1 parent c42be74 commit 1618527

1 file changed: timm/models/vision_transformer.py (102 additions, 9 deletions)
@@ -170,6 +170,11 @@ def _cfg(url='', **kwargs):
         '/vit_base_patch16_224_1k_miil_84_4.pth',
         mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
     ),
+
+    # experimental
+    'vit_small_patch16_36x1_224': _cfg(url=''),
+    'vit_small_patch16_18x2_224': _cfg(url=''),
+    'vit_base_patch16_18x2_224': _cfg(url=''),
 }


@@ -201,28 +206,81 @@ def forward(self, x):
         return x
 
 
+class LayerScale(nn.Module):
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
 class Block(nn.Module):
 
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
-        x = x + self.drop_path1(self.attn(self.norm1(x)))
-        x = x + self.drop_path2(self.mlp(self.norm2(x)))
+        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
 
 
+class ParallelBlock(nn.Module):
+
+    def __init__(
+            self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,
+            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.num_parallel = num_parallel
+        self.attns = nn.ModuleList()
+        self.ffns = nn.ModuleList()
+        for _ in range(num_parallel):
+            self.attns.append(nn.Sequential(OrderedDict([
+                ('norm', norm_layer(dim)),
+                ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
+                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
+                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
+            ])))
+            self.ffns.append(nn.Sequential(OrderedDict([
+                ('norm', norm_layer(dim)),
+                ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
+                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
+                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
+            ])))
+
+    def _forward_jit(self, x):
+        x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
+        x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
+        return x
+
+    @torch.jit.ignore
+    def _forward(self, x):
+        x = x + sum(attn(x) for attn in self.attns)
+        x = x + sum(ffn(x) for ffn in self.ffns)
+        return x
+
+    def forward(self, x):
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            return self._forward_jit(x)
+        else:
+            return self._forward(x)
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer
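For a quick sanity check of the two new modules, here is a hedged, standalone sketch (not part of the diff; it assumes only torch plus a timm checkout that already includes this commit):

    import torch
    import torch.nn as nn
    from timm.models.vision_transformer import Block, LayerScale, ParallelBlock

    # LayerScale is a learnable per-channel scale on each residual branch,
    # initialized to init_values, so a fresh module scales its input by 1e-5.
    ls = LayerScale(384, init_values=1e-5)
    x = torch.randn(2, 197, 384)                       # (batch, tokens, dim)
    assert torch.allclose(ls(x), x * 1e-5)

    # In Block, init_values=None keeps the previous behaviour (ls1/ls2 are Identity).
    assert isinstance(Block(dim=384, num_heads=6).ls1, nn.Identity)
    assert isinstance(Block(dim=384, num_heads=6, init_values=1e-5).ls1, LayerScale)

    # ParallelBlock sums num_parallel attention branches, then num_parallel MLP branches;
    # the eager path (_forward) and the scriptable path (_forward_jit) should agree.
    blk = ParallelBlock(dim=384, num_heads=6, num_parallel=2, init_values=1e-5).eval()
    with torch.no_grad():
        assert torch.allclose(blk._forward(x), blk._forward_jit(x), atol=1e-6)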
@@ -233,8 +291,8 @@ class VisionTransformer(nn.Module):
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
             embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
-            embed_layer=PatchEmbed, norm_layer=None, act_layer=None):
+            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
+            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
         """
         Args:
             img_size (int, tuple): input image size
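As a hedged aside (not part of the diff; assumes a timm checkout with this change): the new arguments default to init_values=None and block_fn=Block, so existing configs build exactly as before.

    import torch.nn as nn
    from timm.models.vision_transformer import VisionTransformer, Block

    m = VisionTransformer()                            # defaults: init_values=None, block_fn=Block
    assert isinstance(m.blocks[0], Block)
    assert isinstance(m.blocks[0].ls1, nn.Identity)    # no LayerScale without init_values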
@@ -248,10 +306,11 @@ def __init__(
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
             representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
-            weight_init: (str): weight init scheme
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
+            weight_init: (str): weight init scheme
+            init_values: (float): layer-scale init values
             embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
             act_layer: (nn.Module): MLP activation layer
@@ -277,9 +336,9 @@ def __init__(
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.Sequential(*[
-            Block(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
-                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+            block_fn(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
             for i in range(depth)])
         use_fc_norm = self.global_pool == 'avg'
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
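A hedged construction sketch (not part of the diff; assumes a timm checkout with this change): block_fn makes the block type pluggable, which is what the 18x2 configs below rely on.

    import torch
    from timm.models.vision_transformer import VisionTransformer, ParallelBlock

    # 18 parallel (2-branch) blocks at ViT-Small width, mirroring vit_small_patch16_18x2_224.
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=18, num_heads=6,
        init_values=1e-5, block_fn=ParallelBlock)
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)                                # torch.Size([1, 1000])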
@@ -941,3 +1000,37 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
     return model
+
+
+@register_model
+def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
+    """ ViT-Small w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.
+    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs)
+    model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
+    """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
+    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
+    model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
+    """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
+    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
+    model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
+    return model
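A hedged usage sketch (not part of the diff): because the three experimental variants are registered via @register_model, they can be built through the normal timm factory; their default cfgs have url='', so no pretrained weights are available yet.

    import timm
    import torch

    model = timm.create_model('vit_small_patch16_18x2_224', pretrained=False)
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)                                # torch.Size([1, 1000])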
