@@ -6,12 +6,13 @@
 try:
     from apex import amp
     from apex.parallel import DistributedDataParallel as DDP
+    from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
     has_apex = False
 
 from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
-from models import create_model, resume_checkpoint
+from models import create_model, resume_checkpoint, load_checkpoint
 from utils import *
 from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from optim import create_optimizer
@@ -91,11 +92,17 @@
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+parser.add_argument('--model-ema', action='store_true', default=False,
+                    help='Enable tracking moving average of model weights')
+parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
+                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
+parser.add_argument('--model-ema-decay', type=float, default=0.9998,
+                    help='decay factor for model weights moving average (default: 0.9998)')
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
-parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
+parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 4)')
@@ -109,6 +116,8 @@
                     help='save images of input batches every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
+parser.add_argument('--sync-bn', action='store_true',
+                    help='enable apex synchronized BatchNorm')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -131,31 +140,28 @@ def main():
 
     args.device = 'cuda:0'
     args.world_size = 1
-    r = -1
+    args.rank = 0  # global rank
     if args.distributed:
         args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl',
-                                             init_method='env://')
+        torch.distributed.init_process_group(
+            backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
-        r = torch.distributed.get_rank()
+        args.rank = torch.distributed.get_rank()
+        assert args.rank >= 0
 
     if args.distributed:
         print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-              % (r, args.world_size))
+              % (args.rank, args.world_size))
     else:
         print('Training with a single process on %d GPUs.' % args.num_gpu)
 
-    # FIXME seed handling for multi-process distributed?
-    torch.manual_seed(args.seed)
+    torch.manual_seed(args.seed + args.rank)
 
     output_dir = ''
     if args.local_rank == 0:
-        if args.output:
-            output_base = args.output
-        else:
-            output_base = './output'
+        output_base = args.output if args.output else './output'
         exp_name = '-'.join([
             datetime.now().strftime("%Y%m%d-%H%M%S"),
             args.model,
@@ -191,6 +197,8 @@ def main():
             args.amp = False
         model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
+        if args.distributed and args.sync_bn and has_apex:
+            model = convert_syncbn_model(model)
         model.cuda()
 
     optimizer = create_optimizer(args, model)
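
Note on the sync-BN hook added above: `convert_syncbn_model` comes from the apex import block at the top of the file, and the `has_apex` guard means the conversion is silently skipped when apex is not installed. Below is a hedged sketch (not part of this commit) of how a native-PyTorch fallback could be structured; the helper name `maybe_convert_syncbn` is made up, and `torch.nn.SyncBatchNorm.convert_sync_batchnorm` requires PyTorch 1.1+ and an initialized process group, which main() sets up earlier.

import torch

try:
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False


def maybe_convert_syncbn(model, args):
    # Sketch only: prefer apex's fused sync BN when available, otherwise fall
    # back to PyTorch's native SyncBatchNorm. Assumes the default process
    # group is already initialized, as main() does before reaching this point.
    if args.distributed and args.sync_bn:
        if has_apex:
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model
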
@@ -205,8 +213,20 @@ def main():
         use_amp = False
         print('AMP disabled')
 
+    model_ema = None
+    if args.model_ema:
+        model_ema = ModelEma(
+            model,
+            decay=args.model_ema_decay,
+            device='cpu' if args.model_ema_force_cpu else '',
+            resume=args.resume)
+
     if args.distributed:
         model = DDP(model, delay_allreduce=True)
+        if model_ema is not None and not args.model_ema_force_cpu:
+            # must also distribute EMA model to allow validation
+            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
+            model_ema.ema_has_module = True
 
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     if start_epoch > 0:
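
Note on the EMA tracking wired in above: `ModelEma` arrives via `from utils import *` and its implementation is not part of this diff. The sketch below only illustrates the underlying technique, a deep-copied shadow model whose floating-point weights are blended toward the live weights at the configured decay on every update; the class name `EmaSketch` and its details are assumptions, and the real helper additionally handles `resume`, optional CPU placement, and the `module.` prefix that DDP wrapping adds (which is what the `ema_has_module` flag signals).

from copy import deepcopy

import torch


class EmaSketch:
    """Illustrative exponential moving average of model weights.

    Assumes the tracked model and the shadow copy share identical
    state_dict keys; the real ModelEma in utils also strips/adds the
    'module.' prefix and can resume EMA weights from a checkpoint.
    """

    def __init__(self, model, decay=0.9998, device=''):
        self.ema = deepcopy(model)   # shadow copy updated alongside the live model
        self.ema.eval()
        self.decay = decay
        if device:
            self.ema.to(device=device)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # ema_weight = decay * ema_weight + (1 - decay) * live_weight
        with torch.no_grad():
            msd = model.state_dict()
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v.copy_(v * self.decay + (1. - self.decay) * msd[k].detach().to(v.device))
                else:
                    v.copy_(msd[k].to(v.device))  # e.g. BatchNorm num_batches_tracked
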
@@ -273,6 +293,7 @@ def main():
     eval_metric = args.eval_metric
     saver = None
     if output_dir:
+        # only set if process is rank 0
         decreasing = True if eval_metric == 'loss' else False
         saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
     best_metric = None
@@ -284,10 +305,15 @@ def main():
 
             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
+                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                use_amp=use_amp, model_ema=model_ema)
+
+            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
 
-            eval_metrics = validate(
-                model, loader_eval, validate_loss_fn, args)
+            if model_ema is not None and not args.model_ema_force_cpu:
+                ema_eval_metrics = validate(
+                    model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
+                eval_metrics = ema_eval_metrics
 
             if lr_scheduler is not None:
                 lr_scheduler.step(epoch, eval_metrics[eval_metric])
@@ -298,15 +324,12 @@ def main():
 
             if saver is not None:
                 # save proper checkpoint with eval metric
-                best_metric, best_epoch = saver.save_checkpoint({
-                    'epoch': epoch + 1,
-                    'arch': args.model,
-                    'state_dict': model.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'args': args,
-                    },
+                save_metric = eval_metrics[eval_metric]
+                best_metric, best_epoch = saver.save_checkpoint(
+                    model, optimizer, args,
                     epoch=epoch + 1,
-                    metric=eval_metrics[eval_metric])
+                    model_ema=model_ema,
+                    metric=save_metric)
 
     except KeyboardInterrupt:
         pass
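
Note on the checkpoint call above: the saver is now handed the raw `model`, `optimizer`, `args` and `model_ema` objects instead of a pre-assembled dict, so assembling the checkpoint becomes `CheckpointSaver`'s job. The sketch below shows roughly what such a save step would have to write, based on the fields the removed inline dict carried; the function name, the `state_dict_ema` key and the exact layout are assumptions, since `CheckpointSaver` lives in utils and is not shown in this diff.

import torch


def save_checkpoint_sketch(path, model, optimizer, args, epoch, model_ema=None, metric=None):
    # Mirrors the fields the old inline dict carried, plus the EMA weights.
    save_state = {
        'epoch': epoch,
        'arch': args.model,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'args': args,
    }
    if model_ema is not None:
        save_state['state_dict_ema'] = model_ema.ema.state_dict()
    if metric is not None:
        save_state['metric'] = metric
    torch.save(save_state, path)
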
@@ -316,7 +339,7 @@ def main():
 
 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
 
     if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
         if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
@@ -359,6 +382,8 @@ def train_epoch(
         optimizer.step()
 
         torch.cuda.synchronize()
+        if model_ema is not None:
+            model_ema.update(model)
         num_updates += 1
 
         batch_time_m.update(time.time() - end)
@@ -394,18 +419,11 @@ def train_epoch(
                         padding=0,
                         normalize=True)
 
-        if args.local_rank == 0 and (
-                saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+        if saver is not None and args.recovery_interval and (
+                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             save_epoch = epoch + 1 if last_batch else epoch
-            saver.save_recovery({
-                'epoch': save_epoch,
-                'arch': args.model,
-                'state_dict': model.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'args': args,
-                },
-                epoch=save_epoch,
-                batch_idx=batch_idx)
+            saver.save_recovery(
+                model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)
 
         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
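
One behavioural detail of the reworked recovery-save condition above: with the new `--recovery-interval` default of 0, recovery checkpoints are opt-in, and placing `args.recovery_interval` in the conjunction both disables them and short-circuits before the `% args.recovery_interval` that would otherwise divide by zero; the old explicit `args.local_rank == 0` test is dropped because `saver` is only created on the rank-0 process. A small standalone illustration with made-up values:

# Not part of the commit; standalone sketch of the guard's short-circuit.
recovery_interval = 0      # value of --recovery-interval; 0 means disabled
saver = object()           # stand-in for the rank-0-only CheckpointSaver
num_batches = 10

for batch_idx in range(num_batches):
    last_batch = batch_idx == num_batches - 1
    if saver is not None and recovery_interval and (
            last_batch or (batch_idx + 1) % recovery_interval == 0):
        print('would write recovery checkpoint at batch', batch_idx)

# With recovery_interval = 0 nothing is saved and the modulo is never reached;
# setting it to 3 would save after batches 2, 5, 8 and on the final batch.
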
@@ -415,7 +433,7 @@ def train_epoch(
     return OrderedDict([('loss', losses_m.avg)])
 
 
-def validate(model, loader, loss_fn, args):
+def validate(model, loader, loss_fn, args, log_suffix=''):
     batch_time_m = AverageMeter()
     losses_m = AverageMeter()
     prec1_m = AverageMeter()
@@ -461,12 +479,13 @@ def validate(model, loader, loss_fn, args):
             batch_time_m.update(time.time() - end)
             end = time.time()
             if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
-                print('Test: [{0}/{1}]\t'
+                log_name = 'Test' + log_suffix
+                print('{0}: [{1}/{2}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                       'Loss {loss.val:.4f} ({loss.avg:.4f}) '
                       'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
                       'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
-                          batch_idx, last_idx,
+                          log_name, batch_idx, last_idx,
                           batch_time=batch_time_m, loss=losses_m,
                           top1=prec1_m, top5=prec5_m))
 
@@ -475,12 +494,5 @@ def validate(model, loader, loss_fn, args):
     return metrics
 
 
-def reduce_tensor(tensor, n):
-    rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    rt /= n
-    return rt
-
-
 if __name__ == '__main__':
     main()