Skip to content

Commit b9f8d40

Browse files
committed
Fix pretrained override logic for validate, checkpoint always trump pretrained flag during model create
1 parent 0e1fd11 commit b9f8d40

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

models/model_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def create_model(
3636
else:
3737
raise RuntimeError('Unknown model (%s)' % model_name)
3838

39-
if checkpoint_path and not pretrained:
39+
if checkpoint_path:
4040
load_checkpoint(model, checkpoint_path)
4141

4242
return model

validate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454

5555

5656
def validate(args):
57+
# might as well try to validate something
58+
args.pretrained = args.pretrained or not args.checkpoint
5759

5860
# create model
5961
model = create_model(
@@ -62,10 +64,8 @@ def validate(args):
6264
in_chans=3,
6365
pretrained=args.pretrained)
6466

65-
if args.checkpoint and not args.pretrained:
67+
if args.checkpoint:
6668
load_checkpoint(model, args.checkpoint, args.use_ema)
67-
else:
68-
args.pretrained = True # might as well try to validate something...
6969

7070
param_count = sum([m.numel() for m in model.parameters()])
7171
print('Model %s created, param count: %d' % (args.model, param_count))

0 commit comments

Comments
 (0)