4848 help = 'Number of GPUS to use' )
4949parser .add_argument ('--no-test-pool' , dest = 'no_test_pool' , action = 'store_true' ,
5050 help = 'disable test time pool' )
51+ parser .add_argument ('--topk' , default = 5 , type = int ,
52+ metavar = 'N' , help = 'Top-k to output to CSV' )
5153
5254
5355def main ():
@@ -85,15 +87,16 @@ def main():
8587
8688 model .eval ()
8789
90+ k = min (args .topk , args .num_classes )
8891 batch_time = AverageMeter ()
8992 end = time .time ()
90- top5_ids = []
93+ topk_ids = []
9194 with torch .no_grad ():
9295 for batch_idx , (input , _ ) in enumerate (loader ):
9396 input = input .cuda ()
9497 labels = model (input )
95- top5 = labels .topk (5 )[1 ]
96- top5_ids .append (top5 .cpu ().numpy ())
98+ topk = labels .topk (k )[1 ]
99+ topk_ids .append (topk .cpu ().numpy ())
97100
98101 # measure elapsed time
99102 batch_time .update (time .time () - end )
@@ -104,11 +107,11 @@ def main():
104107 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})' .format (
105108 batch_idx , len (loader ), batch_time = batch_time ))
106109
107- top5_ids = np .concatenate (top5_ids , axis = 0 ).squeeze ()
110+ topk_ids = np .concatenate (topk_ids , axis = 0 ).squeeze ()
108111
109- with open (os .path .join (args .output_dir , './top5_ids .csv' ), 'w' ) as out_file :
112+ with open (os .path .join (args .output_dir , './topk_ids .csv' ), 'w' ) as out_file :
110113 filenames = loader .dataset .filenames ()
111- for filename , label in zip (filenames , top5_ids ):
114+ for filename , label in zip (filenames , topk_ids ):
112115 filename = os .path .basename (filename )
113116 out_file .write ('{0},{1},{2},{3},{4},{5}\n ' .format (
114117 filename , label [0 ], label [1 ], label [2 ], label [3 ], label [4 ]))
0 commit comments