Skip to content

Commit c3fbdd4

Browse files
committed
Fix efficient head for MobileNetV3
1 parent 17da1ad commit c3fbdd4

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

models/genmobilenet.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def _make_block(self, ba):
290290
ba['bn_eps'] = self.bn_eps
291291
ba['folded_bn'] = self.folded_bn
292292
ba['padding_same'] = self.padding_same
293+
# block act fn overrides the model default
293294
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
294295
assert ba['act_fn'] is not None
295296
if _DEBUG:
@@ -611,15 +612,14 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
611612
depth_multiplier=1.0, depth_divisor=8, min_depth=None,
612613
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
613614
drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
614-
global_pool='avg', skip_head_conv=False, efficient_head=False,
615-
weight_init='goog', folded_bn=False, padding_same=False):
615+
global_pool='avg', head_conv='default', weight_init='goog',
616+
folded_bn=False, padding_same=False):
616617
super(GenMobileNet, self).__init__()
617618
self.num_classes = num_classes
618619
self.depth_multiplier = depth_multiplier
619620
self.drop_rate = drop_rate
620621
self.act_fn = act_fn
621622
self.num_features = num_features
622-
self.efficient_head = efficient_head # pool before last conv
623623

624624
stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
625625
self.conv_stem = sconv2d(
@@ -629,19 +629,22 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
629629
in_chs = stem_size
630630

631631
builder = _BlockBuilder(
632-
depth_multiplier, depth_divisor, min_depth, act_fn, se_gate_fn, se_reduce_mid,
632+
depth_multiplier, depth_divisor, min_depth,
633+
act_fn, se_gate_fn, se_reduce_mid,
633634
bn_momentum, bn_eps, folded_bn, padding_same)
634635
self.blocks = nn.Sequential(*builder(in_chs, block_args))
635636
in_chs = builder.in_chs
636637

637-
if skip_head_conv:
638+
if not head_conv or head_conv == 'none':
639+
self.efficient_head = False
638640
self.conv_head = None
639641
assert in_chs == self.num_features
640642
else:
643+
self.efficient_head = head_conv == 'efficient'
641644
self.conv_head = sconv2d(
642645
in_chs, self.num_features, 1,
643-
padding=_padding_arg(0, padding_same), bias=folded_bn and not efficient_head)
644-
self.bn2 = None if (folded_bn or efficient_head) else \
646+
padding=_padding_arg(0, padding_same), bias=folded_bn and not self.efficient_head)
647+
self.bn2 = None if (folded_bn or self.efficient_head) else \
645648
nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
646649

647650
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -674,7 +677,7 @@ def forward_features(self, x, pool=True):
674677
x = self.blocks(x)
675678
if self.efficient_head:
676679
# efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
677-
x = self.global_pool(x) # always need to pool here regardless of bool
680+
x = self.global_pool(x) # always need to pool here regardless of flag
678681
x = self.conv_head(x)
679682
# no BN
680683
x = self.act_fn(x)
@@ -836,7 +839,7 @@ def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs):
836839
bn_momentum=bn_momentum,
837840
bn_eps=bn_eps,
838841
act_fn=F.relu6,
839-
skip_head_conv=True,
842+
head_conv='none',
840843
**kwargs
841844
)
842845
return model
@@ -914,6 +917,7 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
914917
act_fn=hard_swish,
915918
se_gate_fn=hard_sigmoid,
916919
se_reduce_mid=True,
920+
head_conv='efficient',
917921
**kwargs
918922
)
919923
return model

0 commit comments

Comments
 (0)