Skip to content

Commit 8e6fb86

Browse files
committed
Add wandb support
1 parent 779107b commit 8e6fb86

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

train.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,8 @@
2323
from contextlib import suppress
2424
from datetime import datetime
2525

26+
import wandb
27+
2628
import torch
2729
import torch.nn as nn
2830
import torchvision.utils
@@ -293,7 +295,8 @@ def _parse_args():
293295
def main():
294296
setup_default_logging()
295297
args, args_text = _parse_args()
296-
298+
wandb.init(project='efficientnet_v2', config=args)
299+
wandb.run.name = args.model
297300
args.prefetcher = not args.no_prefetcher
298301
args.distributed = False
299302
if 'WORLD_SIZE' in os.environ:
@@ -572,14 +575,14 @@ def main():
572575
epoch, model, loader_train, optimizer, train_loss_fn, args,
573576
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
574577
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
575-
578+
wandb.log(train_metrics)
576579
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
577580
if args.local_rank == 0:
578581
_logger.info("Distributing BatchNorm running means and vars")
579582
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
580583

581584
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
582-
585+
wandb.log(eval_metrics)
583586
if model_ema is not None and not args.model_ema_force_cpu:
584587
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
585588
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
@@ -711,7 +714,7 @@ def train_one_epoch(
711714
if hasattr(optimizer, 'sync_lookahead'):
712715
optimizer.sync_lookahead()
713716

714-
return OrderedDict([('loss', losses_m.avg)])
717+
return OrderedDict([('train_loss', losses_m.avg)])
715718

716719

717720
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
@@ -773,7 +776,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
773776
log_name, batch_idx, last_idx, batch_time=batch_time_m,
774777
loss=losses_m, top1=top1_m, top5=top5_m))
775778

776-
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
779+
metrics = OrderedDict([('val_loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
777780

778781
return metrics
779782

0 commit comments

Comments
 (0)