Skip to content

Commit bdb165a

Browse files
committed
Merge changes in feature extraction interface to MobileNetV3
Experimental feature extraction interface seems to be changed a little bit with the most up to date version apparently found in EfficientNet class. Here these changes are added to MobileNetV3 class to make it support it and work again, too.
1 parent 13cf688 commit bdb165a

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

timm/models/mobilenetv3.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
148148
and object detection models.
149149
"""
150150

151-
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
151+
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
152152
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
153153
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
154154
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
@@ -174,34 +174,47 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pr
174174
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
175175
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
176176
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
177-
self.feature_info = builder.features # builder provides info about feature channels for each block
177+
self._feature_info = builder.features # builder provides info about feature channels for each block
178+
self._stage_to_feature_idx = {
179+
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
178180
self._in_chs = builder.in_chs
179181

180182
efficientnet_init_weights(self)
181183
if _DEBUG:
182-
for k, v in self.feature_info.items():
184+
for k, v in self._feature_info.items():
183185
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
184186

185187
# Register feature extraction hooks with FeatureHooks helper
186-
hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
187-
hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
188-
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
188+
self.feature_hooks = None
189+
if feature_location != 'bottleneck':
190+
hooks = [dict(
191+
name=self._feature_info[idx]['module'],
192+
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
193+
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
189194

190195
def feature_channels(self, idx=None):
191196
""" Feature Channel Shortcut
192197
Returns feature channel count for each output index if idx == None. If idx is an integer, will
193198
return feature channel count for that feature block index (independent of out_indices setting).
194199
"""
195200
if isinstance(idx, int):
196-
return self.feature_info[idx]['num_chs']
197-
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
201+
return self._feature_info[idx]['num_chs']
202+
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
198203

199204
def forward(self, x):
200205
x = self.conv_stem(x)
201206
x = self.bn1(x)
202207
x = self.act1(x)
203-
self.blocks(x)
204-
return self.feature_hooks.get_output(x.device)
208+
if self.feature_hooks is None:
209+
features = []
210+
for i, b in enumerate(self.blocks):
211+
x = b(x)
212+
if i in self._stage_to_feature_idx:
213+
features.append(x)
214+
return features
215+
else:
216+
self.blocks(x)
217+
return self.feature_hooks.get_output(x.device)
205218

206219

207220
def _create_model(model_kwargs, default_cfg, pretrained=False):

0 commit comments

Comments
 (0)