 from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
 from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
 from ._pretrained import generate_default_cfgs
 from ._registry import register_model
 
@@ -498,7 +499,10 @@ def __init__(
 
     def forward(self, x):
         x = self.downsample(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing:
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
         return x
 
 
@@ -508,6 +512,7 @@ def __init__(
             depths,
             in_chans=3,
             img_size=224,
+            global_pool='avg',
             embed_dims=None,
             downsamples=None,
             mlp_ratios=4,
@@ -522,7 +527,9 @@ def __init__(
             distillation=True,
     ):
         super().__init__()
+        assert global_pool in ('avg', '')
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.feature_info = []
         img_size = to_2tuple(img_size)
         norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
@@ -583,11 +590,49 @@ def init_weights(self, m):
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
 
-    def forward(self, x):
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {k for k, _ in self.named_parameters() if 'attention_biases' in k}
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        matcher = dict(
+            stem=r'^stem',  # stem and embed
+            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+        )
+        return matcher
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head, self.head_dist
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is not None:
+            self.global_pool = global_pool
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    @torch.jit.ignore
+    def set_distilled_training(self, enable=True):
+        self.distilled_training = enable
+
+    def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
         x = self.norm(x)
-        x = x.mean(dim=(2, 3))
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool == 'avg':
+            x = x.mean(dim=(2, 3))
+        if pre_logits:
+            return x
         x, x_dist = self.head(x), self.head_dist(x)
         if self.distilled_training and self.training and not torch.jit.is_scripting():
             # only return separate classification predictions when training in distilled mode
@@ -596,6 +641,11 @@ def forward(self, x):
             # during standard train/finetune, inference average the classifier predictions
             return (x + x_dist) / 2
 
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
 
 def _cfg(url='', **kwargs):
     return {
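
For reference, a minimal usage sketch of the API surface added above. The model name is only a placeholder for whichever architecture this file registers; any timm model exposing these methods behaves the same way.

import torch
import timm

model_name = 'efficientformer_l1'  # placeholder, substitute a model registered by this file
model = timm.create_model(model_name, pretrained=False, num_classes=10)

# enable the per-stage activation checkpointing added here, trading compute for memory during training
model.set_grad_checkpointing(True)

x = torch.randn(2, 3, 224, 224)
feats = model.forward_features(x)                    # unpooled NCHW feature map, no heads applied
pooled = model.forward_head(feats, pre_logits=True)  # global-average pooled features, heads skipped
logits = model(x)                                    # full forward: features -> pool -> (head + head_dist) / 2

# swap the classifier heads for a new number of classes (0 replaces them with nn.Identity)
model.reset_classifier(num_classes=5)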