|
42 | 42 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') |
43 | 43 | parser.add_argument('--pretrained', action='store_true', default=False, |
44 | 44 | help='Start with pretrained version of specified network (if avail)') |
45 | | -parser.add_argument('--img-size', type=int, default=224, metavar='N', |
46 | | - help='Image patch size (default: 224)') |
| 45 | +parser.add_argument('--img-size', type=int, default=None, metavar='N', |
| 46 | + help='Image patch size (default: None => model default)') |
47 | 47 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', |
48 | 48 | help='Override mean pixel value of dataset') |
49 | 49 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', |
@@ -159,15 +159,6 @@ def main(): |
159 | 159 |
|
160 | 160 | torch.manual_seed(args.seed + args.rank) |
161 | 161 |
|
162 | | - output_dir = '' |
163 | | - if args.local_rank == 0: |
164 | | - output_base = args.output if args.output else './output' |
165 | | - exp_name = '-'.join([ |
166 | | - datetime.now().strftime("%Y%m%d-%H%M%S"), |
167 | | - args.model, |
168 | | - str(args.img_size)]) |
169 | | - output_dir = get_outdir(output_base, 'train', exp_name) |
170 | | - |
171 | 162 | model = create_model( |
172 | 163 | args.model, |
173 | 164 | pretrained=args.pretrained, |
@@ -291,13 +282,21 @@ def main(): |
291 | 282 | validate_loss_fn = train_loss_fn |
292 | 283 |
|
293 | 284 | eval_metric = args.eval_metric |
| 285 | + best_metric = None |
| 286 | + best_epoch = None |
294 | 287 | saver = None |
295 | | - if output_dir: |
296 | | - # only set if process is rank 0 |
| 288 | + output_dir = '' |
| 289 | + if args.local_rank == 0: |
| 290 | + output_base = args.output if args.output else './output' |
| 291 | + exp_name = '-'.join([ |
| 292 | + datetime.now().strftime("%Y%m%d-%H%M%S"), |
| 293 | + args.model, |
| 294 | + str(data_config['input_size'][-1]) |
| 295 | + ]) |
| 296 | + output_dir = get_outdir(output_base, 'train', exp_name) |
297 | 297 | decreasing = True if eval_metric == 'loss' else False |
298 | 298 | saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) |
299 | | - best_metric = None |
300 | | - best_epoch = None |
| 299 | + |
301 | 300 | try: |
302 | 301 | for epoch in range(start_epoch, num_epochs): |
303 | 302 | if args.distributed: |
|
0 commit comments