Skip to content

Commit 1019414

Browse files
committed
Add ported Tensorflow EfficientNet B4/B5 weights
1 parent c9a61b7 commit 1019414

File tree

3 files changed

+70
-3
lines changed

3 files changed

+70
-3
lines changed

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,19 @@ I've leveraged the training scripts in this repository to train a few of the mod
129129
| tf_inception_v3 | 77.856 (22.144) | 93.644 (6.356) | 27.16M | bicubic | [Tensorflow Slim](https://github.com/tensorflow/models/tree/master/research/slim) |
130130
| adv_inception_v3 | 77.576 (22.424) | 93.724 (6.276) | 27.16M | bicubic | [Tensorflow Adv models](https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models) |
131131

132+
#### @ 380x380
133+
| Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Source |
134+
|---|---|---|---|---|---|
135+
| tf_efficientnet_b4 | 82.604 (17.396) | 96.128 (3.872) | 19.34 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
136+
| tf_efficientnet_b4 *tfp | 82.604 (17.396) | 96.094 (3.906) | 19.34 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
137+
138+
#### @ 456x456
139+
| Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Source |
140+
|---|---|---|---|---|---|
141+
| tf_efficientnet_b5 *tfp | 83.200 (16.800) | 96.456 (3.544) | 30.39 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
142+
| tf_efficientnet_b5 | 83.176 (16.824) | 96.536 (3.464) | 30.39 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
143+
144+
132145
NOTE: For some reason I can't hit the stated accuracy with my impl of MNASNet and Google's tflite weights. Using a TF equivalent to 'SAME' padding was important to get > 70%, but something small is still missing. Trying to train my own weights from scratch with these models has so far to leveled off in the same 72-73% range.
133146

134147
Models with `*tfp` next to them were scored with `--tf-preprocessing` flag.

models/gen_efficientnet.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
3232
'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
3333
'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', 'chamnetv1_100', 'chamnetv2_100',
34-
'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100', 'efficientnet_b0',
35-
'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'tf_efficientnet_b0',
36-
'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3']
34+
'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100', 'efficientnet_b0', 'efficientnet_b1',
35+
'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'tf_efficientnet_b0',
36+
'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', 'tf_efficientnet_b4', 'tf_efficientnet_b5']
3737
__all__ = ['GenEfficientNet', 'gen_efficientnet_model_names'] + _models
3838

3939

@@ -91,6 +91,8 @@ def _cfg(url='', **kwargs):
9191
url='', input_size=(3, 300, 300), pool_size=(10, 10)),
9292
'efficientnet_b4': _cfg(
9393
url='', input_size=(3, 380, 380), pool_size=(12, 12)),
94+
'efficientnet_b5': _cfg(
95+
url='', input_size=(3, 456, 456), pool_size=(15, 15)),
9496
'tf_efficientnet_b0': _cfg(
9597
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth',
9698
input_size=(3, 224, 224), interpolation='bicubic'),
@@ -103,8 +105,15 @@ def _cfg(url='', **kwargs):
103105
'tf_efficientnet_b3': _cfg(
104106
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth',
105107
input_size=(3, 300, 300), pool_size=(10, 10), interpolation='bicubic', crop_pct=0.904),
108+
'tf_efficientnet_b4': _cfg(
109+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth',
110+
input_size=(3, 380, 380), pool_size=(12, 12), interpolation='bicubic', crop_pct=0.922),
111+
'tf_efficientnet_b5': _cfg(
112+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth',
113+
input_size=(3, 456, 456), pool_size=(15, 15), interpolation='bicubic', crop_pct=0.934)
106114
}
107115

116+
108117
_DEBUG = False
109118

110119
# Default args for PyTorch BN impl
@@ -1436,6 +1445,19 @@ def efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
14361445
return model
14371446

14381447

1448+
def efficientnet_b5(num_classes, in_chans=3, pretrained=False, **kwargs):
1449+
""" EfficientNet-B5 """
1450+
# NOTE for train, drop_rate should be 0.4
1451+
default_cfg = default_cfgs['efficientnet_b5']
1452+
model = _gen_efficientnet(
1453+
channel_multiplier=1.6, depth_multiplier=2.2,
1454+
num_classes=num_classes, in_chans=in_chans, **kwargs)
1455+
model.default_cfg = default_cfg
1456+
if pretrained:
1457+
load_pretrained(model, default_cfg, num_classes, in_chans)
1458+
return model
1459+
1460+
14391461
def tf_efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
14401462
""" EfficientNet-B0. Tensorflow compatible variant """
14411463
default_cfg = default_cfgs['tf_efficientnet_b0']
@@ -1492,5 +1514,33 @@ def tf_efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
14921514
return model
14931515

14941516

1517+
def tf_efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
1518+
""" EfficientNet-B4. Tensorflow compatible variant """
1519+
default_cfg = default_cfgs['tf_efficientnet_b4']
1520+
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
1521+
kwargs['padding_same'] = True
1522+
model = _gen_efficientnet(
1523+
channel_multiplier=1.4, depth_multiplier=1.8,
1524+
num_classes=num_classes, in_chans=in_chans, **kwargs)
1525+
model.default_cfg = default_cfg
1526+
if pretrained:
1527+
load_pretrained(model, default_cfg, num_classes, in_chans)
1528+
return model
1529+
1530+
1531+
def tf_efficientnet_b5(num_classes, in_chans=3, pretrained=False, **kwargs):
1532+
""" EfficientNet-B5. Tensorflow compatible variant """
1533+
default_cfg = default_cfgs['tf_efficientnet_b5']
1534+
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
1535+
kwargs['padding_same'] = True
1536+
model = _gen_efficientnet(
1537+
channel_multiplier=1.6, depth_multiplier=2.2,
1538+
num_classes=num_classes, in_chans=in_chans, **kwargs)
1539+
model.default_cfg = default_cfg
1540+
if pretrained:
1541+
load_pretrained(model, default_cfg, num_classes, in_chans)
1542+
return model
1543+
1544+
14951545
def gen_efficientnet_model_names():
14961546
return set(_models)

models/helpers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
5353

5454

5555
def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None):
56+
if 'url' not in default_cfg or not default_cfg['url']:
57+
print("Warning: pretrained model URL is invalid, using random initialization.")
58+
return
59+
5660
state_dict = model_zoo.load_url(default_cfg['url'])
5761

5862
if in_chans == 1:

0 commit comments

Comments
 (0)