@@ -93,6 +93,30 @@ def cmd_from_args(args) -> Tuple[Union[Callable, str], List[str]]:
9393 return cmd , cmd_args
9494
9595
96+ def _get_model_cfgs (
97+ model_names ,
98+ num_classes = None ,
99+ expand_train_test = False ,
100+ include_crop = True ,
101+ ):
102+ model_cfgs = []
103+ for n in model_names :
104+ pt_cfg = get_pretrained_cfg (n )
105+ if num_classes is not None and getattr (pt_cfg , 'num_classes' , 0 ) != num_classes :
106+ continue
107+ model_cfgs .append ((n , pt_cfg .input_size [- 1 ], pt_cfg .crop_pct ))
108+ if expand_train_test and pt_cfg .test_input_size is not None :
109+ if pt_cfg .test_crop_pct is not None :
110+ model_cfgs .append ((n , pt_cfg .test_input_size [- 1 ], pt_cfg .test_crop_pct ))
111+ else :
112+ model_cfgs .append ((n , pt_cfg .test_input_size [- 1 ], pt_cfg .crop_pct ))
113+ if include_crop :
114+ model_cfgs = [(n , {'img-size' : r , 'crop-pct' : cp }) for n , r , cp in sorted (model_cfgs )]
115+ else :
116+ model_cfgs = [(n , {'img-size' : r }) for n , r , cp in sorted (model_cfgs )]
117+ return model_cfgs
118+
119+
96120def main ():
97121 args = parser .parse_args ()
98122 cmd , cmd_args = cmd_from_args (args )
@@ -105,35 +129,25 @@ def main():
105129 model_cfgs = [(n , None ) for n in model_names ]
106130 elif args .model_list == 'all_in1k' :
107131 model_names = list_models (pretrained = True )
108- model_cfgs = []
109- for n in model_names :
110- pt_cfg = get_pretrained_cfg (n )
111- if getattr (pt_cfg , 'num_classes' , 0 ) == 1000 :
112- print (n , pt_cfg .num_classes )
113- model_cfgs .append ((n , None ))
132+ model_cfgs = _get_model_cfgs (model_names , num_classes = 1000 , expand_train_test = True )
114133 elif args .model_list == 'all_res' :
115134 model_names = list_models ()
116- model_names += list_models (pretrained = True )
117- model_cfgs = set ()
118- for n in model_names :
119- pt_cfg = get_pretrained_cfg (n )
120- if pt_cfg is None :
121- print (f'Model { n } is missing pretrained cfg, skipping.' )
122- continue
123- n = n .split ('.' )[0 ]
124- model_cfgs .add ((n , pt_cfg .input_size [- 1 ]))
125- if pt_cfg .test_input_size is not None :
126- model_cfgs .add ((n , pt_cfg .test_input_size [- 1 ]))
127- model_cfgs = [(n , {'img-size' : r }) for n , r in sorted (model_cfgs )]
135+ model_cfgs = _get_model_cfgs (model_names , expand_train_test = True , include_crop = False )
128136 elif not is_model (args .model_list ):
129137 # model name doesn't exist, try as wildcard filter
130138 model_names = list_models (args .model_list )
131139 model_cfgs = [(n , None ) for n in model_names ]
132140
133141 if not model_cfgs and os .path .exists (args .model_list ):
134142 with open (args .model_list ) as f :
143+ model_cfgs = []
135144 model_names = [line .rstrip () for line in f ]
136- model_cfgs = [(n , None ) for n in model_names ]
145+ _get_model_cfgs (
146+ model_names ,
147+ #num_classes=1000,
148+ expand_train_test = True ,
149+ #include_crop=False,
150+ )
137151
138152 if len (model_cfgs ):
139153 results_file = args .results_file or './results.csv'
0 commit comments