 from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
     convert_splitbn_model, model_parameters
-from timm.utils import setup_default_logging, random_seed, set_jit_fuser, ModelEmaV2, \
-    get_outdir, CheckpointSaver, distribute_bn, update_summary, accuracy, AverageMeter, \
-    dispatch_clip_grad, reduce_tensor
+from timm import utils
 from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
     LabelSmoothingCrossEntropy
 from timm.optim import create_optimizer_v2, optimizer_kwargs
@@ -346,7 +344,7 @@ def _parse_args():


 def main():
-    setup_default_logging()
+    utils.setup_default_logging()
     args, args_text = _parse_args()

     if args.log_wandb:
@@ -391,10 +389,10 @@ def main():
         _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                         "Install NVIDA apex or upgrade to PyTorch 1.6")

-    random_seed(args.seed, args.rank)
+    utils.random_seed(args.seed, args.rank)

     if args.fuser:
-        set_jit_fuser(args.fuser)
+        utils.set_jit_fuser(args.fuser)

     model = create_model(
         args.model,
@@ -492,7 +490,7 @@ def main():
     model_ema = None
     if args.model_ema:
         # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
-        model_ema = ModelEmaV2(
+        model_ema = utils.ModelEmaV2(
             model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
         if args.resume:
             load_checkpoint(model_ema.module, args.resume, use_ema=True)
@@ -640,9 +638,9 @@ def main():
                 safe_model_name(args.model),
                 str(data_config['input_size'][-1])
             ])
-        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
+        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
-        saver = CheckpointSaver(
+        saver = utils.CheckpointSaver(
             model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
             checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
@@ -661,13 +659,13 @@ def main():
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     _logger.info("Distributing BatchNorm running means and vars")
-                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
+                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
-                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
+                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                 ema_eval_metrics = validate(
                     model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                 eval_metrics = ema_eval_metrics
@@ -677,7 +675,7 @@ def main():
                 lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

             if output_dir is not None:
-                update_summary(
+                utils.update_summary(
                     epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                     write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)

@@ -704,9 +702,9 @@ def train_one_epoch(
             mixup_fn.mixup_enabled = False

     second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
-    batch_time_m = AverageMeter()
-    data_time_m = AverageMeter()
-    losses_m = AverageMeter()
+    batch_time_m = utils.AverageMeter()
+    data_time_m = utils.AverageMeter()
+    losses_m = utils.AverageMeter()

     model.train()

@@ -740,7 +738,7 @@ def train_one_epoch(
         else:
             loss.backward(create_graph=second_order)
             if args.clip_grad is not None:
-                dispatch_clip_grad(
+                utils.dispatch_clip_grad(
                     model_parameters(model, exclude_head='agc' in args.clip_mode),
                     value=args.clip_grad, mode=args.clip_mode)
             optimizer.step()
@@ -756,7 +754,7 @@ def train_one_epoch(
             lr = sum(lrl) / len(lrl)

             if args.distributed:
-                reduced_loss = reduce_tensor(loss.data, args.world_size)
+                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                 losses_m.update(reduced_loss.item(), input.size(0))

             if args.local_rank == 0:
@@ -801,10 +799,10 @@ def train_one_epoch(


 def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
-    batch_time_m = AverageMeter()
-    losses_m = AverageMeter()
-    top1_m = AverageMeter()
-    top5_m = AverageMeter()
+    batch_time_m = utils.AverageMeter()
+    losses_m = utils.AverageMeter()
+    top1_m = utils.AverageMeter()
+    top5_m = utils.AverageMeter()

     model.eval()

@@ -831,12 +829,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
                 target = target[0:target.size(0):reduce_factor]

             loss = loss_fn(output, target)
-            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

             if args.distributed:
-                reduced_loss = reduce_tensor(loss.data, args.world_size)
-                acc1 = reduce_tensor(acc1, args.world_size)
-                acc5 = reduce_tensor(acc5, args.world_size)
+                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
+                acc1 = utils.reduce_tensor(acc1, args.world_size)
+                acc5 = utils.reduce_tensor(acc5, args.world_size)
             else:
                 reduced_loss = loss.data

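
Note: below is a minimal sketch of the namespace-import pattern this diff adopts, assuming a timm install that exposes timm.utils. Instead of importing each helper by name, the utils module is imported once and helpers are referenced through it at the call site.

    # Sketch: access timm helpers through the utils namespace (assumes timm is installed)
    from timm import utils

    top1_m = utils.AverageMeter()   # running-average tracker used throughout train.py
    top1_m.update(76.2, n=128)      # record one batch's top-1 accuracy over 128 samples
    print(top1_m.avg)               # running average so far

This keeps the import block short and makes call sites such as utils.reduce_tensor and utils.distribute_bn self-documenting about where each helper comes from.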