@@ -170,6 +170,11 @@ def _cfg(url='', **kwargs):
             '/vit_base_patch16_224_1k_miil_84_4.pth',
         mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
     ),
+
+    # experimental
+    'vit_small_patch16_36x1_224': _cfg(url=''),
+    'vit_small_patch16_18x2_224': _cfg(url=''),
+    'vit_base_patch16_18x2_224': _cfg(url=''),
 }


@@ -201,28 +206,81 @@ def forward(self, x):
         return x


+class LayerScale(nn.Module):
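+    # Learnable per-channel scaling of a residual branch; gamma is initialized to init_values.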
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
 class Block(nn.Module):

     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x):
-        x = x + self.drop_path1(self.attn(self.norm1(x)))
-        x = x + self.drop_path2(self.mlp(self.norm2(x)))
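+        # LayerScale (ls1 / ls2) scales each residual branch; both are nn.Identity() when init_values is None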
+        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x


+class ParallelBlock(nn.Module):
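+    # Runs num_parallel attention branches and num_parallel MLP branches per block and sums
+    # each group into the residual stream, rather than stacking the extra blocks serially.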
+
+    def __init__(
+            self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,
+            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.num_parallel = num_parallel
+        self.attns = nn.ModuleList()
+        self.ffns = nn.ModuleList()
+        for _ in range(num_parallel):
+            self.attns.append(nn.Sequential(OrderedDict([
+                ('norm', norm_layer(dim)),
+                ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
+                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
+                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
+            ])))
+            self.ffns.append(nn.Sequential(OrderedDict([
+                ('norm', norm_layer(dim)),
+                ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
+                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
+                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
+            ])))
+
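+    # torch.jit does not support sum() over a generator, so the scripted/traced path stacks branch outputs and sums them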
+    def _forward_jit(self, x):
+        x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
+        x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
+        return x
+
+    @torch.jit.ignore
+    def _forward(self, x):
+        x = x + sum(attn(x) for attn in self.attns)
+        x = x + sum(ffn(x) for ffn in self.ffns)
+        return x
+
+    def forward(self, x):
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            return self._forward_jit(x)
+        else:
+            return self._forward(x)
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer

@@ -233,8 +291,8 @@ class VisionTransformer(nn.Module):
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
             embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
-            embed_layer=PatchEmbed, norm_layer=None, act_layer=None):
+            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
+            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
         """
         Args:
             img_size (int, tuple): input image size
@@ -248,10 +306,11 @@ def __init__(
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
             representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
-            weight_init: (str): weight init scheme
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
+            weight_init: (str): weight init scheme
+            init_values: (float): layer-scale init values
             embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
             act_layer: (nn.Module): MLP activation layer
@@ -277,9 +336,9 @@ def __init__(

         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.Sequential(*[
-            Block(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
-                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+            block_fn(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
             for i in range(depth)])
         use_fc_norm = self.global_pool == 'avg'
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@@ -941,3 +1000,37 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
     return model
+
+
+@register_model
+def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
1007+ """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.
+    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs)
+    model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
+    """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
+    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
+    model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
+    """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
+    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
+    model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
+    return model
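
A minimal usage sketch for the experimental configs added above (not part of the diff; it assumes this branch of timm is installed, and since the new entries register with url='' there are no pretrained weights, so pretrained must stay False):

    import torch
    import timm  # this branch, with the experimental ViT configs registered

    # build one of the new 18x2 parallel-block variants
    model = timm.create_model('vit_small_patch16_18x2_224', pretrained=False)
    model.eval()

    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))  # default img_size=224, in_chans=3
    print(logits.shape)  # torch.Size([1, 1000]) with the default num_classes=1000

Because the registration functions forward **kwargs into the VisionTransformer constructor, overrides such as init_values or block_fn should also be passable through timm.create_model for the other vit_* variants.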