 try:
     from apex import amp
     from apex.parallel import DistributedDataParallel as DDP
+    from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
     has_apex = False

 from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
-from models import create_model, resume_checkpoint
+from models import create_model, resume_checkpoint, load_checkpoint
 from utils import *
 from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from optim import create_optimizer
@@ ... @@
                     help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
 parser.add_argument('--pretrained', action='store_true', default=False,
                     help='Start with pretrained version of specified network (if avail)')
-parser.add_argument('--img-size', type=int, default=224, metavar='N',
-                    help='Image patch size (default: 224)')
+parser.add_argument('--img-size', type=int, default=None, metavar='N',
+                    help='Image patch size (default: None => model default)')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ ... @@
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+parser.add_argument('--model-ema', action='store_true', default=False,
+                    help='Enable tracking moving average of model weights')
+parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
+                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
+parser.add_argument('--model-ema-decay', type=float, default=0.9998,
+                    help='decay factor for model weights moving average (default: 0.9998)')
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
-parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
+parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 4)')
@@ ... @@
                     help='save images of input batches every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
+parser.add_argument('--sync-bn', action='store_true',
+                    help='enabling apex sync BN.')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -131,36 +140,24 @@ def main():

     args.device = 'cuda:0'
     args.world_size = 1
-    r = -1
+    args.rank = 0  # global rank
     if args.distributed:
         args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl',
-                                             init_method='env://')
+        torch.distributed.init_process_group(
+            backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
-        r = torch.distributed.get_rank()
+        args.rank = torch.distributed.get_rank()
+        assert args.rank >= 0

     if args.distributed:
         print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-              % (r, args.world_size))
+              % (args.rank, args.world_size))
     else:
         print('Training with a single process on %d GPUs.' % args.num_gpu)

-    # FIXME seed handling for multi-process distributed?
-    torch.manual_seed(args.seed)
-
-    output_dir = ''
-    if args.local_rank == 0:
-        if args.output:
-            output_base = args.output
-        else:
-            output_base = './output'
-        exp_name = '-'.join([
-            datetime.now().strftime("%Y%m%d-%H%M%S"),
-            args.model,
-            str(args.img_size)])
-        output_dir = get_outdir(output_base, 'train', exp_name)
+    torch.manual_seed(args.seed + args.rank)

     model = create_model(
         args.model,
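
Aside on the seeding change above: replacing the old FIXME with torch.manual_seed(args.seed + args.rank) keeps runs reproducible while giving each distributed process its own random stream. A minimal illustration of the idea, outside the diff (the helper name is made up for this sketch):

import torch

def seed_for_rank(base_seed, rank):
    # each rank gets a distinct, deterministic seed: rank 0 -> base_seed, rank 1 -> base_seed + 1, ...
    torch.manual_seed(base_seed + rank)
    return torch.rand(3)  # ranks now draw different augmentation/noise samples

# seed_for_rank(42, 0) and seed_for_rank(42, 1) differ from each other,
# but each is identical across re-runs with the same arguments.
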
@@ -191,6 +188,8 @@ def main():
             args.amp = False
         model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
+        if args.distributed and args.sync_bn and has_apex:
+            model = convert_syncbn_model(model)
         model.cuda()

     optimizer = create_optimizer(args, model)
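
The --sync-bn path above relies on Apex's convert_syncbn_model. As a side note: if Apex were unavailable, PyTorch's built-in torch.nn.SyncBatchNorm.convert_sync_batchnorm offers the same kind of conversion. A small sketch of that fallback (an assumption about an alternative, not what this commit does):

import torch.nn as nn

def to_sync_bn(model, distributed, sync_bn, has_apex):
    # convert BatchNorm layers to a synchronized variant only when training distributed
    if distributed and sync_bn:
        if has_apex:
            from apex.parallel import convert_syncbn_model
            model = convert_syncbn_model(model)
        else:
            # native PyTorch fallback (requires torch >= 1.1)
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model
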
@@ -205,8 +204,20 @@ def main():
         use_amp = False
         print('AMP disabled')

+    model_ema = None
+    if args.model_ema:
+        model_ema = ModelEma(
+            model,
+            decay=args.model_ema_decay,
+            device='cpu' if args.model_ema_force_cpu else '',
+            resume=args.resume)
+
     if args.distributed:
         model = DDP(model, delay_allreduce=True)
+        if model_ema is not None and not args.model_ema_force_cpu:
+            # must also distribute EMA model to allow validation
+            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
+            model_ema.ema_has_module = True

     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     if start_epoch > 0:
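
ModelEma itself comes in via the utils import and is not part of this diff. For orientation, a stripped-down sketch of what such an EMA tracker typically does is below; the argument names mirror the call site (decay, device), but the real helper likely also handles checkpoint resume and the 'module.' prefix of wrapped models, which this sketch omits:

from copy import deepcopy

import torch

class EmaSketch:
    """Keeps a frozen copy of the model whose weights decay toward the live weights."""

    def __init__(self, model, decay=0.9998, device=''):
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        if device:
            self.ema.to(device=device)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # new_average = decay * old_average + (1 - decay) * current_weights
        with torch.no_grad():
            msd = model.state_dict()
            for k, v in self.ema.state_dict().items():
                if v.is_floating_point():
                    v.copy_(v * self.decay + (1. - self.decay) * msd[k].detach().to(v.device))
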
@@ -271,23 +282,37 @@ def main():
         validate_loss_fn = train_loss_fn

     eval_metric = args.eval_metric
+    best_metric = None
+    best_epoch = None
     saver = None
-    if output_dir:
+    output_dir = ''
+    if args.local_rank == 0:
+        output_base = args.output if args.output else './output'
+        exp_name = '-'.join([
+            datetime.now().strftime("%Y%m%d-%H%M%S"),
+            args.model,
+            str(data_config['input_size'][-1])
+        ])
+        output_dir = get_outdir(output_base, 'train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
         saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
-    best_metric = None
-    best_epoch = None
+
     try:
         for epoch in range(start_epoch, num_epochs):
             if args.distributed:
                 loader_train.sampler.set_epoch(epoch)

             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
+                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                use_amp=use_amp, model_ema=model_ema)
+
+            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

-            eval_metrics = validate(
-                model, loader_eval, validate_loss_fn, args)
+            if model_ema is not None and not args.model_ema_force_cpu:
+                ema_eval_metrics = validate(
+                    model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
+                eval_metrics = ema_eval_metrics

             if lr_scheduler is not None:
                 lr_scheduler.step(epoch, eval_metrics[eval_metric])
@@ -298,15 +323,12 @@ def main():

             if saver is not None:
                 # save proper checkpoint with eval metric
-                best_metric, best_epoch = saver.save_checkpoint({
-                    'epoch': epoch + 1,
-                    'arch': args.model,
-                    'state_dict': model.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'args': args,
-                    },
+                save_metric = eval_metrics[eval_metric]
+                best_metric, best_epoch = saver.save_checkpoint(
+                    model, optimizer, args,
                     epoch=epoch + 1,
-                    metric=eval_metrics[eval_metric])
+                    model_ema=model_ema,
+                    metric=save_metric)

     except KeyboardInterrupt:
         pass
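
The saver now receives live objects (model, optimizer, args, model_ema) instead of a pre-assembled dict. Presumably CheckpointSaver builds a payload along the lines of the removed dict plus an EMA entry; the sketch below is an assumption for illustration (key names such as 'state_dict_ema' are guesses, not taken from this diff):

def build_checkpoint_state(model, optimizer, args, epoch, model_ema=None, metric=None):
    # hypothetical reconstruction of what the saver might serialize
    state = {
        'epoch': epoch,
        'arch': args.model,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'args': args,
    }
    if model_ema is not None:
        state['state_dict_ema'] = model_ema.ema.state_dict()  # assumed key name
    if metric is not None:
        state['metric'] = metric
    return state
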
@@ -316,7 +338,7 @@ def main():

 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):

     if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
         if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
@@ -359,6 +381,8 @@ def train_epoch(
             optimizer.step()

         torch.cuda.synchronize()
+        if model_ema is not None:
+            model_ema.update(model)
         num_updates += 1

         batch_time_m.update(time.time() - end)
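
The EMA update is hooked in right after the optimizer step, so the running average always folds in the freshly updated weights. Condensed, the per-iteration order looks roughly like this (a sketch distilled from the loop above, not the complete train_epoch body; AMP scaling is omitted):

import torch

def train_step(model, input, target, loss_fn, optimizer, model_ema=None):
    # one training iteration with optional EMA tracking
    output = model(input)
    loss = loss_fn(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if torch.cuda.is_available():
        torch.cuda.synchronize()   # mirror the loop above: wait for the update before timing/EMA
    if model_ema is not None:
        model_ema.update(model)    # fold the freshly updated weights into the moving average
    return loss.detach()
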
@@ -394,18 +418,11 @@ def train_epoch(
                         padding=0,
                         normalize=True)

-        if args.local_rank == 0 and (
-                saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+        if saver is not None and args.recovery_interval and (
+                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             save_epoch = epoch + 1 if last_batch else epoch
-            saver.save_recovery({
-                'epoch': save_epoch,
-                'arch': args.model,
-                'state_dict': model.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'args': args,
-                },
-                epoch=save_epoch,
-                batch_idx=batch_idx)
+            saver.save_recovery(
+                model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)

         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@@ -415,7 +432,7 @@ def train_epoch(
     return OrderedDict([('loss', losses_m.avg)])


-def validate(model, loader, loss_fn, args):
+def validate(model, loader, loss_fn, args, log_suffix=''):
     batch_time_m = AverageMeter()
     losses_m = AverageMeter()
     prec1_m = AverageMeter()
@@ -461,12 +478,13 @@ def validate(model, loader, loss_fn, args):
             batch_time_m.update(time.time() - end)
             end = time.time()
             if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
-                print('Test: [{0}/{1}]\t'
+                log_name = 'Test' + log_suffix
+                print('{0}: [{1}/{2}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                       'Loss {loss.val:.4f} ({loss.avg:.4f}) '
                       'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
                       'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
-                          batch_idx, last_idx,
+                          log_name, batch_idx, last_idx,
                           batch_time=batch_time_m, loss=losses_m,
                           top1=prec1_m, top5=prec5_m))

@@ -475,12 +493,5 @@ def validate(model, loader, loss_fn, args):
     return metrics


-def reduce_tensor(tensor, n):
-    rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    rt /= n
-    return rt
-
-
 if __name__ == '__main__':
     main()