@@ -600,7 +600,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
600600 block_args = decode_arch_def (arch_def ),
601601 stem_size = 32 ,
602602 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
603- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
603+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
604604 ** kwargs
605605 )
606606 model = _create_effnet (variant , pretrained , ** model_kwargs )
@@ -636,7 +636,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
636636 block_args = decode_arch_def (arch_def ),
637637 stem_size = 32 ,
638638 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
639- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
639+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
640640 ** kwargs
641641 )
642642 model = _create_effnet (variant , pretrained , ** model_kwargs )
@@ -665,7 +665,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
665665 block_args = decode_arch_def (arch_def ),
666666 stem_size = 8 ,
667667 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
668- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
668+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
669669 ** kwargs
670670 )
671671 model = _create_effnet (variant , pretrained , ** model_kwargs )
@@ -694,7 +694,7 @@ def _gen_mobilenet_v2(
694694 stem_size = 32 ,
695695 fix_stem = fix_stem_head ,
696696 round_chs_fn = round_chs_fn ,
697- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
697+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
698698 act_layer = resolve_act_layer (kwargs , 'relu6' ),
699699 ** kwargs
700700 )
@@ -725,7 +725,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
725725 stem_size = 16 ,
726726 num_features = 1984 , # paper suggests this, but is not 100% clear
727727 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
728- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
728+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
729729 ** kwargs
730730 )
731731 model = _create_effnet (variant , pretrained , ** model_kwargs )
@@ -760,7 +760,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
760760 block_args = decode_arch_def (arch_def ),
761761 stem_size = 32 ,
762762 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
763- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
763+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
764764 ** kwargs
765765 )
766766 model = _create_effnet (variant , pretrained , ** model_kwargs )
@@ -807,7 +807,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
807807 stem_size = 32 ,
808808 round_chs_fn = round_chs_fn ,
809809 act_layer = resolve_act_layer (kwargs , 'swish' ),
810- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
810+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
811811 ** kwargs ,
812812 )
813813 model = _create_effnet (variant , pretrained , ** model_kwargs )
@@ -836,7 +836,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
836836 num_features = round_chs_fn (1280 ),
837837 stem_size = 32 ,
838838 round_chs_fn = round_chs_fn ,
839- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
839+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
840840 act_layer = resolve_act_layer (kwargs , 'relu' ),
841841 ** kwargs ,
842842 )
@@ -867,7 +867,7 @@ def _gen_efficientnet_condconv(
867867 num_features = round_chs_fn (1280 ),
868868 stem_size = 32 ,
869869 round_chs_fn = round_chs_fn ,
870- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
870+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
871871 act_layer = resolve_act_layer (kwargs , 'swish' ),
872872 ** kwargs ,
873873 )
@@ -909,7 +909,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
909909 fix_stem = True ,
910910 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
911911 act_layer = resolve_act_layer (kwargs , 'relu6' ),
912- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
912+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
913913 ** kwargs ,
914914 )
915915 model = _create_effnet (variant , pretrained , ** model_kwargs )
@@ -937,7 +937,7 @@ def _gen_efficientnetv2_base(
937937 num_features = round_chs_fn (1280 ),
938938 stem_size = 32 ,
939939 round_chs_fn = round_chs_fn ,
940- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
940+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
941941 act_layer = resolve_act_layer (kwargs , 'silu' ),
942942 ** kwargs ,
943943 )
@@ -976,7 +976,7 @@ def _gen_efficientnetv2_s(
976976 num_features = round_chs_fn (num_features ),
977977 stem_size = 24 ,
978978 round_chs_fn = round_chs_fn ,
979- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
979+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
980980 act_layer = resolve_act_layer (kwargs , 'silu' ),
981981 ** kwargs ,
982982 )
@@ -1006,7 +1006,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
10061006 num_features = 1280 ,
10071007 stem_size = 24 ,
10081008 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
1009- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
1009+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
10101010 act_layer = resolve_act_layer (kwargs , 'silu' ),
10111011 ** kwargs ,
10121012 )
@@ -1036,7 +1036,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
10361036 num_features = 1280 ,
10371037 stem_size = 32 ,
10381038 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
1039- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
1039+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
10401040 act_layer = resolve_act_layer (kwargs , 'silu' ),
10411041 ** kwargs ,
10421042 )
@@ -1066,7 +1066,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
10661066 num_features = 1280 ,
10671067 stem_size = 32 ,
10681068 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
1069- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
1069+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
10701070 act_layer = resolve_act_layer (kwargs , 'silu' ),
10711071 ** kwargs ,
10721072 )
@@ -1100,7 +1100,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
11001100 num_features = 1536 ,
11011101 stem_size = 16 ,
11021102 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
1103- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
1103+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
11041104 ** kwargs
11051105 )
11061106 model = _create_effnet (variant , pretrained , ** model_kwargs )
@@ -1133,7 +1133,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
11331133 num_features = 1536 ,
11341134 stem_size = 24 ,
11351135 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
1136- norm_layer = partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
1136+ norm_layer = kwargs . pop ( 'norm_layer' , None ) or partial (nn .BatchNorm2d , ** resolve_bn_args (kwargs )),
11371137 ** kwargs
11381138 )
11391139 model = _create_effnet (variant , pretrained , ** model_kwargs )
0 commit comments