
Commit ef147fd

Add forward_intermediates API to Hiera for features_only=True support
1 parent d88bed6 commit ef147fd
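
This commit wires Hiera into timm's forward_intermediates() / features_only machinery. As a rough usage sketch of what the change enables (the registry name and printed shapes are illustrative assumptions, not part of this commit):

    import torch
    import timm

    # 'hiera_small_224' is an assumed registry name; any hiera_* variant should behave the same.
    model = timm.create_model('hiera_small_224', pretrained=False, features_only=True)
    x = torch.randn(1, 3, 224, 224)
    features = model(x)  # list of NCHW feature maps, one per feature_info entry

    print(model.feature_info.channels())   # per-stage channel counts
    print(model.feature_info.reduction())  # per-stage reductions, 2**(stage+2) -> 4, 8, 16, 32
    for f in features:
        print(f.shape)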

File tree

2 files changed: +58 -4 lines changed


tests/test_models.py

Lines changed: 2 additions & 2 deletions
@@ -49,15 +49,15 @@
 
 # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
 FEAT_INTER_FILTERS = [
-    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'
+    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*', 'hiera_*'
 ]
 
 # transformer models don't support many of the spatial / feature based model functionalities
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
 
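
Adding 'hiera_*' to FEAT_INTER_FILTERS opts the Hiera models into the forward_intermediates() test path, while the NON_STD_FILTERS entry marks them as transformer-style models that skip the spatial/feature tests those models don't support. The opted-in check is roughly the following sketch (simplified, model name assumed; the real test in test_models.py does more):

    import torch
    import timm

    model = timm.create_model('hiera_tiny_224', pretrained=False)  # assumed registry name
    x = torch.randn(1, 3, 224, 224)
    final, intermediates = model.forward_intermediates(x)
    assert len(intermediates) == len(model.feature_info)
    for t in intermediates:
        assert t.ndim == 4  # NCHW feature maps, one per stage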

timm/models/hiera.py

Lines changed: 56 additions & 2 deletions
@@ -37,6 +37,7 @@
 
 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 
 
 def conv_nd(n: int) -> Type[nn.Module]:
@@ -517,7 +518,7 @@ def __init__(
         # Transformer blocks
         cur_stage = 0
         self.blocks = nn.ModuleList()
-
+        self.feature_info = []
         for i in range(depth):
             dim_out = embed_dim
             # Mask unit or global attention.
@@ -543,8 +544,10 @@ def __init__(
                 window_size=flat_mu_size,
                 use_mask_unit_attn=use_mask_unit_attn,
             )
-
             embed_dim = dim_out
+            if i in self.stage_ends:
+                self.feature_info += [
+                    dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
             self.blocks.append(block)
 
         self.norm = norm_layer(embed_dim)
@@ -616,6 +619,57 @@ def _pos_embed(self, x) -> torch.Tensor:
         x = x + pos_embed
         return x
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert not norm, 'normalization of features not supported'
+        assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+
+        # FIXME using existing return_intermediates support in model, doesn't have early stopping.
+        x, intermediates = self.forward_features(x, return_intermediates=True)
+        intermediates = [y.permute(0, 3, 1, 2) for i, y in enumerate(intermediates) if i in take_indices]
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
+        max_index = self.stage_ends[max_index]
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_head:
+            # norm part of head for this model, equivalent to fc_norm in other vit.
+            self.norm = nn.Identity()
+            self.head = nn.Identity()
+        return take_indices
+
+
     def forward_features(
             self,
             x: torch.Tensor,
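
The two new methods can also be called directly on an unwrapped Hiera model. A hedged sketch following the signatures added in this diff (registry name assumed, shapes illustrative):

    import torch
    import timm

    model = timm.create_model('hiera_tiny_224', pretrained=False)  # assumed registry name
    x = torch.randn(1, 3, 224, 224)

    # Last two stage outputs only, returned as NCHW tensors.
    feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
    print([tuple(f.shape) for f in feats])

    # Keep only the blocks needed for the first two stages; norm and head become nn.Identity().
    model.prune_intermediate_layers([0, 1], prune_head=True)
    feats = model.forward_intermediates(x, indices=[0, 1], intermediates_only=True)
    print([tuple(f.shape) for f in feats])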
