Skip to content

Commit fa81164

Browse files
committed
Fix stem width for really small mobilenetv3 arch defs
1 parent edd3d73 commit fa81164

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

timm/models/mobilenetv3.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,10 @@ class MobileNetV3(nn.Module):
110110
* LCNet - https://arxiv.org/abs/2109.15099
111111
"""
112112

113-
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
114-
pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
115-
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
113+
def __init__(
114+
self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280,
115+
head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
116+
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
116117
super(MobileNetV3, self).__init__()
117118
act_layer = act_layer or nn.ReLU
118119
norm_layer = norm_layer or nn.BatchNorm2d
@@ -122,7 +123,8 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_f
122123
self.drop_rate = drop_rate
123124

124125
# Stem
125-
stem_size = round_chs_fn(stem_size)
126+
if not fix_stem:
127+
stem_size = round_chs_fn(stem_size)
126128
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
127129
self.bn1 = norm_layer(stem_size)
128130
self.act1 = act_layer(inplace=True)
@@ -188,16 +190,17 @@ class MobileNetV3Features(nn.Module):
188190
"""
189191

190192
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
191-
stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True,
192-
act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
193+
stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
194+
se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
193195
super(MobileNetV3Features, self).__init__()
194196
act_layer = act_layer or nn.ReLU
195197
norm_layer = norm_layer or nn.BatchNorm2d
196198
se_layer = se_layer or SqueezeExcite
197199
self.drop_rate = drop_rate
198200

199201
# Stem
200-
stem_size = round_chs_fn(stem_size)
202+
if not fix_stem:
203+
stem_size = round_chs_fn(stem_size)
201204
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
202205
self.bn1 = norm_layer(stem_size)
203206
self.act1 = act_layer(inplace=True)
@@ -381,6 +384,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
381384
block_args=decode_arch_def(arch_def),
382385
num_features=num_features,
383386
stem_size=16,
387+
fix_stem=channel_multiplier < 0.75,
384388
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
385389
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
386390
act_layer=act_layer,

0 commit comments

Comments
 (0)