@@ -32,7 +32,7 @@
 
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn, _assert
+from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer
 
 
 from ._registry import generate_default_cfgs, register_model
@@ -372,20 +372,41 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class Head(nn.Module):
+class NormClassifierHead(nn.Module):
     def __init__(
             self,
-            dim: int,
+            in_features: int,
             num_classes: int,
+            pool_type: str = 'avg',
             drop_rate: float = 0.0,
+            norm_layer: Union[str, Callable] = 'layernorm',
     ):
         super().__init__()
-        self.dropout = nn.Dropout(drop_rate) if drop_rate > 0 else nn.Identity()
-        self.projection = nn.Linear(dim, num_classes)
+        norm_layer = get_norm_layer(norm_layer)
+        assert pool_type in ('avg', '')
+        self.in_features = self.num_features = in_features
+        self.pool_type = pool_type
+        self.norm = norm_layer(in_features)
+        self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
+        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
+        if pool_type is not None:
+            assert pool_type in ('avg', '')
+            self.pool_type = pool_type
+        if other:
+            # reset other non-fc layers
+            self.norm = nn.Identity()
+        self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.dropout(x)
-        x = self.projection(x)
+    def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        if self.pool_type == 'avg':
+            x = x.mean(dim=1)
+        x = self.norm(x)
+        x = self.drop(x)
+        if pre_logits:
+            return x
+        x = self.fc(x)
         return x
 
 
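For reference, the new head folds pooling, norm, dropout, and the final linear into one module, so it consumes unpooled (B, N, C) token features rather than pre-pooled vectors. A minimal sketch of that contract, assuming the class above is in scope:

    import torch

    head = NormClassifierHead(in_features=96, num_classes=1000, pool_type='avg', drop_rate=0.1)
    x = torch.randn(2, 49, 96)          # 2 images, 49 tokens, 96 channels
    logits = head(x)                    # mean over tokens -> norm -> dropout -> fc
    assert logits.shape == (2, 1000)
    feats = head(x, pre_logits=True)    # stop before fc, returns pooled features
    assert feats.shape == (2, 96)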
@@ -438,6 +459,7 @@ def __init__(
             embed_dim: int = 96,  # initial embed dim
             num_heads: int = 1,  # initial number of heads
             num_classes: int = 1000,
+            global_pool: str = 'avg',
             stages: Tuple[int, ...] = (2, 3, 16, 3),
             q_pool: int = 3,  # number of q_pool stages
             q_stride: Tuple[int, ...] = (2, 2),
@@ -458,11 +480,7 @@ def __init__(
     ):
         super().__init__()
         self.num_classes = num_classes
-
-        # Do it this way to ensure that the init args are all PoD (for config usage)
-        if isinstance(norm_layer, str):
-            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
-
+        norm_layer = get_norm_layer(norm_layer)
         depth = sum(stages)
         self.patch_stride = patch_stride
         self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
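The get_norm_layer helper subsumes the old isinstance/getattr dance while keeping its intent: init args stay plain data (a string) for config serialization, and callables still pass through. A small sketch of both call forms, assuming a timm version that exports this helper:

    import torch.nn as nn
    from timm.layers import get_norm_layer

    norm_cls = get_norm_layer('layernorm')    # resolve a norm class by name
    norm = norm_cls(96)                       # instantiate like any nn.Module
    same = get_norm_layer(nn.LayerNorm)       # callables are passed through unchanged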
@@ -552,8 +570,14 @@ def __init__(
                     dict(num_chs=dim_out, reduction=2 ** (cur_stage + 2), module=f'blocks.{self.stage_ends[cur_stage]}')]
             self.blocks.append(block)
 
-        self.norm = norm_layer(embed_dim)
-        self.head = Head(embed_dim, num_classes, drop_rate=drop_rate)
+        self.num_features = embed_dim
+        self.head = NormClassifierHead(
+            embed_dim,
+            num_classes,
+            pool_type=global_pool,
+            drop_rate=drop_rate,
+            norm_layer=norm_layer,
+        )
 
         # Initialize everything
         if sep_pos_embed:
@@ -562,8 +586,8 @@ def __init__(
         else:
             nn.init.trunc_normal_(self.pos_embed, std=0.02)
         self.apply(partial(self._init_weights))
-        self.head.projection.weight.data.mul_(head_init_scale)
-        self.head.projection.bias.data.mul_(head_init_scale)
+        self.head.fc.weight.data.mul_(head_init_scale)
+        self.head.fc.bias.data.mul_(head_init_scale)
 
     def _init_weights(self, m, init_bias=0.02):
         if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
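head_init_scale now targets head.fc rather than head.projection. Shrinking the classifier weights at init is a common fine-tuning trick to keep early logits small; a rough standalone illustration, with a hypothetical scale value:

    import torch.nn as nn

    fc = nn.Linear(768, 1000)
    head_init_scale = 0.001           # hypothetical; fine-tune recipes often use small scales
    fc.weight.data.mul_(head_init_scale)
    fc.bias.data.mul_(head_init_scale)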
@@ -678,19 +702,17 @@ def forward_intermediates(
 
     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
         max_index = self.stage_ends[max_index]
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_head:
-            # norm part of head for this model, equivalent to fc_norm in other vit.
-            self.norm = nn.Identity()
-            self.head = nn.Identity()
+            self.head.reset(0, other=True)
         return take_indices
 
 
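Since the final norm now lives inside the head, pruning for feature extraction collapses to head.reset(0, other=True), which swaps both fc and norm for nn.Identity instead of touching two separate attributes. A usage sketch; the model name is assumed to be registered in this timm version:

    import timm

    model = timm.create_model('hiera_tiny_224', pretrained=False)  # name assumed
    kept = model.prune_intermediate_layers(indices=[0, 1], prune_head=True)
    # blocks past the end of stage 1 are truncated; head.norm and head.fc are now Identity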
@@ -732,11 +754,7 @@ def forward_features(
         return x
 
     def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
-        x = x.mean(dim=1)
-        x = self.norm(x)
-        if pre_logits:
-            return x
-        x = self.head(x)
+        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
         return x
 
     def forward(
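forward_head now defers entirely to the head; pre_logits=True returns the pooled, normalized features just before fc. A quick sketch of both paths, under the same model-name assumption as above:

    import timm
    import torch

    model = timm.create_model('hiera_tiny_224', pretrained=False)  # name assumed
    feats = model.forward_features(torch.randn(1, 3, 224, 224))    # (1, N, C) tokens
    logits = model.forward_head(feats)                             # (1, num_classes)
    pooled = model.forward_head(feats, pre_logits=True)            # (1, C)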
@@ -756,7 +774,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
         **kwargs
     }
 
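Pointing 'classifier' at 'head.fc' keeps pretrained-config tooling working after the rename, since downstream code resolves the classifier module by this dotted name (e.g. when remapping num_classes at load time). A quick way to check the wiring, names assumed as above:

    import timm

    model = timm.create_model('hiera_tiny_224', pretrained=False)  # name assumed
    fc = model.get_submodule(model.pretrained_cfg['classifier'])   # resolves 'head.fc'
    print(type(fc).__name__)                                       # Linear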
@@ -837,6 +855,10 @@ def checkpoint_filter_fn(state_dict, model=None):
             #     )
             #v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
             pass
+        if 'head.projection.' in k:
+            k = k.replace('head.projection.', 'head.fc.')
+        if k.startswith('norm.'):
+            k = k.replace('norm.', 'head.norm.')
         output[k] = v
     return output
 
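These two rules translate checkpoints saved before this refactor: head.projection.* becomes head.fc.*, and the old top-level norm.* moves under head.norm.*. The same mapping applied to a toy state dict, with dummy values standing in for tensors:

    old_sd = {
        'head.projection.weight': 'w', 'head.projection.bias': 'b',
        'norm.weight': 'nw', 'norm.bias': 'nb',
        'blocks.0.norm1.weight': 'x',   # untouched: not a top-level 'norm.' key
    }
    new_sd = {}
    for k, v in old_sd.items():
        if 'head.projection.' in k:
            k = k.replace('head.projection.', 'head.fc.')
        if k.startswith('norm.'):
            k = k.replace('norm.', 'head.norm.')
        new_sd[k] = v
    assert 'head.fc.weight' in new_sd and 'head.norm.weight' in new_sd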