Skip to content

Commit 337702d

Browse files
authored
Merge pull request #2 from zhunzhong07/patch-1
Fix bug for prefetcher
2 parents e8cf619 + 1274873 commit 337702d

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,9 @@ def validate(model, loader, loss_fn, args):
428428
with torch.no_grad():
429429
for batch_idx, (input, target) in enumerate(loader):
430430
last_batch = batch_idx == last_idx
431+
if not args.prefetcher:
432+
input = input.cuda()
433+
target = target.cuda()
431434

432435
output = model(input)
433436
if isinstance(output, (tuple, list)):

0 commit comments

Comments
 (0)