Commit c6db404

Update forward_intermediates for hiera to have its own fwd impl w/ early stopping. Remove return_intermediates bool from forward(). Still an fx issue with None mask arg :(
1 parent e8b08a4 commit c6db404
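
For callers that relied on the removed return_intermediates flag, a minimal sketch of the call pattern after this change (the model name 'hiera_small_224', the input size, and the use of timm.create_model are illustrative assumptions only; the method names come from the diff below):

import torch
import timm  # assumes a timm checkout that includes this commit

model = timm.create_model('hiera_small_224', pretrained=False).eval()  # illustrative Hiera variant
x = torch.randn(1, 3, 224, 224)

# Previously, intermediates could be requested via forward(x, return_intermediates=True).
# Now forward() only takes an optional mask, and intermediate features go through
# the dedicated forward_intermediates() method instead.
with torch.no_grad():
    logits = model(x)
    feats = model.forward_intermediates(x, intermediates_only=True)
print(logits.shape, [tuple(f.shape) for f in feats])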

File tree: 1 file changed, +40 −23 lines

timm/models/hiera.py

Lines changed: 40 additions & 23 deletions
@@ -32,12 +32,13 @@


from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn
+from timm.layers import DropPath, Mlp, use_fused_attn, _assert


from ._registry import generate_default_cfgs, register_model
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
+from ._features_fx import register_notrace_function


def conv_nd(n: int) -> Type[nn.Module]:
@@ -48,13 +49,14 @@ def conv_nd(n: int) -> Type[nn.Module]:
    return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]


+@register_notrace_function
def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
    # target_size: [(T), (H), W]
    # (spatial) mask: [B, C, (t), (h), w]
    if mask is None:
        return mask

-    assert len(mask.shape[2:]) == len(target_size)
+    _assert(len(mask.shape[2:]) == len(target_size), "mask spatial shape and target_size must match.")
    if mask.shape[2:] != target_size:
        return F.interpolate(mask.float(), size=target_size)
    return mask
@@ -622,6 +624,7 @@ def _pos_embed(self, x) -> torch.Tensor:
    def forward_intermediates(
            self,
            x: torch.Tensor,
+            mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
            norm: bool = False,
            stop_early: bool = True,
@@ -643,10 +646,31 @@ def forward_intermediates(
        assert not norm, 'normalization of features not supported'
        assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+
+        if mask is not None:
+            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
+        else:
+            patch_mask = None
+        x = self.patch_embed(x, mask=patch_mask)
+        x = self._pos_embed(x)
+        x = self.unroll(x)
+
+        # Discard masked tokens
+        if mask is not None:
+            x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
+
+        intermediates = []
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))

-        # FIXME using existing return_intermediates support in model, doesn't have early stopping.
-        x, intermediates = self.forward_features(x, return_intermediates=True)
-        intermediates = [y.permute(0, 3, 1, 2) for i, y in enumerate(intermediates) if i in take_indices]
        if intermediates_only:
            return intermediates

@@ -673,18 +697,18 @@ def prune_intermediate_layers(
    def forward_features(
            self,
            x: torch.Tensor,
-            mask: torch.Tensor = None,
+            mask: Optional[torch.Tensor] = None,
            return_intermediates: bool = False,
    ) -> torch.Tensor:
        """
        mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
        Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
        """
-        x = self.patch_embed(
-            x,
-            mask=mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
-            if mask is not None else None,
-        )
+        if mask is not None:
+            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
+        else:
+            patch_mask = None
+        x = self.patch_embed(x, mask=patch_mask)
        x = self._pos_embed(x)
        x = self.unroll(x)

@@ -718,19 +742,12 @@ def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
    def forward(
            self,
            x: torch.Tensor,
-            mask: torch.Tensor = None,
-            return_intermediates: bool = False,
+            mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
-        if return_intermediates:
-            x, intermediates = self.forward_features(x, mask=mask, return_intermediates=return_intermediates)
-            if mask is not None:
-                x = self.forward_head(x)
-            return x, intermediates
-        else:
-            x = self.forward_features(x, mask=mask)
-            if mask is None:
-                x = self.forward_head(x)
-            return x
+        x = self.forward_features(x, mask=mask)
+        if mask is None:
+            x = self.forward_head(x)
+        return x


def _cfg(url='', **kwargs):
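
As a usage note rather than part of the commit: a short sketch of the new forward_intermediates() path with early stopping, plus a masked forward() call matching the docstring above. The model name, keep ratio, and input shape are assumptions for illustration; mask_spatial_shape, the stage indices, and the "1 = keep" convention are read off the code in this diff.

import math
import torch
import timm  # assumes a timm checkout that includes this commit

model = timm.create_model('hiera_small_224', pretrained=False).eval()  # illustrative Hiera variant
x = torch.randn(2, 3, 224, 224)

with torch.no_grad():
    # With stop_early=True (the default), blocks past the deepest requested stage
    # are skipped; intermediates come back as NCHW maps, one per selected stage.
    feats = model.forward_intermediates(
        x,
        indices=(0, 1, 2, 3),        # one entry per Hiera stage
        stop_early=True,
        output_fmt='NCHW',
        intermediates_only=True,
    )
    for f in feats:
        print(tuple(f.shape))

    # MAE-style masking through forward(): mask is [B, #mask_units] boolean,
    # 1 = keep / 0 = drop, with the same number of kept units per sample.
    num_units = math.prod(model.mask_spatial_shape)
    keep = int(0.25 * num_units)     # 25% keep ratio is an arbitrary choice
    mask = torch.zeros(x.shape[0], num_units, dtype=torch.bool)
    for b in range(x.shape[0]):
        mask[b, torch.randperm(num_units)[:keep]] = True

    # With a mask, the head is skipped and forward() returns unpooled token features.
    tokens = model(x, mask=mask)
    print(tuple(tokens.shape))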
