 import torch.nn as nn
 from torch.nn import functional as F
 
-from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer
+from .layers import create_conv2d, drop_path, make_divisible, create_act_layer
 from .layers.activations import sigmoid
 
 __all__ = [
@@ -19,31 +19,32 @@ class SqueezeExcite(nn.Module):
 
     Args:
         in_chs (int): input channels to layer
-        se_ratio (float): ratio of squeeze reduction
+        rd_ratio (float): ratio of squeeze reduction
         act_layer (nn.Module): activation layer of containing block
-        gate_fn (Callable): attention gate function
+        gate_layer (Callable): attention gate function
         force_act_layer (nn.Module): override block's activation fn if this is set/bound
-        round_chs_fn (Callable): specify a fn to calculate rounding of reduced chs
+        rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
     """
 
     def __init__(
-            self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
-            force_act_layer=None, round_chs_fn=None):
+            self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
+            gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
         super(SqueezeExcite, self).__init__()
-        round_chs_fn = round_chs_fn or round
-        reduced_chs = round_chs_fn(in_chs * se_ratio)
+        if rd_channels is None:
+            rd_round_fn = rd_round_fn or round
+            rd_channels = rd_round_fn(in_chs * rd_ratio)
         act_layer = force_act_layer or act_layer
-        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
+        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
         self.act1 = create_act_layer(act_layer, inplace=True)
-        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
+        self.gate = create_act_layer(gate_layer)
 
     def forward(self, x):
         x_se = x.mean((2, 3), keepdim=True)
         x_se = self.conv_reduce(x_se)
         x_se = self.act1(x_se)
         x_se = self.conv_expand(x_se)
-        return x * self.gate_fn(x_se)
+        return x * self.gate(x_se)
 
 
 class ConvBnAct(nn.Module):
@@ -85,10 +86,9 @@ class DepthwiseSeparableConv(nn.Module):
8586 """
8687 def __init__ (
8788 self , in_chs , out_chs , dw_kernel_size = 3 , stride = 1 , dilation = 1 , pad_type = '' ,
88- noskip = False , pw_kernel_size = 1 , pw_act = False , se_ratio = 0. ,
89- act_layer = nn . ReLU , norm_layer = nn . BatchNorm2d , se_layer = None , drop_path_rate = 0. ):
89+ noskip = False , pw_kernel_size = 1 , pw_act = False , act_layer = nn . ReLU , norm_layer = nn . BatchNorm2d ,
90+ se_layer = None , drop_path_rate = 0. ):
9091 super (DepthwiseSeparableConv , self ).__init__ ()
91- has_se = se_layer is not None and se_ratio > 0.
9292 self .has_residual = (stride == 1 and in_chs == out_chs ) and not noskip
9393 self .has_pw_act = pw_act # activation after point-wise conv
9494 self .drop_path_rate = drop_path_rate
@@ -99,7 +99,7 @@ def __init__(
         self.act1 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
         self.bn2 = norm_layer(out_chs)
@@ -144,12 +144,11 @@ class InvertedResidual(nn.Module):
 
     def __init__(
             self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
-            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
         super(InvertedResidual, self).__init__()
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate
 
@@ -166,7 +165,7 @@ def __init__(
         self.act2 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = se_layer(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
@@ -212,17 +211,17 @@ class CondConvResidual(InvertedResidual):
 
     def __init__(
             self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
-            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
-            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
 
         self.num_experts = num_experts
         conv_kwargs = dict(num_experts=self.num_experts)
 
         super(CondConvResidual, self).__init__(
             in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
             act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
-            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer,
-            norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate)
+            pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
+            drop_path_rate=drop_path_rate)
 
         self.routing_fn = nn.Linear(in_chs, self.num_experts)
 
@@ -271,8 +270,8 @@ class EdgeResidual(nn.Module):
271270
272271 def __init__ (
273272 self , in_chs , out_chs , exp_kernel_size = 3 , stride = 1 , dilation = 1 , pad_type = '' ,
274- force_in_chs = 0 , noskip = False , exp_ratio = 1.0 , pw_kernel_size = 1 , se_ratio = 0. ,
275- act_layer = nn . ReLU , norm_layer = nn .BatchNorm2d , se_layer = None , drop_path_rate = 0. ):
273+ force_in_chs = 0 , noskip = False , exp_ratio = 1.0 , pw_kernel_size = 1 , act_layer = nn . ReLU ,
274+ norm_layer = nn .BatchNorm2d , se_layer = None , drop_path_rate = 0. ):
276275 super (EdgeResidual , self ).__init__ ()
277276 if force_in_chs > 0 :
278277 mid_chs = make_divisible (force_in_chs * exp_ratio )
@@ -289,7 +288,7 @@ def __init__(
         self.act1 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
+        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
 
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
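For reviewers, a minimal usage sketch of the `SqueezeExcite` interface after this change. The `timm.models.efficientnet_blocks` import path is assumed from the relative `.layers` imports above, and the unchanged `forward` behavior is inferred from the context lines, not added by this diff:

```python
import torch

# Import path assumed from the `.layers` relative imports above (timm layout).
from timm.models.efficientnet_blocks import SqueezeExcite

x = torch.randn(2, 64, 32, 32)

# Old call sites passed se_ratio / gate_fn; the new interface takes rd_ratio
# (or an explicit rd_channels) and a gate_layer module class.
se = SqueezeExcite(64, rd_ratio=0.25)        # round(64 * 0.25) = 16 reduced channels
se_fixed = SqueezeExcite(64, rd_channels=8)  # explicit channel count bypasses the ratio math

y = se(x)
assert y.shape == x.shape  # SE only re-weights channels; spatial shape is unchanged
```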
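Since `se_ratio` no longer threads through the block constructors, each block now calls `se_layer(chs, act_layer=act_layer)` with only the channel count. Presumably any SE hyper-parameters are bound onto `se_layer` before it reaches the block; the builder side is not shown in this diff, so the `functools.partial` binding below is an assumption:

```python
from functools import partial

import torch
from timm.models.efficientnet_blocks import DepthwiseSeparableConv, SqueezeExcite

# Hypothetical builder-side binding: the reduction ratio is fixed here rather than
# passed through the block, matching the new `se_layer(in_chs, act_layer=...)` call.
se_layer = partial(SqueezeExcite, rd_ratio=0.25)

block = DepthwiseSeparableConv(32, 32, dw_kernel_size=3, se_layer=se_layer)
out = block(torch.randn(1, 32, 28, 28))
print(out.shape)  # torch.Size([1, 32, 28, 28]) -- stride 1, in_chs == out_chs
```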