Commit 9a25fdf

Merge pull request #297 from rwightman/ema_simplify
Simplified JIT compatible Ema module. Fixes for SiLU export and torchscript training w/ Linear layer.
2 parents c9ebe86 + cd72e66 commit 9a25fdf

15 files changed: +225 additions, -107 deletions

tests/test_models.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def test_model_load_pretrained(model_name, batch_size):
     create_model(model_name, pretrained=True, in_chans=in_chans)
 
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(pretrained=True))
+@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=['vit_*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_features_pretrained(model_name, batch_size):
     """Create that pretrained weights load when features_only==True."""

timm/models/efficientnet.py

Lines changed: 9 additions & 7 deletions
@@ -34,7 +34,7 @@
 from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .features import FeatureInfo, FeatureHooks
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import create_conv2d, create_classifier
 from .registry import register_model
 

@@ -453,18 +453,20 @@ def forward(self, x) -> List[torch.Tensor]:
 
 
 def _create_effnet(model_kwargs, variant, pretrained=False):
+    features_only = False
+    model_cls = EfficientNet
     if model_kwargs.pop('features_only', False):
-        load_strict = False
+        features_only = True
         model_kwargs.pop('num_classes', 0)
         model_kwargs.pop('num_features', 0)
         model_kwargs.pop('head_conv', None)
         model_cls = EfficientNetFeatures
-    else:
-        load_strict = True
-        model_cls = EfficientNet
-    return build_model_with_cfg(
+    model = build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        pretrained_strict=load_strict, **model_kwargs)
+        pretrained_strict=not features_only, **model_kwargs)
+    if features_only:
+        model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model
 
 
 def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
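The user-visible effect of this refactor is sketched below (illustrative only; the model name and input size are arbitrary examples, not taken from the diff): a features_only model is built as EfficientNetFeatures, pretrained weights load non-strictly because the classifier head is absent, and its default_cfg is trimmed by default_cfg_for_features.

    import torch
    import timm

    # build a feature-extraction backbone; classifier keys in the checkpoint
    # have no counterpart here, hence pretrained_strict=not features_only above
    m = timm.create_model('efficientnet_b0', pretrained=True, features_only=True)
    feats = m(torch.randn(1, 3, 224, 224))   # list of feature maps, one per output stage
    print([f.shape for f in feats])
    print(m.default_cfg)                      # num_classes / crop_pct / classifier removed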

timm/models/helpers.py

Lines changed: 12 additions & 2 deletions
@@ -14,7 +14,7 @@
 import torch.utils.model_zoo as model_zoo
 
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
-from .layers import Conv2dSame
+from .layers import Conv2dSame, Linear
 
 
 _logger = logging.getLogger(__name__)

@@ -234,7 +234,7 @@ def adapt_model_from_string(parent_module, model_string):
         if isinstance(old_module, nn.Linear):
             # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
             num_features = state_dict[n + '.weight'][1]
-            new_fc = nn.Linear(
+            new_fc = Linear(
                 in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
             set_layer(new_module, n, new_fc)
             if hasattr(new_module, 'num_features'):

@@ -251,6 +251,15 @@ def adapt_model_from_file(parent_module, model_variant):
         return adapt_model_from_string(parent_module, f.read().strip())
 
 
+def default_cfg_for_features(default_cfg):
+    default_cfg = deepcopy(default_cfg)
+    # remove default pretrained cfg fields that don't have much relevance for feature backbone
+    to_remove = ('num_classes', 'crop_pct', 'classifier')  # add default final pool size?
+    for tr in to_remove:
+        default_cfg.pop(tr, None)
+    return default_cfg
+
+
 def build_model_with_cfg(
         model_cls: Callable,
         variant: str,

@@ -296,5 +305,6 @@ def build_model_with_cfg(
         else:
             assert False, f'Unknown feature class {feature_cls}'
         model = feature_cls(model, **feature_cfg)
+        model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
 
     return model
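A rough illustration of what default_cfg_for_features does (the cfg values below are made up for the example, not taken from the repo):

    cfg = dict(
        url='https://example.com/weights.pth',  # placeholder
        num_classes=1000, input_size=(3, 224, 224), pool_size=(7, 7),
        crop_pct=0.875, interpolation='bicubic', classifier='fc')

    feat_cfg = default_cfg_for_features(cfg)
    # num_classes, crop_pct and classifier are dropped; the rest is kept,
    # and the original dict is untouched thanks to the deepcopy
    assert all(k not in feat_cfg for k in ('num_classes', 'crop_pct', 'classifier'))
    assert cfg['classifier'] == 'fc'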

timm/models/hrnet.py

Lines changed: 8 additions & 6 deletions
@@ -17,7 +17,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .features import FeatureInfo
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import create_classifier
 from .registry import register_model
 from .resnet import BasicBlock, Bottleneck  # leveraging ResNet blocks w/ additional features like SE

@@ -773,15 +773,17 @@ def forward(self, x) -> List[torch.tensor]:
 
 def _create_hrnet(variant, pretrained, **model_kwargs):
     model_cls = HighResolutionNet
-    strict = True
+    features_only = False
     if model_kwargs.pop('features_only', False):
         model_cls = HighResolutionNetFeatures
         model_kwargs['num_classes'] = 0
-        strict = False
-
-    return build_model_with_cfg(
+        features_only = True
+    model = build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
+        model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs)
+    if features_only:
+        model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model
 
 
 @register_model

timm/models/inception_v3.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg
 from .registry import register_model
-from .layers import trunc_normal_, create_classifier
+from .layers import trunc_normal_, create_classifier, Linear
 
 
 def _cfg(url='', **kwargs):

@@ -250,7 +250,7 @@ def __init__(self, in_channels, num_classes, conv_block=None):
         self.conv0 = conv_block(in_channels, 128, kernel_size=1)
         self.conv1 = conv_block(128, 768, kernel_size=5)
         self.conv1.stddev = 0.01
-        self.fc = nn.Linear(768, num_classes)
+        self.fc = Linear(768, num_classes)
         self.fc.stddev = 0.001
 
     def forward(self, x):

timm/models/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
 from .inplace_abn import InplaceAbn
+from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .norm_act import BatchNormAct2d
 from .padding import get_padding

timm/models/layers/activations.py

Lines changed: 24 additions & 0 deletions
@@ -119,3 +119,27 @@ def __init__(self, inplace: bool = False):
 
     def forward(self, x):
         return hard_mish(x, self.inplace)
+
+
+class PReLU(nn.PReLU):
+    """Applies PReLU (w/ dummy inplace arg)
+    """
+    def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
+        super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return F.prelu(input, self.weight)
+
+
+def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
+    return F.gelu(x)
+
+
+class GELU(nn.Module):
+    """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
+    """
+    def __init__(self, inplace: bool = False):
+        super(GELU, self).__init__()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return F.gelu(input)
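These wrappers exist because timm typically constructs activation layers as act_layer(inplace=True), while nn.PReLU and nn.GELU take no inplace argument. A minimal illustration (assuming the classes are imported directly rather than via get_act_layer):

    import torch
    from timm.models.layers.activations import GELU, PReLU

    act = GELU(inplace=True)          # dummy arg accepted and ignored
    y = act(torch.randn(2, 8))

    # torch.nn.GELU() has no inplace parameter, so nn.GELU(inplace=True)
    # would raise a TypeError in the same spot; likewise for nn.PReLU.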

timm/models/layers/classifier.py

Lines changed: 3 additions & 1 deletion
@@ -6,6 +6,7 @@
 from torch.nn import functional as F
 
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .linear import Linear
 
 
 def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):

@@ -21,7 +22,8 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
     elif use_conv:
         fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
     else:
-        fc = nn.Linear(num_pooled_features, num_classes, bias=True)
+        # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
+        fc = Linear(num_pooled_features, num_classes, bias=True)
     return global_pool, fc
 
 
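A small usage sketch of create_classifier after this change (the feature and class counts are arbitrary examples): the returned fc is now the timm Linear wrapper rather than nn.Linear, so a scripted head keeps working under AMP.

    import torch
    from timm.models.layers import create_classifier, Linear

    global_pool, fc = create_classifier(num_features=2048, num_classes=1000, pool_type='avg')
    assert isinstance(fc, Linear)            # still an nn.Linear subclass
    x = torch.randn(2, 2048, 7, 7)
    logits = fc(global_pool(x))              # pooled + flattened -> (2, 1000)
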
timm/models/layers/create_act.py

Lines changed: 11 additions & 6 deletions
@@ -19,10 +19,9 @@
     relu6=F.relu6,
     leaky_relu=F.leaky_relu,
     elu=F.elu,
-    prelu=F.prelu,
     celu=F.celu,
     selu=F.selu,
-    gelu=F.gelu,
+    gelu=gelu,
     sigmoid=sigmoid,
     tanh=tanh,
     hard_sigmoid=hard_sigmoid,

@@ -56,10 +55,10 @@
     relu6=nn.ReLU6,
     leaky_relu=nn.LeakyReLU,
     elu=nn.ELU,
-    prelu=nn.PReLU,
+    prelu=PReLU,
     celu=nn.CELU,
     selu=nn.SELU,
-    gelu=nn.GELU,
+    gelu=GELU,
     sigmoid=Sigmoid,
     tanh=Tanh,
     hard_sigmoid=HardSigmoid,

@@ -98,7 +97,10 @@ def get_act_fn(name='relu'):
         # custom autograd, then fallback
         if name in _ACT_FN_ME:
             return _ACT_FN_ME[name]
-    if not is_no_jit():
+    if is_exportable() and name in ('silu', 'swish'):
+        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+        return swish
+    if not (is_no_jit() or is_exportable()):
         if name in _ACT_FN_JIT:
             return _ACT_FN_JIT[name]
     return _ACT_FN_DEFAULT[name]

@@ -114,7 +116,10 @@ def get_act_layer(name='relu'):
     if not (is_no_jit() or is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
-    if not is_no_jit():
+    if is_exportable() and name in ('silu', 'swish'):
+        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+        return Swish
+    if not (is_no_jit() or is_exportable()):
         if name in _ACT_LAYER_JIT:
             return _ACT_LAYER_JIT[name]
     return _ACT_LAYER_DEFAULT[name]
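With export mode enabled, 'silu'/'swish' now resolve to the plain Python Swish/swish implementations rather than nn.SiLU/F.silu, which did not go through ONNX export at the time. A hedged sketch, assuming the set_exportable helper from timm.models.layers.config is available:

    from timm.models.layers import get_act_layer, set_exportable

    act_cls = get_act_layer('silu')           # normal path: memory-efficient / jit / default variant
    set_exportable(True)                      # mark layer config for ONNX export
    export_cls = get_act_layer('silu')        # falls back to the plain Swish module
    act = export_cls(inplace=True)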

timm/models/layers/linear.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+""" Linear layer (alternate definition)
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn as nn
+
+
+class Linear(nn.Linear):
+    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
+
+    Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
+    weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
+    """
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if torch.jit.is_scripting():
+            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
+            return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
+        else:
+            return F.linear(input, self.weight, self.bias)
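The issue this works around (the "torchscript training w/ Linear layer" part of the commit message): under torch.cuda.amp autocast, a scripted nn.Linear could fail inside torch.addmm with a Half/Float dtype mismatch when float16 activations reach float32 parameters; the wrapper sidesteps this by casting weight and bias to the input dtype, but only in the scripted path. A minimal, hypothetical repro sketch (assumes a CUDA device and the torch.cuda.amp API of that era):

    import torch
    from timm.models.layers import Linear

    fc = torch.jit.script(Linear(512, 10)).cuda()

    with torch.cuda.amp.autocast():
        # mimic an fp16 activation produced by earlier autocast ops
        x = torch.randn(8, 512, device='cuda').half()
        out = fc(x)   # weight/bias are cast to x.dtype inside the scripted forward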
