Skip to content

Commit 7d657d2

Browse files
committed
Improve resolve_pretrained_cfg behaviour when no cfg exists, warn instead of crash. Improve usability ex #1311
1 parent 879df47 commit 7d657d2

File tree

4 files changed

+21
-12
lines changed

4 files changed

+21
-12
lines changed

timm/models/helpers.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -455,18 +455,27 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
455455
filter_kwargs(kwargs, names=kwargs_filter)
456456

457457

458-
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None):
458+
def resolve_pretrained_cfg(variant: str, **kwargs):
459+
pretrained_cfg = kwargs.pop('pretrained_cfg', None)
459460
if pretrained_cfg and isinstance(pretrained_cfg, dict):
460-
# highest priority, pretrained_cfg available and passed explicitly
461+
# highest priority, pretrained_cfg available and passed in args
461462
return deepcopy(pretrained_cfg)
462-
if kwargs and 'pretrained_cfg' in kwargs:
463-
# next highest, pretrained_cfg in a kwargs dict, pop and return
464-
pretrained_cfg = kwargs.pop('pretrained_cfg', {})
465-
if pretrained_cfg:
466-
return deepcopy(pretrained_cfg)
467-
# lookup pretrained cfg in model registry by variant
463+
# fallback to looking up pretrained cfg in model registry by variant identifier
468464
pretrained_cfg = get_pretrained_cfg(variant)
469-
assert pretrained_cfg
465+
if not pretrained_cfg:
466+
_logger.warning(
467+
f"No pretrained configuration specified for {variant} model. Using a default."
468+
f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
469+
pretrained_cfg = dict(
470+
url='',
471+
num_classes=1000,
472+
input_size=(3, 224, 224),
473+
pool_size=None,
474+
crop_pct=.9,
475+
interpolation='bicubic',
476+
first_conv='',
477+
classifier='',
478+
)
470479
return pretrained_cfg
471480

472481

timm/models/inception_v3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def forward(self, x):
428428

429429

430430
def _create_inception_v3(variant, pretrained=False, **kwargs):
431-
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
431+
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
432432
aux_logits = kwargs.pop('aux_logits', False)
433433
if aux_logits:
434434
assert not kwargs.pop('features_only', False)

timm/models/vision_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
633633
if kwargs.get('features_only', None):
634634
raise RuntimeError('features_only not implemented for Vision Transformer models.')
635635

636-
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
636+
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
637637
model = build_model_with_cfg(
638638
VisionTransformer, variant, pretrained,
639639
pretrained_cfg=pretrained_cfg,

timm/models/vision_transformer_relpos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torch.utils.checkpoint import checkpoint
1717

1818
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
19-
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
19+
from .helpers import build_model_with_cfg, named_apply
2020
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
2121
from .registry import register_model
2222

0 commit comments

Comments (0)