
Commit 28e0152

Add --no-retry flag to benchmark.py to skip batch_size decay and retry on error. Fix #1226. Update deepspeed profile usage for latest DS releases. Fix #1333
1 parent: db0cee9

File tree

1 file changed (+43, -10)

benchmark.py

Lines changed: 43 additions & 10 deletions
@@ -71,6 +71,8 @@
                     help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
 parser.add_argument('--detail', action='store_true', default=False,
                     help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
+parser.add_argument('--no-retry', action='store_true', default=False,
+                    help='Do not decay batch size and retry on error.')
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
 parser.add_argument('--num-warm-iter', default=10, type=int,
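
The new flag parses like any other argparse store-true option; a minimal standalone sketch (not part of this commit) showing that `--no-retry` becomes `args.no_retry`, which `benchmark()` reads further down:

```python
# Standalone sketch, not from the commit: argparse maps the dashed flag
# '--no-retry' to the attribute args.no_retry used later in benchmark().
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-retry', action='store_true', default=False,
                    help='Do not decay batch size and retry on error.')

args = parser.parse_args(['--no-retry'])
assert args.no_retry is True
```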
@@ -169,10 +171,9 @@ def resolve_precision(precision: str):
 
 
 def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
-    macs, _ = get_model_profile(
+    _, macs, _ = get_model_profile(
         model=model,
-        input_res=(batch_size,) + input_size,  # input shape or input to the input_constructor
-        input_constructor=None,  # if specified, a constructor taking input_res is used as input to the model
+        input_shape=(batch_size,) + input_size,  # input shape/resolution
         print_profile=detailed,  # prints the model graph with the measured profile attached to each module
         detailed=detailed,  # print the detailed profile
         warm_up=10,  # the number of warm-ups before measuring the time of each module
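
This tracks the flops-profiler API in recent DeepSpeed releases, where `get_model_profile()` takes an `input_shape` argument (replacing `input_res` and `input_constructor`) and returns a `(flops, macs, params)` triple, hence the `_, macs, _` unpacking. A hedged, self-contained sketch of a call against the newer API (the model choice is illustrative):

```python
# Sketch assuming a recent DeepSpeed release where get_model_profile()
# accepts input_shape and returns (flops, macs, params).
import torchvision.models as models
from deepspeed.profiling.flops_profiler import get_model_profile

model = models.resnet50()
flops, macs, params = get_model_profile(
    model=model,
    input_shape=(1, 3, 224, 224),  # batch size prepended to the input size
    print_profile=False,
    detailed=False,
    warm_up=10,
    as_string=False,  # raw numbers rather than formatted strings
)
print(f'flops={flops}, macs={macs}, params={params}')
```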
@@ -197,8 +198,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
 
 class BenchmarkRunner:
     def __init__(
-            self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32',
-            fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
+            self,
+            model_name,
+            detail=False,
+            device='cuda',
+            torchscript=False,
+            aot_autograd=False,
+            precision='float32',
+            fuser='',
+            num_warm_iter=10,
+            num_bench_iter=50,
+            use_train_size=False,
+            **kwargs
+    ):
         self.model_name = model_name
         self.detail = detail
         self.device = device
@@ -256,7 +268,13 @@ def _init_input(self):
 
 class InferenceBenchmarkRunner(BenchmarkRunner):
 
-    def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
+    def __init__(
+            self,
+            model_name,
+            device='cuda',
+            torchscript=False,
+            **kwargs
+    ):
         super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
         self.model.eval()
 
@@ -325,7 +343,13 @@ def _step():
 
 class TrainBenchmarkRunner(BenchmarkRunner):
 
-    def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
+    def __init__(
+            self,
+            model_name,
+            device='cuda',
+            torchscript=False,
+            **kwargs
+    ):
         super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
         self.model.train()
 
@@ -492,7 +516,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
     return max(0, int(out_batch_size))
 
 
-def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
+def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
    batch_size = initial_batch_size
    results = dict()
    error_str = 'Unknown'
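
Only the final `return` of `decay_batch_exp()` is visible in this hunk; a hedged reconstruction of the decay it performs (the snap-to-divisor rounding is an assumption, not quoted from benchmark.py):

```python
# Hedged reconstruction: halve the batch size on each retry, snapping larger
# sizes to a multiple of `divisor`; only the return line appears in the diff.
def decay_batch_exp(batch_size, factor=0.5, divisor=16):
    out_batch_size = batch_size * factor
    if out_batch_size > divisor:
        out_batch_size = (out_batch_size + 1) // divisor * divisor
    return max(0, int(out_batch_size))

# e.g. 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1 -> 0 (retry loop stops at 0)
```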
@@ -507,8 +531,11 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
             if 'channels_last' in error_str:
                 _logger.error(f'{model_name} not supported in channels_last, skipping.')
                 break
-            _logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
+            _logger.error(f'"{error_str}" while running benchmark.')
+            if no_batch_size_retry:
+                break
             batch_size = decay_batch_exp(batch_size)
+            _logger.warning(f'Reducing batch size to {batch_size} for retry.')
     results['error'] = error_str
     return results
 
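The net effect of `--no-retry` on control flow, as a minimal runnable sketch (`run_with_retry` and its prints are illustrative stand-ins, not the actual `_try_run` internals):

```python
# Illustrative control flow for retry-with-decay and the --no-retry
# short-circuit; run_fn stands in for building and running a benchmark.
def run_with_retry(run_fn, initial_batch_size, no_batch_size_retry=False):
    batch_size = initial_batch_size
    error_str = 'Unknown'
    while batch_size:
        try:
            return run_fn(batch_size)
        except RuntimeError as e:
            error_str = str(e)
            print(f'"{error_str}" while running benchmark.')
            if no_batch_size_retry:
                break  # fail fast instead of decaying and retrying
            batch_size = max(0, int(batch_size * 0.5))  # stand-in for decay_batch_exp()
            print(f'Reducing batch size to {batch_size} for retry.')
    return {'error': error_str}
```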
@@ -550,7 +577,13 @@ def benchmark(args):
 
     model_results = OrderedDict(model=model)
     for prefix, bench_fn in zip(prefixes, bench_fns):
-        run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
+        run_results = _try_run(
+            model,
+            bench_fn,
+            bench_kwargs=bench_kwargs,
+            initial_batch_size=batch_size,
+            no_batch_size_retry=args.no_retry,
+        )
         if prefix and 'error' not in run_results:
             run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
         model_results.update(run_results)
