Skip to content

Commit a604011

Browse files
committed
Add support for passing model args via hf hub config
1 parent 23e7f17 commit a604011

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

timm/models/_factory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ def create_model(
9999
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
100100
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
101101
# load model weights + pretrained_cfg from Hugging Face hub.
102-
pretrained_cfg, model_name = load_model_config_from_hf(model_name)
102+
pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
103+
if model_args:
104+
for k, v in model_args.items():
105+
kwargs.setdefault(k, v)
103106
else:
104107
model_name, pretrained_tag = split_model_name_tag(model_name)
105108
if pretrained_tag and not pretrained_cfg:

timm/models/_hub.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,9 @@ def load_model_config_from_hf(model_id: str):
164164
if 'label_descriptions' in hf_config:
165165
pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
166166

167+
model_args = hf_config.get('model_args', {})
167168
model_name = hf_config['architecture']
168-
return pretrained_cfg, model_name
169+
return pretrained_cfg, model_name, model_args
169170

170171

171172
def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
@@ -193,19 +194,23 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
193194
def save_config_for_hf(
194195
model,
195196
config_path: str,
196-
model_config: Optional[dict] = None
197+
model_config: Optional[dict] = None,
198+
model_args: Optional[dict] = None
197199
):
198200
model_config = model_config or {}
199201
hf_config = {}
200202
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
201203
# set some values at root config level
202204
hf_config['architecture'] = pretrained_cfg.pop('architecture')
203-
hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
204-
hf_config['num_features'] = model_config.get('num_features', model.num_features)
205-
global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
205+
hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
206+
207+
# NOTE: these attributes are saved for informational purposes only; they do not impact the model build
208+
hf_config['num_features'] = model_config.pop('num_features', model.num_features)
209+
global_pool_type = model_config.pop('global_pool', getattr(model, 'global_pool', None))
206210
if isinstance(global_pool_type, str) and global_pool_type:
207211
hf_config['global_pool'] = global_pool_type
208212

213+
# Save class label info
209214
if 'labels' in model_config:
210215
_logger.warning(
211216
"'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
@@ -225,6 +230,9 @@ def save_config_for_hf(
225230
# maps label names -> descriptions
226231
hf_config['label_descriptions'] = label_descriptions
227232

233+
if model_args:
234+
hf_config['model_args'] = model_args
235+
228236
hf_config['pretrained_cfg'] = pretrained_cfg
229237
hf_config.update(model_config)
230238

@@ -236,6 +244,7 @@ def save_for_hf(
236244
model,
237245
save_directory: str,
238246
model_config: Optional[dict] = None,
247+
model_args: Optional[dict] = None,
239248
safe_serialization: Union[bool, Literal["both"]] = False,
240249
):
241250
assert has_hf_hub(True)
@@ -251,11 +260,16 @@ def save_for_hf(
251260
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
252261

253262
config_path = save_directory / 'config.json'
254-
save_config_for_hf(model, config_path, model_config=model_config)
263+
save_config_for_hf(
264+
model,
265+
config_path,
266+
model_config=model_config,
267+
model_args=model_args,
268+
)
255269

256270

257271
def push_to_hf_hub(
258-
model,
272+
model: torch.nn.Module,
259273
repo_id: str,
260274
commit_message: str = 'Add model',
261275
token: Optional[str] = None,
@@ -264,6 +278,7 @@ def push_to_hf_hub(
264278
create_pr: bool = False,
265279
model_config: Optional[dict] = None,
266280
model_card: Optional[dict] = None,
281+
model_args: Optional[dict] = None,
267282
safe_serialization: Union[bool, Literal["both"]] = False,
268283
):
269284
"""
@@ -291,7 +306,13 @@ def push_to_hf_hub(
291306
# Dump model and push to Hub
292307
with TemporaryDirectory() as tmpdir:
293308
# Save model weights and config.
294-
save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)
309+
save_for_hf(
310+
model,
311+
tmpdir,
312+
model_config=model_config,
313+
model_args=model_args,
314+
safe_serialization=safe_serialization,
315+
)
295316

296317
# Add readme if it does not exist
297318
if not has_readme:

0 commit comments

Comments
 (0)