 import torch.nn as nn
 from torch.nn import functional as F
 
-from .layers import create_conv2d, drop_path, get_act_layer
+from .layers import create_conv2d, drop_path, make_divisible
 from .layers.activations import sigmoid
 
-# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
-# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
-# NOTE: momentum varies btw .99 and .9997 depending on source
-# .99 in official TF TPU impl
-# .9997 (/w .999 in search space) for paper
-BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
-BN_EPS_TF_DEFAULT = 1e-3
-_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
-
-
-def get_bn_args_tf():
-    return _BN_ARGS_TF.copy()
-
-
-def resolve_bn_args(kwargs):
-    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
-    bn_momentum = kwargs.pop('bn_momentum', None)
-    if bn_momentum is not None:
-        bn_args['momentum'] = bn_momentum
-    bn_eps = kwargs.pop('bn_eps', None)
-    if bn_eps is not None:
-        bn_args['eps'] = bn_eps
-    return bn_args
-
-
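For context on the helpers being removed above: they only translated TF-style BatchNorm settings into PyTorch ones, where PyTorch momentum equals 1 minus the TF decay. A minimal standalone sketch of that mapping (not part of this diff; the channel count is illustrative):

import torch.nn as nn

# TF reference impls use decay ~0.99 and eps 1e-3; the PyTorch equivalent momentum is 1 - decay.
tf_decay, tf_eps = 0.99, 1e-3
bn = nn.BatchNorm2d(32, momentum=1 - tf_decay, eps=tf_eps)  # momentum=0.01, eps=0.001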
-_SE_ARGS_DEFAULT = dict(
-    gate_fn=sigmoid,
-    act_layer=None,
-    reduce_mid=False,
-    divisor=1)
-
-
-def resolve_se_args(kwargs, in_chs, act_layer=None):
-    se_kwargs = kwargs.copy() if kwargs is not None else {}
-    # fill in args that aren't specified with the defaults
-    for k, v in _SE_ARGS_DEFAULT.items():
-        se_kwargs.setdefault(k, v)
-    # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
-    if not se_kwargs.pop('reduce_mid'):
-        se_kwargs['reduced_base_chs'] = in_chs
-    # act_layer override, if it remains None, the containing block's act_layer will be used
-    if se_kwargs['act_layer'] is None:
-        assert act_layer is not None
-        se_kwargs['act_layer'] = act_layer
-    return se_kwargs
-
-
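To see what the removed resolver actually produced, here is an illustrative trace with assumed values (in_chs=64, se_ratio=0.25), following the code above rather than anything added by this diff:

# With the defaults (reduce_mid=False), reduction was computed from the layer's own input channels.
se_kwargs = resolve_se_args(dict(gate_fn=sigmoid), in_chs=64, act_layer=nn.ReLU)
# -> {'gate_fn': sigmoid, 'act_layer': nn.ReLU, 'divisor': 1, 'reduced_base_chs': 64}
# SqueezeExcite(64, se_ratio=0.25, **se_kwargs) then used 64 * 0.25 = 16 reduction channels.

In the new code this plumbing is replaced by passing a pre-bound se_layer callable into each block, with the MobileNetV3-style "reduce from block channels" behaviour expressed by SqueezeExcite's block_in_chs / reduce_from_block arguments below.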
-def resolve_act_layer(kwargs, default='relu'):
-    act_layer = kwargs.pop('act_layer', default)
-    if isinstance(act_layer, str):
-        act_layer = get_act_layer(act_layer)
-    return act_layer
-
-
-def make_divisible(v, divisor=8, min_value=None):
-    min_value = min_value or divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    # Make sure that round down does not go down by more than 10%.
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v
-
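make_divisible itself is unchanged by this refactor; it just moves out of this module (note the new import of make_divisible from .layers above). A few worked values, checked against the definition shown here and runnable anywhere the function is in scope:

assert make_divisible(16, divisor=1) == 16   # e.g. SE reduction chs for in_chs=64, se_ratio=0.25
assert make_divisible(32 * 6.0) == 192       # already a multiple of the default divisor 8
assert make_divisible(19) == 24              # rounding down to 16 would lose >10% of 19, so bump up one step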
-
-def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
-    """Round number of filters based on depth multiplier."""
-    if not multiplier:
-        return channels
-    channels *= multiplier
-    return make_divisible(channels, divisor, channel_min)
-
-
-class ChannelShuffle(nn.Module):
-    # FIXME haven't used yet
-    def __init__(self, groups):
-        super(ChannelShuffle, self).__init__()
-        self.groups = groups
-
-    def forward(self, x):
-        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]"""
-        N, C, H, W = x.size()
-        g = self.groups
-        assert C % g == 0, "Incompatible group size {} for input channel {}".format(
-            g, C
-        )
-        return (
-            x.view(N, g, int(C / g), H, W)
-            .permute(0, 2, 1, 3, 4)
-            .contiguous()
-            .view(N, C, H, W)
-        )
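The removed ChannelShuffle (never wired up, per the FIXME) regroups channels so information mixes across groups after a grouped conv. A tiny standalone shape check of the same view/permute/view trick, not part of this diff:

import torch

x = torch.arange(8).view(1, 8, 1, 1)   # channels 0..7, two groups of four
g = 2
y = x.view(1, g, 4, 1, 1).permute(0, 2, 1, 3, 4).contiguous().view(1, 8, 1, 1)
print(y.flatten().tolist())            # [0, 4, 1, 5, 2, 6, 3, 7] - the groups are interleaved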
+__all__ = [
+    'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']
 
 
 class SqueezeExcite(nn.Module):
-    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
-                 act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
+    """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family
+
+    Args:
+        in_chs (int): input channels to layer
+        se_ratio (float): ratio of squeeze reduction
+        act_layer (nn.Module): activation layer of containing block
+        gate_fn (Callable): attention gate function
+        block_in_chs (int): input channels of containing block (for calculating reduction from)
+        reduce_from_block (bool): calculate reduction from block input channels if True
+        force_act_layer (nn.Module): override block's activation fn if this is set/bound
+        divisor (int): make reduction channels divisible by this
+    """
+
+    def __init__(
+            self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
+            block_in_chs=None, reduce_from_block=True, force_act_layer=None, divisor=1):
         super(SqueezeExcite, self).__init__()
-        reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
+        reduced_chs = (block_in_chs or in_chs) if reduce_from_block else in_chs
+        reduced_chs = make_divisible(reduced_chs * se_ratio, divisor)
+        act_layer = force_act_layer or act_layer
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
         self.act1 = act_layer(inplace=True)
         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
@@ -121,13 +49,16 @@ def forward(self, x):
 
 
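The reworked SqueezeExcite is intended to be handed to the blocks below as a pre-bound se_layer callable; the block then supplies only in_chs, se_ratio, act_layer and, for IR blocks, block_in_chs. A hedged sketch of how a model builder might bind it (the functools.partial pattern and the channel numbers are assumptions based on the new signature, not code from this diff):

from functools import partial

# Bind SE variations once, then hand the callable to block constructors as `se_layer`.
se_layer = partial(SqueezeExcite, force_act_layer=nn.ReLU, divisor=8)
block = InvertedResidual(32, 32, exp_ratio=4., se_ratio=0.25, se_layer=se_layer)

The gate_fn could also be bound this way (sigmoid is the default), which is what previously required the se_kwargs dict plumbing.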
 class ConvBnAct(nn.Module):
-    def __init__(self, in_chs, out_chs, kernel_size,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
+    """ Conv + Norm Layer + Activation w/ optional skip connection
+    """
+    def __init__(
+            self, in_chs, out_chs, kernel_size, stride=1, dilation=1, pad_type='',
+            skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.):
         super(ConvBnAct, self).__init__()
-        norm_kwargs = norm_kwargs or {}
+        self.has_residual = skip and stride == 1 and in_chs == out_chs
+        self.drop_path_rate = drop_path_rate
         self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
-        self.bn1 = norm_layer(out_chs, **norm_kwargs)
+        self.bn1 = norm_layer(out_chs)
         self.act1 = act_layer(inplace=True)
 
     def feature_info(self, location):
@@ -138,9 +69,14 @@ def feature_info(self, location):
         return info
 
     def forward(self, x):
+        shortcut = x
         x = self.conv(x)
         x = self.bn1(x)
         x = self.act1(x)
+        if self.has_residual:
+            if self.drop_path_rate > 0.:
+                x = drop_path(x, self.drop_path_rate, self.training)
+            x += shortcut
         return x
 
 
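ConvBnAct now supports an optional residual with stochastic depth via drop_path. Conceptually, drop_path zeroes the residual branch for a random subset of samples during training and rescales the survivors; a minimal standalone sketch of that idea (illustrative only, not the library's drop_path implementation):

import torch

def drop_path_sketch(x, drop_prob=0.1, training=True):
    # Per-sample stochastic depth: drop the whole branch for some samples, scale the rest.
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1. - drop_prob
    mask = torch.rand(x.shape[0], 1, 1, 1, device=x.device) < keep_prob
    return x * mask / keep_prob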
@@ -149,31 +85,26 @@ class DepthwiseSeparableConv(nn.Module):
     Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
     (factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
     """
-    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
-                 pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
-        norm_kwargs = norm_kwargs or {}
-        has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.drop_path_rate = drop_path_rate
 
         self.conv_dw = create_conv2d(
             in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
-        self.bn1 = norm_layer(in_chs, **norm_kwargs)
+        self.bn1 = norm_layer(in_chs)
         self.act1 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
-        else:
-            self.se = None
+        self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
 
         self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
-        self.bn2 = norm_layer(out_chs, **norm_kwargs)
+        self.bn2 = norm_layer(out_chs)
         self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
 
     def feature_info(self, location):
@@ -190,8 +121,7 @@ def forward(self, x):
         x = self.bn1(x)
         x = self.act1(x)
 
-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)
 
         x = self.conv_pw(x)
         x = self.bn2(x)
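With se_kwargs gone, blocks take an SE module class (or a pre-bound partial) directly, and the unconditional self.se(x) call works because the no-SE case is an nn.Identity. A hedged usage sketch of the new constructor, with illustrative sizes:

import torch

# Depthwise-separable block with SE; se_layer is now the class (or a partial) rather than a kwargs dict.
dsc = DepthwiseSeparableConv(64, 64, dw_kernel_size=3, se_ratio=0.25, se_layer=SqueezeExcite)
out = dsc(torch.randn(2, 64, 56, 56))   # -> (2, 64, 56, 56); residual applies since stride=1 and chs match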
@@ -214,41 +144,36 @@ class InvertedResidual(nn.Module):
       * MobileNet-V3 - https://arxiv.org/abs/1905.02244
     """
 
-    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
-                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 conv_kwargs=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
         super(InvertedResidual, self).__init__()
-        norm_kwargs = norm_kwargs or {}
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate
 
         # Point-wise expansion
         self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
-        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
+        self.bn1 = norm_layer(mid_chs)
         self.act1 = act_layer(inplace=True)
 
         # Depth-wise convolution
         self.conv_dw = create_conv2d(
             mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
             padding=pad_type, depthwise=True, **conv_kwargs)
-        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
+        self.bn2 = norm_layer(mid_chs)
         self.act2 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
-        else:
-            self.se = None
+        self.se = se_layer(
+            mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
 
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
-        self.bn3 = norm_layer(out_chs, **norm_kwargs)
+        self.bn3 = norm_layer(out_chs)
 
     def feature_info(self, location):
         if location == 'expansion':  # after SE, input to PWL
@@ -271,8 +196,7 @@ def forward(self, x):
         x = self.act2(x)
 
         # Squeeze-and-excitation
-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)
 
         # Point-wise linear projection
         x = self.conv_pwl(x)
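For orientation, the MBConv data flow in this block is: 1x1 expand (in_chs to in_chs * exp_ratio), k x k depthwise, SE on the expanded channels (with the reduction computed from the block's input channels via block_in_chs), then a 1x1 linear projection. A brief hedged usage sketch with assumed sizes:

# EfficientNet-style MBConv: 24 -> 144 (1x1 expand, ratio 6) -> 5x5 depthwise -> SE (144 -> 6 -> 144) -> 1x1 project -> 24
ir = InvertedResidual(24, 24, dw_kernel_size=5, exp_ratio=6., se_ratio=0.25, se_layer=SqueezeExcite)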
@@ -289,21 +213,19 @@ def forward(self, x):
 class CondConvResidual(InvertedResidual):
     """ Inverted residual block w/ CondConv routing"""
 
-    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
-                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 num_experts=0, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
 
         self.num_experts = num_experts
         conv_kwargs = dict(num_experts=self.num_experts)
 
         super(CondConvResidual, self).__init__(
             in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
             act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
-            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
-            drop_path_rate=drop_path_rate)
+            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer,
+            norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate)
 
         self.routing_fn = nn.Linear(in_chs, self.num_experts)
 
@@ -325,8 +247,7 @@ def forward(self, x):
         x = self.act2(x)
 
         # Squeeze-and-excitation
-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)
 
         # Point-wise linear projection
         x = self.conv_pwl(x, routing_weights)
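The routing_weights used above come from the part of forward() that this hunk elides: the input is pooled and fed through routing_fn (the nn.Linear defined above) to produce one mixing weight per expert, which the CondConv layers use to blend expert kernels per sample. A conceptual sketch of that routing step, hedged rather than a line-for-line quote of the file:

# Given x of shape (N, in_chs, H, W) and routing_fn = nn.Linear(in_chs, num_experts):
pooled = F.adaptive_avg_pool2d(x, 1).flatten(1)        # (N, in_chs)
routing_weights = torch.sigmoid(routing_fn(pooled))    # (N, num_experts), one weight per expert
# conv_pw / conv_dw / conv_pwl then accept these weights to mix per-sample expert kernels.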
@@ -351,36 +272,32 @@ class EdgeResidual(nn.Module):
       * EfficientNet-V2 - https://arxiv.org/abs/2104.00298
     """
 
-    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
+            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
         super(EdgeResidual, self).__init__()
-        norm_kwargs = norm_kwargs or {}
-        if fake_in_chs > 0:
-            mid_chs = make_divisible(fake_in_chs * exp_ratio)
+        if force_in_chs > 0:
+            mid_chs = make_divisible(force_in_chs * exp_ratio)
         else:
             mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate
 
         # Expansion convolution
         self.conv_exp = create_conv2d(
             in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
-        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
+        self.bn1 = norm_layer(mid_chs)
         self.act1 = act_layer(inplace=True)
 
         # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
-        else:
-            self.se = None
+        self.se = SqueezeExcite(
+            mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
 
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
-        self.bn2 = norm_layer(out_chs, **norm_kwargs)
+        self.bn2 = norm_layer(out_chs)
 
     def feature_info(self, location):
         if location == 'expansion':  # after SE, before PWL
@@ -398,8 +315,7 @@ def forward(self, x):
         x = self.act1(x)
 
         # Squeeze-and-excitation
-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)
 
         # Point-wise linear projection
         x = self.conv_pwl(x)