|
11 | 11 | from collections import OrderedDict, defaultdict |
12 | 12 | from copy import deepcopy |
13 | 13 | from functools import partial |
14 | | -from typing import Dict, List, Optional, Sequence, Set, Tuple, Union |
| 14 | +from typing import Dict, List, Optional, Sequence, Tuple, Union |
15 | 15 |
|
16 | 16 | import torch |
17 | 17 | import torch.nn as nn |
18 | 18 | from torch.utils.checkpoint import checkpoint |
19 | 19 |
|
20 | | -from timm.layers import Format |
| 20 | +from timm.layers import Format, _assert |
21 | 21 |
|
22 | 22 |
|
23 | 23 | __all__ = [ |
|
26 | 26 | ] |
27 | 27 |
|
28 | 28 |
|
29 | | -def _take_indices( |
30 | | - num_blocks: int, |
31 | | - n: Optional[Union[int, List[int], Tuple[int]]], |
32 | | -) -> Tuple[Set[int], int]: |
33 | | - if isinstance(n, int): |
34 | | - assert n >= 0 |
35 | | - take_indices = {x for x in range(num_blocks - n, num_blocks)} |
36 | | - else: |
37 | | - take_indices = {num_blocks + idx if idx < 0 else idx for idx in n} |
38 | | - return take_indices, max(take_indices) |
39 | | - |
40 | | - |
41 | | -def _take_indices_jit( |
42 | | - num_blocks: int, |
43 | | - n: Union[int, List[int], Tuple[int]], |
| 29 | +def feature_take_indices( |
| 30 | + num_features: int, |
| 31 | + indices: Optional[Union[int, List[int]]] = None, |
| 32 | + as_set: bool = False, |
44 | 33 | ) -> Tuple[List[int], int]: |
45 | | - if isinstance(n, int): |
46 | | - assert n >= 0 |
47 | | - take_indices = [num_blocks - n + i for i in range(n)] |
48 | | - elif isinstance(n, tuple): |
49 | | - # splitting this up is silly, but needed for torchscript type resolution of n |
50 | | - take_indices = [num_blocks + idx if idx < 0 else idx for idx in n] |
51 | | - else: |
52 | | - take_indices = [num_blocks + idx if idx < 0 else idx for idx in n] |
53 | | - return take_indices, max(take_indices) |
| 34 | + """ Determine the absolute feature indices to 'take' from. |
54 | 35 |
|
| 36 | + Note: This function can be called in forwar() so must be torchscript compatible, |
| 37 | + which requires some incomplete typing and workaround hacks. |
55 | 38 |
|
56 | | -def feature_take_indices( |
57 | | - num_blocks: int, |
58 | | - indices: Optional[Union[int, List[int], Tuple[int]]] = None, |
59 | | -) -> Tuple[List[int], int]: |
| 39 | + Args: |
| 40 | + num_features: total number of features to select from |
| 41 | + indices: indices to select, |
| 42 | + None -> select all |
| 43 | + int -> select last n |
| 44 | + list/tuple of int -> return specified (-ve indices specify from end) |
| 45 | + as_set: return as a set |
| 46 | +
|
| 47 | + Returns: |
| 48 | + List (or set) of absolute (from beginning) indices, Maximum index |
| 49 | + """ |
60 | 50 | if indices is None: |
61 | | - indices = num_blocks # all blocks if None |
62 | | - if torch.jit.is_scripting(): |
63 | | - return _take_indices_jit(num_blocks, indices) |
| 51 | + indices = num_features # all features if None |
| 52 | + |
| 53 | + if isinstance(indices, int): |
| 54 | + # convert int -> last n indices |
| 55 | + _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})') |
| 56 | + take_indices = [num_features - indices + i for i in range(indices)] |
64 | 57 | else: |
65 | | - # NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno |
66 | | - return _take_indices(num_blocks, indices) |
| 58 | + take_indices: List[int] = [] |
| 59 | + for i in indices: |
| 60 | + idx = num_features + i if i < 0 else i |
| 61 | + _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})') |
| 62 | + take_indices.append(idx) |
| 63 | + |
| 64 | + if not torch.jit.is_scripting() and as_set: |
| 65 | + return set(take_indices), max(take_indices) |
| 66 | + |
| 67 | + return take_indices, max(take_indices) |
67 | 68 |
|
68 | 69 |
|
69 | 70 | def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: |
@@ -464,7 +465,6 @@ def __init__( |
464 | 465 | out_indices, |
465 | 466 | prune_norm=not norm, |
466 | 467 | ) |
467 | | - out_indices = list(out_indices) |
468 | 468 | self.feature_info = _get_feature_info(model, out_indices) |
469 | 469 | self.model = model |
470 | 470 | self.out_indices = out_indices |
|
0 commit comments