Skip to content

Commit f313497

Browse files
committed
Add the MnasNet-B1 variant weights, add A1/B1 model names as in the stand-alone repo, remove a bit of unused code
1 parent c1a84ec commit f313497

File tree

2 files changed

+19
-49
lines changed

2 files changed

+19
-49
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,11 @@ I've leveraged the training scripts in this repository to train a few of the mod
6969
| seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic |
7070
| efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29M | bicubic |
7171
| mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic |
72-
| semnasnet_100 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic |
72+
| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic |
7373
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear |
7474
| resnet34 | 75.110 (24.890) | 92.284 (7.716) | 22M | bilinear |
7575
| seresnet34 | 74.808 (25.192) | 92.124 (7.876) | 22M | bilinear |
76+
| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.38M | bicubic |
7677
| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.42M | bilinear |
7778
| seresnet18 | 71.742 (28.258) | 90.334 (9.666) | 11.8M | bicubic |
7879

models/gen_efficientnet.py

Lines changed: 17 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2929

3030
_models = [
31-
'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
32-
'semnasnet_100', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
31+
'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
32+
'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
3333
'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', 'chamnetv1_100', 'chamnetv2_100',
3434
'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100', 'efficientnet_b0',
3535
'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'tf_efficientnet_b0',
@@ -50,7 +50,9 @@ def _cfg(url='', **kwargs):
5050
default_cfgs = {
5151
'mnasnet_050': _cfg(url=''),
5252
'mnasnet_075': _cfg(url=''),
53-
'mnasnet_100': _cfg(url=''),
53+
'mnasnet_100': _cfg(
54+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
55+
interpolation='bicubic'),
5456
'tflite_mnasnet_100': _cfg(
5557
url='https://www.dropbox.com/s/q55ir3tx8mpeyol/tflite_mnasnet_100-31639cdc.pth?dl=1',
5658
interpolation='bicubic'),
@@ -161,8 +163,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
161163
is assumed to indicate the block type.
162164
163165
leading string - block type (
164-
ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act,
165-
ca = Cascade3x3, and possibly more)
166+
ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
166167
r - number of repeat blocks,
167168
k - kernel size,
168169
s - strides (1-9),
@@ -227,15 +228,6 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
227228
block_args['pw_group'] = options['g']
228229
if options['g'] > 1:
229230
block_args['shuffle_type'] = 'mid'
230-
elif block_type == 'ca':
231-
block_args = dict(
232-
block_type=block_type,
233-
kernel_size=int(options['k']),
234-
out_chs=int(options['c']),
235-
stride=int(options['s']),
236-
act_fn=act_fn,
237-
noskip=noskip,
238-
)
239231
elif block_type == 'ds' or block_type == 'dsa':
240232
block_args = dict(
241233
block_type=block_type,
@@ -345,8 +337,6 @@ def _make_block(self, ba):
345337
elif bt == 'ds' or bt == 'dsa':
346338
ba['drop_connect_rate'] = self.drop_connect_rate
347339
block = DepthwiseSeparableConv(**ba)
348-
elif bt == 'ca':
349-
block = CascadeConv(**ba)
350340
elif bt == 'cn':
351341
block = ConvBnAct(**ba)
352342
else:
@@ -565,36 +555,6 @@ def forward(self, x):
565555
return x
566556

567557

568-
class CascadeConv(nn.Sequential):
569-
# FIXME haven't used yet
570-
def __init__(self, in_chs, out_chs, kernel_size=3, stride=2, act_fn=F.relu, noskip=False,
571-
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
572-
folded_bn=False, padding_same=False):
573-
super(CascadeConv, self).__init__()
574-
assert stride in [1, 2]
575-
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
576-
self.act_fn = act_fn
577-
padding = _padding_arg(1, padding_same)
578-
579-
self.conv1 = sconv2d(in_chs, in_chs, kernel_size, stride=stride, padding=padding, bias=folded_bn)
580-
self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
581-
self.conv2 = sconv2d(in_chs, out_chs, kernel_size, stride=1, padding=padding, bias=folded_bn)
582-
self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
583-
584-
def forward(self, x):
585-
residual = x
586-
x = self.conv1(x)
587-
if self.bn1 is not None:
588-
x = self.bn1(x)
589-
x = self.act_fn(x)
590-
x = self.conv2(x)
591-
if self.bn2 is not None:
592-
x = self.bn2(x)
593-
if self.has_residual:
594-
x += residual
595-
return x
596-
597-
598558
class InvertedResidual(nn.Module):
599559
""" Inverted residual block w/ optional SE"""
600560

@@ -699,7 +659,6 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
699659
super(GenEfficientNet, self).__init__()
700660
self.num_classes = num_classes
701661
self.drop_rate = drop_rate
702-
self.drop_connect_rate = drop_connect_rate
703662
self.act_fn = act_fn
704663
self.num_features = num_features
705664

@@ -730,7 +689,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
730689
nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
731690

732691
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
733-
self.classifier = nn.Linear(self.num_features, self.num_classes)
692+
self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
734693

735694
for m in self.modules():
736695
if weight_init == 'goog':
@@ -1220,6 +1179,11 @@ def mnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
12201179
return model
12211180

12221181

1182+
def mnasnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
1183+
""" MNASNet B1, depth multiplier of 1.0. """
1184+
return mnasnet_100(num_classes, in_chans, pretrained, **kwargs)
1185+
1186+
12231187
def tflite_mnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
12241188
""" MNASNet B1, depth multiplier of 1.0. """
12251189
default_cfg = default_cfgs['tflite_mnasnet_100']
@@ -1273,6 +1237,11 @@ def semnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
12731237
return model
12741238

12751239

1240+
def mnasnet_a1(num_classes, in_chans=3, pretrained=False, **kwargs):
1241+
""" MNASNet A1 (w/ SE), depth multiplier of 1.0. """
1242+
return semnasnet_100(num_classes, in_chans, pretrained, **kwargs)
1243+
1244+
12761245
def tflite_semnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
12771246
""" MNASNet A1, depth multiplier of 1.0. """
12781247
default_cfg = default_cfgs['tflite_semnasnet_100']

0 commit comments

Comments
 (0)