     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks
 from ._manipulate import checkpoint_seq
+from ._pretrained import generate_default_cfgs
 from ._registry import register_model
 
 __all__ = ['MobileNetV3', 'MobileNetV3Features']
 
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'conv_stem', 'classifier': 'classifier',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'mobilenetv3_large_075': _cfg(url=''),
-    'mobilenetv3_large_100': _cfg(
-        interpolation='bicubic',
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
-    'mobilenetv3_large_100_miil': _cfg(
-        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth'),
-    'mobilenetv3_large_100_miil_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',
-        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
-
-    'mobilenetv3_small_050': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
-        interpolation='bicubic'),
-    'mobilenetv3_small_075': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',
-        interpolation='bicubic'),
-    'mobilenetv3_small_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',
-        interpolation='bicubic'),
-
-    'mobilenetv3_rw': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
-        interpolation='bicubic'),
-
-    'tf_mobilenetv3_large_075': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-    'tf_mobilenetv3_large_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-    'tf_mobilenetv3_large_minimal_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-    'tf_mobilenetv3_small_075': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-    'tf_mobilenetv3_small_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-    'tf_mobilenetv3_small_minimal_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-
-    'fbnetv3_b': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',
-        test_input_size=(3, 256, 256), crop_pct=0.95),
-    'fbnetv3_d': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',
-        test_input_size=(3, 256, 256), crop_pct=0.95),
-    'fbnetv3_g': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
-        input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
-
-    "lcnet_035": _cfg(),
-    "lcnet_050": _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
-        interpolation='bicubic',
-    ),
-    "lcnet_075": _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
-        interpolation='bicubic',
-    ),
-    "lcnet_100": _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
-        interpolation='bicubic',
-    ),
-    "lcnet_150": _cfg(),
-}
-
-
 class MobileNetV3(nn.Module):
11231 """ MobiletNet-V3
11332
@@ -124,9 +43,24 @@ class MobileNetV3(nn.Module):
12443 """
12544
12645 def __init__ (
127- self , block_args , num_classes = 1000 , in_chans = 3 , stem_size = 16 , fix_stem = False , num_features = 1280 ,
128- head_bias = True , pad_type = '' , act_layer = None , norm_layer = None , se_layer = None , se_from_exp = True ,
129- round_chs_fn = round_channels , drop_rate = 0. , drop_path_rate = 0. , global_pool = 'avg' ):
46+ self ,
47+ block_args ,
48+ num_classes = 1000 ,
49+ in_chans = 3 ,
50+ stem_size = 16 ,
51+ fix_stem = False ,
52+ num_features = 1280 ,
53+ head_bias = True ,
54+ pad_type = '' ,
55+ act_layer = None ,
56+ norm_layer = None ,
57+ se_layer = None ,
58+ se_from_exp = True ,
59+ round_chs_fn = round_channels ,
60+ drop_rate = 0. ,
61+ drop_path_rate = 0. ,
62+ global_pool = 'avg' ,
63+ ):
13064 super (MobileNetV3 , self ).__init__ ()
13165 act_layer = act_layer or nn .ReLU
13266 norm_layer = norm_layer or nn .BatchNorm2d
@@ -145,8 +79,15 @@ def __init__(
 
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
-            output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
-            act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
+            output_stride=32,
+            pad_type=pad_type,
+            round_chs_fn=round_chs_fn,
+            se_from_exp=se_from_exp,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            se_layer=se_layer,
+            drop_path_rate=drop_path_rate,
+        )
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
         head_chs = builder.in_chs
