Skip to content

Commit c1a84ec

Browse files
committed
dataset not passed through PrefetchLoader for inference script. Fix #10
* also, make top5 configurable for lower class count cases
1 parent 2060e43 commit c1a84ec

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

data/loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ def __len__(self):
6161
def sampler(self):
6262
return self.loader.sampler
6363

64+
@property
65+
def dataset(self):
66+
return self.loader.dataset
67+
6468
@property
6569
def mixup_enabled(self):
6670
if isinstance(self.loader.collate_fn, FastCollateMixup):

inference.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
help='Number of GPUS to use')
4949
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
5050
help='disable test time pool')
51+
parser.add_argument('--topk', default=5, type=int,
52+
metavar='N', help='Top-k to output to CSV')
5153

5254

5355
def main():
@@ -85,15 +87,16 @@ def main():
8587

8688
model.eval()
8789

90+
k = min(args.topk, args.num_classes)
8891
batch_time = AverageMeter()
8992
end = time.time()
90-
top5_ids = []
93+
topk_ids = []
9194
with torch.no_grad():
9295
for batch_idx, (input, _) in enumerate(loader):
9396
input = input.cuda()
9497
labels = model(input)
95-
top5 = labels.topk(5)[1]
96-
top5_ids.append(top5.cpu().numpy())
98+
topk = labels.topk(k)[1]
99+
topk_ids.append(topk.cpu().numpy())
97100

98101
# measure elapsed time
99102
batch_time.update(time.time() - end)
@@ -104,11 +107,11 @@ def main():
104107
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
105108
batch_idx, len(loader), batch_time=batch_time))
106109

107-
top5_ids = np.concatenate(top5_ids, axis=0).squeeze()
110+
topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
108111

109-
with open(os.path.join(args.output_dir, './top5_ids.csv'), 'w') as out_file:
112+
with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
110113
filenames = loader.dataset.filenames()
111-
for filename, label in zip(filenames, top5_ids):
114+
for filename, label in zip(filenames, topk_ids):
112115
filename = os.path.basename(filename)
113116
out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
114117
filename, label[0], label[1], label[2], label[3], label[4]))

0 commit comments

Comments
 (0)