Skip to content

Commit 6bff9c7

Browse files
committed
Cleanup model_factory imports, consistent __all__ for models, fixed inception_v4 weight url
1 parent e6c1442 commit 6bff9c7

File tree

12 files changed

+71
-74
lines changed

12 files changed

+71
-74
lines changed

models/densenet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1313
import re
1414

15-
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
15+
_models = ['densenet121', 'densenet169', 'densenet201', 'densenet161']
16+
__all__ = ['DenseNet'] + _models
1617

1718

1819
def _cfg(url=''):

models/dpn.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from models.adaptive_avgmax_pool import select_adaptive_pool2d
2020
from data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
2121

22-
__all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
22+
_models = ['dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
23+
__all__ = ['DPN'] + _models
2324

2425

2526
def _cfg(url=''):
@@ -32,18 +33,12 @@ def _cfg(url=''):
3233

3334

3435
default_cfgs = {
35-
'dpn68':
36-
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth'),
37-
'dpn68b_extra':
38-
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-84854c156.pth'),
39-
'dpn92_extra':
40-
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'),
41-
'dpn98':
42-
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth'),
43-
'dpn131':
44-
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth'),
45-
'dpn107_extra':
46-
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth')
36+
'dpn68': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth'),
37+
'dpn68b_extra': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-84854c156.pth'),
38+
'dpn92_extra': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'),
39+
'dpn98': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth'),
40+
'dpn131': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth'),
41+
'dpn107_extra': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth')
4742
}
4843

4944

models/genmobilenet.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@
2626
from models.conv2d_same import sconv2d
2727
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2828

29-
__all__ = ['GenMobileNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140',
30-
'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'semnasnet_140', 'mnasnet_small',
31-
'mobilenetv1_100', 'mobilenetv2_100', 'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100',
32-
'chamnetv1_100', 'chamnetv2_100', 'fbnetc_100', 'spnasnet_100']
29+
_models = [
30+
'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
31+
'semnasnet_100', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
32+
'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', 'chamnetv1_100', 'chamnetv2_100',
33+
'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100']
34+
__all__ = ['GenMobileNet', 'genmobilenet_model_names'] + _models
3335

3436

3537
def _cfg(url='', **kwargs):
@@ -67,7 +69,7 @@ def _cfg(url='', **kwargs):
6769
'spnasnet_100': _cfg(url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet_100-048bc3f4.pth?dl=1'),
6870
}
6971

70-
_DEBUG = True
72+
_DEBUG = False
7173

7274
# Default args for PyTorch BN impl
7375
_BN_MOMENTUM_PT_DEFAULT = 0.1
@@ -266,7 +268,7 @@ class _BlockBuilder:
266268
def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
267269
act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
268270
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
269-
folded_bn=False, padding_same=False):
271+
folded_bn=False, padding_same=False, verbose=False):
270272
self.depth_multiplier = depth_multiplier
271273
self.depth_divisor = depth_divisor
272274
self.min_depth = min_depth
@@ -277,6 +279,7 @@ def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
277279
self.bn_eps = bn_eps
278280
self.folded_bn = folded_bn
279281
self.padding_same = padding_same
282+
self.verbose = verbose
280283
self.in_chs = None
281284