@@ -225,9 +166,23 @@ class MobileNetV3Features(nn.Module):
225166 """
226167
227168 def __init__ (
228- self , block_args , out_indices = (0 , 1 , 2 , 3 , 4 ), feature_location = 'bottleneck' , in_chans = 3 ,
229- stem_size = 16 , fix_stem = False , output_stride = 32 , pad_type = '' , round_chs_fn = round_channels ,
230- se_from_exp = True , act_layer = None , norm_layer = None , se_layer = None , drop_rate = 0. , drop_path_rate = 0. ):
169+ self ,
170+ block_args ,
171+ out_indices = (0 , 1 , 2 , 3 , 4 ),
172+ feature_location = 'bottleneck' ,
173+ in_chans = 3 ,
174+ stem_size = 16 ,
175+ fix_stem = False ,
176+ output_stride = 32 ,
177+ pad_type = '' ,
178+ round_chs_fn = round_channels ,
179+ se_from_exp = True ,
180+ act_layer = None ,
181+ norm_layer = None ,
182+ se_layer = None ,
183+ drop_rate = 0. ,
184+ drop_path_rate = 0. ,
185+ ):
231186 super (MobileNetV3Features , self ).__init__ ()
232187 act_layer = act_layer or nn .ReLU
233188 norm_layer = norm_layer or nn .BatchNorm2d
@@ -243,9 +198,16 @@ def __init__(
 
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
-            output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
-            act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
-            drop_path_rate=drop_path_rate, feature_location=feature_location)
+            output_stride=output_stride,
+            pad_type=pad_type,
+            round_chs_fn=round_chs_fn,
+            se_from_exp=se_from_exp,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            se_layer=se_layer,
+            drop_path_rate=drop_path_rate,
+            feature_location=feature_location,
+        )
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = FeatureInfo(builder.features, out_indices)
         self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
@@ -286,7 +248,9 @@ def _create_mnv3(variant, pretrained=False, **kwargs):
         kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
         model_cls = MobileNetV3Features
     model = build_model_with_cfg(
-        model_cls, variant, pretrained,
+        model_cls,
+        variant,
+        pretrained,
         pretrained_strict=not features_only,
         kwargs_filter=kwargs_filter,
         **kwargs)
@@ -567,6 +531,110 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
     return model
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'conv_stem', 'classifier': 'classifier',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'mobilenetv3_large_075.untrained': _cfg(url=''),
+    'mobilenetv3_large_100.ra_in1k': _cfg(
+        interpolation='bicubic',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
+        hf_hub_id='timm/'),
+    'mobilenetv3_large_100.miil_in21k_ft_in1k': _cfg(
+        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),
+        origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
+        paper_ids='arXiv:2104.10972v4',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth',
+        hf_hub_id='timm/'),
+    'mobilenetv3_large_100.miil_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',
+        hf_hub_id='timm/',
+        origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
+        paper_ids='arXiv:2104.10972v4',
+        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
+
+    'mobilenetv3_small_050.lamb_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic'),
+    'mobilenetv3_small_075.lamb_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic'),
+    'mobilenetv3_small_100.lamb_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic'),
+
+    'mobilenetv3_rw.rmsp_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
+        interpolation='bicubic'),
+
+    'tf_mobilenetv3_large_075.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+    'tf_mobilenetv3_large_100.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+    'tf_mobilenetv3_large_minimal_100.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+    'tf_mobilenetv3_small_075.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+    'tf_mobilenetv3_small_100.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+    'tf_mobilenetv3_small_minimal_100.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+
+    'fbnetv3_b.ra2_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',
+        hf_hub_id='timm/',
+        test_input_size=(3, 256, 256), crop_pct=0.95),
+    'fbnetv3_d.ra2_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',
+        hf_hub_id='timm/',
+        test_input_size=(3, 256, 256), crop_pct=0.95),
+    'fbnetv3_g.ra2_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
+
+    "lcnet_035.untrained": _cfg(),
+    "lcnet_050.ra2_in1k": _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic',
+    ),
+    "lcnet_075.ra2_in1k": _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic',
+    ),
+    "lcnet_100.ra2_in1k": _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic',
+    ),
+    "lcnet_150.untrained": _cfg(),
+})
+
+
 @register_model
 def mobilenetv3_large_075(pretrained=False, **kwargs):
     """ MobileNet V3 """
@@ -581,24 +649,6 @@ def mobilenetv3_large_100(pretrained=False, **kwargs):
     return model
 
 
-@register_model
-def mobilenetv3_large_100_miil(pretrained=False, **kwargs):
-    """ MobileNet V3
-    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
-    """
-    model = _gen_mobilenet_v3('mobilenetv3_large_100_miil', 1.0, pretrained=pretrained, **kwargs)
-    return model
-
-
-@register_model
-def mobilenetv3_large_100_miil_in21k(pretrained=False, **kwargs):
-    """ MobileNet V3, 21k pretraining
-    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
-    """
-    model = _gen_mobilenet_v3('mobilenetv3_large_100_miil_in21k', 1.0, pretrained=pretrained, **kwargs)
-    return model
-
-
 @register_model
 def mobilenetv3_small_050(pretrained=False, **kwargs):
     """ MobileNet V3 """