Skip to content

Commit 907555f

Browse files
committed
Update model img-size/crop expansion for bulk runner
1 parent 4515a43 commit 907555f

File tree

1 file changed

+33
-19
lines changed

1 file changed

+33
-19
lines changed

bulk_runner.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,30 @@ def cmd_from_args(args) -> Tuple[Union[Callable, str], List[str]]:
9393
return cmd, cmd_args
9494

9595

96+
def _get_model_cfgs(
97+
model_names,
98+
num_classes=None,
99+
expand_train_test=False,
100+
include_crop=True,
101+
):
102+
model_cfgs = []
103+
for n in model_names:
104+
pt_cfg = get_pretrained_cfg(n)
105+
if num_classes is not None and getattr(pt_cfg, 'num_classes', 0) != num_classes:
106+
continue
107+
model_cfgs.append((n, pt_cfg.input_size[-1], pt_cfg.crop_pct))
108+
if expand_train_test and pt_cfg.test_input_size is not None:
109+
if pt_cfg.test_crop_pct is not None:
110+
model_cfgs.append((n, pt_cfg.test_input_size[-1], pt_cfg.test_crop_pct))
111+
else:
112+
model_cfgs.append((n, pt_cfg.test_input_size[-1], pt_cfg.crop_pct))
113+
if include_crop:
114+
model_cfgs = [(n, {'img-size': r, 'crop-pct': cp}) for n, r, cp in sorted(model_cfgs)]
115+
else:
116+
model_cfgs = [(n, {'img-size': r}) for n, r, cp in sorted(model_cfgs)]
117+
return model_cfgs
118+
119+
96120
def main():
97121
args = parser.parse_args()
98122
cmd, cmd_args = cmd_from_args(args)
@@ -105,35 +129,25 @@ def main():
105129
model_cfgs = [(n, None) for n in model_names]
106130
elif args.model_list == 'all_in1k':
107131
model_names = list_models(pretrained=True)
108-
model_cfgs = []
109-
for n in model_names:
110-
pt_cfg = get_pretrained_cfg(n)
111-
if getattr(pt_cfg, 'num_classes', 0) == 1000:
112-
print(n, pt_cfg.num_classes)
113-
model_cfgs.append((n, None))
132+
model_cfgs = _get_model_cfgs(model_names, num_classes=1000, expand_train_test=True)
114133
elif args.model_list == 'all_res':
115134
model_names = list_models()
116-
model_names += list_models(pretrained=True)
117-
model_cfgs = set()
118-
for n in model_names:
119-
pt_cfg = get_pretrained_cfg(n)
120-
if pt_cfg is None:
121-
print(f'Model {n} is missing pretrained cfg, skipping.')
122-
continue
123-
n = n.split('.')[0]
124-
model_cfgs.add((n, pt_cfg.input_size[-1]))
125-
if pt_cfg.test_input_size is not None:
126-
model_cfgs.add((n, pt_cfg.test_input_size[-1]))
127-
model_cfgs = [(n, {'img-size': r}) for n, r in sorted(model_cfgs)]
135+
model_cfgs = _get_model_cfgs(model_names, expand_train_test=True, include_crop=False)
128136
elif not is_model(args.model_list):
129137
# model name doesn't exist, try as wildcard filter
130138
model_names = list_models(args.model_list)
131139
model_cfgs = [(n, None) for n in model_names]
132140

133141
if not model_cfgs and os.path.exists(args.model_list):
134142
with open(args.model_list) as f:
143+
model_cfgs = []
135144
model_names = [line.rstrip() for line in f]
136-
model_cfgs = [(n, None) for n in model_names]
145+
_get_model_cfgs(
146+
model_names,
147+
#num_classes=1000,
148+
expand_train_test=True,
149+
#include_crop=False,
150+
)
137151

138152
if len(model_cfgs):
139153
results_file = args.results_file or './results.csv'

0 commit comments

Comments
 (0)