Skip to content

Commit d3ee3de

Browse files
committed
Update validation script first batch prime and clear cuda cache between multi-model runs
1 parent 0aca083 commit d3ee3de

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

validate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ def validate(args):
145145
model.eval()
146146
with torch.no_grad():
147147
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
148-
model(torch.randn((args.batch_size,) + data_config['input_size']).cuda())
148+
input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
149+
model(input)
149150
end = time.time()
150151
for i, (input, target) in enumerate(loader):
151152
if args.no_prefetcher:
@@ -238,6 +239,7 @@ def main():
238239
raise e
239240
batch_size = max(batch_size // 2, args.num_gpu)
240241
print("Validation failed, reducing batch size by 50%")
242+
torch.cuda.empty_cache()
241243
result.update(r)
242244
if args.checkpoint:
243245
result['checkpoint'] = args.checkpoint

0 commit comments

Comments
 (0)