Skip to content

Commit 38d8f67

Browse files
committed
Fix potential issue with change to num_classes arg in train/validate.py defaulting to None (rely on model def / default_cfg)
1 parent 587780e commit 38d8f67

File tree

3 files changed

+7
-0
lines changed

3 files changed

+7
-0
lines changed

timm/models/helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
198198

199199
classifier_name = cfg['classifier']
200200
if num_classes == 1000 and cfg['num_classes'] == 1001:
201+
# FIXME this special case is problematic as number of pretrained weight sources increases
201202
# special case for imagenet trained models with extra background class in pretrained weights
202203
classifier_weight = state_dict[classifier_name + '.weight']
203204
state_dict[classifier_name + '.weight'] = classifier_weight[1:]

train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,9 @@ def main():
337337
bn_eps=args.bn_eps,
338338
scriptable=args.torchscript,
339339
checkpoint_path=args.initial_checkpoint)
340+
if args.num_classes is None:
341+
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
342+
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
340343

341344
if args.local_rank == 0:
342345
_logger.info('Model %s created, param count: %d' %

validate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ def validate(args):
137137
in_chans=3,
138138
global_pool=args.gp,
139139
scriptable=args.torchscript)
140+
if args.num_classes is None:
141+
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
142+
args.num_classes = model.num_classes
140143

141144
if args.checkpoint:
142145
load_checkpoint(model, args.checkpoint, args.use_ema)

0 commit comments

Comments
 (0)