7171 help = "Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'" )
7272parser .add_argument ('--detail' , action = 'store_true' , default = False ,
7373 help = 'Provide train fwd/bwd/opt breakdown detail if True. Defaults to False' )
74+ parser .add_argument ('--no-retry' , action = 'store_true' , default = False ,
75+ help = 'Do not decay batch size and retry on error.' )
7476parser .add_argument ('--results-file' , default = '' , type = str , metavar = 'FILENAME' ,
7577 help = 'Output csv file for validation results (summary)' )
7678parser .add_argument ('--num-warm-iter' , default = 10 , type = int ,
@@ -169,10 +171,9 @@ def resolve_precision(precision: str):
169171
170172
171173def profile_deepspeed (model , input_size = (3 , 224 , 224 ), batch_size = 1 , detailed = False ):
172- macs , _ = get_model_profile (
174+ _ , macs , _ = get_model_profile (
173175 model = model ,
174- input_res = (batch_size ,) + input_size , # input shape or input to the input_constructor
175- input_constructor = None , # if specified, a constructor taking input_res is used as input to the model
176+ input_shape = (batch_size ,) + input_size , # input shape/resolution
176177 print_profile = detailed , # prints the model graph with the measured profile attached to each module
177178 detailed = detailed , # print the detailed profile
178179 warm_up = 10 , # the number of warm-ups before measuring the time of each module
@@ -197,8 +198,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
197198
198199class BenchmarkRunner :
199200 def __init__ (
200- self , model_name , detail = False , device = 'cuda' , torchscript = False , aot_autograd = False , precision = 'float32' ,
201- fuser = '' , num_warm_iter = 10 , num_bench_iter = 50 , use_train_size = False , ** kwargs ):
201+ self ,
202+ model_name ,
203+ detail = False ,
204+ device = 'cuda' ,
205+ torchscript = False ,
206+ aot_autograd = False ,
207+ precision = 'float32' ,
208+ fuser = '' ,
209+ num_warm_iter = 10 ,
210+ num_bench_iter = 50 ,
211+ use_train_size = False ,
212+ ** kwargs
213+ ):
202214 self .model_name = model_name
203215 self .detail = detail
204216 self .device = device
@@ -256,7 +268,13 @@ def _init_input(self):
256268
257269class InferenceBenchmarkRunner (BenchmarkRunner ):
258270
259- def __init__ (self , model_name , device = 'cuda' , torchscript = False , ** kwargs ):
271+ def __init__ (
272+ self ,
273+ model_name ,
274+ device = 'cuda' ,
275+ torchscript = False ,
276+ ** kwargs
277+ ):
260278 super ().__init__ (model_name = model_name , device = device , torchscript = torchscript , ** kwargs )
261279 self .model .eval ()
262280
@@ -325,7 +343,13 @@ def _step():
325343
326344class TrainBenchmarkRunner (BenchmarkRunner ):
327345
328- def __init__ (self , model_name , device = 'cuda' , torchscript = False , ** kwargs ):
346+ def __init__ (
347+ self ,
348+ model_name ,
349+ device = 'cuda' ,
350+ torchscript = False ,
351+ ** kwargs
352+ ):
329353 super ().__init__ (model_name = model_name , device = device , torchscript = torchscript , ** kwargs )
330354 self .model .train ()
331355
@@ -492,7 +516,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
492516 return max (0 , int (out_batch_size ))
493517
494518
495- def _try_run (model_name , bench_fn , initial_batch_size , bench_kwargs ):
519+ def _try_run (model_name , bench_fn , bench_kwargs , initial_batch_size , no_batch_size_retry = False ):
496520 batch_size = initial_batch_size
497521 results = dict ()
498522 error_str = 'Unknown'
@@ -507,8 +531,11 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
507531 if 'channels_last' in error_str :
508532 _logger .error (f'{ model_name } not supported in channels_last, skipping.' )
509533 break
510- _logger .warning (f'"{ error_str } " while running benchmark. Reducing batch size to { batch_size } for retry.' )
534+ _logger .error (f'"{ error_str } " while running benchmark.' )
535+ if no_batch_size_retry :
536+ break
511537 batch_size = decay_batch_exp (batch_size )
538+ _logger .warning (f'Reducing batch size to { batch_size } for retry.' )
512539 results ['error' ] = error_str
513540 return results
514541
@@ -550,7 +577,13 @@ def benchmark(args):
550577
551578 model_results = OrderedDict (model = model )
552579 for prefix , bench_fn in zip (prefixes , bench_fns ):
553- run_results = _try_run (model , bench_fn , initial_batch_size = batch_size , bench_kwargs = bench_kwargs )
580+ run_results = _try_run (
581+ model ,
582+ bench_fn ,
583+ bench_kwargs = bench_kwargs ,
584+ initial_batch_size = batch_size ,
585+ no_batch_size_retry = args .no_retry ,
586+ )
554587 if prefix and 'error' not in run_results :
555588 run_results = {'_' .join ([prefix , k ]): v for k , v in run_results .items ()}
556589 model_results .update (run_results )
0 commit comments