
Commit 9e12530

kaczmarj authored and rwightman committed
use utils namespace instead of function/classnames
This fixes buggy behavior introduced by #1266. Related to #1273.
1 parent db64393 commit 9e12530
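For context, the change below consistently swaps direct name imports from timm.utils for a single module-level import, so helpers are referenced through the utils namespace. A minimal sketch of the before/after pattern, shown with just two representative helpers rather than the full import list from train.py:

# Before: each helper is bound to its own module-level name in train.py
from timm.utils import AverageMeter, reduce_tensor
losses_m = AverageMeter()

# After: only the utils module is imported; helpers resolve through the namespace
from timm import utils
losses_m = utils.AverageMeter()

The namespace form keeps train.py from having to track every individual utility in its import block, which is the style the diff applies throughout.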

File tree

1 file changed (+23, -25 lines)


train.py

Lines changed: 23 additions & 25 deletions
@@ -31,9 +31,7 @@
 from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
     convert_splitbn_model, model_parameters
-from timm.utils import setup_default_logging, random_seed, set_jit_fuser, ModelEmaV2,\
-    get_outdir, CheckpointSaver, distribute_bn, update_summary, accuracy, AverageMeter,\
-    dispatch_clip_grad, reduce_tensor
+from timm import utils
 from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
     LabelSmoothingCrossEntropy
 from timm.optim import create_optimizer_v2, optimizer_kwargs
@@ -346,7 +344,7 @@ def _parse_args():
 
 
 def main():
-    setup_default_logging()
+    utils.setup_default_logging()
     args, args_text = _parse_args()
 
     if args.log_wandb:
@@ -391,10 +389,10 @@ def main():
         _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                         "Install NVIDA apex or upgrade to PyTorch 1.6")
 
-    random_seed(args.seed, args.rank)
+    utils.random_seed(args.seed, args.rank)
 
     if args.fuser:
-        set_jit_fuser(args.fuser)
+        utils.set_jit_fuser(args.fuser)
 
     model = create_model(
         args.model,
@@ -492,7 +490,7 @@ def main():
     model_ema = None
     if args.model_ema:
         # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
-        model_ema = ModelEmaV2(
+        model_ema = utils.ModelEmaV2(
             model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
         if args.resume:
             load_checkpoint(model_ema.module, args.resume, use_ema=True)
@@ -640,9 +638,9 @@ def main():
                 safe_model_name(args.model),
                 str(data_config['input_size'][-1])
             ])
-        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
+        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
-        saver = CheckpointSaver(
+        saver = utils.CheckpointSaver(
             model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
             checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
@@ -661,13 +659,13 @@ def main():
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     _logger.info("Distributing BatchNorm running means and vars")
-                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
+                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
 
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
-                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
+                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                 ema_eval_metrics = validate(
                     model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                 eval_metrics = ema_eval_metrics
@@ -677,7 +675,7 @@ def main():
                 lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
 
             if output_dir is not None:
-                update_summary(
+                utils.update_summary(
                     epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                     write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
 
@@ -704,9 +702,9 @@ def train_one_epoch(
             mixup_fn.mixup_enabled = False
 
     second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
-    batch_time_m = AverageMeter()
-    data_time_m = AverageMeter()
-    losses_m = AverageMeter()
+    batch_time_m = utils.AverageMeter()
+    data_time_m = utils.AverageMeter()
+    losses_m = utils.AverageMeter()
 
     model.train()
 
@@ -740,7 +738,7 @@ def train_one_epoch(
         else:
             loss.backward(create_graph=second_order)
             if args.clip_grad is not None:
-                dispatch_clip_grad(
+                utils.dispatch_clip_grad(
                     model_parameters(model, exclude_head='agc' in args.clip_mode),
                     value=args.clip_grad, mode=args.clip_mode)
             optimizer.step()
@@ -756,7 +754,7 @@ def train_one_epoch(
             lr = sum(lrl) / len(lrl)
 
             if args.distributed:
-                reduced_loss = reduce_tensor(loss.data, args.world_size)
+                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                 losses_m.update(reduced_loss.item(), input.size(0))
 
             if args.local_rank == 0:
@@ -801,10 +799,10 @@ def train_one_epoch(
 
 
 def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
-    batch_time_m = AverageMeter()
-    losses_m = AverageMeter()
-    top1_m = AverageMeter()
-    top5_m = AverageMeter()
+    batch_time_m = utils.AverageMeter()
+    losses_m = utils.AverageMeter()
+    top1_m = utils.AverageMeter()
+    top5_m = utils.AverageMeter()
 
     model.eval()
 
@@ -831,12 +829,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
                 target = target[0:target.size(0):reduce_factor]
 
             loss = loss_fn(output, target)
-            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
 
             if args.distributed:
-                reduced_loss = reduce_tensor(loss.data, args.world_size)
-                acc1 = reduce_tensor(acc1, args.world_size)
-                acc5 = reduce_tensor(acc5, args.world_size)
+                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
+                acc1 = utils.reduce_tensor(acc1, args.world_size)
+                acc5 = utils.reduce_tensor(acc5, args.world_size)
             else:
                 reduced_loss = loss.data
 
0 commit comments
