
Commit 758802e (2 parents: c16d65a + 328249f)

Merge pull request #636 from rwightman/effnetv2_official
EfficientNetV2 official impl w/ weights ported from TF.

File tree: 9 files changed, +608 −320 lines

README.md

Lines changed: 9 additions & 0 deletions
@@ -23,6 +23,15 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 
 ## What's New
 
+### May 14, 2021
+* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
+    * 1k trained variants: `tf_efficientnetv2_s/m/l`
+    * 21k trained variants: `tf_efficientnetv2_s/m/l_21k`
+    * 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_21ft1k`
+    * v2 models w/ v1 scaling: `tf_efficientnet_v2_b0` through `b3`
+    * Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s`
+    * Some blank `efficientnetv2_*` models in-place for future native PyTorch training
+
 ### May 5, 2021
 * Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen)
 * Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit)
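
Usage note (not part of the diff): a minimal sketch of loading one of the newly added variants through `timm.create_model`, assuming a timm install that includes this commit. The input resolution below is arbitrary; check the model's `default_cfg` for the size the ported weights expect.

import torch
import timm

# one of the 1k-trained variants added here; see also tf_efficientnetv2_m/l and the _21k / _21ft1k weights
model = timm.create_model('tf_efficientnetv2_s', pretrained=True)
model.eval()

# dummy input; model.default_cfg['input_size'] gives the intended resolution for these weights
x = torch.randn(1, 3, 300, 300)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])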

timm/models/efficientnet.py

Lines changed: 397 additions & 78 deletions
Large diffs are not rendered by default.

timm/models/efficientnet_blocks.py

Lines changed: 74 additions & 158 deletions
Large diffs are not rendered by default.

timm/models/efficientnet_builder.py

Lines changed: 72 additions & 38 deletions
@@ -14,13 +14,55 @@
 import torch.nn as nn
 
 from .efficientnet_blocks import *
-from .layers import CondConv2d, get_condconv_initializer
+from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible
 
-__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
+__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
+           'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
 
 _logger = logging.getLogger(__name__)
 
 
+_DEBUG_BUILDER = False
+
+# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
+# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
+# NOTE: momentum varies btw .99 and .9997 depending on source
+# .99 in official TF TPU impl
+# .9997 (/w .999 in search space) for paper
+BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
+BN_EPS_TF_DEFAULT = 1e-3
+_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
+
+
+def get_bn_args_tf():
+    return _BN_ARGS_TF.copy()
+
+
+def resolve_bn_args(kwargs):
+    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
+    bn_momentum = kwargs.pop('bn_momentum', None)
+    if bn_momentum is not None:
+        bn_args['momentum'] = bn_momentum
+    bn_eps = kwargs.pop('bn_eps', None)
+    if bn_eps is not None:
+        bn_args['eps'] = bn_eps
+    return bn_args
+
+
+def resolve_act_layer(kwargs, default='relu'):
+    act_layer = kwargs.pop('act_layer', default)
+    if isinstance(act_layer, str):
+        act_layer = get_act_layer(act_layer)
+    return act_layer
+
+
+def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
+    """Round number of filters based on depth multiplier."""
+    if not multiplier:
+        return channels
+    return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)
+
+
 def _log_info_if(msg, condition):
     if condition:
         _logger.info(msg)
@@ -63,11 +105,13 @@ def _decode_block_str(block_str):
     block_type = ops[0]  # take the block type off the front
     ops = ops[1:]
     options = {}
-    noskip = False
+    skip = None
     for op in ops:
         # string options being checked on individual basis, combine if they grow
         if op == 'noskip':
-            noskip = True
+            skip = False  # force no skip connection
+        elif op == 'skip':
+            skip = True  # force a skip connection
         elif op.startswith('n'):
             # activation fn
             key = op[0]
@@ -94,7 +138,7 @@ def _decode_block_str(block_str):
     act_layer = options['n'] if 'n' in options else None
     exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
     pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
-    fake_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
+    force_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
 
     num_repeat = int(options['r'])
     # each type of block has different valid arguments, fill accordingly
@@ -106,10 +150,10 @@
             pw_kernel_size=pw_kernel_size,
             out_chs=int(options['c']),
             exp_ratio=float(options['e']),
-            se_ratio=float(options['se']) if 'se' in options else None,
+            se_ratio=float(options['se']) if 'se' in options else 0.,
             stride=int(options['s']),
             act_layer=act_layer,
-            noskip=noskip,
+            noskip=skip is False,
         )
         if 'cc' in options:
             block_args['num_experts'] = int(options['cc'])
@@ -119,11 +163,11 @@
             dw_kernel_size=_parse_ksize(options['k']),
             pw_kernel_size=pw_kernel_size,
             out_chs=int(options['c']),
-            se_ratio=float(options['se']) if 'se' in options else None,
+            se_ratio=float(options['se']) if 'se' in options else 0.,
             stride=int(options['s']),
             act_layer=act_layer,
             pw_act=block_type == 'dsa',
-            noskip=block_type == 'dsa' or noskip,
+            noskip=block_type == 'dsa' or skip is False,
         )
     elif block_type == 'er':
         block_args = dict(
@@ -132,11 +176,11 @@
             pw_kernel_size=pw_kernel_size,
             out_chs=int(options['c']),
             exp_ratio=float(options['e']),
-            fake_in_chs=fake_in_chs,
-            se_ratio=float(options['se']) if 'se' in options else None,
+            force_in_chs=force_in_chs,
+            se_ratio=float(options['se']) if 'se' in options else 0.,
             stride=int(options['s']),
             act_layer=act_layer,
-            noskip=noskip,
+            noskip=skip is False,
         )
     elif block_type == 'cn':
         block_args = dict(
@@ -145,6 +189,7 @@
             out_chs=int(options['c']),
             stride=int(options['s']),
             act_layer=act_layer,
+            skip=skip is True,
         )
     else:
         assert False, 'Unknown block type (%s)' % block_type
@@ -219,74 +264,63 @@ class EfficientNetBuilder:
     https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
 
     """
-    def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
-                 verbose=False):
-        self.channel_multiplier = channel_multiplier
-        self.channel_divisor = channel_divisor
-        self.channel_min = channel_min
+    def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels,
+                 act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
         self.output_stride = output_stride
         self.pad_type = pad_type
+        self.round_chs_fn = round_chs_fn
         self.act_layer = act_layer
-        self.se_kwargs = se_kwargs
         self.norm_layer = norm_layer
-        self.norm_kwargs = norm_kwargs
+        self.se_layer = se_layer
         self.drop_path_rate = drop_path_rate
         if feature_location == 'depthwise':
             # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
            _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
             feature_location = 'expansion'
         self.feature_location = feature_location
         assert feature_location in ('bottleneck', 'expansion', '')
-        self.verbose = verbose
+        self.verbose = _DEBUG_BUILDER
 
         # state updated during build, consumed by model
         self.in_chs = None
         self.features = []
 
-    def _round_channels(self, chs):
-        return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
-
     def _make_block(self, ba, block_idx, block_count):
         drop_path_rate = self.drop_path_rate * block_idx / block_count
         bt = ba.pop('block_type')
         ba['in_chs'] = self.in_chs
-        ba['out_chs'] = self._round_channels(ba['out_chs'])
-        if 'fake_in_chs' in ba and ba['fake_in_chs']:
-            # FIXME this is a hack to work around mismatch in origin impl input filters
-            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
-        ba['norm_layer'] = self.norm_layer
-        ba['norm_kwargs'] = self.norm_kwargs
+        ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
+        if 'force_in_chs' in ba and ba['force_in_chs']:
+            # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
+            ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
         ba['pad_type'] = self.pad_type
         # block act fn overrides the model default
         ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
         assert ba['act_layer'] is not None
-        if bt == 'ir':
+        ba['norm_layer'] = self.norm_layer
+        if bt != 'cn':
+            ba['se_layer'] = self.se_layer
             ba['drop_path_rate'] = drop_path_rate
-            ba['se_kwargs'] = self.se_kwargs
+
+        if bt == 'ir':
             _log_info_if('  InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             if ba.get('num_experts', 0) > 0:
                 block = CondConvResidual(**ba)
             else:
                 block = InvertedResidual(**ba)
         elif bt == 'ds' or bt == 'dsa':
-            ba['drop_path_rate'] = drop_path_rate
-            ba['se_kwargs'] = self.se_kwargs
             _log_info_if('  DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             block = DepthwiseSeparableConv(**ba)
         elif bt == 'er':
-            ba['drop_path_rate'] = drop_path_rate
-            ba['se_kwargs'] = self.se_kwargs
            _log_info_if('  EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             block = EdgeResidual(**ba)
         elif bt == 'cn':
             _log_info_if('  ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             block = ConvBnAct(**ba)
         else:
             assert False, 'Uknkown block type (%s) while building model.' % bt
-        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
 
+        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
         return block
 
     def __call__(self, in_chs, model_block_args):
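
Note (not part of the diff): a small sketch of the new decode path. The arch strings below are hypothetical, not a real timm model definition; they only illustrate the 'skip'/'noskip' tokens handled above and the kwargs-to-partial resolution that replaces `norm_kwargs`.

from functools import partial

import torch.nn as nn

from timm.models.efficientnet_builder import decode_arch_def, resolve_bn_args

# Hypothetical stage definitions: 'ds' = depthwise-separable, 'ir' = inverted residual, 'cn' = conv-bn-act.
arch_def = [
    ['ds_r1_k3_s1_c16'],
    ['ir_r2_k3_s2_e6_c24_se0.25'],
    ['ir_r2_k3_s2_e6_c24_noskip'],   # 'noskip' -> skip=False -> noskip=True
    ['cn_r1_k1_s1_c256_skip'],       # new 'skip' token forces a residual on the conv block
]
block_args = decode_arch_def(arch_def)
# e.g. block_args[1][0] ~ dict(block_type='ir', dw_kernel_size=3, stride=2,
#                              exp_ratio=6.0, out_chs=24, se_ratio=0.25, noskip=False, ...)

# BN config now arrives as a preconfigured layer partial rather than a norm_kwargs dict
kwargs = dict(bn_momentum=0.01, bn_eps=1e-3)
norm_layer = partial(nn.BatchNorm2d, **resolve_bn_args(kwargs))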

timm/models/ghostnet.py

Lines changed: 2 additions & 3 deletions
@@ -13,8 +13,8 @@
 
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid
-from .efficientnet_blocks import SqueezeExcite, ConvBnAct, make_divisible
+from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible
+from .efficientnet_blocks import SqueezeExcite, ConvBnAct
 from .helpers import build_model_with_cfg
 from .registry import register_model
 
@@ -110,7 +110,6 @@ def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,
             nn.BatchNorm2d(out_chs),
         )
 
-
     def forward(self, x):
         shortcut = x
 

timm/models/hardcorenas.py

Lines changed: 13 additions & 9 deletions
@@ -1,10 +1,14 @@
+from functools import partial
+
 import torch.nn as nn
-from .efficientnet_builder import decode_arch_def, resolve_bn_args
-from .mobilenetv3 import MobileNetV3, MobileNetV3Features, build_model_with_cfg, default_cfg_for_features
-from .layers import hard_sigmoid
-from .efficientnet_blocks import resolve_act_layer
-from .registry import register_model
+
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .efficientnet_blocks import SqueezeExcite
+from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args
+from .helpers import build_model_with_cfg, default_cfg_for_features
+from .layers import get_act_fn
+from .mobilenetv3 import MobileNetV3, MobileNetV3Features
+from .registry import register_model
 
 
 def _cfg(url='', **kwargs):
@@ -35,15 +39,15 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
 
     """
     num_features = 1280
-
+    se_layer = partial(
+        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def),
         num_features=num_features,
         stem_size=32,
-        channel_multiplier=1,
-        norm_kwargs=resolve_bn_args(kwargs),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         act_layer=resolve_act_layer(kwargs, 'hard_swish'),
-        se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
+        se_layer=se_layer,
         **kwargs,
     )
 
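
Note (not part of the diff): the same kwargs-to-layer-partial migration applies across the MobileNetV3/HardCoReNAS family — SE configuration moves from a `se_kwargs` dict into a preconfigured `se_layer` callable, and BN configuration from `norm_kwargs` into a `norm_layer` partial. A sketch of the new-style arguments, mirroring the diff above (parameter names are those used at this commit and may differ in later timm versions):

from functools import partial

import torch.nn as nn

from timm.models.efficientnet_blocks import SqueezeExcite
from timm.models.layers import get_act_fn

# New style: pass fully configured layer callables instead of kwargs dicts.
se_layer = partial(
    SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU,
    reduce_from_block=False, divisor=8)
norm_layer = partial(nn.BatchNorm2d, momentum=0.01, eps=1e-3)

# These then flow into MobileNetV3(..., se_layer=se_layer, norm_layer=norm_layer)
# in place of the removed se_kwargs / norm_kwargs arguments.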

timm/models/layers/helpers.py

Lines changed: 3 additions & 3 deletions
@@ -22,10 +22,10 @@ def parse(x):
 to_ntuple = _ntuple
 
 
-def make_divisible(v, divisor=8, min_value=None):
+def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
     min_value = min_value or divisor
     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
     # Make sure that round down does not go down by more than 10%.
-    if new_v < 0.9 * v:
+    if new_v < round_limit * v:
         new_v += divisor
-    return new_v
+    return new_v
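
Note (not part of the diff): a quick worked example of the rounding behaviour with the newly parameterised `round_limit`; the values are illustrative only.

from timm.models.layers import make_divisible

print(make_divisible(28.8))                 # 32: nearest multiple of 8, within the default 10% round-down limit
print(make_divisible(10))                   # 16: rounding down to 8 would drop >10%, so it bumps up a step
print(make_divisible(10, round_limit=0.7))  # 8: a looser limit accepts the 20% drop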
