Skip to content

Commit f3c11dc

Browse files
authored
Merge pull request #2239 from huggingface/fix_out_indices_order
Fix issue where feature out_indices out of order after wrapping with FeatureGetterNet
2 parents a1996ec + 8efdc38 commit f3c11dc

26 files changed

+86
-86
lines changed

timm/models/_features.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
from collections import OrderedDict, defaultdict
1212
from copy import deepcopy
1313
from functools import partial
14-
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
14+
from typing import Dict, List, Optional, Sequence, Tuple, Union
1515

1616
import torch
1717
import torch.nn as nn
1818
from torch.utils.checkpoint import checkpoint
1919

20-
from timm.layers import Format
20+
from timm.layers import Format, _assert
2121

2222

2323
__all__ = [
@@ -26,44 +26,45 @@
2626
]
2727

2828

29-
def _take_indices(
30-
num_blocks: int,
31-
n: Optional[Union[int, List[int], Tuple[int]]],
32-
) -> Tuple[Set[int], int]:
33-
if isinstance(n, int):
34-
assert n >= 0
35-
take_indices = {x for x in range(num_blocks - n, num_blocks)}
36-
else:
37-
take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
38-
return take_indices, max(take_indices)
39-
40-
41-
def _take_indices_jit(
42-
num_blocks: int,
43-
n: Union[int, List[int], Tuple[int]],
29+
def feature_take_indices(
30+
num_features: int,
31+
indices: Optional[Union[int, List[int]]] = None,
32+
as_set: bool = False,
4433
) -> Tuple[List[int], int]:
45-
if isinstance(n, int):
46-
assert n >= 0
47-
take_indices = [num_blocks - n + i for i in range(n)]
48-
elif isinstance(n, tuple):
49-
# splitting this up is silly, but needed for torchscript type resolution of n
50-
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
51-
else:
52-
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
53-
return take_indices, max(take_indices)
34+
""" Determine the absolute feature indices to 'take' from.
5435
36+
Note: This function can be called in forwar() so must be torchscript compatible,
37+
which requires some incomplete typing and workaround hacks.
5538
56-
def feature_take_indices(
57-
num_blocks: int,
58-
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
59-
) -> Tuple[List[int], int]:
39+
Args:
40+
num_features: total number of features to select from
41+
indices: indices to select,
42+
None -> select all
43+
int -> select last n
44+
list/tuple of int -> return specified (-ve indices specify from end)
45+
as_set: return as a set
46+
47+
Returns:
48+
List (or set) of absolute (from beginning) indices, Maximum index
49+
"""
6050
if indices is None:
61-
indices = num_blocks # all blocks if None
62-
if torch.jit.is_scripting():
63-
return _take_indices_jit(num_blocks, indices)
51+
indices = num_features # all features if None
52+
53+
if isinstance(indices, int):
54+
# convert int -> last n indices
55+
_assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
56+
take_indices = [num_features - indices + i for i in range(indices)]
6457
else:
65-
# NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno
66-
return _take_indices(num_blocks, indices)
58+
take_indices: List[int] = []
59+
for i in indices:
60+
idx = num_features + i if i < 0 else i
61+
_assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
62+
take_indices.append(idx)
63+
64+
if not torch.jit.is_scripting() and as_set:
65+
return set(take_indices), max(take_indices)
66+
67+
return take_indices, max(take_indices)
6768

6869

6970
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
@@ -464,7 +465,6 @@ def __init__(
464465
out_indices,
465466
prune_norm=not norm,
466467
)
467-
out_indices = list(out_indices)
468468
self.feature_info = _get_feature_info(model, out_indices)
469469
self.model = model
470470
self.out_indices = out_indices

timm/models/beit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
404404
def forward_intermediates(
405405
self,
406406
x: torch.Tensor,
407-
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
407+
indices: Optional[Union[int, List[int]]] = None,
408408
return_prefix_tokens: bool = False,
409409
norm: bool = False,
410410
stop_early: bool = False,
@@ -470,7 +470,7 @@ def forward_intermediates(
470470

471471
def prune_intermediate_layers(
472472
self,
473-
indices: Union[int, List[int], Tuple[int]] = 1,
473+
indices: Union[int, List[int]] = 1,
474474
prune_norm: bool = False,
475475
prune_head: bool = True,
476476
):

timm/models/byobnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,7 +1343,7 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
13431343
def forward_intermediates(
13441344
self,
13451345
x: torch.Tensor,
1346-
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
1346+
indices: Optional[Union[int, List[int]]] = None,
13471347
norm: bool = False,
13481348
stop_early: bool = False,
13491349
output_fmt: str = 'NCHW',
@@ -1401,7 +1401,7 @@ def forward_intermediates(
14011401

14021402
def prune_intermediate_layers(
14031403
self,
1404-
indices: Union[int, List[int], Tuple[int]] = 1,
1404+
indices: Union[int, List[int]] = 1,
14051405
prune_norm: bool = False,
14061406
prune_head: bool = True,
14071407
):

timm/models/cait.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
341341
def forward_intermediates(
342342
self,
343343
x: torch.Tensor,
344-
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
344+
indices: Optional[Union[int, List[int]]] = None,
345345
norm: bool = False,
346346
stop_early: bool = False,
347347
output_fmt: str = 'NCHW',
@@ -398,7 +398,7 @@ def forward_intermediates(
398398

399399
def prune_intermediate_layers(
400400
self,
401-
indices: Union[int, List[int], Tuple[int]] = 1,
401+
indices: Union[int, List[int]] = 1,
402402
prune_norm: bool = False,
403403
prune_head: bool = True,
404404
):

timm/models/convnext.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
412412
def forward_intermediates(
413413
self,
414414
x: torch.Tensor,
415-
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
415+
indices: Optional[Union[int, List[int]]] = None,
416416
norm: bool = False,
417417
stop_early: bool = False,
418418
output_fmt: str = 'NCHW',
@@ -460,7 +460,7 @@ def forward_intermediates(
460460

461461
def prune_intermediate_layers(
462462
self,
463-
indices: Union[int, List[int], Tuple[int]] = 1,
463+
indices: Union[int, List[int]] = 1,
464464
prune_norm: bool = False,
465465
prune_head: bool = True,
466466
):

timm/models/efficientformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def set_distilled_training(self, enable=True):
463463
def forward_intermediates(
464464
self,
465465
x: torch.Tensor,
466-
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
466+
indices: Optional[Union[int, List[int]]] = None,
467467
norm: bool = False,
468468
stop_early: bool = False,
469469
output_fmt: str = 'NCHW',
@@ -516,7 +516,7 @@ def forward_intermediates(
516516

517517
def prune_intermediate_layers(
518518
self,
519-
indices: Union[int, List[int], Tuple[int]] = 1,
519+
indices: Union[int, List[int]] = 1,
520520
prune_norm: bool = False,
521521
prune_head: bool = True,
522522
):

timm/models/efficientnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
165165
def forward_intermediates(
166166
self,
167167
x: torch.Tensor,
168-
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
168+
indices: Optional[Union[int, List[int]]] = None,
169169
norm: bool = False,
170170
stop_early: bool = False,
171171
output_fmt: str = 'NCHW',
@@ -221,7 +221,7 @@ def forward_intermediates(
221221

222222
def prune_intermediate_layers(
223223
self,
224-
indices: Union[int, List[int], Tuple[int]] = 1,
224+
indices: Union[int, List[int]] = 1,
225225
prune_norm: bool = False,
226226
prune_head: bool = True,
227227
extra_blocks: bool = False,

timm/models/eva.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
589589
def forward_intermediates(
590590
self,
591591
x: torch.Tensor,
592-
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
592+
indices: Optional[Union[int, List[int]]] = None,
593593
return_prefix_tokens: bool = False,
594594
norm: bool = False,
595595
stop_early: bool = False,
@@ -646,7 +646,7 @@ def forward_intermediates(
646646

647647
def prune_intermediate_layers(
648648
self,
649-
indices: Union[int, List[int], Tuple[int]] = 1,
649+
indices: Union[int, List[int]] = 1,
650650
prune_norm: bool = False,
651651
prune_head: bool = True,
652652
):

timm/models/fastvit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,7 +1251,7 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
12511251
def forward_intermediates(
12521252
self,
12531253
x: torch.Tensor,
1254-
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
1254+
indices: Optional[Union[int, List[int]]] = None,
12551255
norm: bool = False,
12561256
stop_early: bool = False,
12571257
output_fmt: str = 'NCHW',
@@ -1296,7 +1296,7 @@ def forward_intermediates(
12961296

12971297
def prune_intermediate_layers(
12981298
self,
1299-
indices: Union[int, List[int], Tuple[int]] = 1,
1299+
indices: Union[int, List[int]] = 1,
13001300
prune_norm: bool = False,
13011301
prune_head: bool = True,
13021302
):

timm/models/hiera.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ def forward_intermediates(
669669
self,
670670
x: torch.Tensor,
671671
mask: Optional[torch.Tensor] = None,
672-
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
672+
indices: Optional[Union[int, List[int]]] = None,
673673
norm: bool = False,
674674
stop_early: bool = True,
675675
output_fmt: str = 'NCHW',
@@ -722,7 +722,7 @@ def forward_intermediates(
722722

723723
def prune_intermediate_layers(
724724
self,
725-
indices: Union[int, List[int], Tuple[int]] = 1,
725+
indices: Union[int, List[int]] = 1,
726726
prune_norm: bool = False,
727727
prune_head: bool = True,
728728
):

0 commit comments

Comments
 (0)