Skip to content

Commit b9f020a

Browse files
committed
Allow group_size override for more efficientnet and mobilenetv3 based models
1 parent 00c5be7 commit b9f020a

File tree

2 files changed

+38
-18
lines changed

2 files changed

+38
-18
lines changed

timm/models/efficientnet.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
488488

489489
def _gen_mobilenet_v1(
490490
variant, channel_multiplier=1.0, depth_multiplier=1.0,
491-
fix_stem_head=False, head_conv=False, pretrained=False, **kwargs):
491+
group_size=None, fix_stem_head=False, head_conv=False, pretrained=False, **kwargs):
492492
"""
493493
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
494494
Paper: https://arxiv.org/abs/1801.04381
@@ -503,7 +503,12 @@ def _gen_mobilenet_v1(
503503
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
504504
head_features = (1024 if fix_stem_head else max(1024, round_chs_fn(1024))) if head_conv else 0
505505
model_kwargs = dict(
506-
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
506+
block_args=decode_arch_def(
507+
arch_def,
508+
depth_multiplier=depth_multiplier,
509+
fix_first_last=fix_stem_head,
510+
group_size=group_size,
511+
),
507512
num_features=head_features,
508513
stem_size=32,
509514
fix_stem=fix_stem_head,
@@ -517,7 +522,9 @@ def _gen_mobilenet_v1(
517522

518523

519524
def _gen_mobilenet_v2(
520-
variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
525+
variant, channel_multiplier=1.0, depth_multiplier=1.0,
526+
group_size=None, fix_stem_head=False, pretrained=False, **kwargs
527+
):
521528
""" Generate MobileNet-V2 network
522529
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
523530
Paper: https://arxiv.org/abs/1801.04381
@@ -533,7 +540,12 @@ def _gen_mobilenet_v2(
533540
]
534541
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
535542
model_kwargs = dict(
536-
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
543+
block_args=decode_arch_def(
544+
arch_def,
545+
depth_multiplier=depth_multiplier,
546+
fix_first_last=fix_stem_head,
547+
group_size=group_size,
548+
),
537549
num_features=1280 if fix_stem_head else max(1280, round_chs_fn(1280)),
538550
stem_size=32,
539551
fix_stem=fix_stem_head,
@@ -764,7 +776,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
764776

765777

766778
def _gen_efficientnetv2_base(
767-
variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
779+
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
768780
""" Creates an EfficientNet-V2 base model
769781
770782
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -780,7 +792,7 @@ def _gen_efficientnetv2_base(
780792
]
781793
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
782794
model_kwargs = dict(
783-
block_args=decode_arch_def(arch_def, depth_multiplier),
795+
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
784796
num_features=round_chs_fn(1280),
785797
stem_size=32,
786798
round_chs_fn=round_chs_fn,
@@ -831,7 +843,8 @@ def _gen_efficientnetv2_s(
831843
return model
832844

833845

834-
def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
846+
def _gen_efficientnetv2_m(
847+
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
835848
""" Creates an EfficientNet-V2 Medium model
836849
837850
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -849,7 +862,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
849862
]
850863

851864
model_kwargs = dict(
852-
block_args=decode_arch_def(arch_def, depth_multiplier),
865+
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
853866
num_features=1280,
854867
stem_size=24,
855868
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
@@ -861,7 +874,8 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
861874
return model
862875

863876

864-
def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
877+
def _gen_efficientnetv2_l(
878+
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
865879
""" Creates an EfficientNet-V2 Large model
866880
867881
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -879,7 +893,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
879893
]
880894

881895
model_kwargs = dict(
882-
block_args=decode_arch_def(arch_def, depth_multiplier),
896+
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
883897
num_features=1280,
884898
stem_size=32,
885899
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
@@ -891,7 +905,8 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
891905
return model
892906

893907

894-
def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
908+
def _gen_efficientnetv2_xl(
909+
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
895910
""" Creates an EfficientNet-V2 Xtra-Large model
896911
897912
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -909,7 +924,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
909924
]
910925

911926
model_kwargs = dict(
912-
block_args=decode_arch_def(arch_def, depth_multiplier),
927+
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
913928
num_features=1280,
914929
stem_size=32,
915930
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
@@ -1094,7 +1109,8 @@ def _gen_tinynet(
10941109
return model
10951110

10961111

1097-
def _gen_mobilenet_edgetpu(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
1112+
def _gen_mobilenet_edgetpu(
1113+
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
10981114
"""
10991115
Based on definitions in: https://github.com/tensorflow/models/tree/d2427a562f401c9af118e47af2f030a0a5599f55/official/projects/edgetpu/vision
11001116
"""
@@ -1170,7 +1186,7 @@ def _arch_def(chs: List[int], group_size: int):
11701186
]
11711187

11721188
model_kwargs = dict(
1173-
block_args=decode_arch_def(arch_def, depth_multiplier),
1189+
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
11741190
num_features=num_features,
11751191
stem_size=stem_size,
11761192
stem_kernel_size=stem_kernel_size,

timm/models/mobilenetv3.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,9 @@ def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrain
450450
return model
451451

452452

453-
def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
453+
def _gen_mobilenet_v3(
454+
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs
455+
) -> MobileNetV3:
454456
"""Creates a MobileNet-V3 model.
455457
456458
Ref impl: ?
@@ -533,7 +535,7 @@ def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained:
533535
]
534536
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
535537
model_kwargs = dict(
536-
block_args=decode_arch_def(arch_def),
538+
block_args=decode_arch_def(arch_def, group_size=group_size),
537539
num_features=num_features,
538540
stem_size=16,
539541
fix_stem=channel_multiplier < 0.75,
@@ -646,7 +648,9 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool =
646648
return model
647649

648650

649-
def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
651+
def _gen_mobilenet_v4(
652+
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs,
653+
) -> MobileNetV3:
650654
"""Creates a MobileNet-V4 model.
651655
652656
Ref impl: ?
@@ -877,7 +881,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
877881
assert False, f'Unknown variant {variant}.'
878882

879883
model_kwargs = dict(
880-
block_args=decode_arch_def(arch_def),
884+
block_args=decode_arch_def(arch_def, group_size=group_size),
881885
head_bias=False,
882886
head_norm=True,
883887
num_features=num_features,

0 commit comments

Comments
 (0)