Skip to content

Commit e15e68d

Browse files
committed
Fix #566: summary.csv was being written to the current working directory on processes with local_rank != 0. Tweak benchmark memory handling to see if it reduces the likelihood of 'bad' exceptions on OOM.
1 parent 1b0c8e7 commit e15e68d

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,14 +374,14 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
374374
batch_size = initial_batch_size
375375
results = dict()
376376
while batch_size >= 1:
377+
torch.cuda.empty_cache()
377378
try:
378379
bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
379380
results = bench.run()
380381
return results
381382
except RuntimeError as e:
382-
torch.cuda.empty_cache()
383-
batch_size = decay_batch_exp(batch_size)
384383
print(f'Error: {str(e)} while running benchmark. Reducing batch size to {batch_size} for retry.')
384+
batch_size = decay_batch_exp(batch_size)
385385
return results
386386

387387

train.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ def main():
560560
best_metric = None
561561
best_epoch = None
562562
saver = None
563-
output_dir = ''
563+
output_dir = None
564564
if args.local_rank == 0:
565565
if args.experiment:
566566
exp_name = args.experiment
@@ -606,9 +606,10 @@ def main():
606606
# step LR for next epoch
607607
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
608608

609-
update_summary(
610-
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
611-
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
609+
if output_dir is not None:
610+
update_summary(
611+
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
612+
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
612613

613614
if saver is not None:
614615
# save proper checkpoint with eval metric
@@ -623,7 +624,7 @@ def main():
623624

624625
def train_one_epoch(
625626
epoch, model, loader, optimizer, loss_fn, args,
626-
lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
627+
lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
627628
loss_scaler=None, model_ema=None, mixup_fn=None):
628629

629630
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:

0 commit comments

Comments (0)