Skip to content

Commit 3d9c23a

Browse files
authored
Merge pull request #875 from alexander-soare/effnets-norm-layer
make it possible to provide norm_layer via create_model
2 parents a6e8598 + 6bbc50b commit 3d9c23a

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

timm/models/efficientnet.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
600600
block_args=decode_arch_def(arch_def),
601601
stem_size=32,
602602
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
603-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
603+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
604604
**kwargs
605605
)
606606
model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -636,7 +636,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
636636
block_args=decode_arch_def(arch_def),
637637
stem_size=32,
638638
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
639-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
639+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
640640
**kwargs
641641
)
642642
model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -665,7 +665,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
665665
block_args=decode_arch_def(arch_def),
666666
stem_size=8,
667667
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
668-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
668+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
669669
**kwargs
670670
)
671671
model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -694,7 +694,7 @@ def _gen_mobilenet_v2(
694694
stem_size=32,
695695
fix_stem=fix_stem_head,
696696
round_chs_fn=round_chs_fn,
697-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
697+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
698698
act_layer=resolve_act_layer(kwargs, 'relu6'),
699699
**kwargs
700700
)
@@ -725,7 +725,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
725725
stem_size=16,
726726
num_features=1984, # paper suggests this, but is not 100% clear
727727
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
728-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
728+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
729729
**kwargs
730730
)
731731
model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -760,7 +760,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
760760
block_args=decode_arch_def(arch_def),
761761
stem_size=32,
762762
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
763-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
763+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
764764
**kwargs
765765
)
766766
model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -807,7 +807,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
807807
stem_size=32,
808808
round_chs_fn=round_chs_fn,
809809
act_layer=resolve_act_layer(kwargs, 'swish'),
810-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
810+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
811811
**kwargs,
812812
)
813813
model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -836,7 +836,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
836836
num_features=round_chs_fn(1280),
837837
stem_size=32,
838838
round_chs_fn=round_chs_fn,
839-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
839+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
840840
act_layer=resolve_act_layer(kwargs, 'relu'),
841841
**kwargs,
842842
)
@@ -867,7 +867,7 @@ def _gen_efficientnet_condconv(
867867
num_features=round_chs_fn(1280),
868868
stem_size=32,
869869
round_chs_fn=round_chs_fn,
870-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
870+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
871871
act_layer=resolve_act_layer(kwargs, 'swish'),
872872
**kwargs,
873873
)
@@ -909,7 +909,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
909909
fix_stem=True,
910910
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
911911
act_layer=resolve_act_layer(kwargs, 'relu6'),
912-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
912+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
913913
**kwargs,
914914
)
915915
model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -937,7 +937,7 @@ def _gen_efficientnetv2_base(
937937
num_features=round_chs_fn(1280),
938938
stem_size=32,
939939
round_chs_fn=round_chs_fn,
940-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
940+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
941941
act_layer=resolve_act_layer(kwargs, 'silu'),
942942
**kwargs,
943943
)
@@ -976,7 +976,7 @@ def _gen_efficientnetv2_s(
976976
num_features=round_chs_fn(num_features),
977977
stem_size=24,
978978
round_chs_fn=round_chs_fn,
979-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
979+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
980980
act_layer=resolve_act_layer(kwargs, 'silu'),
981981
**kwargs,
982982
)
@@ -1006,7 +1006,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
10061006
num_features=1280,
10071007
stem_size=24,
10081008
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
1009-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
1009+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
10101010
act_layer=resolve_act_layer(kwargs, 'silu'),
10111011
**kwargs,
10121012
)
@@ -1036,7 +1036,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
10361036
num_features=1280,
10371037
stem_size=32,
10381038
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
1039-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
1039+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
10401040
act_layer=resolve_act_layer(kwargs, 'silu'),
10411041
**kwargs,
10421042
)
@@ -1066,7 +1066,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
10661066
num_features=1280,
10671067
stem_size=32,
10681068
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
1069-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
1069+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
10701070
act_layer=resolve_act_layer(kwargs, 'silu'),
10711071
**kwargs,
10721072
)
@@ -1100,7 +1100,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
11001100
num_features=1536,
11011101
stem_size=16,
11021102
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
1103-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
1103+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
11041104
**kwargs
11051105
)
11061106
model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -1133,7 +1133,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
11331133
num_features=1536,
11341134
stem_size=24,
11351135
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
1136-
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
1136+
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
11371137
**kwargs
11381138
)
11391139
model = _create_effnet(variant, pretrained, **model_kwargs)

0 commit comments

Comments
 (0)