Skip to content

Commit cd72e66

Browse files
committed
Bug in last mod for features_only default_cfg
1 parent 867a0e5 commit cd72e66

File tree

3 files changed

+17
-17
lines changed

3 files changed

+17
-17
lines changed

timm/models/efficientnet.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -453,19 +453,19 @@ def forward(self, x) -> List[torch.Tensor]:
453453

454454

455455
def _create_effnet(model_kwargs, variant, pretrained=False):
456+
features_only = False
457+
model_cls = EfficientNet
456458
if model_kwargs.pop('features_only', False):
457-
load_strict = False
459+
features_only = True
458460
model_kwargs.pop('num_classes', 0)
459461
model_kwargs.pop('num_features', 0)
460462
model_kwargs.pop('head_conv', None)
461463
model_cls = EfficientNetFeatures
462-
else:
463-
load_strict = True
464-
model_cls = EfficientNet
465464
model = build_model_with_cfg(
466465
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
467-
pretrained_strict=load_strict, **model_kwargs)
468-
model.default_cfg = default_cfg_for_features(model.default_cfg)
466+
pretrained_strict=not features_only, **model_kwargs)
467+
if features_only:
468+
model.default_cfg = default_cfg_for_features(model.default_cfg)
469469
return model
470470

471471

timm/models/hrnet.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -773,16 +773,16 @@ def forward(self, x) -> List[torch.tensor]:
773773

774774
def _create_hrnet(variant, pretrained, **model_kwargs):
775775
model_cls = HighResolutionNet
776-
strict = True
776+
features_only = False
777777
if model_kwargs.pop('features_only', False):
778778
model_cls = HighResolutionNetFeatures
779779
model_kwargs['num_classes'] = 0
780-
strict = False
781-
780+
features_only = True
782781
model = build_model_with_cfg(
783782
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
784-
model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
785-
model.default_cfg = default_cfg_for_features(model.default_cfg)
783+
model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs)
784+
if features_only:
785+
model.default_cfg = default_cfg_for_features(model.default_cfg)
786786
return model
787787

788788

timm/models/mobilenetv3.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,20 +201,20 @@ def forward(self, x) -> List[torch.Tensor]:
201201

202202

203203
def _create_mnv3(model_kwargs, variant, pretrained=False):
204+
features_only = False
205+
model_cls = MobileNetV3
204206
if model_kwargs.pop('features_only', False):
205-
load_strict = False
207+
features_only = True
206208
model_kwargs.pop('num_classes', 0)
207209
model_kwargs.pop('num_features', 0)
208210
model_kwargs.pop('head_conv', None)
209211
model_kwargs.pop('head_bias', None)
210212
model_cls = MobileNetV3Features
211-
else:
212-
load_strict = True
213-
model_cls = MobileNetV3
214213
model = build_model_with_cfg(
215214
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
216-
pretrained_strict=load_strict, **model_kwargs)
217-
model.default_cfg = default_cfg_for_features(model.default_cfg)
215+
pretrained_strict=not features_only, **model_kwargs)
216+
if features_only:
217+
model.default_cfg = default_cfg_for_features(model.default_cfg)
218218
return model
219219

220220

0 commit comments

Comments
 (0)