Skip to content

Commit 780860d

Browse files
committed
Add norm_act factory method, move JIT of norm layers to factory
1 parent 14edacd commit 780860d

File tree

4 files changed

+74
-79
lines changed

4 files changed

+74
-79
lines changed

timm/models/densenet.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55
import re
66
from collections import OrderedDict
7+
from functools import partial
78

89
import torch
910
import torch.nn as nn
@@ -13,7 +14,7 @@
1314

1415
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1516
from .helpers import load_pretrained
16-
from .layers import SelectAdaptivePool2d, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d
17+
from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act
1718
from .registry import register_model
1819

1920
__all__ = ['DenseNet']
@@ -327,9 +328,11 @@ def densenet121d_evob(pretrained=False, **kwargs):
327328
r"""Densenet-121 model from
328329
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
329330
"""
331+
def norm_act_fn(num_features, **kwargs):
332+
return create_norm_act('EvoNormBatch', num_features, jit=True, **kwargs)
330333
model = _densenet(
331334
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
332-
norm_act_layer=EvoNormBatch2d, pretrained=pretrained, **kwargs)
335+
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
333336
return model
334337

335338

@@ -338,9 +341,11 @@ def densenet121d_evos(pretrained=False, **kwargs):
338341
r"""Densenet-121 model from
339342
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
340343
"""
344+
def norm_act_fn(num_features, **kwargs):
345+
return create_norm_act('EvoNormSample', num_features, jit=True, **kwargs)
341346
model = _densenet(
342347
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
343-
norm_act_layer=EvoNormSample2d, pretrained=pretrained, **kwargs)
348+
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
344349
return model
345350

346351

@@ -349,10 +354,11 @@ def densenet121d_iabn(pretrained=False, **kwargs):
349354
r"""Densenet-121 model from
350355
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
351356
"""
352-
from inplace_abn import InPlaceABN
357+
def norm_act_fn(num_features, **kwargs):
358+
return create_norm_act('iabn', num_features, **kwargs)
353359
model = _densenet(
354360
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
355-
norm_act_layer=InPlaceABN, pretrained=pretrained, **kwargs)
361+
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
356362
return model
357363

358364

timm/models/layers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@
2020
from .space_to_depth import SpaceToDepthModule
2121
from .blur_pool import BlurPool2d
2222
from .norm_act import BatchNormAct2d
23-
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
23+
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
24+
from .create_norm_act import create_norm_act
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
5+
from .norm_act import BatchNormAct2d
6+
try:
7+
from inplace_abn import InPlaceABN
8+
has_iabn = True
9+
except ImportError:
10+
has_iabn = False
11+
12+
13+
def create_norm_act(layer_type, num_features, jit=False, **kwargs):
    """Build a combined normalization + activation layer from a string name.

    Args:
        layer_type: Layer name, matched case-insensitively. Recognized names
            are ``batchnormact``, ``batchnormrelu``, ``evonormbatch``,
            ``evonormsample``, ``iabn`` and ``inplaceabn``. One optional
            ``_``-separated suffix is accepted but currently ignored.
        num_features: Passed as the first positional argument to the selected
            layer's constructor.
        jit: When true, the constructed layer is passed through
            ``torch.jit.script`` before being returned.
        **kwargs: Forwarded unchanged to the selected layer's constructor.

    Returns:
        The constructed layer, scripted when ``jit`` is true.

    Raises:
        AssertionError: If ``layer_type`` has more than one ``_`` separator,
            names an unknown layer, or ``act_layer`` is passed together with
            ``batchnormrelu`` (which fixes the activation itself).
        ImportError: If an InPlaceABN layer is requested but the
            ``inplace_abn`` package could not be imported.
    """
    layer_parts = layer_type.split('_')
    # Explicit raises (instead of `assert` statements) keep the validation
    # active under `python -O` while preserving the AssertionError type.
    if len(layer_parts) not in (1, 2):
        raise AssertionError(
            "Expected 'layer' or 'layer_activation', got (%s)" % layer_type)
    layer_class = layer_parts[0].lower()
    # TODO: support string activation selection via layer_parts[1];
    # the suffix is parsed above but not used yet.

    if layer_class == "batchnormact":
        # Defaults to ReLU if no kwargs override the activation.
        layer = BatchNormAct2d(num_features, **kwargs)
    elif layer_class == "batchnormrelu":
        if 'act_layer' in kwargs:
            raise AssertionError("act_layer cannot be overridden for batchnormrelu")
        layer = BatchNormAct2d(num_features, act_layer=nn.ReLU, **kwargs)
    elif layer_class == "evonormbatch":
        layer = EvoNormBatch2d(num_features, **kwargs)
    elif layer_class == "evonormsample":
        layer = EvoNormSample2d(num_features, **kwargs)
    elif layer_class == "iabn" or layer_class == "inplaceabn":
        if not has_iabn:
            raise ImportError(
                "Please install InplaceABN: "
                "'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
        layer = InPlaceABN(num_features, **kwargs)
    else:
        # Must not be a bare `assert False`: with assertions stripped the
        # function would fall through and hit an unbound `layer` below.
        raise AssertionError("Invalid norm_act layer (%s)" % layer_class)
    if jit:
        layer = torch.jit.script(layer)
    return layer

timm/models/layers/evo_norm.py

Lines changed: 24 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -13,35 +13,12 @@
1313
import torch.nn as nn
1414

1515

16-
@torch.jit.script
17-
def evo_batch_jit(
18-
x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, running_var: torch.Tensor,
19-
momentum: float, training: bool, nonlin: bool, eps: float):
20-
x_type = x.dtype
21-
running_var = running_var.detach() # FIXME why is this needed, it's a buffer?
22-
if training:
23-
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) # FIXME biased, unbiased?
24-
running_var.copy_(momentum * var + (1 - momentum) * running_var)
25-
else:
26-
var = running_var.clone()
27-
28-
if nonlin:
29-
# FIXME biased, unbiased?
30-
d = (x * v.to(x_type)) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(eps).sqrt_().to(dtype=x_type)
31-
d = d.max(var.add(eps).sqrt_().to(dtype=x_type))
32-
x = x / d
33-
return x.mul_(weight).add_(bias)
34-
else:
35-
return x.mul(weight).add_(bias)
36-
37-
3816
class EvoNormBatch2d(nn.Module):
39-
def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5, jit=True):
17+
def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5):
4018
super(EvoNormBatch2d, self).__init__()
4119
self.momentum = momentum
4220
self.nonlin = nonlin
4321
self.eps = eps
44-
self.jit = jit
4522
param_shape = (1, num_features, 1, 1)
4623
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
4724
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
@@ -58,50 +35,29 @@ def reset_parameters(self):
5835

