Skip to content

Commit e5ba5dc

Browse files
committed
Merge branch 'yoniaflalo-adding_Hardcore_NAS'
2 parents 06aa926 + 1f799af commit e5ba5dc

File tree

2 files changed

+151
-0
lines changed

2 files changed

+151
-0
lines changed

timm/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929
from .vovnet import *
3030
from .xception import *
3131
from .xception_aligned import *
32+
from .hardcorenas import *
3233

3334
from .factory import create_model
3435
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
3536
from .layers import TestTimePoolHead, apply_test_time_pool
3637
from .layers import convert_splitbn_model
3738
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
3839
from .registry import *
40+

timm/models/hardcorenas.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import torch.nn as nn
2+
from .efficientnet_builder import decode_arch_def, resolve_bn_args
3+
from .mobilenetv3 import MobileNetV3, MobileNetV3Features, build_model_with_cfg, default_cfg_for_features
4+
from .layers import hard_sigmoid
5+
from .efficientnet_blocks import resolve_act_layer
6+
from .registry import register_model
7+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
8+
9+
10+
def _cfg(url='', **kwargs):
    """Assemble a default-config dict for one pretrained variant.

    Args:
        url: Stored under the ``'url'`` key.
        **kwargs: Extra entries merged in after the defaults, so a key given
            here replaces the default of the same name.

    Returns:
        A freshly built dict; nothing is shared between calls except the
        imported mean/std constants.
    """
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (1, 1),
        'crop_pct': 0.875,
        'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv_stem',
        'classifier': 'classifier',
    }
    # Caller-supplied entries win over the defaults above.
    cfg.update(kwargs)
    return cfg
18+
19+
20+
# Default config per registered variant name; each entry differs only in its
# pretrained-weights URL, everything else comes from _cfg().
default_cfgs = {
    variant: _cfg(url=weights_url)
    for variant, weights_url in (
        ('hardcorenas_a',
         'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_A_Green_38ms_75.9_23474aeb.pth'),
        ('hardcorenas_b',
         'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_B_Green_40ms_76.5_1f882d1e.pth'),
        ('hardcorenas_c',
         'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_C_Green_44ms_77.1_d4148c9e.pth'),
        ('hardcorenas_d',
         'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_D_Green_50ms_77.4_23e3cdde.pth'),
        ('hardcorenas_e',
         'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_E_Green_55ms_77.9_90f20e8a.pth'),
        ('hardcorenas_f',
         'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_F_Green_60ms_78.1_2855edf1.pth'),
    )
}
28+
29+
30+
def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
    """Build a HardCoReNAS model from a block-string architecture definition.

    Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS
    Paper: https://arxiv.org/abs/2102.11646

    Args:
        pretrained: Forwarded unchanged to ``build_model_with_cfg``.
        variant: Variant name; also the key used to look up ``default_cfgs``.
        arch_def: Architecture definition handed to ``decode_arch_def``.
        **kwargs: Extra model arguments. A truthy ``features_only`` entry
            switches the model class to ``MobileNetV3Features``.

    Returns:
        The model returned by ``build_model_with_cfg``.
    """
    # The three helpers below run in this order, and the two that receive the
    # kwargs dict itself run before ``**kwargs`` is expanded -- keep it that
    # way in case they consume entries from it (not visible from here).
    activation = resolve_act_layer(kwargs, 'hard_swish')
    decoded_blocks = decode_arch_def(arch_def)
    bn_args = resolve_bn_args(kwargs)
    squeeze_excite_args = dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8)

    model_kwargs = dict(
        block_args=decoded_blocks,
        num_features=1280,
        stem_size=32,
        channel_multiplier=1,
        norm_kwargs=bn_args,
        act_layer=activation,
        se_kwargs=squeeze_excite_args,
        **kwargs,
    )

    features_only = bool(model_kwargs.pop('features_only', False))
    if features_only:
        # These arguments are dropped before constructing the features model.
        for head_arg in ('num_classes', 'num_features', 'head_conv', 'head_bias'):
            model_kwargs.pop(head_arg, None)
        model_cls = MobileNetV3Features
    else:
        model_cls = MobileNetV3

    model = build_model_with_cfg(
        model_cls,
        variant,
        pretrained,
        default_cfg=default_cfgs[variant],
        pretrained_strict=not features_only,
        **model_kwargs,
    )
    if features_only:
        model.default_cfg = default_cfg_for_features(model.default_cfg)
    return model
