Skip to content

Commit 7be2995

Browse files
committed
Add missing feature_info() on MobileNetV3, make hook feature output order/type consistent with bottleneck (list, decreasing fmap size)
1 parent 88129b2 commit 7be2995

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

timm/models/efficientnet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@
2424
2525
Hacked together by Ross Wightman
2626
"""
27+
import torch
2728
import torch.nn as nn
2829
import torch.nn.functional as F
2930

31+
from typing import List
32+
3033
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
3134
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
3235
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
@@ -471,7 +474,7 @@ def feature_info(self, idx=None):
471474
return self._feature_info[idx]
472475
return [self._feature_info[i] for i in self.out_indices]
473476

474-
def forward(self, x):
477+
def forward(self, x) -> List[torch.Tensor]:
475478
x = self.conv_stem(x)
476479
x = self.bn1(x)
477480
x = self.act1(x)

timm/models/feature_hooks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import torch
2+
13
from collections import defaultdict, OrderedDict
24
from functools import partial
5+
from typing import List
36

47

58
class FeatureHooks:
@@ -25,7 +28,7 @@ def _collect_output_hook(self, name, *args):
2528
x = x[0] # unwrap input tuple
2629
self._feature_outputs[x.device][name] = x
2730

28-
def get_output(self, device):
29-
output = tuple(self._feature_outputs[device].values())[::-1]
31+
def get_output(self, device) -> List[torch.tensor]:
32+
output = list(self._feature_outputs[device].values())
3033
self._feature_outputs[device] = OrderedDict() # clear after reading
3134
return output

timm/models/mobilenetv3.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
88
Hacked together by Ross Wightman
99
"""
10+
import torch
1011
import torch.nn as nn
1112
import torch.nn.functional as F
1213

14+
from typing import List
15+
1316
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
1417
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
1518
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
@@ -206,7 +209,16 @@ def feature_channels(self, idx=None):
206209
return self._feature_info[idx]['num_chs']
207210
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
208211

209-
def forward(self, x):
212+
def feature_info(self, idx=None):
213+
""" Feature Channel Shortcut
214+
Returns feature channel count for each output index if idx == None. If idx is an integer, will
215+
return feature channel count for that feature block index (independent of out_indices setting).
216+
"""
217+
if isinstance(idx, int):
218+
return self._feature_info[idx]
219+
return [self._feature_info[i] for i in self.out_indices]
220+
221+
def forward(self, x) -> List[torch.Tensor]:
210222
x = self.conv_stem(x)
211223
x = self.bn1(x)
212224
x = self.act1(x)

0 commit comments

Comments
 (0)