import torch.utils.model_zoo as model_zoo

from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
-from .layers import Conv2dSame
+from .layers import Conv2dSame, Linear


_logger = logging.getLogger(__name__)
@@ -234,7 +234,7 @@ def adapt_model_from_string(parent_module, model_string):
        if isinstance(old_module, nn.Linear):
            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
            num_features = state_dict[n + '.weight'][1]
-            new_fc = nn.Linear(
+            new_fc = Linear(
                in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
            set_layer(new_module, n, new_fc)
            if hasattr(new_module, 'num_features'):
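(Aside: the `Linear` swapped in here is timm's drop-in subclass of `nn.Linear` from `.layers`, so the constructor call above is unchanged. The sketch below shows the drop-in pattern only; the real layer's forward body may differ — the dtype cast under scripting is an assumption about its AMP/torchscript workaround, not a verbatim copy.)

import torch
import torch.nn as nn
import torch.nn.functional as F

class Linear(nn.Linear):
    # Same __init__ signature as nn.Linear, so
    # Linear(in_features=..., out_features=..., bias=...) is a drop-in swap.
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting():
            # assumed workaround: cast weight/bias to the input dtype so a
            # scripted model still works under AMP autocast
            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
            return F.linear(input, self.weight.to(dtype=input.dtype), bias)
        return F.linear(input, self.weight, self.bias)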
@@ -251,6 +251,15 @@ def adapt_model_from_file(parent_module, model_variant):
        return adapt_model_from_string(parent_module, f.read().strip())


+def default_cfg_for_features(default_cfg):
+    default_cfg = deepcopy(default_cfg)
+    # remove default pretrained cfg fields that don't have much relevance for feature backbone
+    to_remove = ('num_classes', 'crop_pct', 'classifier')  # add default final pool size?
+    for tr in to_remove:
+        default_cfg.pop(tr, None)
+    return default_cfg
+
+
def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
@@ -296,5 +305,6 @@ def build_model_with_cfg(
            else:
                assert False, f'Unknown feature class {feature_cls}'
        model = feature_cls(model, **feature_cfg)
+        model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg

    return model
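
For reference, a standalone sketch of what the new helper does to a pretrained cfg. The helper body is copied from the hunk above so the snippet runs on its own; the extra cfg fields (`url`, `input_size`, `interpolation`) are illustrative assumptions — only the three popped keys come from this commit:

from copy import deepcopy

def default_cfg_for_features(default_cfg):
    # as added above: pop classifier-centric fields from a copy of the cfg
    default_cfg = deepcopy(default_cfg)
    for tr in ('num_classes', 'crop_pct', 'classifier'):
        default_cfg.pop(tr, None)
    return default_cfg

# illustrative pretrained cfg
cfg = dict(url='', num_classes=1000, input_size=(3, 224, 224),
           crop_pct=0.875, interpolation='bicubic', classifier='fc')

feature_cfg = default_cfg_for_features(cfg)
assert 'num_classes' not in feature_cfg and 'classifier' not in feature_cfg
assert cfg['num_classes'] == 1000  # the original cfg is untouched (deepcopy)

With the last hunk, a feature-extraction wrapper built by `build_model_with_cfg` gets this trimmed cfg attached as `model.default_cfg`, rather than losing the attribute when the model is re-wrapped.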