import torch
from torch import nn as nn
from torch.nn import functional as F
55
6+
67__all__ = ['mish' , 'Mish' , 'mish_jit' , 'MishJit' , 'mish_jit_fwd' , 'mish_jit_bwd' , 'MishJitAutoFn' , 'mish_me' , 'MishMe' ,
78 'hard_mish_jit' , 'HardMishJit' , 'hard_mish_jit_fwd' , 'hard_mish_jit_bwd' , 'HardMishJitAutoFn' ,
89 'hard_mish_me' , 'HardMishMe' ]
910
10- # Cell
11+
1112def mish (x , inplace : bool = False ):
1213 """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
1314 NOTE: I don't have a working inplace variant
@@ -24,7 +25,7 @@ def __init__(self, inplace: bool = False):
2425 def forward (self , x ):
2526 return mish (x )
2627
27- # Cell
28+
2829@torch .jit .script
2930def mish_jit (x , _inplace : bool = False ):
3031 """Jit version of Mish.
@@ -42,7 +43,7 @@ def __init__(self, inplace: bool = False):
4243 def forward (self , x ):
4344 return mish_jit (x )
4445
45- # Cell
46+
4647@torch .jit .script
4748def mish_jit_fwd (x ):
4849 # return x.mul(torch.tanh(F.softplus(x)))
@@ -83,7 +84,7 @@ def __init__(self, inplace: bool = False):
8384 def forward (self , x ):
8485 return MishJitAutoFn .apply (x )
8586
86- # Cell
87+
8788@torch .jit .script
8889def hard_mish_jit (x , inplace : bool = False ):
8990 """ Hard Mish
@@ -104,7 +105,7 @@ def __init__(self, inplace: bool = False):
104105 def forward (self , x ):
105106 return hard_mish_jit (x )
106107
107- # Cell
108+
@torch.jit.script
def hard_mish_jit_fwd(x):
    """Forward pass of Hard Mish: ``0.5 * x * clamp(x + 2, 0, 2)``.

    Piecewise: 0 for x <= -2, ``0.5 * x * (x + 2)`` on (-2, 0), and ``x`` for x >= 0.
    Jit-scripted so it can be paired with a scripted backward in a custom
    autograd Function.
    """
    return 0.5 * x * (x + 2).clamp(min=0, max=2)
@@ -146,4 +147,4 @@ def __init__(self, inplace: bool = False):
146147 super (HardMishMe , self ).__init__ ()
147148
148149 def forward (self , x ):
149- return HardMishJitAutoFn .apply (x )
150+ return HardMishJitAutoFn .apply (x )