File tree Expand file tree Collapse file tree 1 file changed +7
-6
lines changed
Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change 1717import torch .nn as nn
1818from torch .utils .checkpoint import checkpoint
1919
20- from timm .layers import Format
20+ from timm .layers import Format , _assert
2121
2222
2323__all__ = [
@@ -51,14 +51,15 @@ def feature_take_indices(
5151 indices = num_features # all features if None
5252
5353 if isinstance (indices , int ):
54- assert indices >= 0
5554 # convert int -> last n indices
55+ _assert (0 < indices <= num_features , f'last-n ({ indices } ) is out of range (1 to { num_features } )' )
5656 take_indices = [num_features - indices + i for i in range (indices )]
57- elif isinstance (indices , tuple ):
58- # duplicating this is silly, but needed for torchscript type resolution of n
59- take_indices = [num_features + idx if idx < 0 else idx for idx in indices ]
6057 else :
61- take_indices = [num_features + idx if idx < 0 else idx for idx in indices ]
58+ take_indices : List [int ] = []
59+ for i in indices :
60+ idx = num_features + i if i < 0 else i
61+ _assert (0 <= idx < num_features , f'feature index { idx } is out of range (0 to { num_features - 1 } )' )
62+ take_indices .append (idx )
6263
6364 if not torch .jit .is_scripting () and as_set :
6465 return set (take_indices ), max (take_indices )
You can’t perform that action at this time.
0 commit comments