|
| 1 | +import torch.nn as nn |
| 2 | +from .efficientnet_builder import decode_arch_def, resolve_bn_args |
| 3 | +from .mobilenetv3 import MobileNetV3, MobileNetV3Features, build_model_with_cfg, default_cfg_for_features |
| 4 | +from .layers import hard_sigmoid |
| 5 | +from .efficientnet_blocks import resolve_act_layer |
| 6 | +from .registry import register_model |
| 7 | +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
| 8 | + |
| 9 | + |
| 10 | +def _cfg(url='', **kwargs): |
| 11 | + return { |
| 12 | + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), |
| 13 | + 'crop_pct': 0.875, 'interpolation': 'bilinear', |
| 14 | + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, |
| 15 | + 'first_conv': 'conv_stem', 'classifier': 'classifier', |
| 16 | + **kwargs |
| 17 | + } |
| 18 | + |
| 19 | + |
| 20 | +default_cfgs = { |
| 21 | + 'hardcorenas_A': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_A_Green_38ms_75.9_23474aeb.pth'), |
| 22 | + 'hardcorenas_B': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_B_Green_40ms_76.5_1f882d1e.pth'), |
| 23 | + 'hardcorenas_C': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_C_Green_44ms_77.1_d4148c9e.pth'), |
| 24 | + 'hardcorenas_D': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_D_Green_50ms_77.4_23e3cdde.pth'), |
| 25 | + 'hardcorenas_E': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_E_Green_55ms_77.9_90f20e8a.pth'), |
| 26 | + 'hardcorenas_F': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_F_Green_60ms_78.1_2855edf1.pth'), |
| 27 | +} |
| 28 | + |
| 29 | +def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs): |
| 30 | + """Creates a hardcorenas model |
| 31 | +
|
| 32 | + Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS |
| 33 | + Paper: https://arxiv.org/abs/2102.11646 |
| 34 | +
|
| 35 | + """ |
| 36 | + num_features = 1280 |
| 37 | + act_layer = resolve_act_layer(kwargs, 'hard_swish') |
| 38 | + |
| 39 | + model_kwargs = dict( |
| 40 | + block_args=decode_arch_def(arch_def), |
| 41 | + num_features=num_features, |
| 42 | + stem_size=32, |
| 43 | + channel_multiplier=1, |
| 44 | + norm_kwargs=resolve_bn_args(kwargs), |
| 45 | + act_layer=act_layer, |
| 46 | + se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8), |
| 47 | + **kwargs, |
| 48 | + ) |
| 49 | + |
| 50 | + features_only = False |
| 51 | + model_cls = MobileNetV3 |
| 52 | + if model_kwargs.pop('features_only', False): |
| 53 | + features_only = True |
| 54 | + model_kwargs.pop('num_classes', 0) |
| 55 | + model_kwargs.pop('num_features', 0) |
| 56 | + model_kwargs.pop('head_conv', None) |
| 57 | + model_kwargs.pop('head_bias', None) |
| 58 | + model_cls = MobileNetV3Features |
| 59 | + model = build_model_with_cfg( |
| 60 | + model_cls, variant, pretrained, default_cfg=default_cfgs[variant], |
| 61 | + pretrained_strict=not features_only, **model_kwargs) |
| 62 | + if features_only: |
| 63 | + model.default_cfg = default_cfg_for_features(model.default_cfg) |
| 64 | + return model |
| 65 | + |
| 66 | + |
| 67 | +@register_model |
| 68 | +def hardcorenas_A(pretrained=False, **kwargs): |
| 69 | + """ hardcorenas_A """ |
| 70 | + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], |
| 71 | + ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e6_c40_nre_se0.25'], |
| 72 | + ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25'], |
| 73 | + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25'], |
| 74 | + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] |
| 75 | + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_A', arch_def=arch_def, **kwargs) |
| 76 | + return model |
| 77 | + |
| 78 | + |
| 79 | +@register_model |
| 80 | +def hardcorenas_B(pretrained=False, **kwargs): |
| 81 | + """ hardcorenas_B """ |
| 82 | + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], |
| 83 | + ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25', 'ir_r1_k3_s1_e3_c24_nre'], |
| 84 | + ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'], |
| 85 | + ['ir_r1_k5_s2_e3_c80', 'ir_r1_k5_s1_e3_c80', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'], |
| 86 | + ['ir_r1_k5_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'], |
| 87 | + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'], |
| 88 | + ['cn_r1_k1_s1_c960']] |
| 89 | + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_B', arch_def=arch_def, **kwargs) |
| 90 | + return model |
| 91 | + |
| 92 | + |
| 93 | +@register_model |
| 94 | +def hardcorenas_C(pretrained=False, **kwargs): |
| 95 | + """ hardcorenas_C """ |
| 96 | + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], |
| 97 | + ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', |
| 98 | + 'ir_r1_k5_s1_e3_c40_nre'], |
| 99 | + ['ir_r1_k5_s2_e4_c80', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'], |
| 100 | + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'], |
| 101 | + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'], |
| 102 | + ['cn_r1_k1_s1_c960']] |
| 103 | + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_C', arch_def=arch_def, **kwargs) |
| 104 | + return model |
| 105 | + |
| 106 | + |
| 107 | +@register_model |
| 108 | +def hardcorenas_D(pretrained=False, **kwargs): |
| 109 | + """ hardcorenas_D """ |
| 110 | + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], |
| 111 | + ['ir_r1_k5_s2_e3_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k3_s1_e3_c40_nre_se0.25'], |
| 112 | + ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', |
| 113 | + 'ir_r1_k3_s1_e3_c80_se0.25'], |
| 114 | + ['ir_r1_k3_s1_e4_c112_se0.25', 'ir_r1_k5_s1_e4_c112_se0.25', 'ir_r1_k3_s1_e3_c112_se0.25', |
| 115 | + 'ir_r1_k5_s1_e3_c112_se0.25'], |
| 116 | + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', |
| 117 | + 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] |
| 118 | + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_D', arch_def=arch_def, **kwargs) |
| 119 | + return model |
| 120 | + |
| 121 | + |
| 122 | +@register_model |
| 123 | +def hardcorenas_E(pretrained=False, **kwargs): |
| 124 | + """ hardcorenas_E """ |
| 125 | + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], |
| 126 | + ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', |
| 127 | + 'ir_r1_k3_s1_e3_c40_nre_se0.25'], ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e6_c80_se0.25'], |
| 128 | + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', |
| 129 | + 'ir_r1_k5_s1_e3_c112_se0.25'], |
| 130 | + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', |
| 131 | + 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] |
| 132 | + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_E', arch_def=arch_def, **kwargs) |
| 133 | + return model |
| 134 | + |
| 135 | + |
| 136 | +@register_model |
| 137 | +def hardcorenas_F(pretrained=False, **kwargs): |
| 138 | + """ hardcorenas_F """ |
| 139 | + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], |
| 140 | + ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e6_c40_nre_se0.25'], |
| 141 | + ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', |
| 142 | + 'ir_r1_k3_s1_e3_c80_se0.25'], |
| 143 | + ['ir_r1_k3_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', |
| 144 | + 'ir_r1_k3_s1_e3_c112_se0.25'], |
| 145 | + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25', |
| 146 | + 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] |
| 147 | + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_F', arch_def=arch_def, **kwargs) |
| 148 | + return model |
0 commit comments