                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
+parser.add_argument('--use-wandb', action='store_true', default=False,
+                    help='use wandb for training and validation logs')
+parser.add_argument('--wandb-project-name', type=str, default=None,
+                    help='wandb project name to be used')
 
 
 def _parse_args():
@@ -295,8 +299,13 @@ def _parse_args():
 def main():
     setup_default_logging()
     args, args_text = _parse_args()
-    wandb.init(project='efficientnet_v2', config=args)
-    wandb.run.name = args.model
+
+    if args.use_wandb:
+        if not args.wandb_project_name:
+            args.wandb_project_name = args.model
+            _logger.warning(f"Wandb project name not provided, defaulting to {args.model}")
+        wandb.init(project=args.wandb_project_name, config=args)
+
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
@@ -575,14 +584,18 @@ def main():
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
-            wandb.log(train_metrics)
+
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     _logger.info("Distributing BatchNorm running means and vars")
                 distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
-            wandb.log(eval_metrics)
+
+            if args.use_wandb:
+                wandb.log(train_metrics)
+                wandb.log(eval_metrics)
+
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                     distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
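A hedged usage sketch of the new flags, assuming this diff applies to timm's train.py and the rest of the CLI (data path, --model) is unchanged; the dataset path and model name below are illustrative, not taken from the diff:

    # project name defaults to the model name, with a warning
    python train.py /data/imagenet --model efficientnetv2_rw_s --use-wandb

    # explicit project name
    python train.py /data/imagenet --model efficientnetv2_rw_s --use-wandb --wandb-project-name my-runs

With --use-wandb omitted (the default), neither wandb.init nor wandb.log is called.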