@@ -488,7 +488,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
488488
489489def _gen_mobilenet_v1 (
490490 variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 ,
491- fix_stem_head = False , head_conv = False , pretrained = False , ** kwargs ):
491+ group_size = None , fix_stem_head = False , head_conv = False , pretrained = False , ** kwargs ):
492492 """
493493 Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
494494 Paper: https://arxiv.org/abs/1801.04381
@@ -503,7 +503,12 @@ def _gen_mobilenet_v1(
503503 round_chs_fn = partial (round_channels , multiplier = channel_multiplier )
504504 head_features = (1024 if fix_stem_head else max (1024 , round_chs_fn (1024 ))) if head_conv else 0
505505 model_kwargs = dict (
506- block_args = decode_arch_def (arch_def , depth_multiplier = depth_multiplier , fix_first_last = fix_stem_head ),
506+ block_args = decode_arch_def (
507+ arch_def ,
508+ depth_multiplier = depth_multiplier ,
509+ fix_first_last = fix_stem_head ,
510+ group_size = group_size ,
511+ ),
507512 num_features = head_features ,
508513 stem_size = 32 ,
509514 fix_stem = fix_stem_head ,
@@ -517,7 +522,9 @@ def _gen_mobilenet_v1(
517522
518523
519524def _gen_mobilenet_v2 (
520- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , fix_stem_head = False , pretrained = False , ** kwargs ):
525+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 ,
526+ group_size = None , fix_stem_head = False , pretrained = False , ** kwargs
527+ ):
521528 """ Generate MobileNet-V2 network
522529 Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
523530 Paper: https://arxiv.org/abs/1801.04381
@@ -533,7 +540,12 @@ def _gen_mobilenet_v2(
533540 ]
534541 round_chs_fn = partial (round_channels , multiplier = channel_multiplier )
535542 model_kwargs = dict (
536- block_args = decode_arch_def (arch_def , depth_multiplier = depth_multiplier , fix_first_last = fix_stem_head ),
543+ block_args = decode_arch_def (
544+ arch_def ,
545+ depth_multiplier = depth_multiplier ,
546+ fix_first_last = fix_stem_head ,
547+ group_size = group_size ,
548+ ),
537549 num_features = 1280 if fix_stem_head else max (1280 , round_chs_fn (1280 )),
538550 stem_size = 32 ,
539551 fix_stem = fix_stem_head ,
@@ -764,7 +776,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
764776
765777
766778def _gen_efficientnetv2_base (
767- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
779+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs ):
768780 """ Creates an EfficientNet-V2 base model
769781
770782 Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -780,7 +792,7 @@ def _gen_efficientnetv2_base(
780792 ]
781793 round_chs_fn = partial (round_channels , multiplier = channel_multiplier , round_limit = 0. )
782794 model_kwargs = dict (
783- block_args = decode_arch_def (arch_def , depth_multiplier ),
795+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
784796 num_features = round_chs_fn (1280 ),
785797 stem_size = 32 ,
786798 round_chs_fn = round_chs_fn ,
@@ -831,7 +843,8 @@ def _gen_efficientnetv2_s(
831843 return model
832844
833845
834- def _gen_efficientnetv2_m (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
846+ def _gen_efficientnetv2_m (
847+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs ):
835848 """ Creates an EfficientNet-V2 Medium model
836849
837850 Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -849,7 +862,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
849862 ]
850863
851864 model_kwargs = dict (
852- block_args = decode_arch_def (arch_def , depth_multiplier ),
865+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
853866 num_features = 1280 ,
854867 stem_size = 24 ,
855868 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
@@ -861,7 +874,8 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
861874 return model
862875
863876
864- def _gen_efficientnetv2_l (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
877+ def _gen_efficientnetv2_l (
878+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs ):
865879 """ Creates an EfficientNet-V2 Large model
866880
867881 Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -879,7 +893,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
879893 ]
880894
881895 model_kwargs = dict (
882- block_args = decode_arch_def (arch_def , depth_multiplier ),
896+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
883897 num_features = 1280 ,
884898 stem_size = 32 ,
885899 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
@@ -891,7 +905,8 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
891905 return model
892906
893907
894- def _gen_efficientnetv2_xl (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
908+ def _gen_efficientnetv2_xl (
909+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs ):
895910 """ Creates an EfficientNet-V2 Xtra-Large model
896911
897912 Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -909,7 +924,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
909924 ]
910925
911926 model_kwargs = dict (
912- block_args = decode_arch_def (arch_def , depth_multiplier ),
927+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
913928 num_features = 1280 ,
914929 stem_size = 32 ,
915930 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
@@ -1094,7 +1109,8 @@ def _gen_tinynet(
10941109 return model
10951110
10961111
1097- def _gen_mobilenet_edgetpu (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
1112+ def _gen_mobilenet_edgetpu (
1113+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs ):
10981114 """
10991115 Based on definitions in: https://github.com/tensorflow/models/tree/d2427a562f401c9af118e47af2f030a0a5599f55/official/projects/edgetpu/vision
11001116 """
@@ -1170,7 +1186,7 @@ def _arch_def(chs: List[int], group_size: int):
11701186 ]
11711187
11721188 model_kwargs = dict (
1173- block_args = decode_arch_def (arch_def , depth_multiplier ),
1189+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
11741190 num_features = num_features ,
11751191 stem_size = stem_size ,
11761192 stem_kernel_size = stem_kernel_size ,
0 commit comments