Skip to content

Commit ba793f5

Browse files
committed
Merge branch 'adding_ECA_resnet' of https://github.com/yoniaflalo/pytorch-image-models into yoniaflalo-adding_ECA_resnet
2 parents 1d4ac1b + 07f19dd commit ba793f5

File tree

2 files changed

+185
-3
lines changed

2 files changed

+185
-3
lines changed

timm/models/helpers.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import torch
2+
import torch.nn as nn
3+
from copy import deepcopy
24
import torch.utils.model_zoo as model_zoo
35
import os
46
import logging
57
from collections import OrderedDict
8+
from timm.models.layers.conv2d_same import Conv2dSame
69

710

811
def load_state_dict(checkpoint_path, use_ema=False):
@@ -101,4 +104,91 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
101104

102105

103106

104-
107+
def extract_layer(model, layer):
108+
layer = layer.split('.')
109+
module = model
110+
if hasattr(model, 'module') and layer[0] != 'module':
111+
module = model.module
112+
if not hasattr(model, 'module') and layer[0] == 'module':
113+
layer = layer[1:]
114+
for l in layer:
115+
if hasattr(module, l):
116+
if not l.isdigit():
117+
module = getattr(module, l)
118+
else:
119+
module = module[int(l)]
120+
else:
121+
return module
122+
return module
123+
124+
125+
def set_layer(model, layer, val):
126+
layer = layer.split('.')
127+
module = model
128+
if hasattr(model, 'module') and layer[0] != 'module':
129+
module = model.module
130+
lst_index = 0
131+
module2 = module
132+
for l in layer:
133+
if hasattr(module2, l):
134+
if not l.isdigit():
135+
module2 = getattr(module2, l)
136+
else:
137+
module2 = module2[int(l)]
138+
lst_index += 1
139+
lst_index -= 1
140+
for l in layer[:lst_index]:
141+
if not l.isdigit():
142+
module = getattr(module, l)
143+
else:
144+
module = module[int(l)]
145+
l = layer[lst_index]
146+
setattr(module, l, val)
147+
148+
149+
def adapt_model_from_string(parent_module, model_string):
150+
separator = '***'
151+
state_dict = {}
152+
lst_shape = model_string.split(separator)
153+
for k in lst_shape:
154+
k = k.split(':')
155+
key = k[0]
156+
shape = k[1][1:-1].split(',')
157+
if shape[0] != '':
158+
state_dict[key] = [int(i) for i in shape]
159+
160+
new_module = deepcopy(parent_module)
161+
for n, m in parent_module.named_modules():
162+
old_module = extract_layer(parent_module, n)
163+
if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
164+
if isinstance(old_module, Conv2dSame):
165+
conv = Conv2dSame
166+
else:
167+
conv = nn.Conv2d
168+
s = state_dict[n + '.weight']
169+
in_channels = s[1]
170+
out_channels = s[0]
171+
if old_module.groups > 1:
172+
in_channels = out_channels
173+
g = in_channels
174+
else:
175+
g = 1
176+
new_conv = conv(in_channels=in_channels, out_channels=out_channels,
177+
kernel_size=old_module.kernel_size, bias=old_module.bias is not None,
178+
padding=old_module.padding, dilation=old_module.dilation,
179+
groups=g, stride=old_module.stride)
180+
set_layer(new_module, n, new_conv)
181+
if isinstance(old_module, nn.BatchNorm2d):
182+
new_bn = nn.BatchNorm2d(num_features=state_dict[n + '.weight'][0], eps=old_module.eps,
183+
momentum=old_module.momentum,
184+
affine=old_module.affine,
185+
track_running_stats=True)
186+
set_layer(new_module, n, new_bn)
187+
if isinstance(old_module, nn.Linear):
188+
new_fc = nn.Linear(in_features=state_dict[n + '.weight'][1], out_features=old_module.out_features,
189+
bias=old_module.bias is not None)
190+
set_layer(new_module, n, new_fc)
191+
new_module.eval()
192+
parent_module.eval()
193+
194+
return new_module

0 commit comments

Comments
 (0)