
Commit c4f482a

EfficientNetV2 official impl w/ weights ported from TF. Cleanup/refactor of related EfficientNet classes and models.
1 parent c16d65a commit c4f482a

File tree

7 files changed: +587 −318 lines


timm/models/efficientnet.py

Lines changed: 386 additions & 77 deletions
Large diffs are not rendered by default.

timm/models/efficientnet_blocks.py

Lines changed: 74 additions & 158 deletions
@@ -7,106 +7,34 @@
 import torch.nn as nn
 from torch.nn import functional as F

-from .layers import create_conv2d, drop_path, get_act_layer
+from .layers import create_conv2d, drop_path, make_divisible
 from .layers.activations import sigmoid

-# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
-# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
-# NOTE: momentum varies btw .99 and .9997 depending on source
-# .99 in official TF TPU impl
-# .9997 (/w .999 in search space) for paper
-BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
-BN_EPS_TF_DEFAULT = 1e-3
-_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
-
-
-def get_bn_args_tf():
-    return _BN_ARGS_TF.copy()
-
-
-def resolve_bn_args(kwargs):
-    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
-    bn_momentum = kwargs.pop('bn_momentum', None)
-    if bn_momentum is not None:
-        bn_args['momentum'] = bn_momentum
-    bn_eps = kwargs.pop('bn_eps', None)
-    if bn_eps is not None:
-        bn_args['eps'] = bn_eps
-    return bn_args
-
-
-_SE_ARGS_DEFAULT = dict(
-    gate_fn=sigmoid,
-    act_layer=None,
-    reduce_mid=False,
-    divisor=1)
-
-
-def resolve_se_args(kwargs, in_chs, act_layer=None):
-    se_kwargs = kwargs.copy() if kwargs is not None else {}
-    # fill in args that aren't specified with the defaults
-    for k, v in _SE_ARGS_DEFAULT.items():
-        se_kwargs.setdefault(k, v)
-    # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
-    if not se_kwargs.pop('reduce_mid'):
-        se_kwargs['reduced_base_chs'] = in_chs
-    # act_layer override, if it remains None, the containing block's act_layer will be used
-    if se_kwargs['act_layer'] is None:
-        assert act_layer is not None
-        se_kwargs['act_layer'] = act_layer
-    return se_kwargs
-
-
-def resolve_act_layer(kwargs, default='relu'):
-    act_layer = kwargs.pop('act_layer', default)
-    if isinstance(act_layer, str):
-        act_layer = get_act_layer(act_layer)
-    return act_layer
-
-
-def make_divisible(v, divisor=8, min_value=None):
-    min_value = min_value or divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    # Make sure that round down does not go down by more than 10%.
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v
-
-
-def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
-    """Round number of filters based on depth multiplier."""
-    if not multiplier:
-        return channels
-    channels *= multiplier
-    return make_divisible(channels, divisor, channel_min)
-
-
-class ChannelShuffle(nn.Module):
-    # FIXME haven't used yet
-    def __init__(self, groups):
-        super(ChannelShuffle, self).__init__()
-        self.groups = groups
-
-    def forward(self, x):
-        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]"""
-        N, C, H, W = x.size()
-        g = self.groups
-        assert C % g == 0, "Incompatible group size {} for input channel {}".format(
-            g, C
-        )
-        return (
-            x.view(N, g, int(C / g), H, W)
-            .permute(0, 2, 1, 3, 4)
-            .contiguous()
-            .view(N, C, H, W)
-        )
+__all__ = [
+    'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']


 class SqueezeExcite(nn.Module):
-    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
-                 act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
+    """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family
+
+    Args:
+        in_chs (int): input channels to layer
+        se_ratio (float): ratio of squeeze reduction
+        act_layer (nn.Module): activation layer of containing block
+        gate_fn (Callable): attention gate function
+        block_in_chs (int): input channels of containing block (for calculating reduction from)
+        reduce_from_block (bool): calculate reduction from block input channels if True
+        force_act_layer (nn.Module): override block's activation fn if this is set/bound
+        divisor (int): make reduction channels divisible by this
+    """
+
+    def __init__(
+            self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
+            block_in_chs=None, reduce_from_block=True, force_act_layer=None, divisor=1):
         super(SqueezeExcite, self).__init__()
-        reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
+        reduced_chs = (block_in_chs or in_chs) if reduce_from_block else in_chs
+        reduced_chs = make_divisible(reduced_chs * se_ratio, divisor)
+        act_layer = force_act_layer or act_layer
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
         self.act1 = act_layer(inplace=True)
         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
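For illustration (not part of the commit): a minimal usage sketch of the refactored SqueezeExcite constructor above, assuming a timm checkout that includes this change; the channel sizes are made up.

# Hypothetical example: an SE module sized the way the blocks below now request it,
# squeezing from the containing block's input channels (16) rather than the expanded width (96).
import torch
from timm.models.efficientnet_blocks import SqueezeExcite

se = SqueezeExcite(96, se_ratio=0.25, block_in_chs=16)  # reduced_chs = make_divisible(16 * 0.25, 1) = 4
x = torch.randn(1, 96, 14, 14)
print(se(x).shape)  # torch.Size([1, 96, 14, 14]) -- same shape, channels re-weighted by the gate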
@@ -121,13 +49,16 @@ def forward(self, x):


 class ConvBnAct(nn.Module):
-    def __init__(self, in_chs, out_chs, kernel_size,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
+    """ Conv + Norm Layer + Activation w/ optional skip connection
+    """
+    def __init__(
+            self, in_chs, out_chs, kernel_size, stride=1, dilation=1, pad_type='',
+            skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.):
         super(ConvBnAct, self).__init__()
-        norm_kwargs = norm_kwargs or {}
+        self.has_residual = skip and stride == 1 and in_chs == out_chs
+        self.drop_path_rate = drop_path_rate
         self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
-        self.bn1 = norm_layer(out_chs, **norm_kwargs)
+        self.bn1 = norm_layer(out_chs)
         self.act1 = act_layer(inplace=True)

     def feature_info(self, location):
@@ -138,9 +69,14 @@ def feature_info(self, location):
         return info

     def forward(self, x):
+        shortcut = x
         x = self.conv(x)
         x = self.bn1(x)
         x = self.act1(x)
+        if self.has_residual:
+            if self.drop_path_rate > 0.:
+                x = drop_path(x, self.drop_path_rate, self.training)
+            x += shortcut
         return x


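For illustration (not part of the commit): a short sketch of the new optional residual in ConvBnAct; the skip only takes effect when stride == 1 and in_chs == out_chs, and the sizes below are made up.

# Hypothetical usage of the skip/drop_path additions to ConvBnAct shown above.
import torch
from timm.models.efficientnet_blocks import ConvBnAct

block = ConvBnAct(32, 32, kernel_size=3, skip=True, drop_path_rate=0.1)
x = torch.randn(2, 32, 56, 56)
out = block(x)  # conv -> bn -> act, then (in training mode) drop_path before x += shortcut
print(out.shape)  # torch.Size([2, 32, 56, 56])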
@@ -149,31 +85,26 @@ class DepthwiseSeparableConv(nn.Module):
     Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
     (factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
     """
-    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
-                 pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
-        norm_kwargs = norm_kwargs or {}
-        has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.drop_path_rate = drop_path_rate

         self.conv_dw = create_conv2d(
             in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
-        self.bn1 = norm_layer(in_chs, **norm_kwargs)
+        self.bn1 = norm_layer(in_chs)
         self.act1 = act_layer(inplace=True)

         # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
-        else:
-            self.se = None
+        self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()

         self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
-        self.bn2 = norm_layer(out_chs, **norm_kwargs)
+        self.bn2 = norm_layer(out_chs)
         self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

     def feature_info(self, location):
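For illustration (not part of the commit): with resolve_se_args gone, the SE module is now injected by the caller via se_layer. A minimal sketch with made-up channel counts:

import torch
from timm.models.efficientnet_blocks import DepthwiseSeparableConv, SqueezeExcite

# se_layer is passed explicitly; with se_layer=None (the default) the block uses nn.Identity() instead.
block = DepthwiseSeparableConv(16, 16, dw_kernel_size=3, se_ratio=0.25, se_layer=SqueezeExcite)
print(block(torch.randn(1, 16, 32, 32)).shape)  # torch.Size([1, 16, 32, 32])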
@@ -190,8 +121,7 @@ def forward(self, x):
         x = self.bn1(x)
         x = self.act1(x)

-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)

         x = self.conv_pw(x)
         x = self.bn2(x)
@@ -214,41 +144,36 @@ class InvertedResidual(nn.Module):
       * MobileNet-V3 - https://arxiv.org/abs/1905.02244
     """

-    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
-                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 conv_kwargs=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
         super(InvertedResidual, self).__init__()
-        norm_kwargs = norm_kwargs or {}
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate

         # Point-wise expansion
         self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
-        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
+        self.bn1 = norm_layer(mid_chs)
         self.act1 = act_layer(inplace=True)

         # Depth-wise convolution
         self.conv_dw = create_conv2d(
             mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
             padding=pad_type, depthwise=True, **conv_kwargs)
-        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
+        self.bn2 = norm_layer(mid_chs)
         self.act2 = act_layer(inplace=True)

         # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
-        else:
-            self.se = None
+        self.se = se_layer(
+            mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()

         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
-        self.bn3 = norm_layer(out_chs, **norm_kwargs)
+        self.bn3 = norm_layer(out_chs)

     def feature_info(self, location):
         if location == 'expansion':  # after SE, input to PWL
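For illustration (not part of the commit): a caller can pre-bind SE behaviour (e.g. divisor, gate, force_act_layer) with functools.partial and hand the result in as se_layer; the block then only supplies mid_chs, se_ratio, act_layer and block_in_chs as shown above. A hedged sketch with made-up sizes:

import torch
from functools import partial
from timm.models.efficientnet_blocks import InvertedResidual, SqueezeExcite

se_layer = partial(SqueezeExcite, divisor=8)  # round the squeeze width to a multiple of 8
block = InvertedResidual(16, 24, dw_kernel_size=3, stride=2, exp_ratio=6.0, se_ratio=0.25, se_layer=se_layer)
print(block(torch.randn(1, 16, 112, 112)).shape)  # torch.Size([1, 24, 56, 56])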
@@ -271,8 +196,7 @@ def forward(self, x):
         x = self.act2(x)

         # Squeeze-and-excitation
-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)

         # Point-wise linear projection
         x = self.conv_pwl(x)
@@ -289,21 +213,19 @@ def forward(self, x):
 class CondConvResidual(InvertedResidual):
     """ Inverted residual block w/ CondConv routing"""

-    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
-                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 num_experts=0, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):

         self.num_experts = num_experts
         conv_kwargs = dict(num_experts=self.num_experts)

         super(CondConvResidual, self).__init__(
             in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
             act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
-            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
-            drop_path_rate=drop_path_rate)
+            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer,
+            norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate)

         self.routing_fn = nn.Linear(in_chs, self.num_experts)

@@ -325,8 +247,7 @@ def forward(self, x):
         x = self.act2(x)

         # Squeeze-and-excitation
-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)

         # Point-wise linear projection
         x = self.conv_pwl(x, routing_weights)
@@ -351,36 +272,32 @@ class EdgeResidual(nn.Module):
       * EfficientNet-V2 - https://arxiv.org/abs/2104.00298
     """

-    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
+            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
         super(EdgeResidual, self).__init__()
-        norm_kwargs = norm_kwargs or {}
-        if fake_in_chs > 0:
-            mid_chs = make_divisible(fake_in_chs * exp_ratio)
+        if force_in_chs > 0:
+            mid_chs = make_divisible(force_in_chs * exp_ratio)
         else:
             mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate

         # Expansion convolution
         self.conv_exp = create_conv2d(
             in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
-        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
+        self.bn1 = norm_layer(mid_chs)
         self.act1 = act_layer(inplace=True)

         # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
-        else:
-            self.se = None
+        self.se = SqueezeExcite(
+            mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()

         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
-        self.bn2 = norm_layer(out_chs, **norm_kwargs)
+        self.bn2 = norm_layer(out_chs)

     def feature_info(self, location):
         if location == 'expansion':  # after SE, before PWL
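For illustration (not part of the commit): force_in_chs (the renamed fake_in_chs) overrides the channel count used to size the expansion in EdgeTPU/EfficientNetV2-style fused blocks. A small sketch with made-up numbers:

import torch
from timm.models.efficientnet_blocks import EdgeResidual

# With force_in_chs=16 the expansion is sized from 16 * 4.0 = 64 channels rather than from in_chs=24.
block = EdgeResidual(24, 32, exp_kernel_size=3, stride=2, exp_ratio=4.0, force_in_chs=16)
print(block(torch.randn(1, 24, 64, 64)).shape)  # torch.Size([1, 32, 32, 32])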
@@ -398,8 +315,7 @@ def forward(self, x):
         x = self.act1(x)

         # Squeeze-and-excitation
-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)

         # Point-wise linear projection
         x = self.conv_pwl(x)
