 import torch.nn as nn

 from .efficientnet_blocks import *
-from .layers import CondConv2d, get_condconv_initializer
+from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible

-__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
+__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
+           'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']

 _logger = logging.getLogger(__name__)


+_DEBUG_BUILDER = False
+
+# Defaults used for Google/TensorFlow training of mobile networks w/ RMSprop as per
+# papers and TF reference implementations. PT momentum equivalent for TF decay is (1 - TF decay).
+# NOTE: momentum varies between .99 and .9997 depending on source:
+# .99 in the official TF TPU impl,
+# .9997 (w/ .999 in the search space) for the paper.
+BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
+BN_EPS_TF_DEFAULT = 1e-3
+_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
+
+
+def get_bn_args_tf():
+    return _BN_ARGS_TF.copy()
+
+
+def resolve_bn_args(kwargs):
+    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
+    bn_momentum = kwargs.pop('bn_momentum', None)
+    if bn_momentum is not None:
+        bn_args['momentum'] = bn_momentum
+    bn_eps = kwargs.pop('bn_eps', None)
+    if bn_eps is not None:
+        bn_args['eps'] = bn_eps
+    return bn_args
+
+
+def resolve_act_layer(kwargs, default='relu'):
+    act_layer = kwargs.pop('act_layer', default)
+    if isinstance(act_layer, str):
+        act_layer = get_act_layer(act_layer)
+    return act_layer
+
+
+def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
+    """Round number of filters based on depth multiplier."""
+    if not multiplier:
+        return channels
+    return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)
+
+
 def _log_info_if(msg, condition):
     if condition:
         _logger.info(msg)
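A quick sketch of how the new helpers behave, for review context; the kwargs values here are illustrative and not part of the patch:

```python
# Sketch only: exercising the helpers added above with made-up kwargs.
kwargs = dict(bn_tf=True, bn_eps=1e-5, act_layer='swish')

bn_args = resolve_bn_args(kwargs)
# bn_tf=True seeds the TF defaults, then bn_eps overrides:
# {'momentum': 0.01, 'eps': 1e-05}

act_layer = resolve_act_layer(kwargs, default='relu')  # 'swish' str -> layer via get_act_layer

# Width scaling now goes through round_channels/make_divisible:
round_channels(32, multiplier=1.2, divisor=8)  # 38.4 -> 40, nearest multiple of 8
```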
@@ -63,11 +105,13 @@ def _decode_block_str(block_str):
     block_type = ops[0]  # take the block type off the front
     ops = ops[1:]
     options = {}
-    noskip = False
+    skip = None
     for op in ops:
         # string options being checked on individual basis, combine if they grow
         if op == 'noskip':
-            noskip = True
+            skip = False  # force no skip connection
+        elif op == 'skip':
+            skip = True  # force a skip connection
         elif op.startswith('n'):
             # activation fn
             key = op[0]
@@ -94,7 +138,7 @@ def _decode_block_str(block_str):
     act_layer = options['n'] if 'n' in options else None
     exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
     pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
-    fake_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
+    force_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def

     num_repeat = int(options['r'])
     # each type of block has different valid arguments, fill accordingly
@@ -106,10 +150,10 @@ def _decode_block_str(block_str):
             pw_kernel_size=pw_kernel_size,
             out_chs=int(options['c']),
             exp_ratio=float(options['e']),
-            se_ratio=float(options['se']) if 'se' in options else None,
+            se_ratio=float(options['se']) if 'se' in options else 0.,
             stride=int(options['s']),
             act_layer=act_layer,
-            noskip=noskip,
+            noskip=skip is False,
         )
         if 'cc' in options:
             block_args['num_experts'] = int(options['cc'])
@@ -119,11 +163,11 @@ def _decode_block_str(block_str):
             dw_kernel_size=_parse_ksize(options['k']),
             pw_kernel_size=pw_kernel_size,
             out_chs=int(options['c']),
-            se_ratio=float(options['se']) if 'se' in options else None,
+            se_ratio=float(options['se']) if 'se' in options else 0.,
             stride=int(options['s']),
             act_layer=act_layer,
             pw_act=block_type == 'dsa',
-            noskip=block_type == 'dsa' or noskip,
+            noskip=block_type == 'dsa' or skip is False,
         )
     elif block_type == 'er':
         block_args = dict(
@@ -132,11 +176,11 @@ def _decode_block_str(block_str):
             pw_kernel_size=pw_kernel_size,
             out_chs=int(options['c']),
             exp_ratio=float(options['e']),
-            fake_in_chs=fake_in_chs,
-            se_ratio=float(options['se']) if 'se' in options else None,
+            force_in_chs=force_in_chs,
+            se_ratio=float(options['se']) if 'se' in options else 0.,
             stride=int(options['s']),
             act_layer=act_layer,
-            noskip=noskip,
+            noskip=skip is False,
         )
     elif block_type == 'cn':
         block_args = dict(
@@ -145,6 +189,7 @@ def _decode_block_str(block_str):
             out_chs=int(options['c']),
             stride=int(options['s']),
             act_layer=act_layer,
+            skip=skip is True,
         )
     else:
         assert False, 'Unknown block type (%s)' % block_type
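For context, `_decode_block_str` consumes `_`-separated block definition strings; the strings below are illustrative examples (not taken from this PR) showing the old `noskip` flag alongside the new `skip` override:

```python
# Illustrative block definition strings (keys per the parser above:
# r=repeats, k=kernel, s=stride, e=exp_ratio, c=out_chs, se=se_ratio, n=act fn).
_decode_block_str('ir_r2_k3_s2_e6_c40_se0.25')  # InvertedResidual, default skip behaviour
_decode_block_str('ds_r1_k3_s1_c16_noskip')     # DepthwiseSeparable, skip forced off
_decode_block_str('cn_r1_k1_s1_c960_skip')      # ConvBnAct, skip forced on via the new flag
```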
@@ -219,74 +264,63 @@ class EfficientNetBuilder:
     https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

     """
-    def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
-                 verbose=False):
-        self.channel_multiplier = channel_multiplier
-        self.channel_divisor = channel_divisor
-        self.channel_min = channel_min
+    def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels,
+                 act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
         self.output_stride = output_stride
         self.pad_type = pad_type
+        self.round_chs_fn = round_chs_fn
         self.act_layer = act_layer
-        self.se_kwargs = se_kwargs
         self.norm_layer = norm_layer
-        self.norm_kwargs = norm_kwargs
+        self.se_layer = se_layer
         self.drop_path_rate = drop_path_rate
         if feature_location == 'depthwise':
             # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
             _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
             feature_location = 'expansion'
         self.feature_location = feature_location
         assert feature_location in ('bottleneck', 'expansion', '')
-        self.verbose = verbose
+        self.verbose = _DEBUG_BUILDER

         # state updated during build, consumed by model
         self.in_chs = None
         self.features = []

-    def _round_channels(self, chs):
-        return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
-
     def _make_block(self, ba, block_idx, block_count):
         drop_path_rate = self.drop_path_rate * block_idx / block_count
         bt = ba.pop('block_type')
         ba['in_chs'] = self.in_chs
-        ba['out_chs'] = self._round_channels(ba['out_chs'])
-        if 'fake_in_chs' in ba and ba['fake_in_chs']:
-            # FIXME this is a hack to work around mismatch in origin impl input filters
-            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
-        ba['norm_layer'] = self.norm_layer
-        ba['norm_kwargs'] = self.norm_kwargs
+        ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
+        if 'force_in_chs' in ba and ba['force_in_chs']:
+            # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
+            ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
         ba['pad_type'] = self.pad_type
         # block act fn overrides the model default
         ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
         assert ba['act_layer'] is not None
-        if bt == 'ir':
+        ba['norm_layer'] = self.norm_layer
+        if bt != 'cn':
+            ba['se_layer'] = self.se_layer
             ba['drop_path_rate'] = drop_path_rate
-            ba['se_kwargs'] = self.se_kwargs
+
+        if bt == 'ir':
             _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             if ba.get('num_experts', 0) > 0:
                 block = CondConvResidual(**ba)
             else:
                 block = InvertedResidual(**ba)
         elif bt == 'ds' or bt == 'dsa':
-            ba['drop_path_rate'] = drop_path_rate
-            ba['se_kwargs'] = self.se_kwargs
             _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             block = DepthwiseSeparableConv(**ba)
         elif bt == 'er':
-            ba['drop_path_rate'] = drop_path_rate
-            ba['se_kwargs'] = self.se_kwargs
             _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             block = EdgeResidual(**ba)
         elif bt == 'cn':
             _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             block = ConvBnAct(**ba)
         else:
             assert False, 'Unknown block type (%s) while building model.' % bt
-        self.in_chs = ba['out_chs']  # update in_chs for arg of next block

+        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
         return block

     def __call__(self, in_chs, model_block_args):
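Since `channel_multiplier`/`_round_channels` are gone, callers now bake any width scaling into `round_chs_fn` before constructing the builder. A minimal sketch of the new wiring; the `partial` usage and argument values are assumptions for illustration, not taken from this diff:

```python
from functools import partial

# Assumed usage: the width multiplier is curried into the rounding fn up front,
# then the builder applies it to every block's out_chs via self.round_chs_fn.
builder = EfficientNetBuilder(
    output_stride=32,
    pad_type='same',
    round_chs_fn=partial(round_channels, multiplier=1.1),
    act_layer=resolve_act_layer({}, 'swish'),
    norm_layer=nn.BatchNorm2d,
    drop_path_rate=0.2,
)
```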