Skip to content

Commit 656e177

Browse files
committed
Convert mobilenetv3 to multi-weight, tweak PretrainedCfg metadata
1 parent 6a01101 commit 656e177

File tree

2 files changed

+167
-115
lines changed

2 files changed

+167
-115
lines changed

timm/models/_pretrained.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ class PretrainedCfg:
4545
classifier: Optional[str] = None
4646

4747
license: Optional[str] = None
48-
source_url: Optional[str] = None
49-
paper: Optional[str] = None
50-
notes: Optional[str] = None
48+
description: Optional[str] = None
49+
origin_url: Optional[str] = None
50+
paper_name: Optional[str] = None
51+
paper_ids: Optional[Union[str, Tuple[str]]] = None
52+
notes: Optional[Tuple[str]] = None
5153

5254
@property
5355
def has_weights(self):

timm/models/mobilenetv3.py

Lines changed: 162 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -21,93 +21,12 @@
2121
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
2222
from ._features import FeatureInfo, FeatureHooks
2323
from ._manipulate import checkpoint_seq
24+
from ._pretrained import generate_default_cfgs
2425
from ._registry import register_model
2526

2627
__all__ = ['MobileNetV3', 'MobileNetV3Features']
2728

2829

29-
def _cfg(url='', **kwargs):
30-
return {
31-
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
32-
'crop_pct': 0.875, 'interpolation': 'bilinear',
33-
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
34-
'first_conv': 'conv_stem', 'classifier': 'classifier',
35-
**kwargs
36-
}
37-
38-
39-
default_cfgs = {
40-
'mobilenetv3_large_075': _cfg(url=''),
41-
'mobilenetv3_large_100': _cfg(
42-
interpolation='bicubic',
43-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
44-
'mobilenetv3_large_100_miil': _cfg(
45-
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),
46-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth'),
47-
'mobilenetv3_large_100_miil_in21k': _cfg(
48-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',
49-
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
50-
51-
'mobilenetv3_small_050': _cfg(
52-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
53-
interpolation='bicubic'),
54-
'mobilenetv3_small_075': _cfg(
55-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',
56-
interpolation='bicubic'),
57-
'mobilenetv3_small_100': _cfg(
58-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',
59-
interpolation='bicubic'),
60-
61-
'mobilenetv3_rw': _cfg(
62-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
63-
interpolation='bicubic'),
64-
65-
'tf_mobilenetv3_large_075': _cfg(
66-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
67-
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
68-
'tf_mobilenetv3_large_100': _cfg(
69-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
70-
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
71-
'tf_mobilenetv3_large_minimal_100': _cfg(
72-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
73-
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
74-
'tf_mobilenetv3_small_075': _cfg(
75-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
76-
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
77-
'tf_mobilenetv3_small_100': _cfg(
78-
url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
79-
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
80-
'tf_mobilenetv3_small_minimal_100': _cfg(
81-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
82-
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
83-
84-
'fbnetv3_b': _cfg(
85-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',
86-
test_input_size=(3, 256, 256), crop_pct=0.95),
87-
'fbnetv3_d': _cfg(
88-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',
89-
test_input_size=(3, 256, 256), crop_pct=0.95),
90-
'fbnetv3_g': _cfg(
91-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
92-
input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
93-
94-
"lcnet_035": _cfg(),
95-
"lcnet_050": _cfg(
96-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
97-
interpolation='bicubic',
98-
),
99-
"lcnet_075": _cfg(
100-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
101-
interpolation='bicubic',
102-
),
103-
"lcnet_100": _cfg(
104-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
105-
interpolation='bicubic',
106-
),
107-
"lcnet_150": _cfg(),
108-
}
109-
110-
11130
class MobileNetV3(nn.Module):
11231
""" MobiletNet-V3
11332
@@ -124,9 +43,24 @@ class MobileNetV3(nn.Module):
12443
"""
12544

12645
def __init__(
127-
self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280,
128-
head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
129-
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
46+
self,
47+
block_args,
48+
num_classes=1000,
49+
in_chans=3,
50+
stem_size=16,
51+
fix_stem=False,
52+
num_features=1280,
53+
head_bias=True,
54+
pad_type='',
55+
act_layer=None,
56+
norm_layer=None,
57+
se_layer=None,
58+
se_from_exp=True,
59+
round_chs_fn=round_channels,
60+
drop_rate=0.,
61+
drop_path_rate=0.,
62+
global_pool='avg',
63+
):
13064
super(MobileNetV3, self).__init__()
13165
act_layer = act_layer or nn.ReLU
13266
norm_layer = norm_layer or nn.BatchNorm2d
@@ -145,8 +79,15 @@ def __init__(
14579

14680
# Middle stages (IR/ER/DS Blocks)
14781
builder = EfficientNetBuilder(
148-
output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
149-
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
82+
output_stride=32,
83+
pad_type=pad_type,
84+
round_chs_fn=round_chs_fn,
85+
se_from_exp=se_from_exp,
86+
act_layer=act_layer,
87+
norm_layer=norm_layer,
88+
se_layer=se_layer,
89+
drop_path_rate=drop_path_rate,
90+
)
15091
self.blocks = nn.Sequential(*builder(stem_size, block_args))
15192
self.feature_info = builder.features
15293
head_chs = builder.in_chs
@@ -225,9 +166,23 @@ class MobileNetV3Features(nn.Module):
225166
"""
226167

227168
def __init__(
228-
self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
229-
stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
230-
se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
169+
self,
170+
block_args,
171+
out_indices=(0, 1, 2, 3, 4),
172+
feature_location='bottleneck',
173+
in_chans=3,
174+
stem_size=16,
175+
fix_stem=False,
176+
output_stride=32,
177+
pad_type='',
178+
round_chs_fn=round_channels,
179+
se_from_exp=True,
180+
act_layer=None,
181+
norm_layer=None,
182+
se_layer=None,
183+
drop_rate=0.,
184+
drop_path_rate=0.,
185+
):
231186
super(MobileNetV3Features, self).__init__()
232187
act_layer = act_layer or nn.ReLU
233188
norm_layer = norm_layer or nn.BatchNorm2d
@@ -243,9 +198,16 @@ def __init__(
243198

244199
# Middle stages (IR/ER/DS Blocks)
245200
builder = EfficientNetBuilder(
246-
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
247-
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
248-
drop_path_rate=drop_path_rate, feature_location=feature_location)
201+
output_stride=output_stride,
202+
pad_type=pad_type,
203+
round_chs_fn=round_chs_fn,
204+
se_from_exp=se_from_exp,
205+
act_layer=act_layer,
206+
norm_layer=norm_layer,
207+
se_layer=se_layer,
208+
drop_path_rate=drop_path_rate,
209+
feature_location=feature_location,
210+
)
249211
self.blocks = nn.Sequential(*builder(stem_size, block_args))
250212
self.feature_info = FeatureInfo(builder.features, out_indices)
251213
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
@@ -286,7 +248,9 @@ def _create_mnv3(variant, pretrained=False, **kwargs):
286248
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
287249
model_cls = MobileNetV3Features
288250
model = build_model_with_cfg(
289-
model_cls, variant, pretrained,
251+
model_cls,
252+
variant,
253+
pretrained,
290254
pretrained_strict=not features_only,
291255
kwargs_filter=kwargs_filter,
292256
**kwargs)
@@ -567,6 +531,110 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
567531
return model
568532

569533

534+
def _cfg(url='', **kwargs):
535+
return {
536+
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
537+
'crop_pct': 0.875, 'interpolation': 'bilinear',
538+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
539+
'first_conv': 'conv_stem', 'classifier': 'classifier',
540+
**kwargs
541+
}
542+
543+
544+
default_cfgs = generate_default_cfgs({
545+
'mobilenetv3_large_075.untrained': _cfg(url=''),
546+
'mobilenetv3_large_100.ra_in1k': _cfg(
547+
interpolation='bicubic',
548+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
549+
hf_hub_id='timm/'),
550+
'mobilenetv3_large_100.miil_in21k_ft_in1k': _cfg(
551+
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),
552+
origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
553+
paper_ids='arXiv:2104.10972v4',
554+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth',
555+
hf_hub_id='timm/'),
556+
'mobilenetv3_large_100.miil_in21k': _cfg(
557+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',
558+
hf_hub_id='timm/',
559+
origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
560+
paper_ids='arXiv:2104.10972v4',
561+
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
562+
563+
'mobilenetv3_small_050.lamb_in1k': _cfg(
564+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
565+
hf_hub_id='timm/',
566+
interpolation='bicubic'),
567+
'mobilenetv3_small_075.lamb_in1k': _cfg(
568+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',
569+
hf_hub_id='timm/',
570+
interpolation='bicubic'),
571+
'mobilenetv3_small_100.lamb_in1k': _cfg(
572+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',
573+
hf_hub_id='timm/',
574+
interpolation='bicubic'),
575+
576+
'mobilenetv3_rw.rmsp_in1k': _cfg(
577+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
578+
interpolation='bicubic'),
579+
580+
'tf_mobilenetv3_large_075.in1k': _cfg(
581+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
582+
hf_hub_id='timm/',
583+
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
584+
'tf_mobilenetv3_large_100.in1k': _cfg(
585+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
586+
hf_hub_id='timm/',
587+
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
588+
'tf_mobilenetv3_large_minimal_100.in1k': _cfg(
589+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
590+
hf_hub_id='timm/',
591+
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
592+
'tf_mobilenetv3_small_075.in1k': _cfg(
593+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
594+
hf_hub_id='timm/',
595+
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
596+
'tf_mobilenetv3_small_100.in1k': _cfg(
597+
url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
598+
hf_hub_id='timm/',
599+
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
600+
'tf_mobilenetv3_small_minimal_100.in1k': _cfg(
601+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
602+
hf_hub_id='timm/',
603+
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
604+
605+
'fbnetv3_b.ra2_in1k': _cfg(
606+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',
607+
hf_hub_id='timm/',
608+
test_input_size=(3, 256, 256), crop_pct=0.95),
609+
'fbnetv3_d.ra2_in1k': _cfg(
610+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',
611+
hf_hub_id='timm/',
612+
test_input_size=(3, 256, 256), crop_pct=0.95),
613+
'fbnetv3_g.ra2_in1k': _cfg(
614+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
615+
hf_hub_id='timm/',
616+
input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
617+
618+
"lcnet_035.untrained": _cfg(),
619+
"lcnet_050.ra2_in1k": _cfg(
620+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
621+
hf_hub_id='timm/',
622+
interpolation='bicubic',
623+
),
624+
"lcnet_075.ra2_in1k": _cfg(
625+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
626+
hf_hub_id='timm/',
627+
interpolation='bicubic',
628+
),
629+
"lcnet_100.ra2_in1k": _cfg(
630+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
631+
hf_hub_id='timm/',
632+
interpolation='bicubic',
633+
),
634+
"lcnet_150.untrained": _cfg(),
635+
})
636+
637+
570638
@register_model
571639
def mobilenetv3_large_075(pretrained=False, **kwargs):
572640
""" MobileNet V3 """
@@ -581,24 +649,6 @@ def mobilenetv3_large_100(pretrained=False, **kwargs):
581649
return model
582650

583651

584-
@register_model
585-
def mobilenetv3_large_100_miil(pretrained=False, **kwargs):
586-
""" MobileNet V3
587-
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
588-
"""
589-
model = _gen_mobilenet_v3('mobilenetv3_large_100_miil', 1.0, pretrained=pretrained, **kwargs)
590-
return model
591-
592-
593-
@register_model
594-
def mobilenetv3_large_100_miil_in21k(pretrained=False, **kwargs):
595-
""" MobileNet V3, 21k pretraining
596-
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
597-
"""
598-
model = _gen_mobilenet_v3('mobilenetv3_large_100_miil_in21k', 1.0, pretrained=pretrained, **kwargs)
599-
return model
600-
601-
602652
@register_model
603653
def mobilenetv3_small_050(pretrained=False, **kwargs):
604654
""" MobileNet V3 """

0 commit comments

Comments
 (0)