@@ -74,11 +74,11 @@ def _cfg(url='', **kwargs):
 class ClassAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to do CA
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim ** -0.5
+        self.scale = head_dim ** -0.5
 
         self.q = nn.Linear(dim, dim, bias=qkv_bias)
         self.k = nn.Linear(dim, dim, bias=qkv_bias)
@@ -110,13 +110,13 @@ class LayerScaleBlockClassAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to add CA and LayerScale
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn,
             mlp_block=Mlp, init_values=1e-4):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = attn_block(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
@@ -134,14 +134,14 @@ def forward(self, x, x_cls):
 class TalkingHeadAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
 
         self.num_heads = num_heads
 
         head_dim = dim // num_heads
 
-        self.scale = qk_scale or head_dim ** -0.5
+        self.scale = head_dim ** -0.5
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -177,13 +177,13 @@ class LayerScaleBlock(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to add layerScale
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn,
             mlp_block=Mlp, init_values=1e-4):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = attn_block(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
@@ -202,7 +202,7 @@ class Cait(nn.Module):
     # with slight modifications to adapt to our cait models
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-            num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+            num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
             drop_path_rate=0.,
             norm_layer=partial(nn.LayerNorm, eps=1e-6),
             global_pool=None,
@@ -235,14 +235,14 @@ def __init__(
         dpr = [drop_path_rate for i in range(depth)]
         self.blocks = nn.ModuleList([
             block_layers(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                 act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale)
             for i in range(depth)])
 
         self.blocks_token_only = nn.ModuleList([
             block_layers_token(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias,
                 drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
                 act_layer=act_layer, attn_block=attn_block_token_only,
                 mlp_block=mlp_block_token_only, init_values=init_scale)
@@ -270,6 +270,13 @@ def _init_weights(self, m):
     def no_weight_decay(self):
         return {'pos_embed', 'cls_token'}
 
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
     def forward_features(self, x):
         B = x.shape[0]
         x = self.patch_embed(x)
@@ -293,7 +300,6 @@ def forward_features(self, x):
     def forward(self, x):
         x = self.forward_features(x)
         x = self.head(x)
-
         return x
 
 
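
A minimal usage sketch of the changed surface, for illustration only: it assumes the class stays importable as timm.models.cait.Cait and that the small configuration below is accepted by the constructor. qk_scale is no longer a parameter anywhere, so the attention scale is always head_dim ** -0.5, and the new get_classifier/reset_classifier helpers expose and replace the classification head.

import torch
from timm.models.cait import Cait  # assumed import path for this file

# Build a small CaiT directly; note that qk_scale is no longer accepted.
model = Cait(img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=4, num_classes=1000)

# Helpers added by this diff: inspect and swap the classification head.
head = model.get_classifier()           # the current nn.Linear head
model.reset_classifier(num_classes=10)  # replace it with a fresh 10-way head

out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 10])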