5050 help = 'Number of GPUS to use' )
5151parser .add_argument ('--no-test-pool' , dest = 'no_test_pool' , action = 'store_true' ,
5252 help = 'disable test time pool' )
53- parser .add_argument ('--tf-preprocessing' , dest = 'tf_preprocessing' , action = 'store_true' ,
53+ parser .add_argument ('--no-prefetcher' , action = 'store_true' , default = False ,
54+ help = 'disable fast prefetcher' )
55+ parser .add_argument ('--fp16' , action = 'store_true' , default = False ,
56+ help = 'Use half precision (fp16)' )
57+ parser .add_argument ('--tf-preprocessing' , action = 'store_true' , default = False ,
5458 help = 'Use Tensorflow preprocessing pipeline (require CPU TF installed' )
5559parser .add_argument ('--use-ema' , dest = 'use_ema' , action = 'store_true' ,
5660 help = 'use ema version of weights if present' )
5963def validate (args ):
6064 # might as well try to validate something
6165 args .pretrained = args .pretrained or not args .checkpoint
66+ args .prefetcher = not args .no_prefetcher
6267
6368 # create model
6469 model = create_model (
@@ -81,19 +86,23 @@ def validate(args):
8186 else :
8287 model = model .cuda ()
8388
89+ if args .fp16 :
90+ model = model .half ()
91+
8492 criterion = nn .CrossEntropyLoss ().cuda ()
8593
8694 crop_pct = 1.0 if test_time_pool else data_config ['crop_pct' ]
8795 loader = create_loader (
8896 Dataset (args .data , load_bytes = args .tf_preprocessing ),
8997 input_size = data_config ['input_size' ],
9098 batch_size = args .batch_size ,
91- use_prefetcher = True ,
99+ use_prefetcher = args . prefetcher ,
92100 interpolation = data_config ['interpolation' ],
93101 mean = data_config ['mean' ],
94102 std = data_config ['std' ],
95103 num_workers = args .workers ,
96104 crop_pct = crop_pct ,
105+ fp16 = args .fp16 ,
97106 tf_preprocessing = args .tf_preprocessing )
98107
99108 batch_time = AverageMeter ()
@@ -105,8 +114,11 @@ def validate(args):
105114 end = time .time ()
106115 with torch .no_grad ():
107116 for i , (input , target ) in enumerate (loader ):
108- target = target .cuda ()
109- input = input .cuda ()
117+ if args .no_prefetcher :
118+ target = target .cuda ()
119+ input = input .cuda ()
120+ if args .fp16 :
121+ input = input .half ()
110122
111123 # compute output
112124 output = model (input )
@@ -125,7 +137,7 @@ def validate(args):
125137 if i % args .log_freq == 0 :
126138 logging .info (
127139 'Test: [{0:>4d}/{1}] '
128- 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
140+ 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s ) '
129141 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
130142 'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
131143 'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})' .format (
0 commit comments