@@ -358,22 +358,33 @@ def __call__(self, in_chs, model_block_args):
         return stages
 
 
-def _init_weight_goog(m, n=''):
+def _init_weight_goog(m, n='', fix_group_fanout=False):
     """ Weight initialization as per Tensorflow official implementations.
 
+    Args:
+        m (nn.Module): module to init
+        n (str): module name
+        fix_group_fanout (bool): enable correct fanout calculation w/ group convs
+
+    FIXME change fix_group_fanout to default to True if experiments show better training results
+
     Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
     * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
     * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
     """
     if isinstance(m, CondConv2d):
         fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        if fix_group_fanout:
+            fan_out //= m.groups
         init_weight_fn = get_condconv_initializer(
             lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
         init_weight_fn(m.weight)
         if m.bias is not None:
             m.bias.data.zero_()
     elif isinstance(m, nn.Conv2d):
         fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        if fix_group_fanout:
+            fan_out //= m.groups
         m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
         if m.bias is not None:
             m.bias.data.zero_()
@@ -390,21 +401,6 @@ def _init_weight_goog(m, n=''):
             m.bias.data.zero_()
 
 
-def _init_weight_default(m, n=''):
-    """ Basic ResNet (Kaiming) style weight init"""
-    if isinstance(m, CondConv2d):
-        init_fn = get_condconv_initializer(partial(
-            nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
-        init_fn(m.weight)
-    elif isinstance(m, nn.Conv2d):
-        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-    elif isinstance(m, nn.BatchNorm2d):
-        m.weight.data.fill_(1.0)
-        m.bias.data.zero_()
-    elif isinstance(m, nn.Linear):
-        nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
-
-
 def efficientnet_init_weights(model: nn.Module, init_fn=None):
     init_fn = init_fn or _init_weight_goog
     for n, m in model.named_modules():
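A minimal usage sketch (not part of this commit): since `efficientnet_init_weights` accepts any `init_fn`, the new flag can be opted into by partially applying `_init_weight_goog`. The `model` variable below is assumed to be an already-constructed EfficientNet-family `nn.Module`.

    from functools import partial

    # Opt into the corrected group-conv fanout (fan_out //= groups) while
    # keeping the default Google/TF-style initialization for everything else.
    init_fn = partial(_init_weight_goog, fix_group_fanout=True)
    efficientnet_init_weights(model, init_fn=init_fn)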