3737
3838from ._registry import generate_default_cfgs , register_model
3939from ._builder import build_model_with_cfg
40+ from ._features import feature_take_indices
4041
4142
4243def conv_nd (n : int ) -> Type [nn .Module ]:
@@ -517,7 +518,7 @@ def __init__(
517518 # Transformer blocks
518519 cur_stage = 0
519520 self .blocks = nn .ModuleList ()
520-
521+ self . feature_info = []
521522 for i in range (depth ):
522523 dim_out = embed_dim
523524 # Mask unit or global attention.
@@ -543,8 +544,10 @@ def __init__(
543544 window_size = flat_mu_size ,
544545 use_mask_unit_attn = use_mask_unit_attn ,
545546 )
546-
547547 embed_dim = dim_out
548+ if i in self .stage_ends :
549+ self .feature_info += [
550+ dict (num_chs = dim_out , reduction = 2 ** (cur_stage + 2 ), module = f'blocks.{ self .stage_ends [cur_stage ]} ' )]
548551 self .blocks .append (block )
549552
550553 self .norm = norm_layer (embed_dim )
@@ -616,6 +619,57 @@ def _pos_embed(self, x) -> torch.Tensor:
616619 x = x + pos_embed
617620 return x
618621
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int], Tuple[int]]] = None,
        norm: bool = False,
        stop_early: bool = True,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """ Run the feature forward pass and collect per-stage intermediate outputs.

    Args:
        x: Input image tensor
        indices: Which stage outputs to keep; resolved by ``feature_take_indices``
            against ``len(self.stage_ends)`` (last n if int, all if None,
            matching indices if sequence)
        norm: Must be False, normalization of intermediates is not supported
        stop_early: Accepted for interface compatibility; not used here, the
            full ``forward_features`` pass always runs
        output_fmt: Format of intermediate outputs, only 'NCHW' is accepted
        intermediates_only: Return just the list of intermediates

    Returns:
        The list of selected intermediates when ``intermediates_only`` is True,
        otherwise a tuple of (final ``forward_features`` output, selected intermediates).
    """
    assert not norm, 'normalization of features not supported'
    assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
    take_indices, _ = feature_take_indices(len(self.stage_ends), indices)

    # FIXME relies on the model's own return_intermediates support, so there is no early stopping.
    x, stage_outputs = self.forward_features(x, return_intermediates=True)

    selected = []
    for stage_idx, feat in enumerate(stage_outputs):
        if stage_idx not in take_indices:
            continue
        # Move the last dim to position 1 (presumably NHWC -> NCHW; confirm against forward_features).
        selected.append(feat.permute(0, 3, 1, 2))

    if not intermediates_only:
        return x, selected
    return selected
654+
def prune_intermediate_layers(
        self,
        n: Union[int, List[int], Tuple[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
):
    """ Drop blocks (and optionally the head) not needed for the requested intermediates.

    Args:
        n: Stage selection passed to ``feature_take_indices`` together with
            ``len(self.stage_ends)``
        prune_norm: Accepted for interface compatibility; not used here, the norm
            layer is only replaced as part of head pruning
        prune_head: Replace both ``self.norm`` and ``self.head`` with ``nn.Identity``

    Returns:
        The stage indices resolved by ``feature_take_indices``.
    """
    take_indices, last_stage = feature_take_indices(len(self.stage_ends), n)

    # Translate the deepest requested stage into the index of its final block.
    last_block = self.stage_ends[last_stage]
    keep = last_block + 1
    self.blocks = self.blocks[:keep]  # truncate blocks

    if prune_head:
        # norm is part of the head for this model, equivalent to fc_norm in other vit.
        self.norm = nn.Identity()
        self.head = nn.Identity()
    return take_indices
671+
672+
619673 def forward_features (
620674 self ,
621675 x : torch .Tensor ,
0 commit comments