Skip to content

Commit 856f618

Browse files
add {rank,top1,top5}_diff in results
1 parent d72ac0d commit 856f618

File tree

4 files changed

+743
-693
lines changed

4 files changed

+743
-693
lines changed

results/generate_csv_results.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
results = {
5+
'results-imagenet.csv' : pd.read_csv('results-imagenet.csv'),
6+
'results-imagenetv2-matched-frequency.csv': pd.read_csv('results-imagenetv2-matched-frequency.csv'),
7+
'results-sketch.csv' : pd.read_csv('results-sketch.csv'),
8+
'results-imagenet-a.csv' : pd.read_csv('results-imagenet-a.csv'),
9+
}
10+
11+
def diff(csv_file):
12+
base_models = results['results-imagenet.csv']['model'].values
13+
csv_models = results[csv_file]['model'].values
14+
15+
rank_diff = np.zeros_like(csv_models, dtype='object')
16+
top1_diff = np.zeros_like(csv_models, dtype='object')
17+
top5_diff = np.zeros_like(csv_models, dtype='object')
18+
19+
for rank, model in enumerate(csv_models):
20+
if model in base_models:
21+
base_rank = int(np.where(base_models==model)[0])
22+
top1_d = results[csv_file]['top1'][rank]-results['results-imagenet.csv']['top1'][base_rank]
23+
top5_d = results[csv_file]['top5'][rank]-results['results-imagenet.csv']['top5'][base_rank]
24+
25+
# rank_diff
26+
if rank == base_rank: rank_diff[rank] = f'='
27+
elif rank > base_rank: rank_diff[rank] = f'-{rank-base_rank}'
28+
else: rank_diff[rank] = f'+{base_rank-rank}'
29+
30+
# top1_diff
31+
if top1_d >= .0: top1_diff[rank] = f'+{top1_d:.3f}'
32+
else : top1_diff[rank] = f'-{abs(top1_d):.3f}'
33+
34+
# top5_diff
35+
if top5_d >= .0: top5_diff[rank] = f'+{top5_d:.3f}'
36+
else : top5_diff[rank] = f'-{abs(top5_d):.3f}'
37+
38+
else:
39+
rank_diff[rank] = 'X'
40+
top1_diff[rank] = 'X'
41+
top5_diff[rank] = 'X'
42+
43+
results[csv_file]['rank_diff'] = rank_diff
44+
results[csv_file]['top1_diff'] = top1_diff
45+
results[csv_file]['top5_diff'] = top5_diff
46+
47+
results[csv_file].to_csv(csv_file, index=False, float_format='%.4f')
48+
49+
for csv_file in results:
50+
if csv_file != 'results-imagenet.csv': diff(csv_file)

0 commit comments

Comments
 (0)