2323from contextlib import suppress
2424from datetime import datetime
2525
26- import wandb
27-
2826import torch
2927import torch .nn as nn
3028import torchvision .utils
5452except AttributeError :
5553 pass
5654
55+ try :
56+ import wandb
57+ has_wandb = True
58+ except ModuleNotFoundError :
59+ has_wandb = False
60+
5761torch .backends .cudnn .benchmark = True
5862_logger = logging .getLogger ('train' )
5963
274278parser .add_argument ('--torchscript' , dest = 'torchscript' , action = 'store_true' ,
275279 help = 'convert model torchscript for inference' )
276280parser .add_argument ('--log-wandb' , action = 'store_true' , default = False ,
277- help = 'use wandb for training and validation logs ' )
281+ help = 'log training and validation metrics to wandb ' )
278282
279283
280284def _parse_args ():
@@ -299,8 +303,12 @@ def main():
299303 args , args_text = _parse_args ()
300304
301305 if args .log_wandb :
302- wandb .init (project = args .experiment , config = args )
303-
306+ if has_wandb :
307+ wandb .init (project = args .experiment , config = args )
308+ else :
309+ _logger .warning ("You've requested to log metrics to wandb but package not found. "
310+ "Metrics not being logged to wandb, try `pip install wandb`" )
311+
304312 args .prefetcher = not args .no_prefetcher
305313 args .distributed = False
306314 if 'WORLD_SIZE' in os .environ :
@@ -600,7 +608,7 @@ def main():
600608
601609 update_summary (
602610 epoch , train_metrics , eval_metrics , os .path .join (output_dir , 'summary.csv' ),
603- write_header = best_metric is None , log_wandb = args .log_wandb )
611+ write_header = best_metric is None , log_wandb = args .log_wandb and has_wandb )
604612
605613 if saver is not None :
606614 # save proper checkpoint with eval metric
0 commit comments