Skip to content

Commit 8efdc38

Browse files
committed
Fix #2242 add checks for out indices with intermediate getter mode
1 parent d224074 commit 8efdc38

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

timm/models/_features.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch.nn as nn
1818
from torch.utils.checkpoint import checkpoint
1919

20-
from timm.layers import Format
20+
from timm.layers import Format, _assert
2121

2222

2323
__all__ = [
@@ -51,14 +51,15 @@ def feature_take_indices(
5151
indices = num_features # all features if None
5252

5353
if isinstance(indices, int):
54-
assert indices >= 0
5554
# convert int -> last n indices
55+
_assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
5656
take_indices = [num_features - indices + i for i in range(indices)]
57-
elif isinstance(indices, tuple):
58-
# duplicating this is silly, but needed for torchscript type resolution of n
59-
take_indices = [num_features + idx if idx < 0 else idx for idx in indices]
6057
else:
61-
take_indices = [num_features + idx if idx < 0 else idx for idx in indices]
58+
take_indices: List[int] = []
59+
for i in indices:
60+
idx = num_features + i if i < 0 else i
61+
_assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
62+
take_indices.append(idx)
6263

6364
if not torch.jit.is_scripting() and as_set:
6465
return set(take_indices), max(take_indices)

0 commit comments

Comments
 (0)