@@ -24,11 +24,12 @@
 # --------------------------------------------------------
 import math
 from functools import partial
-from typing import List, Tuple, Type, Callable, Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -480,14 +481,14 @@ def __init__(
     ):
         super().__init__()
         self.num_classes = num_classes
+        self.grad_checkpointing = False
         norm_layer = get_norm_layer(norm_layer)
-        depth = sum(stages)
+
         self.patch_stride = patch_stride
         self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
         num_tokens = math.prod(self.tokens_spatial_shape)
         flat_mu_size = math.prod(mask_unit_size)
         flat_q_stride = math.prod(q_stride)
-
         assert q_pool < len(stages)
         self.q_pool, self.q_stride = q_pool, q_stride
         self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
@@ -532,11 +533,10 @@ def __init__(
         # q_pool locations
         q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]
 
-        # stochastic depth decay rule
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
-
         # Transformer blocks
         cur_stage = 0
+        depth = sum(stages)
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList()
         self.feature_info = []
         for i in range(depth):
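Note (reviewer sketch, not part of the diff): moving depth = sum(stages) next to the dpr computation keeps the stochastic depth decay rule self-contained; each of the depth blocks gets a drop-path probability rising linearly from 0 to drop_path_rate, so early blocks are dropped least. A minimal standalone illustration, with made-up stage depths and rate:

    import torch

    stages = (2, 3, 16, 3)    # hypothetical per-stage block counts
    drop_path_rate = 0.1      # illustrative value
    depth = sum(stages)       # 24 blocks total
    # one drop-path probability per block, shallowest blocks dropped least
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    assert dpr[0] == 0.0 and abs(dpr[-1] - drop_path_rate) < 1e-6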
@@ -586,8 +586,9 @@ def __init__(
         else:
             nn.init.trunc_normal_(self.pos_embed, std=0.02)
         self.apply(partial(self._init_weights))
-        self.head.fc.weight.data.mul_(head_init_scale)
-        self.head.fc.bias.data.mul_(head_init_scale)
+        if isinstance(self.head.fc, nn.Linear):
+            self.head.fc.weight.data.mul_(head_init_scale)
+            self.head.fc.bias.data.mul_(head_init_scale)
 
     def _init_weights(self, m, init_bias=0.02):
         if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
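Note (reviewer sketch, not part of the diff): the new isinstance guard matters because timm heads typically swap fc for nn.Identity when num_classes == 0, and nn.Identity has no weight or bias to scale:

    import torch.nn as nn

    fc = nn.Identity()                 # what the head holds when num_classes == 0
    if isinstance(fc, nn.Linear):      # the guard skips scaling in that case
        fc.weight.data.mul_(2.0)       # would raise AttributeError without the guard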
@@ -605,6 +606,25 @@ def no_weight_decay(self):
         else:
             return ["pos_embed_spatial", "pos_embed_temporal"]
 
+    @torch.jit.ignore
+    def group_matcher(self, coarse: bool = False) -> Dict:
+        return dict(
+            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed',  # stem and embed
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True) -> None:
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool, other=other)
+
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """
         Generates a random mask, mask_ratio fraction are dropped.
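Note (reviewer sketch, not part of the diff): these methods match the interface most timm models expose. Assuming model is a constructed Hiera instance from this file, usage would look like:

    model.set_grad_checkpointing(True)           # trade recompute for activation memory
    fc = model.get_classifier()                  # final classification layer (head.fc)
    model.reset_classifier(num_classes=10)       # swap the head for a new task
    groups = model.group_matcher(coarse=False)   # regex spec for layer-wise LR decay grouping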
@@ -740,7 +760,10 @@ def forward_features(
 
         intermediates = []
         for i, blk in enumerate(self.blocks):
-            x = blk(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if return_intermediates and i in self.stage_ends:
                 intermediates.append(self.reroll(x, i, mask=mask))
 
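Note (reviewer sketch, not part of the diff): torch.utils.checkpoint.checkpoint runs blk without storing its intermediate activations and recomputes them during backward, which is why the path is skipped under TorchScript (torch.jit.is_scripting()), where checkpointing is unsupported. A minimal standalone sketch of the same pattern:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))  # stand-ins for transformer blocks
    x = torch.randn(2, 16, requires_grad=True)
    grad_checkpointing = True
    for blk in blocks:
        if grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint(blk, x)   # activations inside blk recomputed in backward
        else:
            x = blk(x)
    x.sum().backward()               # gradients still flow through recomputed segments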