
Commit 7419e98

Add MxNet Gluon ResNet variants w/ converted pretrained weights. Very well trained set of models.
1 parent 2da0b4d commit 7419e98

File tree

3 files changed: +795, -0 lines changed


convert/convert_from_mxnet.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
import argparse
import hashlib
import os

import mxnet as mx
import gluoncv
import torch
from models.model_factory import create_model

parser = argparse.ArgumentParser(description='Convert from MXNet')
parser.add_argument('--model', default='all', type=str, metavar='MODEL',
                    help='Name of model to convert (default: "all")')


def convert(mxnet_name, torch_name):
    # download and load the pre-trained model
    net = gluoncv.model_zoo.get_model(mxnet_name, pretrained=True)

    # create corresponding torch model
    torch_net = create_model(torch_name)

    mxp = [(k, v) for k, v in net.collect_params().items() if 'running' not in k]
    torchp = list(torch_net.named_parameters())
    torch_params = {}

    # convert parameters
    # NOTE: we are relying on the fact that the order of parameters
    # is usually exactly the same between these models, thus no key name mapping
    # is necessary. Asserts will trip if this is not the case.
    for (tn, tv), (mn, mv) in zip(torchp, mxp):
        m_split = mn.split('_')
        t_split = tn.split('.')
        print(t_split, m_split)
        print(tv.shape, mv.shape)

        # ensure ordering of BN params matches since their sizes are not specific
        if m_split[-1] == 'gamma':
            assert t_split[-1] == 'weight'
        if m_split[-1] == 'beta':
            assert t_split[-1] == 'bias'

        # ensure shapes match
        assert all(t == m for t, m in zip(tv.shape, mv.shape))

        torch_tensor = torch.from_numpy(mv.data().asnumpy())
        torch_params[tn] = torch_tensor

    # convert buffers (batch norm running stats)
    mxb = [(k, v) for k, v in net.collect_params().items() if any(x in k for x in ['running_mean', 'running_var'])]
    torchb = [(k, v) for k, v in torch_net.named_buffers() if 'num_batches' not in k]
    for (tn, tv), (mn, mv) in zip(torchb, mxb):
        print(tn, mn)
        print(tv.shape, mv.shape)

        # ensure ordering of BN running stats matches
        if 'running_var' in tn:
            assert 'running_var' in mn
        if 'running_mean' in tn:
            assert 'running_mean' in mn

        torch_tensor = torch.from_numpy(mv.data().asnumpy())
        torch_params[tn] = torch_tensor

    torch_net.load_state_dict(torch_params)
    torch_filename = './%s.pth' % torch_name
    torch.save(torch_net.state_dict(), torch_filename)
    with open(torch_filename, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    final_filename = os.path.splitext(torch_filename)[0] + '-' + sha_hash[:8] + '.pth'
    os.rename(torch_filename, final_filename)
    print("=> Saved converted model to '{}', SHA256: {}".format(final_filename, sha_hash))


def map_mx_to_torch_model(mx_name):
    torch_name = mx_name.lower()
    if torch_name.startswith('se_'):
        torch_name = torch_name.replace('se_', 'se')
    elif torch_name.startswith('senet_'):
        torch_name = torch_name.replace('senet_', 'senet')
    elif torch_name.startswith('inceptionv3'):
        torch_name = torch_name.replace('inceptionv3', 'inception_v3')
    torch_name = 'gluon_' + torch_name
    return torch_name


ALL = ['resnet18_v1b', 'resnet34_v1b', 'resnet50_v1b', 'resnet101_v1b', 'resnet152_v1b',
       'resnet50_v1c', 'resnet101_v1c', 'resnet152_v1c', 'resnet50_v1d', 'resnet101_v1d', 'resnet152_v1d',
       #'resnet50_v1e', 'resnet101_v1e', 'resnet152_v1e',
       'resnet50_v1s', 'resnet101_v1s', 'resnet152_v1s', 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d',
       'se_resnext50_32x4d', 'se_resnext101_32x4d', 'se_resnext101_64x4d', 'senet_154', 'inceptionv3']


def main():
    args = parser.parse_args()

    if not args.model or args.model == 'all':
        for mx_model in ALL:
            torch_model = map_mx_to_torch_model(mx_model)
            convert(mx_model, torch_model)
    else:
        mx_model = args.model
        torch_model = map_mx_to_torch_model(mx_model)
        convert(mx_model, torch_model)


if __name__ == '__main__':
    main()
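
A minimal usage sketch, assuming the script above has completed for resnet50_v1b; the hash suffix in the checkpoint name below is a placeholder for the 8-character SHA-256 prefix that convert() appends:

import torch
from models.model_factory import create_model

# Build the PyTorch counterpart of the converted Gluon model;
# 'gluon_resnet50_v1b' is what map_mx_to_torch_model yields for 'resnet50_v1b'.
model = create_model('gluon_resnet50_v1b')

# Load the converted weights; replace the placeholder hash suffix with the
# one printed by convert() when the checkpoint was saved.
state_dict = torch.load('./gluon_resnet50_v1b-01234567.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()

The conversion itself is driven by the --model flag (python convert/convert_from_mxnet.py --model resnet50_v1b); leaving it at the default 'all' converts every entry in ALL.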
