@@ -23,6 +23,8 @@
 from contextlib import suppress
 from datetime import datetime
 
+import wandb
+
 import torch
 import torch.nn as nn
 import torchvision.utils
@@ -293,7 +295,8 @@ def _parse_args():
 def main():
     setup_default_logging()
     args, args_text = _parse_args()
-
+    wandb.init(project='efficientnet_v2', config=args)
+    wandb.run.name = args.model
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
@@ -572,14 +575,14 @@ def main():
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
-
+            wandb.log(train_metrics)
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     _logger.info("Distributing BatchNorm running means and vars")
                 distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
-
+            wandb.log(eval_metrics)
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                     distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
@@ -711,7 +714,7 @@ def train_one_epoch(
     if hasattr(optimizer, 'sync_lookahead'):
         optimizer.sync_lookahead()
 
-    return OrderedDict([('loss', losses_m.avg)])
+    return OrderedDict([('train_loss', losses_m.avg)])
 
 
 def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
@@ -773,7 +776,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
                         log_name, batch_idx, last_idx, batch_time=batch_time_m,
                         loss=losses_m, top1=top1_m, top5=top5_m))
 
-    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
+    metrics = OrderedDict([('val_loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
 
     return metrics
|
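A minimal sketch, separate from the patch above, of how these pieces behave together: wandb.init() opens the run, the run is renamed after the model, and the renamed keys ('train_loss', 'val_loss') keep the training and validation losses as distinct metrics instead of interleaving both under a single 'loss' key. The project name, model name, and metric values below are placeholders, and mode='offline' is used only so the snippet runs without a W&B account.

import wandb
from collections import OrderedDict

# Stand-in config; the real script passes the full argparse Namespace as config=args.
config = {'model': 'efficientnetv2_s', 'epochs': 2}

wandb.init(project='efficientnet_v2', config=config, mode='offline')
wandb.run.name = config['model']  # mirrors wandb.run.name = args.model

for epoch in range(config['epochs']):
    # Stand-ins for the dicts returned by train_one_epoch() and validate().
    train_metrics = OrderedDict([('train_loss', 1.0 / (epoch + 1))])
    eval_metrics = OrderedDict([('val_loss', 1.2 / (epoch + 1)),
                                ('top1', 60.0 + epoch), ('top5', 82.0 + epoch)])
    # Because the keys no longer collide on 'loss', both dicts can even be
    # merged into one log call per epoch; the patch logs them in two separate
    # calls, which also works since each wandb.log() advances the step counter.
    wandb.log({**train_metrics, **eval_metrics}, step=epoch)

wandb.finish()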