282285
def _round_channels(self, chs):
@@ -293,7 +296,7 @@ def _make_block(self, ba):
293296
# block act fn overrides the model default
294297
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
295298
assert ba['act_fn'] is not None
296-
if _DEBUG:
299+
if self.verbose:
297300
print('args:', ba)
298301
# could replace this if with lambdas or functools binding if variety increases
299302
if bt == 'ir':
@@ -315,7 +318,7 @@ def _make_stack(self, stack_args):
315318
blocks = []
316319
# each stack (stage) contains a list of block arguments
317320
for block_idx, ba in enumerate(stack_args):
318-
if _DEBUG:
321+
if self.verbose:
319322
print('block', block_idx, end=', ')
320323
if block_idx >= 1:
321324
# only the first block in any stack/stage can have a stride > 1
@@ -334,18 +337,18 @@ def __call__(self, in_chs, arch_def):
334337
List of block stacks (each stack wrapped in nn.Sequential)
335338
"""
336339
arch_args = _decode_arch_def(arch_def) # convert and expand string defs to arg dicts
337-
if _DEBUG:
340+
if self.verbose:
338341
print('Building model trunk with %d stacks (stages)...' % len(arch_args))
339342
self.in_chs = in_chs
340343
blocks = []
341344
# outer list of arch_args defines the stacks ('stages' by some conventions)
342345
for stack_idx, stack in enumerate(arch_args):
343-
if _DEBUG:
346+
if self.verbose:
344347
print('stack', stack_idx)
345348
assert isinstance(stack, list)
346349
stack = self._make_stack(stack)
347350
blocks.append(stack)
348-
if _DEBUG:
351+
if self.verbose:
349352
print()
350353
return blocks
351354

@@ -631,7 +634,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
631634
builder = _BlockBuilder(
632635
depth_multiplier, depth_divisor, min_depth,
633636
act_fn, se_gate_fn, se_reduce_mid,
634-
bn_momentum, bn_eps, folded_bn, padding_same)
637+
bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
635638
self.blocks = nn.Sequential(*builder(in_chs, block_args))
636639
in_chs = builder.in_chs
637640

@@ -1265,3 +1268,7 @@ def spnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
12651268
if pretrained:
12661269
load_pretrained(model, default_cfg, num_classes, in_chans)
12671270
return model
1271+
1272+
1273+
def genmobilenet_model_names():
1274+
return set(_models)

models/gluon_resnet.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
1212
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1313

14-
__all__ = ['GluonResNet', 'gluon_resnet18_v1b', 'gluon_resnet34_v1b', 'gluon_resnet50_v1b', 'gluon_resnet101_v1b',
15-
'gluon_resnet152_v1b', 'gluon_resnet50_v1c', 'gluon_resnet101_v1c', 'gluon_resnet152_v1c', 'gluon_resnet50_v1d',
16-
'gluon_resnet101_v1d', 'gluon_resnet152_v1d', 'gluon_resnet50_v1e', 'gluon_resnet101_v1e', 'gluon_resnet152_v1e',
17-
'gluon_resnet50_v1s', 'gluon_resnet101_v1s', 'gluon_resnet152_v1s', 'gluon_resnext50_32x4d',
18-
'gluon_resnext101_32x4d', 'gluon_resnext101_64x4d', 'gluon_resnext152_32x4d', 'gluon_seresnext50_32x4d',
19-
'gluon_seresnext101_32x4d', 'gluon_seresnext101_64x4d', 'gluon_seresnext152_32x4d', 'gluon_senet154'
20-
]
14+
_models = [
15+
'gluon_resnet18_v1b', 'gluon_resnet34_v1b', 'gluon_resnet50_v1b', 'gluon_resnet101_v1b', 'gluon_resnet152_v1b',
16+
'gluon_resnet50_v1c', 'gluon_resnet101_v1c', 'gluon_resnet152_v1c', 'gluon_resnet50_v1d', 'gluon_resnet101_v1d',
17+
'gluon_resnet152_v1d', 'gluon_resnet50_v1e', 'gluon_resnet101_v1e', 'gluon_resnet152_v1e', 'gluon_resnet50_v1s',
18+
'gluon_resnet101_v1s', 'gluon_resnet152_v1s', 'gluon_resnext50_32x4d', 'gluon_resnext101_32x4d',
19+
'gluon_resnext101_64x4d', 'gluon_resnext152_32x4d', 'gluon_seresnext50_32x4d', 'gluon_seresnext101_32x4d',
20+
'gluon_seresnext101_64x4d', 'gluon_seresnext152_32x4d', 'gluon_senet154']
21+
__all__ = ['GluonResNet'] + _models
2122

2223

2324
def _cfg(url='', **kwargs):

models/inception_resnet_v2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from models.adaptive_avgmax_pool import *
1010
from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
1111

12+
_models = ['inception_resnet_v2']
13+
__all__ = ['InceptionResnetV2'] + _models
14+
1215
default_cfgs = {
1316
'inception_resnet_v2': {
1417
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',

models/inception_v3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from models.helpers import load_pretrained
33
from data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
44

5+
_models = ['inception_v3', 'tf_inception_v3', 'adv_inception_v3', 'gluon_inception_v3']
6+
__all__ = _models
7+
58
default_cfgs = {
69
# original PyTorch weights, ported from Tensorflow but modified
710
'inception_v3': {

models/inception_v4.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99
from models.adaptive_avgmax_pool import *
1010
from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
1111

12+
_models = ['inception_v4']
13+
__all__ = ['InceptionV4'] + _models
14+
1215
default_cfgs = {
1316
'inception_v4': {
14-
'url': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth',
17+
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
1518
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
1619
'crop_pct': 0.875, 'interpolation': 'bicubic',
1720
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
18-
'first_conv': 'features.0.conv', 'classifier': 'classif',
21+
'first_conv': 'features.0.conv', 'classifier': 'last_linear',
1922
}
2023
}
2124

@@ -268,7 +271,7 @@ def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'
268271
Inception_C(),
269272
Inception_C(),
270273
)
271-
self.classif = nn.Linear(self.num_features, num_classes)
274+
self.last_linear = nn.Linear(self.num_features, num_classes)
272275

273276
def get_classifier(self):
274277
return self.classif
@@ -289,7 +292,7 @@ def forward(self, x):
289292
x = self.forward_features(x)
290293
if self.drop_rate > 0:
291294
x = F.dropout(x, p=self.drop_rate, training=self.training)
292-
x = self.classif(x)
295+
x = self.last_linear(x)
293296
return x
294297

295298

models/model_factory.py

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,18 @@
1-
from models.inception_v4 import inception_v4
2-
from models.inception_resnet_v2 import inception_resnet_v2
3-
from models.densenet import densenet161, densenet121, densenet169, densenet201
4-
from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \
5-
resnext50_32x4d, resnext101_32x4d, resnext101_64x4d, resnext152_32x4d
6-
from models.dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
7-
from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
8-
seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
9-
from models.xception import xception
10-
from models.pnasnet import pnasnet5large
11-
from models.genmobilenet import \
12-
mnasnet_050, mnasnet_075, mnasnet_100, mnasnet_140, tflite_mnasnet_100,\
13-
semnasnet_050, semnasnet_075, semnasnet_100, semnasnet_140, tflite_semnasnet_100, mnasnet_small,\
14-
mobilenetv1_100, mobilenetv2_100, mobilenetv3_050, mobilenetv3_075, mobilenetv3_100,\
15-
fbnetc_100, chamnetv1_100, chamnetv2_100, spnasnet_100
16-
from models.inception_v3 import inception_v3, gluon_inception_v3, tf_inception_v3, adv_inception_v3
17-
from models.gluon_resnet import gluon_resnet18_v1b, gluon_resnet34_v1b, gluon_resnet50_v1b, gluon_resnet101_v1b, \
18-
gluon_resnet152_v1b, gluon_resnet50_v1c, gluon_resnet101_v1c, gluon_resnet152_v1c, \
19-
gluon_resnet50_v1d, gluon_resnet101_v1d, gluon_resnet152_v1d, \
20-
gluon_resnet50_v1e, gluon_resnet101_v1e, gluon_resnet152_v1e, \
21-
gluon_resnet50_v1s, gluon_resnet101_v1s, gluon_resnet152_v1s, \
22-
gluon_resnext50_32x4d, gluon_resnext101_32x4d , gluon_resnext101_64x4d, gluon_resnext152_32x4d, \
23-
gluon_seresnext50_32x4d, gluon_seresnext101_32x4d, gluon_seresnext101_64x4d, gluon_seresnext152_32x4d, \
24-
gluon_senet154
1+
from models.inception_v4 import *
2+
from models.inception_resnet_v2 import *
3+
from models.densenet import *
4+
from models.resnet import *
5+
from models.dpn import *
6+
from models.senet import *
7+
from models.xception import *
8+
from models.pnasnet import *
9+
from models.genmobilenet import *
10+
from models.inception_v3 import *
11+
from models.gluon_resnet import *
2512

2613
from models.helpers import load_checkpoint
2714

2815

29-
def _is_genmobilenet(name):
30-
genmobilenets = ['mnasnet', 'semnasnet', 'fbnet', 'chamnet', 'mobilenet']
31-
if any([name.startswith(x) for x in genmobilenets]):
32-
return True
33-
return False
34-
35-
3616
def create_model(
3717
model_name='resnet50',
3818
pretrained=None,
@@ -44,8 +24,7 @@ def create_model(
4424
margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained)
4525

4626
# Not all models have support for batchnorm params passed as args, only genmobilenet variants
47-
# FIXME better way to do this without pushing support into every other model fn?
48-
supports_bn_params = _is_genmobilenet(model_name)
27+
supports_bn_params = model_name in genmobilenet_model_names()
4928
if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
5029
kwargs.pop('bn_tf', None)
5130
kwargs.pop('bn_momentum', None)

models/pnasnet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from models.helpers import load_pretrained
1616
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
1717

18+
_models = ['pnasnet5large']
19+
__all__ = ['PNASNet5Large'] + _models
20+
1821
default_cfgs = {
1922
'pnasnet5large': {
2023
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',

models/resnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
1313
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1414

15-
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
15+
_models = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
1616
'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d']
17+
__all__ = ['ResNet'] + _models
1718

1819

1920
def _cfg(url='', **kwargs):

0 commit comments

Comments
 (0)