Skip to content

Commit ad150e7

Browse files
committed
Update results csv file rank/diff script and small validate script tweak for batch validation
1 parent d72ddaf commit ad150e7

File tree

2 files changed

+55
-37
lines changed

2 files changed

+55
-37
lines changed

results/generate_csv_results.py

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,75 @@
11
import numpy as np
22
import pandas as pd
33

4+
45
results = {
5-
'results-imagenet.csv': pd.read_csv('results-imagenet.csv'),
6-
'results-imagenetv2-matched-frequency.csv': pd.read_csv('results-imagenetv2-matched-frequency.csv'),
7-
'results-sketch.csv': pd.read_csv('results-sketch.csv'),
8-
'results-imagenet-a.csv': pd.read_csv('results-imagenet-a.csv'),
9-
'results-imagenet-r.csv': pd.read_csv('results-imagenet-r.csv'),
10-
'results-imagenet-real.csv': pd.read_csv('results-imagenet-real.csv'),
6+
'results-imagenet.csv': [
7+
'results-imagenet-real.csv',
8+
'results-imagenetv2-matched-frequency.csv',
9+
'results-sketch.csv'
10+
],
11+
'results-imagenet-a-clean.csv': [
12+
'results-imagenet-a.csv',
13+
],
14+
'results-imagenet-r-clean.csv': [
15+
'results-imagenet-r.csv',
16+
],
1117
}
1218

1319

14-
def diff(csv_file):
15-
base_models = results['results-imagenet.csv']['model'].values
16-
csv_models = results[csv_file]['model'].values
20+
def diff(base_df, test_csv):
21+
base_models = base_df['model'].values
22+
test_df = pd.read_csv(test_csv)
23+
test_models = test_df['model'].values
1724

18-
rank_diff = np.zeros_like(csv_models, dtype='object')
19-
top1_diff = np.zeros_like(csv_models, dtype='object')
20-
top5_diff = np.zeros_like(csv_models, dtype='object')
25+
rank_diff = np.zeros_like(test_models, dtype='object')
26+
top1_diff = np.zeros_like(test_models, dtype='object')
27+
top5_diff = np.zeros_like(test_models, dtype='object')
2128

22-
for rank, model in enumerate(csv_models):
29+
for rank, model in enumerate(test_models):
2330
if model in base_models:
24-
base_rank = int(np.where(base_models==model)[0])
25-
top1_d = results[csv_file]['top1'][rank]-results['results-imagenet.csv']['top1'][base_rank]
26-
top5_d = results[csv_file]['top5'][rank]-results['results-imagenet.csv']['top5'][base_rank]
31+
base_rank = int(np.where(base_models == model)[0])
32+
top1_d = test_df['top1'][rank] - base_df['top1'][base_rank]
33+
top5_d = test_df['top5'][rank] - base_df['top5'][base_rank]
2734

2835
# rank_diff
29-
if rank == base_rank: rank_diff[rank] = f'='
30-
elif rank > base_rank: rank_diff[rank] = f'-{rank-base_rank}'
31-
else: rank_diff[rank] = f'+{base_rank-rank}'
36+
if rank == base_rank:
37+
rank_diff[rank] = f'0'
38+
elif rank > base_rank:
39+
rank_diff[rank] = f'-{rank - base_rank}'
40+
else:
41+
rank_diff[rank] = f'+{base_rank - rank}'
3242

3343
# top1_diff
34-
if top1_d >= .0: top1_diff[rank] = f'+{top1_d:.3f}'
35-
else : top1_diff[rank] = f'-{abs(top1_d):.3f}'
44+
if top1_d >= .0:
45+
top1_diff[rank] = f'+{top1_d:.3f}'
46+
else:
47+
top1_diff[rank] = f'-{abs(top1_d):.3f}'
3648

3749
# top5_diff
38-
if top5_d >= .0: top5_diff[rank] = f'+{top5_d:.3f}'
39-
else : top5_diff[rank] = f'-{abs(top5_d):.3f}'
50+
if top5_d >= .0:
51+
top5_diff[rank] = f'+{top5_d:.3f}'
52+
else:
53+
top5_diff[rank] = f'-{abs(top5_d):.3f}'
4054

4155
else:
42-
rank_diff[rank] = 'X'
43-
top1_diff[rank] = 'X'
44-
top5_diff[rank] = 'X'
45-
46-
results[csv_file]['rank_diff'] = rank_diff
47-
results[csv_file]['top1_diff'] = top1_diff
48-
results[csv_file]['top5_diff'] = top5_diff
49-
50-
results[csv_file]['param_count'] = results[csv_file]['param_count'].map('{:,.2f}'.format)
56+
rank_diff[rank] = ''
57+
top1_diff[rank] = ''
58+
top5_diff[rank] = ''
59+
60+
test_df['top1_diff'] = top1_diff
61+
test_df['top5_diff'] = top5_diff
62+
test_df['rank_diff'] = rank_diff
5163

52-
results[csv_file].to_csv(csv_file, index=False, float_format='%.3f')
64+
test_df['param_count'] = test_df['param_count'].map('{:,.2f}'.format)
65+
test_df.sort_values('top1', ascending=False, inplace=True)
66+
test_df.to_csv(test_csv, index=False, float_format='%.3f')
5367

5468

55-
for csv_file in results:
56-
if csv_file != 'results-imagenet.csv':
57-
diff(csv_file)
69+
for base_results, test_results in results.items():
70+
base_df = pd.read_csv(base_results)
71+
base_df.sort_values('top1', ascending=False, inplace=True)
72+
for test_csv in test_results:
73+
diff(base_df, test_csv)
74+
base_df['param_count'] = base_df['param_count'].map('{:,.2f}'.format)
75+
base_df.to_csv(base_results, index=False, float_format='%.3f')

validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def main():
271271
result = OrderedDict(model=args.model)
272272
r = {}
273273
while not r and batch_size >= args.num_gpu:
274+
torch.cuda.empty_cache()
274275
try:
275276
args.batch_size = batch_size
276277
print('Validating with batch size: %d' % args.batch_size)
@@ -281,7 +282,6 @@ def main():
281282
raise e
282283
batch_size = max(batch_size // 2, args.num_gpu)
283284
print("Validation failed, reducing batch size by 50%")
284-
torch.cuda.empty_cache()
285285
result.update(r)
286286
if args.checkpoint:
287287
result['checkpoint'] = args.checkpoint

0 commit comments

Comments
 (0)