Skip to content

Commit 9ec6824

Browse files
committed
Finally got around to adding EdgeTPU EfficientNet variant
1 parent daeaa11 commit 9ec6824

File tree

1 file changed

+206
-3
lines changed

1 file changed

+206
-3
lines changed

timm/models/gen_efficientnet.py

Lines changed: 206 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ def _cfg(url='', **kwargs):
8888
url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
8989
'efficientnet_b7': _cfg(
9090
url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
91+
'efficientnet_es': _cfg(
92+
url=''),
93+
'efficientnet_em': _cfg(
94+
url='',
95+
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
96+
'efficientnet_el': _cfg(
97+
url='',
98+
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
9199
'tf_efficientnet_b0': _cfg(
92100
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
93101
input_size=(3, 224, 224)),
@@ -112,6 +120,18 @@ def _cfg(url='', **kwargs):
112120
'tf_efficientnet_b7': _cfg(
113121
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth',
114122
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
123+
'tf_efficientnet_es': _cfg(
124+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
125+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
126+
input_size=(3, 224, 224), ),
127+
'tf_efficientnet_em': _cfg(
128+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth',
129+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
130+
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
131+
'tf_efficientnet_el': _cfg(
132+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',
133+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
134+
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
115135
'mixnet_s': _cfg(
116136
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'),
117137
'mixnet_m': _cfg(
@@ -239,6 +259,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
239259
act_fn = options['n'] if 'n' in options else None
240260
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
241261
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
262+
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
242263

243264
num_repeat = int(options['r'])
244265
# each type of block has different valid arguments, fill accordingly
@@ -267,6 +288,19 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
267288
pw_act=block_type == 'dsa',
268289
noskip=block_type == 'dsa' or noskip,
269290
)
291+
elif block_type == 'er':
292+
block_args = dict(
293+
block_type=block_type,
294+
exp_kernel_size=_parse_ksize(options['k']),
295+
pw_kernel_size=pw_kernel_size,
296+
out_chs=int(options['c']),
297+
exp_ratio=float(options['e']),
298+
fake_in_chs=fake_in_chs,
299+
se_ratio=float(options['se']) if 'se' in options else None,
300+
stride=int(options['s']),
301+
act_fn=act_fn,
302+
noskip=noskip,
303+
)
270304
elif block_type == 'cn':
271305
block_args = dict(
272306
block_type=block_type,
@@ -356,6 +390,9 @@ def _make_block(self, ba):
356390
bt = ba.pop('block_type')
357391
ba['in_chs'] = self.in_chs
358392
ba['out_chs'] = self._round_channels(ba['out_chs'])
393+
if 'fake_in_chs' in ba and ba['fake_in_chs']:
394+
# FIXME this is a hack to work around mismatch in origin impl input filters
395+
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
359396
ba['bn_args'] = self.bn_args
360397
ba['pad_type'] = self.pad_type
361398
# block act fn overrides the model default
@@ -373,6 +410,13 @@ def _make_block(self, ba):
373410
if self.verbose:
374411
logging.info(' DepthwiseSeparable {}, Args: {}'.format(self.block_idx, str(ba)))
375412
block = DepthwiseSeparableConv(**ba)
413+
elif bt == 'er':
414+
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
415+
ba['se_gate_fn'] = self.se_gate_fn
416+
ba['se_reduce_mid'] = self.se_reduce_mid
417+
if self.verbose:
418+
logging.info(' EdgeResidual {}, Args: {}'.format(self.block_idx, str(ba)))
419+
block = EdgeResidual(**ba)
376420
elif bt == 'cn':
377421
if self.verbose:
378422
logging.info(' ConvBnAct {}, Args: {}'.format(self.block_idx, str(ba)))
@@ -519,10 +563,62 @@ def forward(self, x):
519563
return x
520564

521565

566+
class EdgeResidual(nn.Module):
    """Residual block with an expansion conv followed by a pointwise-linear projection w/ stride.

    Used by the EfficientNet-EdgeTPU variants: the depthwise stage of the usual
    inverted residual is replaced by a single full (expansion) convolution, with
    any stride applied in the pointwise-linear projection.
    """

    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                 stride=1, pad_type='', act_fn=F.relu, noskip=False, pw_kernel_size=1,
                 se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
                 bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
        super(EdgeResidual, self).__init__()
        # `fake_in_chs` overrides the expansion width calc to work around a
        # stem/in_chs mismatch in the original TPU model definition.
        expand_base = fake_in_chs if fake_in_chs > 0 else in_chs
        mid_chs = int(expand_base * exp_ratio)
        self.has_se = se_ratio is not None and se_ratio > 0.
        # skip connection only valid when the block preserves shape
        self.has_residual = not noskip and stride == 1 and in_chs == out_chs
        self.act_fn = act_fn
        self.drop_connect_rate = drop_connect_rate

        # Expansion convolution
        self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
        self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)

        # Squeeze-and-excitation
        if self.has_se:
            se_base_chs = mid_chs if se_reduce_mid else in_chs
            self.se = SqueezeExcite(
                mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)

        # Point-wise linear projection (carries the stride)
        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
        self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)

    def forward(self, x):
        shortcut = x

        # Expansion convolution
        x = self.act_fn(self.bn1(self.conv_exp(x)), inplace=True)

        # Squeeze-and-excitation
        if self.has_se:
            x = self.se(x)

        # Point-wise linear projection
        x = self.bn2(self.conv_pwl(x))

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += shortcut

        return x
616+
617+
522618
class DepthwiseSeparableConv(nn.Module):
523619
""" DepthwiseSeparable block
524620
Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
525-
factor of 1.0. This is an alternative to having a IR with optional first pw conv.
621+
factor of 1.0. This is an alternative to having a IR with an optional first pw conv.
526622
"""
527623
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
528624
stride=1, pad_type='', act_fn=F.relu, noskip=False,
@@ -1092,7 +1188,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
10921188
['ir_r4_k5_s2_e6_c192_se0.25'],
10931189
['ir_r1_k3_s1_e6_c320_se0.25'],
10941190
]
1095-
# NOTE: other models in the family didn't scale the feature count
10961191
num_features = _round_channels(1280, channel_multiplier, 8, None)
10971192
model = GenEfficientNet(
10981193
_decode_arch_def(arch_def, depth_multiplier),
@@ -1107,6 +1202,31 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
11071202
return model
11081203

11091204

1205+
def _gen_efficientnet_edge(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
    """Create an EfficientNet-EdgeTPU model.

    The Small/Medium/Large variants are selected via the channel/depth
    multipliers by the registered factory functions.
    """
    # NOTE `fc` is present to override a mismatch between stem channels and
    # in chs not present in other models
    arch_def = [
        ['er_r1_k3_s1_e4_c24_fc24_noskip'],
        ['er_r2_k3_s2_e8_c32'],
        ['er_r4_k3_s2_e8_c48'],
        ['ir_r5_k5_s2_e8_c96'],
        ['ir_r4_k5_s1_e8_c144'],
        ['ir_r2_k5_s2_e8_c192'],
    ]
    num_features = _round_channels(1280, channel_multiplier, 8, None)
    return GenEfficientNet(
        _decode_arch_def(arch_def, depth_multiplier),
        num_classes=num_classes,
        stem_size=32,
        channel_multiplier=channel_multiplier,
        num_features=num_features,
        bn_args=_resolve_bn_args(kwargs),
        act_fn=F.relu,
        **kwargs
    )
1228+
1229+
11101230
def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
11111231
"""Creates a MixNet Small model.
11121232
@@ -1481,7 +1601,6 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
14811601
return model
14821602

14831603

1484-
14851604
@register_model
14861605
def efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
14871606
""" EfficientNet-B6 """
@@ -1512,6 +1631,45 @@ def efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
15121631
return model
15131632

15141633

1634+
@register_model
def efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge Small. """
    model = _gen_efficientnet_edge(
        channel_multiplier=1.0, depth_multiplier=1.0,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfgs['efficientnet_es']
    if pretrained:
        load_pretrained(model, model.default_cfg, num_classes, in_chans)
    return model
1645+
1646+
1647+
@register_model
def efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge-Medium. """
    model = _gen_efficientnet_edge(
        channel_multiplier=1.0, depth_multiplier=1.1,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfgs['efficientnet_em']
    if pretrained:
        load_pretrained(model, model.default_cfg, num_classes, in_chans)
    return model
1658+
1659+
1660+
@register_model
def efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge-Large. """
    model = _gen_efficientnet_edge(
        channel_multiplier=1.2, depth_multiplier=1.4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfgs['efficientnet_el']
    if pretrained:
        load_pretrained(model, model.default_cfg, num_classes, in_chans)
    return model
1671+
1672+
15151673
@register_model
15161674
def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
15171675
""" EfficientNet-B0. Tensorflow compatible variant """
@@ -1634,6 +1792,51 @@ def tf_efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
16341792
return model
16351793

16361794

1795+
@register_model
def tf_efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge Small. Tensorflow compatible variant """
    # match TF defaults: BN epsilon and 'same' padding
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_edge(
        channel_multiplier=1.0, depth_multiplier=1.0,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfgs['tf_efficientnet_es']
    if pretrained:
        load_pretrained(model, model.default_cfg, num_classes, in_chans)
    return model
1808+
1809+
1810+
@register_model
def tf_efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge-Medium. Tensorflow compatible variant """
    # match TF defaults: BN epsilon and 'same' padding
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_edge(
        channel_multiplier=1.0, depth_multiplier=1.1,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfgs['tf_efficientnet_em']
    if pretrained:
        load_pretrained(model, model.default_cfg, num_classes, in_chans)
    return model
1823+
1824+
1825+
@register_model
def tf_efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge-Large. Tensorflow compatible variant """
    # match TF defaults: BN epsilon and 'same' padding
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_edge(
        channel_multiplier=1.2, depth_multiplier=1.4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfgs['tf_efficientnet_el']
    if pretrained:
        load_pretrained(model, model.default_cfg, num_classes, in_chans)
    return model
1838+
1839+
16371840
@register_model
16381841
def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
16391842
"""Creates a MixNet Small model.

0 commit comments

Comments
 (0)