66+
67+
68+
@register_model
def hardcorenas_a(pretrained=False, **kwargs):
    """HardCoReNAS-A.

    Builds the 'hardcorenas_a' variant through ``_gen_hardcorenas`` from the
    block strings listed below; ``pretrained`` and ``kwargs`` are passed on.
    """
    block_strings = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e3_c40_nre',
            'ir_r1_k5_s1_e6_c40_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c80_se0.25',
            'ir_r1_k5_s1_e6_c80_se0.25',
        ],
        [
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e6_c112_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(
        pretrained=pretrained, variant='hardcorenas_a', arch_def=block_strings, **kwargs)
78+
79+
80+
@register_model
def hardcorenas_b(pretrained=False, **kwargs):
    """HardCoReNAS-B.

    Builds the 'hardcorenas_b' variant through ``_gen_hardcorenas`` from the
    block strings listed below; ``pretrained`` and ``kwargs`` are passed on.
    """
    block_strings = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
            'ir_r1_k3_s1_e3_c24_nre',
        ],
        [
            'ir_r1_k5_s2_e3_c40_nre',
            'ir_r1_k5_s1_e3_c40_nre',
            'ir_r1_k5_s1_e3_c40_nre',
        ],
        [
            'ir_r1_k5_s2_e3_c80',
            'ir_r1_k5_s1_e3_c80',
            'ir_r1_k3_s1_e3_c80',
            'ir_r1_k3_s1_e3_c80',
        ],
        [
            'ir_r1_k5_s1_e3_c112',
            'ir_r1_k3_s1_e3_c112',
            'ir_r1_k3_s1_e3_c112',
            'ir_r1_k3_s1_e3_c112',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e3_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(
        pretrained=pretrained, variant='hardcorenas_b', arch_def=block_strings, **kwargs)
92+
93+
94+
@register_model
def hardcorenas_c(pretrained=False, **kwargs):
    """HardCoReNAS-C.

    Builds the 'hardcorenas_c' variant through ``_gen_hardcorenas`` from the
    block strings listed below; ``pretrained`` and ``kwargs`` are passed on.
    """
    block_strings = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e3_c40_nre',
            'ir_r1_k5_s1_e3_c40_nre',
            'ir_r1_k5_s1_e3_c40_nre',
            'ir_r1_k5_s1_e3_c40_nre',
        ],
        [
            'ir_r1_k5_s2_e4_c80',
            'ir_r1_k5_s1_e6_c80_se0.25',
            'ir_r1_k3_s1_e3_c80',
            'ir_r1_k3_s1_e3_c80',
        ],
        [
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k3_s1_e3_c112',
            'ir_r1_k3_s1_e3_c112',
            'ir_r1_k3_s1_e3_c112',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e3_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(
        pretrained=pretrained, variant='hardcorenas_c', arch_def=block_strings, **kwargs)
106+
107+
108+
@register_model
def hardcorenas_d(pretrained=False, **kwargs):
    """HardCoReNAS-D.

    Builds the 'hardcorenas_d' variant through ``_gen_hardcorenas`` from the
    block strings listed below; ``pretrained`` and ``kwargs`` are passed on.
    """
    block_strings = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre_se0.25',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e3_c40_nre_se0.25',
            'ir_r1_k5_s1_e4_c40_nre_se0.25',
            'ir_r1_k3_s1_e3_c40_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e4_c80_se0.25',
            'ir_r1_k3_s1_e3_c80_se0.25',
            'ir_r1_k3_s1_e3_c80_se0.25',
            'ir_r1_k3_s1_e3_c80_se0.25',
        ],
        [
            'ir_r1_k3_s1_e4_c112_se0.25',
            'ir_r1_k5_s1_e4_c112_se0.25',
            'ir_r1_k3_s1_e3_c112_se0.25',
            'ir_r1_k5_s1_e3_c112_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e6_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(
        pretrained=pretrained, variant='hardcorenas_d', arch_def=block_strings, **kwargs)
121+
122+
123+
@register_model
def hardcorenas_e(pretrained=False, **kwargs):
    """HardCoReNAS-E.

    Builds the 'hardcorenas_e' variant through ``_gen_hardcorenas`` from the
    block strings listed below; ``pretrained`` and ``kwargs`` are passed on.
    """
    block_strings = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre_se0.25',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c40_nre_se0.25',
            'ir_r1_k5_s1_e4_c40_nre_se0.25',
            'ir_r1_k5_s1_e4_c40_nre_se0.25',
            'ir_r1_k3_s1_e3_c40_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e4_c80_se0.25',
            'ir_r1_k3_s1_e6_c80_se0.25',
        ],
        [
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e3_c112_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e6_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(
        pretrained=pretrained, variant='hardcorenas_e', arch_def=block_strings, **kwargs)
135+
136+
137+
@register_model
def hardcorenas_f(pretrained=False, **kwargs):
    """HardCoReNAS-F.

    Builds the 'hardcorenas_f' variant through ``_gen_hardcorenas`` from the
    block strings listed below; ``pretrained`` and ``kwargs`` are passed on.
    """
    block_strings = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre_se0.25',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c40_nre_se0.25',
            'ir_r1_k5_s1_e6_c40_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c80_se0.25',
            'ir_r1_k5_s1_e6_c80_se0.25',
            'ir_r1_k3_s1_e3_c80_se0.25',
            'ir_r1_k3_s1_e3_c80_se0.25',
        ],
        [
            'ir_r1_k3_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k3_s1_e3_c112_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e6_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(
        pretrained=pretrained, variant='hardcorenas_f', arch_def=block_strings, **kwargs)

0 commit comments

Comments
 (0)