
Commit 88129b2

Add set_layer_config context manager to adjust all layer config flags at once; use it in create_model via new args. Remove a few old __constants__ annotations that caused TorchScript warnings.
1 parent f28170d

8 files changed (+65 -18 lines)

timm/models/dpn.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 from __future__ import print_function
 
 from collections import OrderedDict
-from typing import Union, Optional, List, Tuple
+from typing import Tuple
 
 import torch
 import torch.nn as nn

timm/models/factory.py

Lines changed: 14 additions & 6 deletions
@@ -1,5 +1,6 @@
 from .registry import is_model, is_model_in_modules, model_entrypoint
 from .helpers import load_checkpoint
+from .layers import set_layer_config
 
 
 def create_model(
@@ -8,6 +9,9 @@ def create_model(
         num_classes=1000,
         in_chans=3,
         checkpoint_path='',
+        scriptable=None,
+        exportable=None,
+        no_jit=None,
         **kwargs):
     """Create a model
 
@@ -17,13 +21,16 @@ def create_model(
         num_classes (int): number of classes for final fully connected layer (default: 1000)
         in_chans (int): number of input channels / colors (default: 3)
         checkpoint_path (str): path of checkpoint to load after model is initialized
+        scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
+        exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
+        no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
 
     Keyword Args:
         drop_rate (float): dropout rate for training (default: 0.0)
         global_pool (str): global pool type (default: 'avg')
         **: other kwargs are model specific
     """
-    margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
+    model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
 
     # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
     is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
@@ -47,11 +54,12 @@ def create_model(
     if kwargs.get('drop_path_rate', None) is None:
         kwargs.pop('drop_path_rate', None)
 
-    if is_model(model_name):
-        create_fn = model_entrypoint(model_name)
-        model = create_fn(**margs, **kwargs)
-    else:
-        raise RuntimeError('Unknown model (%s)' % model_name)
+    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
+        if is_model(model_name):
+            create_fn = model_entrypoint(model_name)
+            model = create_fn(**model_args, **kwargs)
+        else:
+            raise RuntimeError('Unknown model (%s)' % model_name)
 
     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)
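
For context, a minimal usage sketch (not part of this commit) of the new create_model flags; the model name is an arbitrary illustrative choice:

from timm import create_model

# The layer config flags apply only while the model is being built;
# the previous global state is restored once create_model returns.
model = create_model('mobilenetv3_large_100', pretrained=False, scriptable=True)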

timm/models/inception_resnet_v2.py

Lines changed: 0 additions & 1 deletion
@@ -193,7 +193,6 @@ def forward(self, x):
 
 
 class Block8(nn.Module):
-    __constants__ = ['relu']  # for pre 1.4 torchscript compat
 
     def __init__(self, scale=1.0, no_relu=False):
         super(Block8, self).__init__()
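
The removed __constants__ entry dates from pre-1.4 TorchScript, which required listing attributes treated as constants. A hedged sketch (assumption: a recent PyTorch and Python >= 3.8; this is an alternative pattern, not the commit's code) using typing.Final instead:

import torch
import torch.nn as nn
from typing import Final

class Block(nn.Module):
    no_relu: Final[bool]  # treated as a compile-time constant by torch.jit.script

    def __init__(self, no_relu: bool = False):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, 1)
        self.no_relu = no_relu

    def forward(self, x):
        x = self.conv(x)
        if not self.no_relu:  # branch resolved at script time
            x = torch.relu(x)
        return x

scripted = torch.jit.script(Block(no_relu=True))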

timm/models/layers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,8 @@
 from .anti_aliasing import AntiAliasDownsampleLayer
 from .blur_pool import BlurPool2d
 from .cond_conv2d import CondConv2d, get_condconv_initializer
-from .config import is_exportable, is_scriptable, set_exportable, set_scriptable, is_no_jit, set_no_jit
+from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
+    set_layer_config
 from .conv2d_same import Conv2dSame
 from .conv_bn_act import ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn

timm/models/layers/cond_conv2d.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ class CondConv2d(nn.Module):
     Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
     https://github.com/pytorch/pytorch/issues/17983
     """
-    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
+    __constants__ = ['in_channels', 'out_channels', 'dynamic_padding']
 
     def __init__(self, in_channels, out_channels, kernel_size=3,
                  stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):

timm/models/layers/config.py

Lines changed: 44 additions & 3 deletions
@@ -1,13 +1,18 @@
-""" Model / Layer Config Singleton
+""" Model / Layer Config singleton state
 """
-from typing import Any
+from typing import Any, Optional
 
-__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable', 'is_no_jit', 'set_no_jit']
+__all__ = [
+    'is_exportable', 'is_scriptable', 'is_no_jit',
+    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
+]
 
 # Set to True if prefer to have layers with no jit optimization (includes activations)
 _NO_JIT = False
 
 # Set to True if prefer to have activation layers with no jit optimization
+# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
+# the jit flags so far are activations. This will change as more layers are updated and/or added.
 _NO_ACTIVATION_JIT = False
 
 # Set to True if exporting a model with Same padding via ONNX
@@ -72,3 +77,39 @@ def __exit__(self, *args: Any) -> bool:
         global _SCRIPTABLE
         _SCRIPTABLE = self.prev
         return False
+
+
+class set_layer_config:
+    """ Layer config context manager that allows setting all layer config flags at once.
+    If a flag arg is None, it will not change the current value.
+    """
+    def __init__(
+            self,
+            scriptable: Optional[bool] = None,
+            exportable: Optional[bool] = None,
+            no_jit: Optional[bool] = None,
+            no_activation_jit: Optional[bool] = None):
+        global _SCRIPTABLE
+        global _EXPORTABLE
+        global _NO_JIT
+        global _NO_ACTIVATION_JIT
+        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
+        if scriptable is not None:
+            _SCRIPTABLE = scriptable
+        if exportable is not None:
+            _EXPORTABLE = exportable
+        if no_jit is not None:
+            _NO_JIT = no_jit
+        if no_activation_jit is not None:
+            _NO_ACTIVATION_JIT = no_activation_jit
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, *args: Any) -> bool:
+        global _SCRIPTABLE
+        global _EXPORTABLE
+        global _NO_JIT
+        global _NO_ACTIVATION_JIT
+        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
+        return False
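
A minimal sketch (not part of the diff) of using the new context manager directly; flags passed as None leave the current value untouched, and all four flags revert to their prior values on exit:

from timm.models.layers import set_layer_config, is_scriptable, is_no_jit

with set_layer_config(scriptable=True, no_jit=False):
    # layers constructed inside this block see the overridden flags
    assert is_scriptable() and not is_no_jit()
# previous global state is restored here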

timm/models/layers/pool2d_same.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Union, List, Tuple, Optional
+from typing import List, Tuple, Optional
 
 from .helpers import tup_pair
 from .padding import pad_same, get_padding_value

validate.py

Lines changed: 2 additions & 4 deletions
@@ -85,15 +85,13 @@ def validate(args):
     args.pretrained = args.pretrained or not args.checkpoint
     args.prefetcher = not args.no_prefetcher
 
-    if args.torchscript:
-        set_scriptable(True)
-
     # create model
     model = create_model(
         args.model,
+        pretrained=args.pretrained,
         num_classes=args.num_classes,
         in_chans=3,
-        pretrained=args.pretrained)
+        scriptable=args.torchscript)
 
     if args.checkpoint:
         load_checkpoint(model, args.checkpoint, args.use_ema)
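
With the flag now routed through create_model, scripting itself becomes a follow-up step inside validate(). A sketch (not from the diff) of the intended flow; per the create_model docstring, scriptable is not yet working for all models:

import torch
from timm import create_model

model = create_model(args.model, pretrained=args.pretrained, scriptable=args.torchscript)
model.eval()
if args.torchscript:
    # only valid for models whose layers obey the scriptable config
    model = torch.jit.script(model)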
