@@ -202,28 +202,32 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
         pytest.skip("Fixed input size model > limit.")

     input_tensor = torch.randn((batch_size, *input_size))
+    feat_dim = getattr(model, 'feature_dim', None)

     outputs = model.forward_features(input_tensor)
     if isinstance(outputs, (tuple, list)):
         # cannot currently verify multi-tensor output.
         pass
     else:
-        feat_dim = -1 if outputs.ndim == 3 else 1
+        if feat_dim is None:
+            feat_dim = -1 if outputs.ndim == 3 else 1
         assert outputs.shape[feat_dim] == model.num_features

     # test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
     model.reset_classifier(0)
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
-    feat_dim = -1 if outputs.ndim == 3 else 1
+    if feat_dim is None:
+        feat_dim = -1 if outputs.ndim == 3 else 1
     assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'

     model = create_model(model_name, pretrained=False, num_classes=0).eval()
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
-    feat_dim = -1 if outputs.ndim == 3 else 1
+    if feat_dim is None:
+        feat_dim = -1 if outputs.ndim == 3 else 1
     assert outputs.shape[feat_dim] == model.num_features

     # check classifier name matches default_cfg
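
For context on the change: the patch resolves `feat_dim` once up front, preferring an explicit per-model `feature_dim` attribute and only falling back to rank-based inference (last dim for 3D token outputs, dim 1 for 4D NCHW outputs). A minimal sketch of that fallback pattern, where `DummyNLCModel` and the tensor shapes are illustrative stand-ins, not part of this PR:

```python
import torch


def infer_feat_dim(model, outputs):
    # Prefer an explicit per-model override, mirroring the patched test logic.
    feat_dim = getattr(model, 'feature_dim', None)
    if feat_dim is None:
        # Fall back to rank-based inference: NLC tensors (ndim == 3) keep
        # features in the last dim, NCHW tensors in dim 1.
        feat_dim = -1 if outputs.ndim == 3 else 1
    return feat_dim


class DummyNLCModel(torch.nn.Module):
    # Hypothetical model declaring its feature dim explicitly.
    feature_dim = -1


nlc_out = torch.randn(2, 196, 768)    # (batch, tokens, features)
assert infer_feat_dim(DummyNLCModel(), nlc_out) == -1

nchw_out = torch.randn(2, 512, 7, 7)  # (batch, features, H, W)
assert infer_feat_dim(torch.nn.Module(), nchw_out) == 1
```

Resolving `feat_dim` once before the three forward passes keeps all three asserts consistent, even for models whose output layout differs from the rank-based guess.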