@@ -290,6 +290,7 @@ def _make_block(self, ba):
         ba['bn_eps'] = self.bn_eps
         ba['folded_bn'] = self.folded_bn
         ba['padding_same'] = self.padding_same
+        # block act fn overrides the model default
         ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
         assert ba['act_fn'] is not None
         if _DEBUG:
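The line tagged with the new comment implements a simple precedence rule: a per-block `act_fn` entry in the block args wins, and the model-wide default only fills the gap. A minimal sketch of that idiom, where `resolve_act_fn` is a hypothetical helper, not a name in this repo:

    import torch.nn.functional as F

    def resolve_act_fn(block_args, model_default=F.relu):
        # a per-block activation, when present, overrides the model-wide default
        act_fn = block_args.get('act_fn')
        return act_fn if act_fn is not None else model_default

    assert resolve_act_fn({'act_fn': None}) is F.relu
    assert resolve_act_fn({'act_fn': F.relu6}) is F.relu6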
@@ -611,15 +612,14 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
                  depth_multiplier=1.0, depth_divisor=8, min_depth=None,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
                  drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
-                 global_pool='avg', skip_head_conv=False, efficient_head=False,
-                 weight_init='goog', folded_bn=False, padding_same=False):
+                 global_pool='avg', head_conv='default', weight_init='goog',
+                 folded_bn=False, padding_same=False):
         super(GenMobileNet, self).__init__()
         self.num_classes = num_classes
         self.depth_multiplier = depth_multiplier
         self.drop_rate = drop_rate
         self.act_fn = act_fn
         self.num_features = num_features
-        self.efficient_head = efficient_head  # pool before last conv
 
         stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
         self.conv_stem = sconv2d(
@@ -629,19 +629,22 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
         in_chs = stem_size
 
         builder = _BlockBuilder(
-            depth_multiplier, depth_divisor, min_depth, act_fn, se_gate_fn, se_reduce_mid,
+            depth_multiplier, depth_divisor, min_depth,
+            act_fn, se_gate_fn, se_reduce_mid,
             bn_momentum, bn_eps, folded_bn, padding_same)
         self.blocks = nn.Sequential(*builder(in_chs, block_args))
         in_chs = builder.in_chs
 
-        if skip_head_conv:
+        if not head_conv or head_conv == 'none':
+            self.efficient_head = False
             self.conv_head = None
             assert in_chs == self.num_features
         else:
+            self.efficient_head = head_conv == 'efficient'
             self.conv_head = sconv2d(
                 in_chs, self.num_features, 1,
-                padding=_padding_arg(0, padding_same), bias=folded_bn and not efficient_head)
-            self.bn2 = None if (folded_bn or efficient_head) else \
+                padding=_padding_arg(0, padding_same), bias=folded_bn and not self.efficient_head)
+            self.bn2 = None if (folded_bn or self.efficient_head) else \
                 nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
 
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
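This hunk folds the old `skip_head_conv` / `efficient_head` booleans into one `head_conv` mode with three values: 'none' (no head conv; the last block must already emit `num_features` channels), 'default' (1x1 conv plus BN ahead of pooling), and 'efficient' (pool first, then 1x1 conv, no BN). A condensed sketch of the dispatch, where `configure_head` is a hypothetical helper mirroring the constructor logic above:

    def configure_head(head_conv, folded_bn=False):
        # returns (efficient_head, has_conv_head, has_bn2) per the constructor logic
        if not head_conv or head_conv == 'none':
            return False, False, False
        efficient_head = head_conv == 'efficient'
        # bn2 is skipped when BN is folded or the head pools before the conv
        return efficient_head, True, not (folded_bn or efficient_head)

    assert configure_head('none') == (False, False, False)
    assert configure_head('default') == (False, True, True)
    assert configure_head('efficient') == (True, True, False)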
@@ -674,7 +677,7 @@ def forward_features(self, x, pool=True):
         x = self.blocks(x)
         if self.efficient_head:
             # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
-            x = self.global_pool(x)  # always need to pool here regardless of bool
+            x = self.global_pool(x)  # always need to pool here regardless of flag
             x = self.conv_head(x)
             # no BN
             x = self.act_fn(x)
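The payoff of the 'efficient' ordering is cost: pooling to 1x1 before the head conv means the 1x1 convolution runs on a single spatial position instead of the full feature map. A rough illustration with made-up channel counts (320 -> 1280 is only an example, and the default head's BN and activation are omitted):

    import torch
    import torch.nn as nn

    pool = nn.AdaptiveAvgPool2d(1)
    conv_head = nn.Conv2d(320, 1280, 1)

    x = torch.randn(1, 320, 7, 7)     # feature map produced by the blocks
    y_default = pool(conv_head(x))    # default head: conv over all 49 positions, then pool
    y_efficient = conv_head(pool(x))  # efficient head: pool first, conv over one position
    assert y_default.shape == y_efficient.shape == (1, 1280, 1, 1)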
@@ -836,7 +839,7 @@ def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs):
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         act_fn=F.relu6,
-        skip_head_conv=True,
+        head_conv='none',
         **kwargs
     )
     return model
@@ -914,6 +917,7 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
         act_fn=hard_swish,
         se_gate_fn=hard_sigmoid,
         se_reduce_mid=True,
+        head_conv='efficient',
         **kwargs
     )
     return model
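Net effect across the generators: each model family now names its head style instead of toggling two booleans. MobileNet-V1 passes 'none' (its classifier sits directly on the pooled block output, with no wide 1x1 head conv), MobileNet-V3 passes 'efficient', and generators that pass nothing fall through to 'default'. A small lookup capturing the mapping implied by these two hunks; the 'default' entry for V2 is presumed from the unchanged generator, which this diff does not show:

    HEAD_CONV_BY_MODEL = {
        'mobilenet_v1': 'none',       # classifier directly on pooled block output
        'mobilenet_v2': 'default',    # 1x1 conv + BN ahead of global pooling (presumed)
        'mobilenet_v3': 'efficient',  # global pool first, then 1x1 conv, no BN
    }
    assert HEAD_CONV_BY_MODEL['mobilenet_v3'] == 'efficient'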