 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn
+from timm.layers import DropPath, Mlp, use_fused_attn, _assert


 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._features_fx import register_notrace_function


 def conv_nd(n: int) -> Type[nn.Module]:
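The import change above swaps in timm's _assert helper, which behaves like a plain assert in eager mode but, unlike the assert statement, is preserved when the model is symbolically traced with torch.fx. A small standalone sketch (the tensor and message are illustrative, not from the commit):

import torch
from timm.layers import _assert

# Raises AssertionError with the given message when the condition is False,
# and stays representable in an FX graph, unlike a bare `assert`.
_assert(torch.ones(2, 3).ndim == 2, 'expected a 2D tensor')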
@@ -48,13 +49,14 @@ def conv_nd(n: int) -> Type[nn.Module]:
     return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]


+@register_notrace_function
 def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
     # target_size: [(T), (H), W]
     # (spatial) mask: [B, C, (t), (h), w]
     if mask is None:
         return mask

-    assert len(mask.shape[2:]) == len(target_size)
+    _assert(len(mask.shape[2:]) == len(target_size), "mask spatial shape and target_size must match.")
     if mask.shape[2:] != target_size:
         return F.interpolate(mask.float(), size=target_size)
     return mask
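When the mask's spatial shape does not match the target grid, get_resized_mask() simply interpolates a float view of the mask. A minimal standalone sketch of that behaviour, with illustrative shapes that are not taken from the commit:

import torch
import torch.nn.functional as F

# A 7x7 boolean keep-mask upsampled to a 14x14 token grid; nearest-neighbour
# interpolation keeps the values at 0/1, but the result is float, not bool.
mask = torch.rand(2, 1, 7, 7) > 0.5          # B, C, h, w
resized = F.interpolate(mask.float(), size=(14, 14))
print(resized.shape, resized.dtype)          # torch.Size([2, 1, 14, 14]) torch.float32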
@@ -622,6 +624,7 @@ def _pos_embed(self, x) -> torch.Tensor:
     def forward_intermediates(
             self,
             x: torch.Tensor,
+            mask: Optional[torch.Tensor] = None,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = True,
@@ -643,10 +646,31 @@ def forward_intermediates(
         assert not norm, 'normalization of features not supported'
         assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
         take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+
+        if mask is not None:
+            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
+        else:
+            patch_mask = None
+        x = self.patch_embed(x, mask=patch_mask)
+        x = self._pos_embed(x)
+        x = self.unroll(x)
+
+        # Discard masked tokens
+        if mask is not None:
+            x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
+
+        intermediates = []
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))

-        # FIXME using existing return_intermediates support in model, doesn't have early stopping.
-        x, intermediates = self.forward_features(x, return_intermediates=True)
-        intermediates = [y.permute(0, 3, 1, 2) for i, y in enumerate(intermediates) if i in take_indices]
         if intermediates_only:
             return intermediates

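After this change, forward_intermediates() runs the embedding, unroll, and block loop itself, so it can stop early after the deepest requested stage instead of reusing the return_intermediates path of forward_features(). A hedged usage sketch, assuming a registered Hiera variant such as 'hiera_tiny_224' (the model name is an assumption, not part of this commit):

import timm
import torch

model = timm.create_model('hiera_tiny_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# intermediates_only=True returns only the per-stage feature maps, already
# rerolled to spatial form and permuted to NCHW by the loop above.
feats = model.forward_intermediates(x, indices=[0, 1, 2, 3], intermediates_only=True)
for f in feats:
    print(f.shape)  # one NCHW tensor per selected stage end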
@@ -673,18 +697,18 @@ def prune_intermediate_layers(
     def forward_features(
             self,
             x: torch.Tensor,
-            mask: torch.Tensor = None,
+            mask: Optional[torch.Tensor] = None,
             return_intermediates: bool = False,
     ) -> torch.Tensor:
         """
         mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
         Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
         """
-        x = self.patch_embed(
-            x,
-            mask=mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
-            if mask is not None else None,
-        )
+        if mask is not None:
+            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
+        else:
+            patch_mask = None
+        x = self.patch_embed(x, mask=patch_mask)
         x = self._pos_embed(x)
         x = self.unroll(x)

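The docstring's constraints on the mask (boolean, one entry per mask unit, 1 = keep, equal keep count across the batch) can be met with an MAE-style random mask. A minimal sketch, assuming the default 224x224 configuration where the token grid tiles into 7*7 = 49 mask units (the unit count is an assumption; other input or mask unit sizes change it):

import torch

B, num_mask_units, num_keep = 2, 49, 30
noise = torch.rand(B, num_mask_units)
keep_idx = noise.argsort(dim=-1)[:, :num_keep]      # random units to keep, same count per sample
mask = torch.zeros(B, num_mask_units)
mask.scatter_(1, keep_idx, 1.0)
mask = mask.bool()                                  # 1 = keep, 0 = remove

# With model and x as in the earlier sketch, only the kept tokens are processed:
# feats = model.forward_features(x, mask=mask)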
@@ -718,19 +742,12 @@ def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
     def forward(
             self,
             x: torch.Tensor,
-            mask: torch.Tensor = None,
-            return_intermediates: bool = False,
+            mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if return_intermediates:
-            x, intermediates = self.forward_features(x, mask=mask, return_intermediates=return_intermediates)
-            if mask is not None:
-                x = self.forward_head(x)
-            return x, intermediates
-        else:
-            x = self.forward_features(x, mask=mask)
-            if mask is None:
-                x = self.forward_head(x)
-            return x
+        x = self.forward_features(x, mask=mask)
+        if mask is None:
+            x = self.forward_head(x)
+        return x


 def _cfg(url='', **kwargs):
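With the simplified forward() above, the only difference between the masked pre-training path and the classification path is whether a mask is passed. A short sketch under the same assumptions as earlier (model name and num_classes are illustrative):

import timm
import torch

model = timm.create_model('hiera_tiny_224', pretrained=False, num_classes=10)
x = torch.randn(2, 3, 224, 224)

logits = model(x)        # no mask: features are pooled and classified by forward_head()
print(logits.shape)      # torch.Size([2, 10])
# model(x, mask=mask) would instead return the unpooled features of the kept tokens.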