 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, named_apply
 from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
 from .registry import register_model
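
For context, `named_apply`, newly imported here, is the helpers' recursive module walker that drives the reworked init path below. A sketch of what it does, assuming the usual timm signature (with `depth_first=True`, a parent is visited only after all of its children):

def named_apply(fn, module, name='', depth_first=True, include_root=False):
    # apply fn(module=..., name=...) to every submodule; by default the
    # traversal is depth-first, so leaves are visited before their parent
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module
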
@@ -166,7 +166,7 @@ class ByoModelCfg:
     stem_chs: int = 32
     width_factor: float = 1.0
     num_features: int = 0  # num out_channels for final conv, no final 1x1 conv if 0
-    zero_init_last_bn: bool = True
+    zero_init_last: bool = True  # zero init last weight (usually bn) in residual path
     fixed_input_size: bool = False  # model constrained to a fixed-input size / img_size must be provided on creation
 
     act_layer: str = 'relu'
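
The renamed config field can be overridden like any other dataclass field; a hypothetical example, assuming the `model_cfgs` dict defined further down in byobnet.py:

from dataclasses import replace

# hypothetical: disable residual-path zero init on an existing config
cfg = replace(model_cfgs['gernet_l'], zero_init_last=False)
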
@@ -757,8 +757,8 @@ def __init__(
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
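
The rename reflects that the thing being zeroed is the last weight in the residual path, which in these blocks happens to be a BatchNorm gamma. With gamma (and the default-zero beta) at zero, the branch outputs exactly zero at init, so every residual block starts as an identity mapping. A standalone check of that effect:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8)
nn.init.zeros_(bn.weight)  # gamma = 0; beta already defaults to 0
x = torch.randn(2, 8, 4, 4)
print(bn(x).abs().max().item())  # 0.0 -- the branch contributes nothing at init
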
@@ -814,8 +814,8 @@ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bo
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -871,8 +871,8 @@ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bo
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -924,8 +924,8 @@ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bo
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -967,7 +967,7 @@ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bo
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
         self.act = layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
+    def init_weights(self, zero_init_last: bool = False):
         # NOTE this init overrides that base model init with specific changes for the block type
         for m in self.modules():
             if isinstance(m, nn.BatchNorm2d):
@@ -1024,8 +1024,8 @@ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bo
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         if hasattr(self.self_attn, 'reset_parameters'):
             self.self_attn.reset_parameters()
@@ -1278,7 +1278,7 @@ class ByobNet(nn.Module):
     Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
     """
     def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
-                 zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
+                 zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
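
Because `_create_byobnet` forwards extra kwargs to this constructor, the renamed argument is also what callers would pass through `timm.create_model`; a hypothetical call (model name purely illustrative):

import timm

# hypothetical: build a byobnet variant without the residual zero init
model = timm.create_model('gernet_s', pretrained=False, zero_init_last=False)
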
@@ -1309,12 +1309,8 @@ def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='
 
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
 
-        for n, m in self.named_modules():
-            _init_weights(m, n)
-        for m in self.modules():
-            # call each block's weight init for block-specific overrides to init above
-            if hasattr(m, 'init_weights'):
-                m.init_weights(zero_init_last_bn=zero_init_last_bn)
+        # init weights
+        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
 
     def get_classifier(self):
         return self.head.fc
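
Note that the single `named_apply` pass preserves the old two-pass ordering: with the default depth-first traversal (see the sketch above), leaf Conv/BN/Linear modules receive the generic init before their parent block is visited, so each block's `init_weights` override still runs last.
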
@@ -1334,20 +1330,22 @@ def forward(self, x):
         return x
 
 
-def _init_weights(m, n=''):
-    if isinstance(m, nn.Conv2d):
-        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-        fan_out //= m.groups
-        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-        if m.bias is not None:
-            m.bias.data.zero_()
-    elif isinstance(m, nn.Linear):
-        nn.init.normal_(m.weight, mean=0.0, std=0.01)
-        if m.bias is not None:
-            nn.init.zeros_(m.bias)
-    elif isinstance(m, nn.BatchNorm2d):
-        nn.init.ones_(m.weight)
-        nn.init.zeros_(m.bias)
+def _init_weights(module, name='', zero_init_last=False):
+    if isinstance(module, nn.Conv2d):
+        fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
+        fan_out //= module.groups
+        module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+        if module.bias is not None:
+            module.bias.data.zero_()
+    elif isinstance(module, nn.Linear):
+        nn.init.normal_(module.weight, mean=0.0, std=0.01)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.BatchNorm2d):
+        nn.init.ones_(module.weight)
+        nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights(zero_init_last=zero_init_last)
 
 
 def _create_byobnet(variant, pretrained=False, **kwargs):
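
To see the new dispatch end to end, a toy check built on the `named_apply` sketch above and the patched `_init_weights`; `ToyBlock` is a hypothetical stand-in for a ByobNet block:

from functools import partial

import torch.nn as nn

class ToyBlock(nn.Module):
    # hypothetical block exposing the same init_weights interface as ByobNet blocks
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def init_weights(self, zero_init_last: bool = False):
        if zero_init_last:
            nn.init.zeros_(self.bn.weight)  # zero the last weight in the residual path

    def forward(self, x):
        return x + self.bn(self.conv(x))

model = nn.Sequential(ToyBlock(), ToyBlock())
named_apply(partial(_init_weights, zero_init_last=True), model)
# leaves were initialized first, then each block's init_weights ran:
assert all(float(blk.bn.weight.abs().sum()) == 0. for blk in model)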