Skip to content

Commit b93fcf0

Browse files
committed
Add Facebook Research Semi-Supervised and Semi-Weakly Supervised ResNet model weights.
1 parent a9eb484 commit b93fcf0

File tree

3 files changed

+274
-51
lines changed

3 files changed

+274
-51
lines changed

sotabench.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,65 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
167167
_entry('ig_resnext101_32x48d', 'ResNeXt-101 32x48d (288x288 Mean-Max Pooling)', '1805.00932',
168168
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 8),
169169

170+
## Facebook SSL weights
171+
_entry('ssl_resnet18', 'ResNet-18', '1905.00546',
172+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
173+
_entry('ssl_resnet50', 'ResNet-50', '1905.00546',
174+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
175+
_entry('ssl_resnext50_32x4d', 'ResNeXt-50 32x4d', '1905.00546',
176+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
177+
_entry('ssl_resnext101_32x4d', 'ResNeXt-101 32x4d', '1905.00546',
178+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
179+
_entry('ssl_resnext101_32x8d', 'ResNeXt-101 32x8d', '1905.00546',
180+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
181+
_entry('ssl_resnext101_32x16d', 'ResNeXt-101 32x16d', '1905.00546',
182+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
183+
184+
_entry('ssl_resnet50', 'ResNet-50 (288x288 Mean-Max Pooling)', '1905.00546',
185+
ttp=True, args=dict(img_size=288),
186+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
187+
_entry('ssl_resnext50_32x4d', 'ResNeXt-50 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
188+
ttp=True, args=dict(img_size=288),
189+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
190+
_entry('ssl_resnext101_32x4d', 'ResNeXt-101 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
191+
ttp=True, args=dict(img_size=288),
192+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
193+
_entry('ssl_resnext101_32x8d', 'ResNeXt-101 32x8d (288x288 Mean-Max Pooling)', '1905.00546',
194+
ttp=True, args=dict(img_size=288),
195+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
196+
_entry('ssl_resnext101_32x16d', 'ResNeXt-101 32x16d (288x288 Mean-Max Pooling)', '1905.00546',
197+
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 2,
198+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
199+
200+
## Facebook SWSL weights
201+
_entry('swsl_resnet18', 'ResNet-18', '1905.00546',
202+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
203+
_entry('swsl_resnet50', 'ResNet-50', '1905.00546',
204+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
205+
_entry('swsl_resnext50_32x4d', 'ResNeXt-50 32x4d', '1905.00546',
206+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
207+
_entry('swsl_resnext101_32x4d', 'ResNeXt-101 32x4d', '1905.00546',
208+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
209+
_entry('swsl_resnext101_32x8d', 'ResNeXt-101 32x8d', '1905.00546',
210+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
211+
_entry('swsl_resnext101_32x16d', 'ResNeXt-101 32x16d', '1905.00546'),
212+
213+
_entry('swsl_resnet50', 'ResNet-50 (288x288 Mean-Max Pooling)', '1905.00546',
214+
ttp=True, args=dict(img_size=288),
215+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
216+
_entry('swsl_resnext50_32x4d', 'ResNeXt-50 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
217+
ttp=True, args=dict(img_size=288),
218+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
219+
_entry('swsl_resnext101_32x4d', 'ResNeXt-101 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
220+
ttp=True, args=dict(img_size=288),
221+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
222+
_entry('swsl_resnext101_32x8d', 'ResNeXt-101 32x8d (288x288 Mean-Max Pooling)', '1905.00546',
223+
ttp=True, args=dict(img_size=288),
224+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
225+
_entry('swsl_resnext101_32x16d', 'ResNeXt-101 32x16d (288x288 Mean-Max Pooling)', '1905.00546',
226+
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 2,
227+
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
228+
170229
## DLA official impl weights (to remove if sotabench added to source)
171230
_entry('dla34', 'DLA-34', '1707.06484'),
172231
_entry('dla46_c', 'DLA-46-C', '1707.06484'),

timm/models/helpers.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,30 +57,32 @@ def resume_checkpoint(model, checkpoint_path):
5757
raise FileNotFoundError()
5858

5959

60-
def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None):
61-
if 'url' not in default_cfg or not default_cfg['url']:
60+
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None):
61+
if cfg is None:
62+
cfg = getattr(model, 'default_cfg')
63+
if cfg is None or 'url' not in cfg or not cfg['url']:
6264
logging.warning("Pretrained model URL is invalid, using random initialization.")
6365
return
6466

65-
state_dict = model_zoo.load_url(default_cfg['url'], progress=False)
67+
state_dict = model_zoo.load_url(cfg['url'], progress=False)
6668

6769
if in_chans == 1:
68-
conv1_name = default_cfg['first_conv']
70+
conv1_name = cfg['first_conv']
6971
logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
7072
conv1_weight = state_dict[conv1_name + '.weight']
7173
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
7274
elif in_chans != 3:
7375
assert False, "Invalid in_chans for pretrained weights"
7476

7577
strict = True
76-
classifier_name = default_cfg['classifier']
77-
if num_classes == 1000 and default_cfg['num_classes'] == 1001:
78+
classifier_name = cfg['classifier']
79+
if num_classes == 1000 and cfg['num_classes'] == 1001:
7880
# special case for imagenet trained models with extra background class in pretrained weights
7981
classifier_weight = state_dict[classifier_name + '.weight']
8082
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
8183
classifier_bias = state_dict[classifier_name + '.bias']
8284
state_dict[classifier_name + '.bias'] = classifier_bias[1:]
83-
elif num_classes != default_cfg['num_classes']:
85+
elif num_classes != cfg['num_classes']:
8486
# completely discard fully connected for all other differences between pretrained and created model
8587
del state_dict[classifier_name + '.weight']
8688
del state_dict[classifier_name + '.bias']

0 commit comments

Comments
 (0)