Hacked together by / Copyright 2020 Ross Wightman
"""
import math
-from typing import List, Tuple
+from typing import List, Tuple, Union

import torch
import torch.nn.functional as F

+from .helpers import to_2tuple
+

# Calculate symmetric padding for a convolution
-def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
+def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> Union[int, List[int]]:
+    if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]):
+        kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation)
+        return [get_padding(*a) for a in zip(kernel_size, stride, dilation)]
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding
@@ -25,6 +30,9 @@ def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int):

# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
+    if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]):
+        kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation)
+        return all([is_static_pad(*a) for a in zip(kernel_size, stride, dilation)])
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
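A minimal usage sketch of the per-dimension behavior this change adds. It assumes the patched file is importable as timm.models.layers.padding and that to_2tuple (from .helpers) repeats a scalar into a 2-tuple; treat it as an illustration, not part of the commit:

    from timm.models.layers.padding import get_padding, is_static_pad

    # Scalar args behave as before: symmetric padding for a 3x3 conv.
    assert get_padding(3) == 1

    # Tuple args now recurse per spatial dim and return a list,
    # e.g. a 1x9 kernel with default stride/dilation -> [0, 4].
    assert get_padding((1, 9)) == [0, 4]

    # is_static_pad now requires SAME padding to be static in every dim.
    assert is_static_pad(3, stride=1)
    assert not is_static_pad((3, 3), stride=(1, 2))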