
Commit bcec14d

Bring EfficientNet SE layer in line with others, pull se_ratio outside of blocks. Allows swapping w/ other attn layers.
1 parent 9611458
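
The gist of the change: SE reduction is no longer computed inside the blocks from a se_ratio argument. Blocks now just call se_layer(chs, act_layer=...), and the builder (or model fn) binds the reduction ratio into the layer beforehand, so any attention module with a compatible constructor can be dropped in. A minimal sketch of the new wiring, assuming a timm checkout that contains this commit (channel sizes are arbitrary):

    from functools import partial

    import torch

    from timm.models.efficientnet_blocks import SqueezeExcite, InvertedResidual

    # Bind the reduction ratio (and any gate/rounding options) up front...
    se_layer = partial(SqueezeExcite, rd_ratio=0.25)

    # ...the block only supplies its own channel count and activation.
    block = InvertedResidual(32, 32, exp_ratio=6.0, se_layer=se_layer)
    out = block(torch.randn(1, 32, 56, 56))
    print(out.shape)  # torch.Size([1, 32, 56, 56])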

File tree

5 files changed: +57 -68 lines

    timm/models/efficientnet_blocks.py
    timm/models/efficientnet_builder.py
    timm/models/ghostnet.py
    timm/models/hardcorenas.py
    timm/models/mobilenetv3.py

timm/models/efficientnet_blocks.py

Lines changed: 26 additions & 27 deletions
@@ -7,7 +7,7 @@
 import torch.nn as nn
 from torch.nn import functional as F
 
-from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer
+from .layers import create_conv2d, drop_path, make_divisible, create_act_layer
 from .layers.activations import sigmoid
 
 __all__ = [
@@ -19,31 +19,32 @@ class SqueezeExcite(nn.Module):
 
     Args:
         in_chs (int): input channels to layer
-        se_ratio (float): ratio of squeeze reduction
+        rd_ratio (float): ratio of squeeze reduction
         act_layer (nn.Module): activation layer of containing block
-        gate_fn (Callable): attention gate function
+        gate_layer (Callable): attention gate function
         force_act_layer (nn.Module): override block's activation fn if this is set/bound
-        round_chs_fn (Callable): specify a fn to calculate rounding of reduced chs
+        rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
     """
 
     def __init__(
-            self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
-            force_act_layer=None, round_chs_fn=None):
+            self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
+            gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
         super(SqueezeExcite, self).__init__()
-        round_chs_fn = round_chs_fn or round
-        reduced_chs = round_chs_fn(in_chs * se_ratio)
+        if rd_channels is None:
+            rd_round_fn = rd_round_fn or round
+            rd_channels = rd_round_fn(in_chs * rd_ratio)
         act_layer = force_act_layer or act_layer
-        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
+        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
         self.act1 = create_act_layer(act_layer, inplace=True)
-        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
+        self.gate = create_act_layer(gate_layer)
 
     def forward(self, x):
         x_se = x.mean((2, 3), keepdim=True)
         x_se = self.conv_reduce(x_se)
         x_se = self.act1(x_se)
         x_se = self.conv_expand(x_se)
-        return x * self.gate_fn(x_se)
+        return x * self.gate(x_se)
 
 
 class ConvBnAct(nn.Module):
@@ -85,10 +86,9 @@ class DepthwiseSeparableConv(nn.Module):
     """
     def __init__(
            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
-            noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
+            noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+            se_layer=None, drop_path_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
-        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.drop_path_rate = drop_path_rate
@@ -99,7 +99,7 @@ def __init__(
         self.act1 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
         self.bn2 = norm_layer(out_chs)
@@ -144,12 +144,11 @@ class InvertedResidual(nn.Module):
 
     def __init__(
             self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
-            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
         super(InvertedResidual, self).__init__()
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate
 
@@ -166,7 +165,7 @@ def __init__(
         self.act2 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = se_layer(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
@@ -212,17 +211,17 @@ class CondConvResidual(InvertedResidual):
 
     def __init__(
             self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
-            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
 
         self.num_experts = num_experts
         conv_kwargs = dict(num_experts=self.num_experts)
 
         super(CondConvResidual, self).__init__(
             in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
             act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
-            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer,
-            norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate)
+            pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
+            drop_path_rate=drop_path_rate)
 
         self.routing_fn = nn.Linear(in_chs, self.num_experts)
 
@@ -271,8 +270,8 @@ class EdgeResidual(nn.Module):
 
     def __init__(
             self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
-            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
+            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
         super(EdgeResidual, self).__init__()
         if force_in_chs > 0:
             mid_chs = make_divisible(force_in_chs * exp_ratio)
@@ -289,7 +288,7 @@ def __init__(
         self.act1 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
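
With the rework, SqueezeExcite no longer knows anything about the surrounding block: reduction is given either as rd_ratio (relative to its input channels) or pinned directly via rd_channels, and the gate is built by create_act_layer. A quick standalone sketch (sizes are arbitrary, not part of the commit):

    import torch
    from timm.models.efficientnet_blocks import SqueezeExcite

    se = SqueezeExcite(64, rd_ratio=0.25)                  # reduce 64 -> round(64 * 0.25) -> 64
    se_fixed = SqueezeExcite(64, rd_channels=8)            # explicit bottleneck width, rd_ratio unused
    se_hs = SqueezeExcite(64, gate_layer='hard_sigmoid')   # gate resolved by name via create_act_layer

    x = torch.randn(2, 64, 14, 14)
    assert se(x).shape == se_fixed(x).shape == se_hs(x).shape == x.shape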

timm/models/efficientnet_builder.py

Lines changed: 22 additions & 11 deletions
@@ -10,11 +10,12 @@
 import math
 import re
 from copy import deepcopy
+from functools import partial
 
 import torch.nn as nn
 
 from .efficientnet_blocks import *
-from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible
+from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
 
 __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
            'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
@@ -120,7 +121,9 @@ def _decode_block_str(block_str):
            elif v == 'hs':
                value = get_act_layer('hard_swish')
            elif v == 'sw':
-                value = get_act_layer('swish')
+                value = get_act_layer('swish')  # aka SiLU
+            elif v == 'mi':
+                value = get_act_layer('mish')
            else:
                continue
            options[key] = value
@@ -273,7 +276,12 @@ def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, s
        self.se_from_exp = se_from_exp  # calculate se channel reduction from expanded (mid) chs
        self.act_layer = act_layer
        self.norm_layer = norm_layer
-        self.se_layer = se_layer
+        self.se_layer = get_attn(se_layer)
+        try:
+            self.se_layer(8, rd_ratio=1.0)
+            self.se_has_ratio = True
+        except RuntimeError as e:
+            self.se_has_ratio = False
        self.drop_path_rate = drop_path_rate
        if feature_location == 'depthwise':
            # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
@@ -300,18 +308,21 @@ def _make_block(self, ba, block_idx, block_count):
        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
        assert ba['act_layer'] is not None
        ba['norm_layer'] = self.norm_layer
+        ba['drop_path_rate'] = drop_path_rate
        if bt != 'cn':
-            ba['se_layer'] = self.se_layer
-            if not self.se_from_exp and ba['se_ratio']:
-                ba['se_ratio'] /= ba.get('exp_ratio', 1.0)
-            ba['drop_path_rate'] = drop_path_rate
+            se_ratio = ba.pop('se_ratio')
+            if se_ratio and self.se_layer is not None:
+                if not self.se_from_exp:
+                    # adjust se_ratio by expansion ratio if calculating se channels from block input
+                    se_ratio /= ba.get('exp_ratio', 1.0)
+                if self.se_has_ratio:
+                    ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
+                else:
+                    ba['se_layer'] = self.se_layer
 
        if bt == 'ir':
            _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
-            if ba.get('num_experts', 0) > 0:
-                block = CondConvResidual(**ba)
-            else:
-                block = InvertedResidual(**ba)
+            block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
        elif bt == 'ds' or bt == 'dsa':
            _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = DepthwiseSeparableConv(**ba)

timm/models/ghostnet.py

Lines changed: 2 additions & 2 deletions
@@ -40,7 +40,7 @@ def _cfg(url='', **kwargs):
 }
 
 
-_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', round_chs_fn=partial(make_divisible, divisor=4))
+_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
 
 
 class GhostModule(nn.Module):
@@ -92,7 +92,7 @@ def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,
            self.bn_dw = None
 
        # Squeeze-and-excitation
-        self.se = _SE_LAYER(mid_chs, se_ratio=se_ratio) if has_se else None
+        self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None
 
        # Point-wise linear projection
        self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)

timm/models/hardcorenas.py

Lines changed: 1 addition & 2 deletions
@@ -39,8 +39,7 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
 
    """
    num_features = 1280
-    se_layer = partial(
-        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, round_chs_fn=round_channels)
+    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=num_features,

timm/models/mobilenetv3.py

Lines changed: 6 additions & 26 deletions
@@ -266,7 +266,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
        act_layer=resolve_act_layer(kwargs, 'hard_swish'),
-        se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid')),
+        se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'),
        **kwargs,
    )
    model = _create_mnv3(variant, pretrained, **model_kwargs)
@@ -354,8 +354,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
            # stage 6, 7x7 in
            ['cn_r1_k1_s1_c960'], # hard-swish
        ]
-    se_layer = partial(
-        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, round_chs_fn=round_channels)
+    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=num_features,
@@ -372,67 +371,48 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
 
 def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """ FBNetV3
+    Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining`
+        - https://arxiv.org/abs/2006.02049
    FIXME untested, this is a preliminary impl of some FBNet-V3 variants.
    """
    vl = variant.split('_')[-1]
    if vl in ('a', 'b'):
        stem_size = 16
        arch_def = [
-            # stage 0, 112x112 in
            ['ds_r2_k3_s1_e1_c16'],
-            # stage 1, 112x112 in
            ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'],
-            # stage 2, 56x56 in
            ['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'],
-            # stage 3, 28x28 in
            ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
-            # stage 4, 14x14in
            ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'],
-            # stage 5, 14x14in
            ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'],
-            # stage 6, 7x7 in
            ['cn_r1_k1_s1_c1344'],
        ]
    elif vl == 'd':
        stem_size = 24
        arch_def = [
-            # stage 0, 112x112 in
            ['ds_r2_k3_s1_e1_c16'],
-            # stage 1, 112x112 in
            ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'],
-            # stage 2, 56x56 in
            ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'],
-            # stage 3, 28x28 in
            ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
-            # stage 4, 14x14in
            ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'],
-            # stage 5, 14x14in
            ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'],
-            # stage 6, 7x7 in
            ['cn_r1_k1_s1_c1440'],
        ]
    elif vl == 'g':
        stem_size = 32
        arch_def = [
-            # stage 0, 112x112 in
            ['ds_r3_k3_s1_e1_c24'],
-            # stage 1, 112x112 in
            ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'],
-            # stage 2, 56x56 in
            ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'],
-            # stage 3, 28x28 in
            ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'],
-            # stage 4, 14x14in
            ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'],
-            # stage 5, 14x14in
            ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'],
-            # stage 6, 7x7 in
-            ['cn_r1_k1_s1_c1728'], # hard-swish
+            ['cn_r1_k1_s1_c1728'],
        ]
    else:
        raise NotImplemented
    round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95)
-    se_layer = partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), round_chs_fn=round_chs_fn)
+    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn)
    act_layer = resolve_act_layer(kwargs, 'hard_swish')
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
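
Put together, the MobileNetV3 / FBNetV3 SE configuration becomes pure composition: the model fn fixes the gate, forced activation, and rounding fn in a partial; the builder layers rd_ratio on top per block; the block finally supplies its channel count. A sketch of the resulting call chain, assuming a timm checkout with this commit (the channel count and ratio are illustrative):

    from functools import partial

    import torch
    import torch.nn as nn

    from timm.models.efficientnet_blocks import SqueezeExcite
    from timm.models.efficientnet_builder import round_channels

    # model fn (as in _gen_mobilenet_v3 above)
    se_layer = partial(
        SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
    # builder (per block, from a decoded '..._se0.25' entry)
    se_layer = partial(se_layer, rd_ratio=0.25)
    # block (mid channels only)
    se = se_layer(72)
    print(se(torch.randn(1, 72, 28, 28)).shape)  # torch.Size([1, 72, 28, 28])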
