Commit 5bd0471

Cleanup weight init for byob/byoanet and related
1 parent: 8642401

File tree

4 files changed: +38 -34 lines

timm/models/byobnet.py

Lines changed: 32 additions & 34 deletions
@@ -33,7 +33,7 @@
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, named_apply
 from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
 from .registry import register_model
@@ -166,7 +166,7 @@ class ByoModelCfg:
     stem_chs: int = 32
     width_factor: float = 1.0
     num_features: int = 0  # num out_channels for final conv, no final 1x1 conv if 0
-    zero_init_last_bn: bool = True
+    zero_init_last: bool = True  # zero init last weight (usually bn) in residual path
     fixed_input_size: bool = False  # model constrained to a fixed-input size / img_size must be provided on creation

     act_layer: str = 'relu'
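
The renamed `zero_init_last` flag generalizes the old `zero_init_last_bn`: the zeroed weight is usually a BatchNorm gamma, but need not be. Zeroing the last scale in the residual path makes each block an identity mapping at init, which helps stabilize early training of deep residual nets. A minimal standalone sketch of the idea (ToyBlock is hypothetical, not part of this repo):

    import torch
    import torch.nn as nn

    class ToyBlock(nn.Module):
        # Hypothetical residual block illustrating zero_init_last.
        def __init__(self, chs, zero_init_last=True):
            super().__init__()
            self.conv = nn.Conv2d(chs, chs, 3, padding=1, bias=False)
            self.bn = nn.BatchNorm2d(chs)
            if zero_init_last:
                nn.init.zeros_(self.bn.weight)  # residual branch contributes 0 at init

        def forward(self, x):
            return x + self.bn(self.conv(x))

    x = torch.randn(1, 8, 4, 4)
    block = ToyBlock(8).eval()
    assert torch.allclose(block(x), x)  # block starts as an identity mapping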
@@ -757,8 +757,8 @@ def __init__(
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -814,8 +814,8 @@ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bo
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -871,8 +871,8 @@ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bo
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -924,8 +924,8 @@ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bo
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -967,7 +967,7 @@ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bo
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
         self.act = layers.act(inplace=True)

-    def init_weights(self, zero_init_last_bn: bool = False):
+    def init_weights(self, zero_init_last: bool = False):
         # NOTE this init overrides that base model init with specific changes for the block type
         for m in self.modules():
             if isinstance(m, nn.BatchNorm2d):
@@ -1024,8 +1024,8 @@ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bo
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         if hasattr(self.self_attn, 'reset_parameters'):
             self.self_attn.reset_parameters()
@@ -1278,7 +1278,7 @@ class ByobNet(nn.Module):
     Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
     """
     def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
-                 zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
+                 zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -1309,12 +1309,8 @@ def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='

         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

-        for n, m in self.named_modules():
-            _init_weights(m, n)
-        for m in self.modules():
-            # call each block's weight init for block-specific overrides to init above
-            if hasattr(m, 'init_weights'):
-                m.init_weights(zero_init_last_bn=zero_init_last_bn)
+        # init weights
+        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

     def get_classifier(self):
         return self.head.fc
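
The two hand-rolled walks (a generic init pass over named_modules, then a second pass invoking block-specific init_weights overrides) collapse into a single named_apply traversal, with zero_init_last threaded through via functools.partial. For reference, a sketch of what such a helper does; the real implementation lives in timm/models/helpers.py and may differ in detail:

    import torch.nn as nn

    def named_apply(fn, module: nn.Module, name='', depth_first=True, include_root=False):
        # Recursively apply fn(module=..., name=...) to every submodule,
        # visiting children before their parent when depth_first=True.
        if not depth_first and include_root:
            fn(module=module, name=name)
        for child_name, child_module in module.named_children():
            child_name = '.'.join((name, child_name)) if name else child_name
            named_apply(fn, child_module, name=child_name, depth_first=depth_first, include_root=True)
        if depth_first and include_root:
            fn(module=module, name=name)
        return module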
@@ -1334,20 +1330,22 @@ def forward(self, x):
         return x


-def _init_weights(m, n=''):
-    if isinstance(m, nn.Conv2d):
-        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-        fan_out //= m.groups
-        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-        if m.bias is not None:
-            m.bias.data.zero_()
-    elif isinstance(m, nn.Linear):
-        nn.init.normal_(m.weight, mean=0.0, std=0.01)
-        if m.bias is not None:
-            nn.init.zeros_(m.bias)
-    elif isinstance(m, nn.BatchNorm2d):
-        nn.init.ones_(m.weight)
-        nn.init.zeros_(m.bias)
+def _init_weights(module, name='', zero_init_last=False):
+    if isinstance(module, nn.Conv2d):
+        fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
+        fan_out //= module.groups
+        module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+        if module.bias is not None:
+            module.bias.data.zero_()
+    elif isinstance(module, nn.Linear):
+        nn.init.normal_(module.weight, mean=0.0, std=0.01)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.BatchNorm2d):
+        nn.init.ones_(module.weight)
+        nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights(zero_init_last=zero_init_last)


 def _create_byobnet(variant, pretrained=False, **kwargs):
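
The rewritten _init_weights keeps the generic Conv/Linear/BatchNorm branches and adds a trailing elif that dispatches to any module exposing init_weights, so the block-specific overrides above still run and receive zero_init_last. A hedged usage sketch; 'gernet_s' is assumed to be one of the byobnet variants registered in this version of timm:

    import torch.nn as nn
    import timm

    # With zero_init_last=True (the default), residual-path norm layers
    # should come out of construction with their scale (gamma) at zero.
    model = timm.create_model('gernet_s', pretrained=False)
    zeroed = [n for n, m in model.named_modules()
              if isinstance(m, nn.BatchNorm2d) and not bool(m.weight.abs().sum())]
    print(f'{len(zeroed)} batchnorm layers zero-initialized')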

timm/models/layers/bottleneck_attn.py

Lines changed: 2 additions & 0 deletions
@@ -102,6 +102,8 @@ def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv

         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+        self.reset_parameters()
+
     def reset_parameters(self):
         trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
         trunc_normal_(self.pos_embed.height_rel, std=self.scale)

timm/models/layers/halo_attn.py

Lines changed: 2 additions & 0 deletions
@@ -123,6 +123,8 @@ def __init__(
         self.pos_embed = PosEmbedRel(
             block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)

+        self.reset_parameters()
+
     def reset_parameters(self):
         std = self.q.weight.shape[1] ** -0.5  # fan-in
         trunc_normal_(self.q.weight, std=std)

timm/models/layers/lambda_layer.py

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,8 @@ def __init__(

         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+        self.reset_parameters()
+
     def reset_parameters(self):
         trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
         trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
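
All three attention layers gain the same two lines: calling reset_parameters() at the end of __init__ leaves a layer sensibly initialized even when constructed standalone, while keeping the method public lets the hasattr(..., 'reset_parameters') hooks in the blocks above re-run it after the model-wide init pass. The pattern, sketched abstractly (AttnLayerPattern is illustrative only):

    import torch.nn as nn
    from timm.models.layers import trunc_normal_

    class AttnLayerPattern(nn.Module):
        # Hypothetical skeleton of the pattern shared by BottleneckAttn,
        # HaloAttn and LambdaLayer in this commit.
        def __init__(self, dim):
            super().__init__()
            self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=False)
            self.reset_parameters()  # usable immediately after construction

        def reset_parameters(self):
            # fan-in scaled init; block init hooks may call this again later
            trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)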