5936
def forward(self, x):
6037
assert x.dim() == 4, 'expected 4D input'
61-
62-
if self.jit:
63-
return evo_batch_jit(
64-
x, self.v, self.weight, self.bias, self.running_var, self.momentum,
65-
self.training, self.nonlin, self.eps)
38+
x_type = x.dtype
39+
if self.training:
40+
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
41+
self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var)
6642
else:
67-
x_type = x.dtype
68-
if self.training:
69-
var = x.var(dim=(0, 2, 3), keepdim=True)
70-
self.running_var.copy_(self.momentum * var + (1 - self.momentum) * self.running_var)
71-
else:
72-
var = self.running_var.clone()
73-
74-
if self.nonlin:
75-
v = self.v.to(dtype=x_type)
76-
d = (x * v) + x.var(dim=(2, 3), keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
77-
d = d.max(var.add(self.eps).sqrt_().to(dtype=x_type))
78-
x = x / d
79-
return x.mul_(self.weight).add_(self.bias)
80-
else:
81-
return x.mul(self.weight).add_(self.bias)
43+
var = self.running_var.clone()
8244

83-
84-
@torch.jit.script
85-
def evo_sample_jit(
86-
x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
87-
groups: int, nonlin: bool, eps: float):
88-
B, C, H, W = x.shape
89-
assert C % groups == 0
90-
if nonlin:
91-
n = (x * v).sigmoid_().reshape(B, groups, -1)
92-
x = x.reshape(B, groups, -1)
93-
x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(eps).sqrt_()
94-
x = x.reshape(B, C, H, W)
95-
return x.mul_(weight).add_(bias)
45+
if self.nonlin:
46+
v = self.v.to(dtype=x_type)
47+
d = (x * v) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
48+
d = d.max(var.add_(self.eps).sqrt_().to(dtype=x_type))
49+
x = x / d
50+
return x.mul_(self.weight).add_(self.bias)
51+
else:
52+
return x.mul(self.weight).add_(self.bias)
9653

9754

9855
class EvoNormSample2d(nn.Module):
99-
def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5, jit=True):
56+
def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5):
10057
super(EvoNormSample2d, self).__init__()
10158
self.nonlin = nonlin
10259
self.groups = groups
10360
self.eps = eps
104-
self.jit = jit
10561
param_shape = (1, num_features, 1, 1)
10662
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
10763
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
@@ -117,18 +73,13 @@ def reset_parameters(self):
11773

11874
def forward(self, x):
11975
assert x.dim() == 4, 'expected 4D input'
120-
121-
if self.jit:
122-
return evo_sample_jit(
123-
x, self.v, self.weight, self.bias, self.groups, self.nonlin, self.eps)
76+
B, C, H, W = x.shape
77+
assert C % self.groups == 0
78+
if self.nonlin:
79+
n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
80+
x = x.reshape(B, self.groups, -1)
81+
x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(self.eps).sqrt_()
82+
x = x.reshape(B, C, H, W)
83+
return x.mul_(self.weight).add_(self.bias)
12484
else:
125-
B, C, H, W = x.shape
126-
assert C % self.groups == 0
127-
if self.nonlin:
128-
n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
129-
x = x.reshape(B, self.groups, -1)
130-
x = n / (x.std(dim=-1, unbiased=False, keepdim=True) + self.eps)
131-
x = x.reshape(B, C, H, W)
132-
return x.mul_(self.weight).add_(self.bias)
133-
else:
134-
return x.mul(self.weight).add_(self.bias)
85+
return x.mul(self.weight).add_(self.bias)

0 commit comments

Comments
 (0)