Skip to content

Commit 2b49ab7

Browse files
committed
Fix ResNetV2 pretrained classifier issue. Fixes #540
1 parent de9dff9 commit 2b49ab7

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_model_default_cfgs(model_name, batch_size):
132132
def test_model_load_pretrained(model_name, batch_size):
133133
"""Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
134134
in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change
135-
create_model(model_name, pretrained=True, in_chans=in_chans)
135+
create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)
136136

137137
@pytest.mark.timeout(120)
138138
@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))

timm/models/resnetv2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def get_classifier(self):
365365
return self.head.fc
366366

367367
def reset_classifier(self, num_classes, global_pool='avg'):
368+
self.num_classes = num_classes
368369
self.head = ClassifierHead(
369370
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
370371

@@ -393,8 +394,9 @@ def load_pretrained(self, checkpoint_path, prefix='resnet/'):
393394
self.stem.conv.weight.copy_(stem_conv_w)
394395
self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
395396
self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
396-
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
397-
self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
397+
if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
398+
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
399+
self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
398400
for i, (sname, stage) in enumerate(self.stages.named_children()):
399401
for j, (bname, block) in enumerate(stage.blocks.named_children()):
400402
convname = 'standardized_conv2d'

0 commit comments

Comments
 (0)