
Commit 00c8e0b

Make use of wandb configurable
1 parent 8e6fb86 commit 00c8e0b

File tree

1 file changed (+17 −4 lines)

train.py

Lines changed: 17 additions & 4 deletions
@@ -273,6 +273,10 @@
                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
+parser.add_argument('--use-wandb', action='store_true', default=False,
+                    help='use wandb for training and validation logs')
+parser.add_argument('--wandb-project-name', type=str, default=None,
+                    help='wandb project name to be used')
 
 
 def _parse_args():
@@ -295,8 +299,13 @@ def _parse_args():
 def main():
     setup_default_logging()
     args, args_text = _parse_args()
-    wandb.init(project='efficientnet_v2', config=args)
-    wandb.run.name = args.model
+
+    if args.use_wandb:
+        if not args.wandb_project_name:
+            args.wandb_project_name = args.model
+            _logger.warning(f"Wandb project name not provided, defaulting to {args.model}")
+        wandb.init(project=args.wandb_project_name, config=args)
+
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
@@ -575,14 +584,18 @@ def main():
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
-            wandb.log(train_metrics)
+
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     _logger.info("Distributing BatchNorm running means and vars")
                 distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
-            wandb.log(eval_metrics)
+
+            if args.use_wandb:
+                wandb.log(train_metrics)
+                wandb.log(eval_metrics)
+
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                     distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
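
For reference, below is a minimal, self-contained sketch of the pattern this commit applies: wandb.init is only called when --use-wandb is passed, the project name falls back to the model name with a warning, and metric logging is gated behind the same flag. The --model default, the epoch loop, and the metric values are placeholders for illustration, not taken from train.py.

# Minimal sketch of the flag-gated wandb pattern; placeholder metrics, not real training output.
import argparse
import logging

import wandb

_logger = logging.getLogger('train')


def main():
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='tf_efficientnetv2_s')  # placeholder default
    parser.add_argument('--use-wandb', action='store_true', default=False,
                        help='use wandb for training and validation logs')
    parser.add_argument('--wandb-project-name', type=str, default=None,
                        help='wandb project name to be used')
    args = parser.parse_args()

    if args.use_wandb:
        if not args.wandb_project_name:
            # Fall back to the model name so a run is never created without a project.
            args.wandb_project_name = args.model
            _logger.warning(f"Wandb project name not provided, defaulting to {args.model}")
        wandb.init(project=args.wandb_project_name, config=args)

    for epoch in range(2):
        train_metrics = {'loss': 1.0 / (epoch + 1)}  # placeholder values
        eval_metrics = {'top1': 50.0 + epoch}        # placeholder values
        if args.use_wandb:
            # Logging sits behind the same flag, so runs without --use-wandb never touch wandb.
            wandb.log(train_metrics)
            wandb.log(eval_metrics)

    if args.use_wandb:
        wandb.finish()


if __name__ == '__main__':
    main()

With the change in this commit, a hypothetical invocation such as python train.py /data --model tf_efficientnetv2_s --use-wandb --wandb-project-name my-project enables logging, while omitting --use-wandb skips every wandb call.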
