Commit 5f4b607

Fix inplace arg compat for GELU and PReLU via activation factory
1 parent: fd962c4

2 files changed: +27 -4 lines changed

timm/models/layers/activations.py

Lines changed: 24 additions & 0 deletions
@@ -119,3 +119,27 @@ def __init__(self, inplace: bool = False):
 
     def forward(self, x):
         return hard_mish(x, self.inplace)
+
+
+class PReLU(nn.PReLU):
+    """Applies PReLU (w/ dummy inplace arg)
+    """
+    def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
+        super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return F.prelu(input, self.weight)
+
+
+def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
+    return F.gelu(x)
+
+
+class GELU(nn.Module):
+    """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
+    """
+    def __init__(self, inplace: bool = False):
+        super(GELU, self).__init__()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return F.gelu(input)
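
Context for the wrappers: stock nn.GELU and nn.PReLU do not accept an inplace keyword, so factory-style code that forwards inplace= to every activation class raises a TypeError on them; the classes above accept and ignore the flag. A minimal sketch of the failure mode (the ConvAct block below is hypothetical, not from timm):

    import torch.nn as nn

    class ConvAct(nn.Module):
        """Hypothetical block in the timm style: it always forwards an
        inplace= kwarg to whatever activation class it was configured with."""
        def __init__(self, in_ch, out_ch, act_layer=nn.ReLU):
            super().__init__()
            self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.act = act_layer(inplace=True)  # nn.ReLU: fine; nn.GELU/nn.PReLU: TypeError

        def forward(self, x):
            return self.act(self.conv(x))

    m = ConvAct(3, 8, act_layer=nn.ReLU)    # works, nn.ReLU accepts inplace=
    # m = ConvAct(3, 8, act_layer=nn.GELU)  # TypeError: unexpected keyword 'inplace'
    # m = ConvAct(3, 8, act_layer=GELU)     # fine with the wrapper above; flag is ignored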

timm/models/layers/create_act.py

Lines changed: 3 additions & 4 deletions
@@ -19,10 +19,9 @@
     relu6=F.relu6,
     leaky_relu=F.leaky_relu,
     elu=F.elu,
-    prelu=F.prelu,
     celu=F.celu,
     selu=F.selu,
-    gelu=F.gelu,
+    gelu=gelu,
     sigmoid=sigmoid,
     tanh=tanh,
     hard_sigmoid=hard_sigmoid,
@@ -56,10 +55,10 @@
     relu6=nn.ReLU6,
     leaky_relu=nn.LeakyReLU,
     elu=nn.ELU,
-    prelu=nn.PReLU,
+    prelu=PReLU,
     celu=nn.CELU,
     selu=nn.SELU,
-    gelu=nn.GELU,
+    gelu=GELU,
     sigmoid=Sigmoid,
     tanh=Tanh,
     hard_sigmoid=HardSigmoid,
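
Note the asymmetry in the first hunk: F.prelu(input, weight) needs a learned weight tensor, so it cannot live in a table of stateless activation fns and is dropped rather than wrapped, while gelu gets the dummy-arg functional wrapper. With the layer table now pointing at the wrappers, callers can pass inplace= uniformly. A minimal sketch, assuming the table name _ACT_LAYER_DEFAULT and a get_act_layer lookup as in the surrounding file (both hedged; the import path matches the tree at this commit):

    import torch.nn as nn
    from timm.models.layers.activations import GELU, PReLU  # wrappers added in this commit

    # Abbreviated stand-in for the _ACT_LAYER_DEFAULT table in create_act.py
    _ACT_LAYER_DEFAULT = dict(
        relu=nn.ReLU,
        prelu=PReLU,  # was nn.PReLU, which rejects inplace=
        gelu=GELU,    # was nn.GELU, which rejects inplace=
    )

    def get_act_layer(name='relu'):
        # Simplified stand-in for timm's lookup helper of the same name
        return _ACT_LAYER_DEFAULT[name]

    # Callers can now pass inplace= unconditionally, whatever was configured:
    for name in ('relu', 'prelu', 'gelu'):
        act = get_act_layer(name)(inplace=True)  # no TypeError for any entry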
