Skip to content

Commit 7dab6d1

Browse files
committed
Default to img_size in model default_cfg, defer output folder creation until later in the init sequence
1 parent 9bcd651 commit 7dab6d1

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

train.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@
4242
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
4343
parser.add_argument('--pretrained', action='store_true', default=False,
4444
help='Start with pretrained version of specified network (if avail)')
45-
parser.add_argument('--img-size', type=int, default=224, metavar='N',
46-
help='Image patch size (default: 224)')
45+
parser.add_argument('--img-size', type=int, default=None, metavar='N',
46+
help='Image patch size (default: None => model default)')
4747
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
4848
help='Override mean pixel value of dataset')
4949
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -159,15 +159,6 @@ def main():
159159

160160
torch.manual_seed(args.seed + args.rank)
161161

162-
output_dir = ''
163-
if args.local_rank == 0:
164-
output_base = args.output if args.output else './output'
165-
exp_name = '-'.join([
166-
datetime.now().strftime("%Y%m%d-%H%M%S"),
167-
args.model,
168-
str(args.img_size)])
169-
output_dir = get_outdir(output_base, 'train', exp_name)
170-
171162
model = create_model(
172163
args.model,
173164
pretrained=args.pretrained,
@@ -291,13 +282,21 @@ def main():
291282
validate_loss_fn = train_loss_fn
292283

293284
eval_metric = args.eval_metric
285+
best_metric = None
286+
best_epoch = None
294287
saver = None
295-
if output_dir:
296-
# only set if process is rank 0
288+
output_dir = ''
289+
if args.local_rank == 0:
290+
output_base = args.output if args.output else './output'
291+
exp_name = '-'.join([
292+
datetime.now().strftime("%Y%m%d-%H%M%S"),
293+
args.model,
294+
str(data_config['input_size'][-1])
295+
])
296+
output_dir = get_outdir(output_base, 'train', exp_name)
297297
decreasing = True if eval_metric == 'loss' else False
298298
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
299-
best_metric = None
300-
best_epoch = None
299+
301300
try:
302301
for epoch in range(start_epoch, num_epochs):
303302
if args.distributed:

utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,9 @@ def _load_checkpoint(self, checkpoint_path):
253253
name = k
254254
new_state_dict[name] = v
255255
self.ema.load_state_dict(new_state_dict)
256-
print("=> loaded state_dict_ema")
256+
print("=> Loaded state_dict_ema")
257257
else:
258-
print("=> failed to find state_dict_ema, starting from loaded model weights)")
258+
print("=> Failed to find state_dict_ema, starting from loaded model weights")
259259

260260
def update(self, model):
261261
# correct a mismatch in state dict keys

0 commit comments

Comments
 (0)