Skip to content

Commit 151679c

Browse files
committed
Add custom grad tests, fix cut & paste error with hard_mish ME, add a few more pytorch act fns to factory
1 parent 6c7932f commit 151679c

File tree

3 files changed

+84
-2
lines changed

3 files changed

+84
-2
lines changed

tests/test_layers.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import pytest
2+
import torch
3+
import torch.nn as nn
4+
import platform
5+
import os
6+
7+
from timm.models.layers import create_act_layer, get_act_layer, set_layer_config
8+
9+
10+
class MLP(nn.Module):
    """Minimal two-layer perceptron used to exercise activation layers.

    The activation module is stored on ``self.act`` so tests can swap in a
    different implementation after construction.
    """

    def __init__(self, act_layer="relu"):
        super().__init__()
        self.fc1 = nn.Linear(1000, 100)
        self.act = create_act_layer(act_layer, inplace=True)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        # fc1 -> activation -> fc2
        return self.fc2(self.act(self.fc1(x)))
22+
23+
24+
def _run_act_layer_grad(act_type):
    """Check that the ME (memory-efficient autograd), jit-scripted, and basic
    (plain PyTorch) variants of an activation produce matching losses.

    The default layer config selects the ME implementation; ``scriptable=True``
    selects the jit variant and ``no_jit=True`` the basic variant. All three
    are run on the same random input through the same MLP and their scalar
    losses compared with ``torch.isclose``.
    """
    x = torch.rand(10, 1000) * 10
    m = MLP(act_layer=act_type)

    def _run(x, act_layer=''):
        if act_layer:
            # replace act layer if set
            m.act = create_act_layer(act_layer, inplace=True)
        out = m(x)
        # simple scalar objective: sum of squared outputs
        # (was `l = (out - 0).pow(2).sum()` — the `- 0` was a no-op and
        # `l` an ambiguous name, PEP 8 E741)
        loss = out.pow(2).sum()
        return loss

    out_me = _run(x)

    with set_layer_config(scriptable=True):
        out_jit = _run(x, act_type)

    assert torch.isclose(out_jit, out_me)

    with set_layer_config(no_jit=True):
        out_basic = _run(x, act_type)

    assert torch.isclose(out_basic, out_jit)
47+
48+
49+
def test_swish_grad():
    """Swish: ME/jit/basic implementations must agree over repeated random trials."""
    num_trials = 100
    for _trial in range(num_trials):
        _run_act_layer_grad('swish')
52+
53+
54+
def test_mish_grad():
    """Mish: ME/jit/basic implementations must agree over repeated random trials."""
    num_trials = 100
    for _trial in range(num_trials):
        _run_act_layer_grad('mish')
57+
58+
59+
def test_hard_sigmoid_grad():
    """Hard sigmoid: ME/jit/basic implementations must agree over repeated random trials."""
    num_trials = 100
    for _trial in range(num_trials):
        _run_act_layer_grad('hard_sigmoid')
62+
63+
64+
def test_hard_swish_grad():
    """Hard swish: ME/jit/basic implementations must agree over repeated random trials."""
    num_trials = 100
    for _trial in range(num_trials):
        _run_act_layer_grad('hard_swish')
67+
68+
69+
def test_hard_mish_grad():
    """Hard mish: ME/jit/basic implementations must agree over repeated random trials."""
    num_trials = 100
    for _trial in range(num_trials):
        _run_act_layer_grad('hard_mish')

timm/models/layers/activations_me.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,12 @@ class HardMishJitAutoFn(torch.autograd.Function):
185185
@staticmethod
186186
def forward(ctx, x):
187187
ctx.save_for_backward(x)
188-
return mish_jit_fwd(x)
188+
return hard_mish_jit_fwd(x)
189189

190190
@staticmethod
191191
def backward(ctx, grad_output):
192192
x = ctx.saved_tensors[0]
193-
return mish_jit_bwd(x, grad_output)
193+
return hard_mish_jit_bwd(x, grad_output)
194194

195195

196196
def hard_mish_me(x, inplace: bool = False):

timm/models/layers/create_act.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
mish=mish,
1010
relu=F.relu,
1111
relu6=F.relu6,
12+
leaky_relu=F.leaky_relu,
13+
elu=F.elu,
14+
prelu=F.prelu,
15+
celu=F.celu,
16+
selu=F.selu,
17+
gelu=F.gelu,
1218
sigmoid=sigmoid,
1319
tanh=tanh,
1420
hard_sigmoid=hard_sigmoid,
@@ -37,6 +43,11 @@
3743
mish=Mish,
3844
relu=nn.ReLU,
3945
relu6=nn.ReLU6,
46+
elu=nn.ELU,
47+
prelu=nn.PReLU,
48+
celu=nn.CELU,
49+
selu=nn.SELU,
50+
gelu=nn.GELU,
4051
sigmoid=Sigmoid,
4152
tanh=Tanh,
4253
hard_sigmoid=HardSigmoid,

0 commit comments

Comments
 (0)