Skip to content

Commit d667351

Browse files
committed
Tweak accuracy topk safety. Fix #807
1 parent 35c9740 commit d667351

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

timm/utils/metrics.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
33
Hacked together by / Copyright 2020 Ross Wightman
44
"""
5-
import torch
65

76

87
class AverageMeter:
@@ -30,7 +29,4 @@ def accuracy(output, target, topk=(1,)):
3029
_, pred = output.topk(maxk, 1, True, True)
3130
pred = pred.t()
3231
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
33-
return [
34-
correct[:k].reshape(-1).float().sum(0) * 100. / batch_size
35-
if k <= maxk else torch.tensor(100.) for k in topk
36-
]
32+
return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]

0 commit comments

Comments
 (0)