Skip to content

Commit e15f979

Browse files
authored
Merge pull request #123 from aclex/mobilenetv3_fix_feature_extraction
Merge changes in feature extraction interface to MobileNetV3
2 parents 13cf688 + bdb165a commit e15f979

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

timm/models/mobilenetv3.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
148148
and object detection models.
149149
"""
150150

151-
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
151+
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
152152
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
153153
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
154154
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
@@ -174,34 +174,47 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pr
174174
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
175175
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
176176
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
177-
self.feature_info = builder.features # builder provides info about feature channels for each block
177+
self._feature_info = builder.features # builder provides info about feature channels for each block
178+
self._stage_to_feature_idx = {
179+
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
178180
self._in_chs = builder.in_chs
179181

180182
efficientnet_init_weights(self)
181183
if _DEBUG:
182-
for k, v in self.feature_info.items():
184+
for k, v in self._feature_info.items():
183185
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
184186

185187
# Register feature extraction hooks with FeatureHooks helper
186-
hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
187-
hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
188-
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
188+
self.feature_hooks = None
189+
if feature_location != 'bottleneck':
190+
hooks = [dict(
191+
name=self._feature_info[idx]['module'],
192+
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
193+
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
189194

190195
def feature_channels(self, idx=None):
191196
""" Feature Channel Shortcut
192197
Returns feature channel count for each output index if idx == None. If idx is an integer, will
193198
return feature channel count for that feature block index (independent of out_indices setting).
194199
"""
195200
if isinstance(idx, int):
196-
return self.feature_info[idx]['num_chs']
197-
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
201+
return self._feature_info[idx]['num_chs']
202+
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
198203

199204
def forward(self, x):
200205
x = self.conv_stem(x)
201206
x = self.bn1(x)
202207
x = self.act1(x)
203-
self.blocks(x)
204-
return self.feature_hooks.get_output(x.device)
208+
if self.feature_hooks is None:
209+
features = []
210+
for i, b in enumerate(self.blocks):
211+
x = b(x)
212+
if i in self._stage_to_feature_idx:
213+
features.append(x)
214+
return features
215+
else:
216+
self.blocks(x)
217+
return self.feature_hooks.get_output(x.device)
205218

206219

207220
def _create_model(model_kwargs, default_cfg, pretrained=False):

0 commit comments

Comments
 (0)