@@ -72,6 +72,10 @@ def _cfg(url='', **kwargs):
7272 'tf_mobilenetv3_small_minimal_100' : _cfg (
7373 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth' ,
7474 mean = IMAGENET_INCEPTION_MEAN , std = IMAGENET_INCEPTION_STD ),
75+
76+ 'fbnetv3_b' : _cfg (),
77+ 'fbnetv3_d' : _cfg (),
78+ 'fbnetv3_g' : _cfg (),
7579}
7680
7781
@@ -86,7 +90,7 @@ class MobileNetV3(nn.Module):
8690 """
8791
8892 def __init__ (self , block_args , num_classes = 1000 , in_chans = 3 , stem_size = 16 , num_features = 1280 , head_bias = True ,
89- pad_type = '' , act_layer = None , norm_layer = None , se_layer = None ,
93+ pad_type = '' , act_layer = None , norm_layer = None , se_layer = None , se_from_exp = True ,
9094 round_chs_fn = round_channels , drop_rate = 0. , drop_path_rate = 0. , global_pool = 'avg' ):
9195 super (MobileNetV3 , self ).__init__ ()
9296 act_layer = act_layer or nn .ReLU
@@ -104,7 +108,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_f
104108
105109 # Middle stages (IR/ER/DS Blocks)
106110 builder = EfficientNetBuilder (
107- output_stride = 32 , pad_type = pad_type , round_chs_fn = round_chs_fn ,
111+ output_stride = 32 , pad_type = pad_type , round_chs_fn = round_chs_fn , se_from_exp = se_from_exp ,
108112 act_layer = act_layer , norm_layer = norm_layer , se_layer = se_layer , drop_path_rate = drop_path_rate )
109113 self .blocks = nn .Sequential (* builder (stem_size , block_args ))
110114 self .feature_info = builder .features
@@ -161,8 +165,8 @@ class MobileNetV3Features(nn.Module):
161165 and object detection models.
162166 """
163167
164- def __init__ (self , block_args , out_indices = (0 , 1 , 2 , 3 , 4 ), feature_location = 'bottleneck' ,
165- in_chans = 3 , stem_size = 16 , output_stride = 32 , pad_type = '' , round_chs_fn = round_channels ,
168+ def __init__ (self , block_args , out_indices = (0 , 1 , 2 , 3 , 4 ), feature_location = 'bottleneck' , in_chans = 3 ,
169+ stem_size = 16 , output_stride = 32 , pad_type = '' , round_chs_fn = round_channels , se_from_exp = True ,
166170 act_layer = None , norm_layer = None , se_layer = None , drop_rate = 0. , drop_path_rate = 0. ):
167171 super (MobileNetV3Features , self ).__init__ ()
168172 act_layer = act_layer or nn .ReLU
@@ -178,7 +182,7 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bo
178182
179183 # Middle stages (IR/ER/DS Blocks)
180184 builder = EfficientNetBuilder (
181- output_stride = output_stride , pad_type = pad_type , round_chs_fn = round_chs_fn ,
185+ output_stride = output_stride , pad_type = pad_type , round_chs_fn = round_chs_fn , se_from_exp = se_from_exp ,
182186 act_layer = act_layer , norm_layer = norm_layer , se_layer = se_layer ,
183187 drop_path_rate = drop_path_rate , feature_location = feature_location )
184188 self .blocks = nn .Sequential (* builder (stem_size , block_args ))
@@ -262,7 +266,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
262266 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
263267 norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
264268 act_layer = resolve_act_layer (kwargs , 'hard_swish' ),
265- se_layer = partial (SqueezeExcite , gate_fn = get_act_fn ('hard_sigmoid' ), reduce_from_block = False ),
269+ se_layer = partial (SqueezeExcite , gate_fn = get_act_fn ('hard_sigmoid' )),
266270 ** kwargs ,
267271 )
268272 model = _create_mnv3 (variant , pretrained , ** model_kwargs )
@@ -351,7 +355,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
351355 ['cn_r1_k1_s1_c960' ], # hard-swish
352356 ]
353357 se_layer = partial (
354- SqueezeExcite , gate_fn = get_act_fn ('hard_sigmoid' ), force_act_layer = nn .ReLU , reduce_from_block = False , divisor = 8 )
358+ SqueezeExcite , gate_fn = get_act_fn ('hard_sigmoid' ), force_act_layer = nn .ReLU , round_chs_fn = round_channels )
355359 model_kwargs = dict (
356360 block_args = decode_arch_def (arch_def ),
357361 num_features = num_features ,
@@ -366,6 +370,86 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
366370 return model
367371
368372
373+ def _gen_fbnetv3 (variant , channel_multiplier = 1.0 , pretrained = False , ** kwargs ):
374+ """ FBNetV3
375+ FIXME untested, this is a preliminary impl of some FBNet-V3 variants.
376+ """
377+ vl = variant .split ('_' )[- 1 ]
378+ if vl in ('a' , 'b' ):
379+ stem_size = 16
380+ arch_def = [
381+ # stage 0, 112x112 in
382+ ['ds_r2_k3_s1_e1_c16' ],
383+ # stage 1, 112x112 in
384+ ['ir_r1_k5_s2_e4_c24' , 'ir_r3_k5_s1_e2_c24' ],
385+ # stage 2, 56x56 in
386+ ['ir_r1_k5_s2_e5_c40_se0.25' , 'ir_r4_k5_s1_e3_c40_se0.25' ],
387+ # stage 3, 28x28 in
388+ ['ir_r1_k5_s2_e5_c72' , 'ir_r4_k3_s1_e3_c72' ],
389+ # stage 4, 14x14in
390+ ['ir_r1_k3_s1_e5_c120_se0.25' , 'ir_r5_k5_s1_e3_c120_se0.25' ],
391+ # stage 5, 14x14in
392+ ['ir_r1_k3_s2_e6_c184_se0.25' , 'ir_r5_k5_s1_e4_c184_se0.25' , 'ir_r1_k5_s1_e6_c224_se0.25' ],
393+ # stage 6, 7x7 in
394+ ['cn_r1_k1_s1_c1344' ],
395+ ]
396+ elif vl == 'd' :
397+ stem_size = 24
398+ arch_def = [
399+ # stage 0, 112x112 in
400+ ['ds_r2_k3_s1_e1_c16' ],
401+ # stage 1, 112x112 in
402+ ['ir_r1_k3_s2_e5_c24' , 'ir_r5_k3_s1_e2_c24' ],
403+ # stage 2, 56x56 in
404+ ['ir_r1_k5_s2_e4_c40_se0.25' , 'ir_r4_k3_s1_e3_c40_se0.25' ],
405+ # stage 3, 28x28 in
406+ ['ir_r1_k3_s2_e5_c72' , 'ir_r4_k3_s1_e3_c72' ],
407+ # stage 4, 14x14in
408+ ['ir_r1_k3_s1_e5_c128_se0.25' , 'ir_r6_k5_s1_e3_c128_se0.25' ],
409+ # stage 5, 14x14in
410+ ['ir_r1_k3_s2_e6_c208_se0.25' , 'ir_r5_k5_s1_e5_c208_se0.25' , 'ir_r1_k5_s1_e6_c240_se0.25' ],
411+ # stage 6, 7x7 in
412+ ['cn_r1_k1_s1_c1440' ],
413+ ]
414+ elif vl == 'g' :
415+ stem_size = 32
416+ arch_def = [
417+ # stage 0, 112x112 in
418+ ['ds_r3_k3_s1_e1_c24' ],
419+ # stage 1, 112x112 in
420+ ['ir_r1_k5_s2_e4_c40' , 'ir_r4_k5_s1_e2_c40' ],
421+ # stage 2, 56x56 in
422+ ['ir_r1_k5_s2_e4_c56_se0.25' , 'ir_r4_k5_s1_e3_c56_se0.25' ],
423+ # stage 3, 28x28 in
424+ ['ir_r1_k5_s2_e5_c104' , 'ir_r4_k3_s1_e3_c104' ],
425+ # stage 4, 14x14in
426+ ['ir_r1_k3_s1_e5_c160_se0.25' , 'ir_r8_k5_s1_e3_c160_se0.25' ],
427+ # stage 5, 14x14in
428+ ['ir_r1_k3_s2_e6_c264_se0.25' , 'ir_r6_k5_s1_e5_c264_se0.25' , 'ir_r2_k5_s1_e6_c288_se0.25' ],
429+ # stage 6, 7x7 in
430+ ['cn_r1_k1_s1_c1728' ], # hard-swish
431+ ]
432+ else :
433+ raise NotImplemented
434+ round_chs_fn = partial (round_channels , multiplier = channel_multiplier , round_limit = 0.95 )
435+ se_layer = partial (SqueezeExcite , gate_fn = get_act_fn ('hard_sigmoid' ), round_chs_fn = round_chs_fn )
436+ act_layer = resolve_act_layer (kwargs , 'hard_swish' )
437+ model_kwargs = dict (
438+ block_args = decode_arch_def (arch_def ),
439+ num_features = 1984 ,
440+ head_bias = False ,
441+ stem_size = stem_size ,
442+ round_chs_fn = round_chs_fn ,
443+ se_from_exp = False ,
444+ norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
445+ act_layer = act_layer ,
446+ se_layer = se_layer ,
447+ ** kwargs ,
448+ )
449+ model = _create_mnv3 (variant , pretrained , ** model_kwargs )
450+ return model
451+
452+
369453@register_model
370454def mobilenetv3_large_075 (pretrained = False , ** kwargs ):
371455 """ MobileNet V3 """
@@ -474,3 +558,24 @@ def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
474558 kwargs ['pad_type' ] = 'same'
475559 model = _gen_mobilenet_v3 ('tf_mobilenetv3_small_minimal_100' , 1.0 , pretrained = pretrained , ** kwargs )
476560 return model
561+
562+
563+ @register_model
564+ def fbnetv3_b (pretrained = False , ** kwargs ):
565+ """ FBNetV3-B """
566+ model = _gen_fbnetv3 ('fbnetv3_b' , pretrained = pretrained , ** kwargs )
567+ return model
568+
569+
570+ @register_model
571+ def fbnetv3_d (pretrained = False , ** kwargs ):
572+ """ FBNetV3-D """
573+ model = _gen_fbnetv3 ('fbnetv3_d' , pretrained = pretrained , ** kwargs )
574+ return model
575+
576+
577+ @register_model
578+ def fbnetv3_g (pretrained = False , ** kwargs ):
579+ """ FBNetV3-G """
580+ model = _gen_fbnetv3 ('fbnetv3_g' , pretrained = pretrained , ** kwargs )
581+ return model
0 commit comments