Skip to content

Commit 867a0e5

Browse files
committed
Add default_cfg back to models wrapped in feature extraction module as per discussion in #294.
1 parent 4ca52d7 commit 867a0e5

File tree

4 files changed

+22
-6
lines changed

4 files changed

+22
-6
lines changed

timm/models/efficientnet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
3535
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
3636
from .features import FeatureInfo, FeatureHooks
37-
from .helpers import build_model_with_cfg
37+
from .helpers import build_model_with_cfg, default_cfg_for_features
3838
from .layers import create_conv2d, create_classifier
3939
from .registry import register_model
4040

@@ -462,9 +462,11 @@ def _create_effnet(model_kwargs, variant, pretrained=False):
462462
else:
463463
load_strict = True
464464
model_cls = EfficientNet
465-
return build_model_with_cfg(
465+
model = build_model_with_cfg(
466466
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
467467
pretrained_strict=load_strict, **model_kwargs)
468+
model.default_cfg = default_cfg_for_features(model.default_cfg)
469+
return model
468470

469471

470472
def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):

timm/models/helpers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,15 @@ def adapt_model_from_file(parent_module, model_variant):
251251
return adapt_model_from_string(parent_module, f.read().strip())
252252

253253

254+
def default_cfg_for_features(default_cfg):
255+
default_cfg = deepcopy(default_cfg)
256+
# remove default pretrained cfg fields that don't have much relevance for feature backbone
257+
to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size?
258+
for tr in to_remove:
259+
default_cfg.pop(tr, None)
260+
return default_cfg
261+
262+
254263
def build_model_with_cfg(
255264
model_cls: Callable,
256265
variant: str,
@@ -296,5 +305,6 @@ def build_model_with_cfg(
296305
else:
297306
assert False, f'Unknown feature class {feature_cls}'
298307
model = feature_cls(model, **feature_cfg)
308+
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
299309

300310
return model

timm/models/hrnet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1919
from .features import FeatureInfo
20-
from .helpers import build_model_with_cfg
20+
from .helpers import build_model_with_cfg, default_cfg_for_features
2121
from .layers import create_classifier
2222
from .registry import register_model
2323
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
@@ -779,9 +779,11 @@ def _create_hrnet(variant, pretrained, **model_kwargs):
779779
model_kwargs['num_classes'] = 0
780780
strict = False
781781

782-
return build_model_with_cfg(
782+
model = build_model_with_cfg(
783783
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
784784
model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
785+
model.default_cfg = default_cfg_for_features(model.default_cfg)
786+
return model
785787

786788

787789
@register_model

timm/models/mobilenetv3.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
1818
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
1919
from .features import FeatureInfo, FeatureHooks
20-
from .helpers import build_model_with_cfg
20+
from .helpers import build_model_with_cfg, default_cfg_for_features
2121
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
2222
from .registry import register_model
2323

@@ -211,9 +211,11 @@ def _create_mnv3(model_kwargs, variant, pretrained=False):
211211
else:
212212
load_strict = True
213213
model_cls = MobileNetV3
214-
return build_model_with_cfg(
214+
model = build_model_with_cfg(
215215
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
216216
pretrained_strict=load_strict, **model_kwargs)
217+
model.default_cfg = default_cfg_for_features(model.default_cfg)
218+
return model
217219

218220

219221
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):

0 commit comments

Comments
 (0)