
Commit 3eac7dc

Merge pull request #501 from rwightman/hf_hub_revisit
Support for huggingface hub via create_model and default_cfgs.
2 parents e5ba5dc + 45c048b commit 3eac7dc

39 files changed (+451, -173 lines)

sotabench.py

Lines changed: 1 addition & 1 deletion
@@ -509,7 +509,7 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
     model.eval()
     with torch.no_grad():
         # warmup
-        input = torch.randn((batch_size,) + data_config['input_size']).cuda()
+        input = torch.randn((batch_size,) + tuple(data_config['input_size'])).cuda()
         model(input)

         bar = tqdm(desc="Evaluation", mininterval=5, total=50000)
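
Note: the added tuple() call matters because a default_cfg that round-trips through JSON (as hub-hosted configs do) yields input_size as a list, and adding a list to a tuple raises a TypeError. A minimal standalone sketch of the failure mode this guards against (not part of the commit):

import json
import torch

# values deserialized from JSON come back as lists, not tuples
data_config = json.loads('{"input_size": [3, 224, 224]}')

batch_size = 2
# (batch_size,) + data_config['input_size'] would raise TypeError (can't add tuple and list),
# so the size is normalized to a tuple before concatenation
input = torch.randn((batch_size,) + tuple(data_config['input_size']))
print(input.shape)  # torch.Size([2, 3, 224, 224])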

timm/data/transforms.py

Lines changed: 2 additions & 2 deletions
@@ -72,8 +72,8 @@ class RandomResizedCropAndInterpolation:

     def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                  interpolation='bilinear'):
-        if isinstance(size, tuple):
-            self.size = size
+        if isinstance(size, (list, tuple)):
+            self.size = tuple(size)
         else:
             self.size = (size, size)
         if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):

timm/data/transforms_factory.py

Lines changed: 3 additions & 3 deletions
@@ -78,7 +78,7 @@ def transforms_imagenet_train(
     secondary_tfl = []
     if auto_augment:
         assert isinstance(auto_augment, str)
-        if isinstance(img_size, tuple):
+        if isinstance(img_size, (tuple, list)):
             img_size_min = min(img_size)
         else:
             img_size_min = img_size
@@ -136,7 +136,7 @@ def transforms_imagenet_eval(
         std=IMAGENET_DEFAULT_STD):
     crop_pct = crop_pct or DEFAULT_CROP_PCT

-    if isinstance(img_size, tuple):
+    if isinstance(img_size, (tuple, list)):
         assert len(img_size) == 2
         if img_size[-1] == img_size[-2]:
             # fall-back to older behaviour so Resize scales to shortest edge if target is square
@@ -186,7 +186,7 @@ def create_transform(
         tf_preprocessing=False,
         separate=False):

-    if isinstance(input_size, tuple):
+    if isinstance(input_size, (tuple, list)):
         img_size = input_size[-2:]
     else:
         img_size = input_size
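
The relaxed isinstance checks mean img_size / input_size may now be passed as a list as well as a tuple, e.g. when the value comes from a JSON-serialized default_cfg. A rough usage sketch (the concrete values here are illustrative):

from timm.data import create_transform

# a list-valued input_size (as loaded from JSON) is now accepted alongside tuples
eval_transform = create_transform(input_size=[3, 224, 224], is_training=False, crop_pct=0.875)
print(eval_transform)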

timm/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
 from .xception_aligned import *
 from .hardcorenas import *

-from .factory import create_model
+from .factory import create_model, split_model_name, safe_model_name
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model

timm/models/cspnet.py

Lines changed: 4 additions & 2 deletions
@@ -409,8 +409,10 @@ def forward(self, x):
 def _create_cspnet(variant, pretrained=False, **kwargs):
     cfg_variant = variant.split('_')[0]
     return build_model_with_cfg(
-        CspNet, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant], **kwargs)
+        CspNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant],
+        **kwargs)


 @register_model

timm/models/densenet.py

Lines changed: 4 additions & 2 deletions
@@ -287,8 +287,10 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs):
     kwargs['growth_rate'] = growth_rate
     kwargs['block_config'] = block_config
     return build_model_with_cfg(
-        DenseNet, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained, **kwargs)
+        DenseNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained,
+        **kwargs)


 @register_model

timm/models/dla.py

Lines changed: 5 additions & 2 deletions
@@ -338,8 +338,11 @@ def forward(self, x):

 def _create_dla(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        DLA, variant, pretrained, default_cfg=default_cfgs[variant],
-        pretrained_strict=False, feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)), **kwargs)
+        DLA, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        pretrained_strict=False,
+        feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),
+        **kwargs)


 @register_model

timm/models/dpn.py

Lines changed: 4 additions & 2 deletions
@@ -262,8 +262,10 @@ def forward(self, x):

 def _create_dpn(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        DPN, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(feature_concat=True, flatten_sequential=True), **kwargs)
+        DPN, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=dict(feature_concat=True, flatten_sequential=True),
+        **kwargs)


 @register_model

timm/models/efficientnet.py

Lines changed: 21 additions & 19 deletions
@@ -452,18 +452,20 @@ def forward(self, x) -> List[torch.Tensor]:
         return list(out.values())


-def _create_effnet(model_kwargs, variant, pretrained=False):
+def _create_effnet(variant, pretrained=False, **kwargs):
     features_only = False
     model_cls = EfficientNet
-    if model_kwargs.pop('features_only', False):
+    kwargs_filter = None
+    if kwargs.pop('features_only', False):
         features_only = True
-        model_kwargs.pop('num_classes', 0)
-        model_kwargs.pop('num_features', 0)
-        model_kwargs.pop('head_conv', None)
+        kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
         model_cls = EfficientNetFeatures
     model = build_model_with_cfg(
-        model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        pretrained_strict=not features_only, **model_kwargs)
+        model_cls, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        pretrained_strict=not features_only,
+        kwargs_filter=kwargs_filter,
+        **kwargs)
     if features_only:
         model.default_cfg = default_cfg_for_features(model.default_cfg)
     return model
@@ -501,7 +503,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


@@ -537,7 +539,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


@@ -566,7 +568,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs,variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


@@ -595,7 +597,7 @@ def _gen_mobilenet_v2(
         act_layer=resolve_act_layer(kwargs, 'relu6'),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


@@ -625,7 +627,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


@@ -660,7 +662,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


@@ -706,7 +708,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -735,7 +737,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
         act_layer=resolve_act_layer(kwargs, 'relu'),
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


@@ -765,7 +767,7 @@ def _gen_efficientnet_condconv(
         act_layer=resolve_act_layer(kwargs, 'swish'),
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


@@ -806,7 +808,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


@@ -839,7 +841,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


@@ -872,7 +874,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

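The _create_effnet refactor preserves the public features_only behaviour; classifier-related kwargs are now stripped through build_model_with_cfg's kwargs_filter instead of being popped by hand. A rough sketch of the unchanged usage (output shapes depend on the input size chosen here):

import torch
import timm

# features_only routes construction through EfficientNetFeatures; num_classes,
# num_features, head_conv and global_pool are filtered out via kwargs_filter
model = timm.create_model('efficientnet_b0', features_only=True, pretrained=False)
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # one feature map per stage
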
timm/models/factory.py

Lines changed: 32 additions & 6 deletions
@@ -1,6 +1,25 @@
 from .registry import is_model, is_model_in_modules, model_entrypoint
 from .helpers import load_checkpoint
 from .layers import set_layer_config
+from .hub import load_model_config_from_hf
+
+
+def split_model_name(model_name):
+    model_split = model_name.split(':', 1)
+    if len(model_split) == 1:
+        return '', model_split[0]
+    else:
+        source_name, model_name = model_split
+        assert source_name in ('timm', 'hf_hub')
+        return source_name, model_name
+
+
+def safe_model_name(model_name, remove_source=True):
+    def make_safe(name):
+        return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
+    if remove_source:
+        model_name = split_model_name(model_name)[-1]
+    return make_safe(model_name)


 def create_model(
@@ -26,7 +45,7 @@ def create_model(
         global_pool (str): global pool type (default: 'avg')
         **: other kwargs are model specific
     """
-    model_args = dict(pretrained=pretrained)
+    source_name, model_name = split_model_name(model_name)

     # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
     is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
@@ -47,12 +66,19 @@ def create_model(
     # non-supporting models don't break and default args remain in effect.
     kwargs = {k: v for k, v in kwargs.items() if v is not None}

+    if source_name == 'hf_hub':
+        # For model names specified in the form `hf_hub:path/architecture_name#revision`,
+        # load model weights + default_cfg from Hugging Face hub.
+        hf_default_cfg, model_name = load_model_config_from_hf(model_name)
+        kwargs['external_default_cfg'] = hf_default_cfg  # FIXME revamp default_cfg interface someday
+
+    if is_model(model_name):
+        create_fn = model_entrypoint(model_name)
+    else:
+        raise RuntimeError('Unknown model (%s)' % model_name)
+
     with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
-        if is_model(model_name):
-            create_fn = model_entrypoint(model_name)
-            model = create_fn(**model_args, **kwargs)
-        else:
-            raise RuntimeError('Unknown model (%s)' % model_name)
+        model = create_fn(pretrained=pretrained, **kwargs)

     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)
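
With split_model_name in place, create_model accepts an optional source prefix on the model name; hf_hub: names pull their default_cfg (and weights when pretrained=True) from the Hugging Face hub. A rough usage sketch — the hub repo path below is a made-up placeholder, and resolving it requires the optional huggingface_hub dependency:

import timm
from timm.models import split_model_name, safe_model_name

# plain names resolve against timm's own registry, exactly as before
model = timm.create_model('resnet50', pretrained=False)

# 'hf_hub:' names are split off and routed through load_model_config_from_hf;
# the repo path here is hypothetical
# model = timm.create_model('hf_hub:someuser/some-timm-model', pretrained=True)

print(split_model_name('hf_hub:someuser/some-timm-model'))  # ('hf_hub', 'someuser/some-timm-model')
print(safe_model_name('hf_hub:someuser/some-timm-model'))   # 'someuser_some_timm_model'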
