@@ -323,17 +323,14 @@ def default_cfg_for_features(default_cfg):
323323 return default_cfg
324324
325325
326- def overlay_external_default_cfg (kwargs , default_cfg ):
327- """ Overlay 'default_cfg ' in kwargs on top of default_cfg arg.
326+ def overlay_external_default_cfg (default_cfg , kwargs ):
327+ """ Overlay 'external_default_cfg ' in kwargs on top of default_cfg arg.
328328 """
329- default_cfg = default_cfg or {}
330329 external_default_cfg = kwargs .pop ('external_default_cfg' , None )
331330 if external_default_cfg :
332- default_cfg = deepcopy (default_cfg )
333331 default_cfg .pop ('url' , None ) # url should come from external cfg
334332 default_cfg .pop ('hf_hub' , None ) # hf hub id should come from external cfg
335333 default_cfg .update (external_default_cfg )
336- return default_cfg
337334
338335
339336def set_default_kwargs (kwargs , names , default_cfg ):
@@ -344,7 +341,7 @@ def set_default_kwargs(kwargs, names, default_cfg):
344341 input_size = default_cfg .get ('input_size' , None )
345342 if input_size is not None :
346343 assert len (input_size ) == 3
347- kwargs .setdefault (n , input_size [: - 2 ])
344+ kwargs .setdefault (n , input_size [- 2 : ])
348345 elif n == 'in_chans' :
349346 input_size = default_cfg .get ('input_size' , None )
350347 if input_size is not None :
@@ -363,6 +360,25 @@ def filter_kwargs(kwargs, names):
363360 kwargs .pop (n , None )
364361
365362
363+ def update_default_cfg_and_kwargs (default_cfg , kwargs , kwargs_filter ):
364+ """ Update the default_cfg and kwargs before passing to model
365+
366+ FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
367+ could/should be replaced by an improved configuration mechanism
368+
369+ Args:
370+ default_cfg: input default_cfg (updated in-place)
371+ kwargs: keyword args passed to model build fn (updated in-place)
372+ kwargs_filter: keyword arg keys that must be removed before model __init__
373+ """
374+ # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
375+ overlay_external_default_cfg (default_cfg , kwargs )
376+ # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
377+ set_default_kwargs (kwargs , names = ('num_classes' , 'global_pool' , 'in_chans' ), default_cfg = default_cfg )
378+ # Filter keyword args for task specific model variants (some 'features only' models, etc.)
379+ filter_kwargs (kwargs , names = kwargs_filter )
380+
381+
366382def build_model_with_cfg (
367383 model_cls : Callable ,
368384 variant : str ,
@@ -399,29 +415,20 @@ def build_model_with_cfg(
399415 pruned = kwargs .pop ('pruned' , False )
400416 features = False
401417 feature_cfg = feature_cfg or {}
418+ default_cfg = deepcopy (default_cfg ) if default_cfg else {}
419+ update_default_cfg_and_kwargs (default_cfg , kwargs , kwargs_filter )
420+ default_cfg .setdefault ('architecture' , variant )
402421
403- # Setup for featyre extraction wrapper done at end of this fn
422+ # Setup for feature extraction wrapper done at end of this fn
404423 if kwargs .pop ('features_only' , False ):
405424 features = True
406425 feature_cfg .setdefault ('out_indices' , (0 , 1 , 2 , 3 , 4 ))
407426 if 'out_indices' in kwargs :
408427 feature_cfg ['out_indices' ] = kwargs .pop ('out_indices' )
409428
410- # FIXME this next sequence of overlay default_cfg, set default kwargs, filter kwargs
411- # could/should be replaced by an improved configuration mechanism
412-
413- # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
414- default_cfg = overlay_external_default_cfg (kwargs , default_cfg )
415-
416- # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
417- set_default_kwargs (kwargs , names = ('num_classes' , 'global_pool' , 'in_chans' ), default_cfg = default_cfg )
418-
419- # Filter keyword args for task specific model variants (some 'features only' models, etc.)
420- filter_kwargs (kwargs , names = kwargs_filter )
421-
422429 # Build the model
423430 model = model_cls (** kwargs ) if model_cfg is None else model_cls (cfg = model_cfg , ** kwargs )
424- model .default_cfg = deepcopy ( default_cfg )
431+ model .default_cfg = default_cfg
425432
426433 if pruned :
427434 model = adapt_model_from_file (model , variant )
0 commit comments