Skip to content

Commit a9eb484

Browse files
committed
Add memory efficient Swish impl
1 parent 187ecba commit a9eb484

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

timm/models/gen_efficientnet.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,30 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
371371
return arch_args
372372

373373

374-
def swish(x, inplace=False):
375-
if inplace:
376-
return x.mul_(x.sigmoid())
377-
else:
378-
return x * x.sigmoid()
374+
_USE_SWISH_OPT = True
375+
if _USE_SWISH_OPT:
376+
class SwishAutoFn(torch.autograd.Function):
377+
""" Memory Efficient Swish
378+
From: https://blog.ceshine.net/post/pytorch-memory-swish/
379+
"""
380+
@staticmethod
381+
def forward(ctx, x):
382+
result = x.mul(torch.sigmoid(x))
383+
ctx.save_for_backward(x)
384+
return result
385+
386+
@staticmethod
387+
def backward(ctx, grad_output):
388+
x = ctx.saved_variables[0]
389+
sigmoid_x = torch.sigmoid(x)
390+
return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
391+
392+
393+
def swish(x, inplace=False):
394+
return SwishAutoFn.apply(x)
395+
else:
396+
def swish(x, inplace=False):
397+
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
379398

380399

381400
def sigmoid(x, inplace=False):

0 commit comments

Comments
 (0)