@@ -110,9 +110,10 @@ class MobileNetV3(nn.Module):
110110 * LCNet - https://arxiv.org/abs/2109.15099
111111 """
112112
113- def __init__ (self , block_args , num_classes = 1000 , in_chans = 3 , stem_size = 16 , num_features = 1280 , head_bias = True ,
114- pad_type = '' , act_layer = None , norm_layer = None , se_layer = None , se_from_exp = True ,
115- round_chs_fn = round_channels , drop_rate = 0. , drop_path_rate = 0. , global_pool = 'avg' ):
113+ def __init__ (
114+ self , block_args , num_classes = 1000 , in_chans = 3 , stem_size = 16 , fix_stem = False , num_features = 1280 ,
115+ head_bias = True , pad_type = '' , act_layer = None , norm_layer = None , se_layer = None , se_from_exp = True ,
116+ round_chs_fn = round_channels , drop_rate = 0. , drop_path_rate = 0. , global_pool = 'avg' ):
116117 super (MobileNetV3 , self ).__init__ ()
117118 act_layer = act_layer or nn .ReLU
118119 norm_layer = norm_layer or nn .BatchNorm2d
@@ -122,7 +123,8 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_f
122123 self .drop_rate = drop_rate
123124
124125 # Stem
125- stem_size = round_chs_fn (stem_size )
126+ if not fix_stem :
127+ stem_size = round_chs_fn (stem_size )
126128 self .conv_stem = create_conv2d (in_chans , stem_size , 3 , stride = 2 , padding = pad_type )
127129 self .bn1 = norm_layer (stem_size )
128130 self .act1 = act_layer (inplace = True )
@@ -188,16 +190,17 @@ class MobileNetV3Features(nn.Module):
188190 """
189191
190192 def __init__ (self , block_args , out_indices = (0 , 1 , 2 , 3 , 4 ), feature_location = 'bottleneck' , in_chans = 3 ,
191- stem_size = 16 , output_stride = 32 , pad_type = '' , round_chs_fn = round_channels , se_from_exp = True ,
192- act_layer = None , norm_layer = None , se_layer = None , drop_rate = 0. , drop_path_rate = 0. ):
193+ stem_size = 16 , fix_stem = False , output_stride = 32 , pad_type = '' , round_chs_fn = round_channels ,
194+ se_from_exp = True , act_layer = None , norm_layer = None , se_layer = None , drop_rate = 0. , drop_path_rate = 0. ):
193195 super (MobileNetV3Features , self ).__init__ ()
194196 act_layer = act_layer or nn .ReLU
195197 norm_layer = norm_layer or nn .BatchNorm2d
196198 se_layer = se_layer or SqueezeExcite
197199 self .drop_rate = drop_rate
198200
199201 # Stem
200- stem_size = round_chs_fn (stem_size )
202+ if not fix_stem :
203+ stem_size = round_chs_fn (stem_size )
201204 self .conv_stem = create_conv2d (in_chans , stem_size , 3 , stride = 2 , padding = pad_type )
202205 self .bn1 = norm_layer (stem_size )
203206 self .act1 = act_layer (inplace = True )
@@ -381,6 +384,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
381384 block_args = decode_arch_def (arch_def ),
382385 num_features = num_features ,
383386 stem_size = 16 ,
387+ fix_stem = channel_multiplier < 0.75 ,
384388 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
385389 norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
386390 act_layer = act_layer ,
0 commit comments