Skip to content

Commit ff930ec

Browse files
committed
fix activations
1 parent 9e6386b commit ff930ec

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

model_constructor/activations.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
from torch import nn as nn
44
from torch.nn import functional as F
55

6+
67
__all__ = ['mish', 'Mish', 'mish_jit', 'MishJit', 'mish_jit_fwd', 'mish_jit_bwd', 'MishJitAutoFn', 'mish_me', 'MishMe',
78
'hard_mish_jit', 'HardMishJit', 'hard_mish_jit_fwd', 'hard_mish_jit_bwd', 'HardMishJitAutoFn',
89
'hard_mish_me', 'HardMishMe']
910

10-
# Cell
11+
1112
def mish(x, inplace: bool = False):
1213
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
1314
NOTE: I don't have a working inplace variant
@@ -24,7 +25,7 @@ def __init__(self, inplace: bool = False):
2425
def forward(self, x):
2526
return mish(x)
2627

27-
# Cell
28+
2829
@torch.jit.script
2930
def mish_jit(x, _inplace: bool = False):
3031
"""Jit version of Mish.
@@ -42,7 +43,7 @@ def __init__(self, inplace: bool = False):
4243
def forward(self, x):
4344
return mish_jit(x)
4445

45-
# Cell
46+
4647
@torch.jit.script
4748
def mish_jit_fwd(x):
4849
# return x.mul(torch.tanh(F.softplus(x)))
@@ -83,7 +84,7 @@ def __init__(self, inplace: bool = False):
8384
def forward(self, x):
8485
return MishJitAutoFn.apply(x)
8586

86-
# Cell
87+
8788
@torch.jit.script
8889
def hard_mish_jit(x, inplace: bool = False):
8990
""" Hard Mish
@@ -104,7 +105,7 @@ def __init__(self, inplace: bool = False):
104105
def forward(self, x):
105106
return hard_mish_jit(x)
106107

107-
# Cell
108+
108109
@torch.jit.script
109110
def hard_mish_jit_fwd(x):
110111
return 0.5 * x * (x + 2).clamp(min=0, max=2)
@@ -146,4 +147,4 @@ def __init__(self, inplace: bool = False):
146147
super(HardMishMe, self).__init__()
147148

148149
def forward(self, x):
149-
return HardMishJitAutoFn.apply(x)
150+
return HardMishJitAutoFn.apply(x)

0 commit comments

Comments
 (0)