Skip to content

Commit 45c048b

Browse files
committed
A few minor fixes and bit more cleanup on the huggingface hub integration.
1 parent ead80d3 commit 45c048b

File tree

4 files changed

+43
-33
lines changed

4 files changed

+43
-33
lines changed

timm/models/factory.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .registry import is_model, is_model_in_modules, model_entrypoint
22
from .helpers import load_checkpoint
33
from .layers import set_layer_config
4-
from .hub import load_config_from_hf
4+
from .hub import load_model_config_from_hf
55

66

77
def split_model_name(model_name):
@@ -67,11 +67,9 @@ def create_model(
6767
kwargs = {k: v for k, v in kwargs.items() if v is not None}
6868

6969
if source_name == 'hf_hub':
70-
# Load model weights + default_cfg from Hugging Face hub.
71-
# For model names specified in the form `hf_hub:path/architecture_name#revision`
72-
hf_default_cfg = load_config_from_hf(model_name)
73-
hf_default_cfg['hf_hub'] = model_name # insert hf_hub id for pretrained weight load during creation
74-
model_name = hf_default_cfg.get('architecture')
70+
# For model names specified in the form `hf_hub:path/architecture_name#revision`,
71+
# load model weights + default_cfg from Hugging Face hub.
72+
hf_default_cfg, model_name = load_model_config_from_hf(model_name)
7573
kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday
7674

7775
if is_model(model_name):

timm/models/helpers.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -323,17 +323,14 @@ def default_cfg_for_features(default_cfg):
323323
return default_cfg
324324

325325

326-
def overlay_external_default_cfg(kwargs, default_cfg):
327-
""" Overlay 'default_cfg' in kwargs on top of default_cfg arg.
326+
def overlay_external_default_cfg(default_cfg, kwargs):
327+
""" Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
328328
"""
329-
default_cfg = default_cfg or {}
330329
external_default_cfg = kwargs.pop('external_default_cfg', None)
331330
if external_default_cfg:
332-
default_cfg = deepcopy(default_cfg)
333331
default_cfg.pop('url', None) # url should come from external cfg
334332
default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg
335333
default_cfg.update(external_default_cfg)
336-
return default_cfg
337334

338335

339336
def set_default_kwargs(kwargs, names, default_cfg):
@@ -344,7 +341,7 @@ def set_default_kwargs(kwargs, names, default_cfg):
344341
input_size = default_cfg.get('input_size', None)
345342
if input_size is not None:
346343
assert len(input_size) == 3
347-
kwargs.setdefault(n, input_size[:-2])
344+
kwargs.setdefault(n, input_size[-2:])
348345
elif n == 'in_chans':
349346
input_size = default_cfg.get('input_size', None)
350347
if input_size is not None:
@@ -363,6 +360,25 @@ def filter_kwargs(kwargs, names):
363360
kwargs.pop(n, None)
364361

365362

363+
def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
364+
""" Update the default_cfg and kwargs before passing to model
365+
366+
FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
367+
could/should be replaced by an improved configuration mechanism
368+
369+
Args:
370+
default_cfg: input default_cfg (updated in-place)
371+
kwargs: keyword args passed to model build fn (updated in-place)
372+
kwargs_filter: keyword arg keys that must be removed before model __init__
373+
"""
374+
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
375+
overlay_external_default_cfg(default_cfg, kwargs)
376+
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
377+
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
378+
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
379+
filter_kwargs(kwargs, names=kwargs_filter)
380+
381+
366382
def build_model_with_cfg(
367383
model_cls: Callable,
368384
variant: str,
@@ -399,29 +415,20 @@ def build_model_with_cfg(
399415
pruned = kwargs.pop('pruned', False)
400416
features = False
401417
feature_cfg = feature_cfg or {}
418+
default_cfg = deepcopy(default_cfg) if default_cfg else {}
419+
update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
420+
default_cfg.setdefault('architecture', variant)
402421

403-
# Setup for featyre extraction wrapper done at end of this fn
422+
# Setup for feature extraction wrapper done at end of this fn
404423
if kwargs.pop('features_only', False):
405424
features = True
406425
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
407426
if 'out_indices' in kwargs:
408427
feature_cfg['out_indices'] = kwargs.pop('out_indices')
409428

410-
# FIXME this next sequence of overlay default_cfg, set default kwargs, filter kwargs
411-
# could/should be replaced by an improved configuration mechanism
412-
413-
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
414-
default_cfg = overlay_external_default_cfg(kwargs, default_cfg)
415-
416-
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
417-
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
418-
419-
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
420-
filter_kwargs(kwargs, names=kwargs_filter)
421-
422429
# Build the model
423430
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
424-
model.default_cfg = deepcopy(default_cfg)
431+
model.default_cfg = default_cfg
425432

426433
if pruned:
427434
model = adapt_model_from_file(model, variant)

timm/models/hub.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
_logger = logging.getLogger(__name__)
2424

2525

26-
def get_cache_dir(child=''):
26+
def get_cache_dir(child_dir=''):
2727
"""
2828
Returns the location of the directory where models are cached (and creates it if necessary).
2929
"""
@@ -32,8 +32,8 @@ def get_cache_dir(child=''):
3232
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
3333

3434
hub_dir = get_dir()
35-
children = () if not child else child,
36-
model_dir = os.path.join(hub_dir, 'checkpoints', *children)
35+
child_dir = () if not child_dir else (child_dir,)
36+
model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
3737
os.makedirs(model_dir, exist_ok=True)
3838
return model_dir
3939

@@ -80,10 +80,13 @@ def _download_from_hf(model_id: str, filename: str):
8080
return cached_download(url, cache_dir=get_cache_dir('hf'))
8181

8282

83-
def load_config_from_hf(model_id: str):
83+
def load_model_config_from_hf(model_id: str):
8484
assert has_hf_hub(True)
8585
cached_file = _download_from_hf(model_id, 'config.json')
86-
return load_cfg_from_json(cached_file)
86+
default_cfg = load_cfg_from_json(cached_file)
87+
default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation
88+
model_name = default_cfg.get('architecture')
89+
return default_cfg, model_name
8790

8891

8992
def load_state_dict_from_hf(model_id: str):

timm/models/vision_transformer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import logging
2222
from functools import partial
2323
from collections import OrderedDict
24+
from copy import deepcopy
2425

2526
import torch
2627
import torch.nn as nn
@@ -462,9 +463,10 @@ def checkpoint_filter_fn(state_dict, model):
462463

463464

464465
def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
465-
default_cfg = overlay_external_default_cfg(kwargs, default_cfgs[variant])
466+
default_cfg = deepcopy(default_cfgs[variant])
467+
overlay_external_default_cfg(default_cfg, kwargs)
466468
default_num_classes = default_cfg['num_classes']
467-
default_img_size = default_cfg['input_size'][-1]
469+
default_img_size = default_cfg['input_size'][-2:]
468470

469471
num_classes = kwargs.pop('num_classes', default_num_classes)
470472
img_size = kwargs.pop('img_size', default_img_size)

0 commit comments

Comments
 (0)