Skip to content

Commit 513c9c3

Browse files
alexanderdannrwightman
authored andcommitted
Revert "Running formatting with command from CONTRIBUTING.md"
This reverts commit ed00d06. Reducing diff to keep pull request only for functional change.
1 parent efe7e27 commit 513c9c3

File tree

1 file changed

+60
-68
lines changed

1 file changed

+60
-68
lines changed

timm/models/_hub.py

Lines changed: 60 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
try:
1919
import safetensors.torch
20-
2120
_has_safetensors = True
2221
except ImportError:
2322
_has_safetensors = False
@@ -33,7 +32,6 @@
3332
try:
3433
from huggingface_hub import HfApi, hf_hub_download, model_info
3534
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
36-
3735
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
3836
_has_hf_hub = True
3937
except ImportError:
@@ -42,16 +40,8 @@
4240

4341
_logger = logging.getLogger(__name__)
4442

45-
__all__ = [
46-
'get_cache_dir',
47-
'download_cached_file',
48-
'has_hf_hub',
49-
'hf_split',
50-
'load_model_config_from_hf',
51-
'load_state_dict_from_hf',
52-
'save_for_hf',
53-
'push_to_hf_hub',
54-
]
43+
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
44+
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
5545

5646
# Default name for a weights file hosted on the Huggingface Hub.
5747
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
@@ -76,10 +66,10 @@ def get_cache_dir(child_dir: str = ''):
7666

7767

7868
def download_cached_file(
79-
url: Union[str, List[str], Tuple[str, str]],
80-
check_hash: bool = True,
81-
progress: bool = False,
82-
cache_dir: Optional[Union[str, Path]] = None,
69+
url: Union[str, List[str], Tuple[str, str]],
70+
check_hash: bool = True,
71+
progress: bool = False,
72+
cache_dir: Optional[Union[str, Path]] = None,
8373
):
8474
if isinstance(url, (list, tuple)):
8575
url, filename = url
@@ -102,9 +92,9 @@ def download_cached_file(
10292

10393

10494
def check_cached_file(
105-
url: Union[str, List[str], Tuple[str, str]],
106-
check_hash: bool = True,
107-
cache_dir: Optional[Union[str, Path]] = None,
95+
url: Union[str, List[str], Tuple[str, str]],
96+
check_hash: bool = True,
97+
cache_dir: Optional[Union[str, Path]] = None,
10898
):
10999
if isinstance(url, (list, tuple)):
110100
url, filename = url
@@ -121,7 +111,7 @@ def check_cached_file(
121111
if hash_prefix:
122112
with open(cached_file, 'rb') as f:
123113
hd = hashlib.sha256(f.read()).hexdigest()
124-
if hd[: len(hash_prefix)] != hash_prefix:
114+
if hd[:len(hash_prefix)] != hash_prefix:
125115
return False
126116
return True
127117
return False
@@ -131,8 +121,7 @@ def has_hf_hub(necessary: bool = False):
131121
if not _has_hf_hub and necessary:
132122
# if no HF Hub module installed, and it is necessary to continue, raise error
133123
raise RuntimeError(
134-
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.'
135-
)
124+
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
136125
return _has_hf_hub
137126

138127

@@ -152,9 +141,9 @@ def load_cfg_from_json(json_file: Union[str, Path]):
152141

153142

154143
def download_from_hf(
155-
model_id: str,
156-
filename: str,
157-
cache_dir: Optional[Union[str, Path]] = None,
144+
model_id: str,
145+
filename: str,
146+
cache_dir: Optional[Union[str, Path]] = None,
158147
):
159148
hf_model_id, hf_revision = hf_split(model_id)
160149
return hf_hub_download(
@@ -166,8 +155,8 @@ def download_from_hf(
166155

167156

168157
def _parse_model_cfg(
169-
cfg: Dict[str, Any],
170-
extra_fields: Dict[str, Any],
158+
cfg: Dict[str, Any],
159+
extra_fields: Dict[str, Any],
171160
) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
172161
""""""
173162
# legacy "single‑dict" → split
@@ -178,7 +167,7 @@ def _parse_model_cfg(
178167
"num_features": pretrained_cfg.pop("num_features", None),
179168
"pretrained_cfg": pretrained_cfg,
180169
}
181-
if "labels" in pretrained_cfg: # rename ‑‑> label_names
170+
if "labels" in pretrained_cfg: # rename ‑‑> label_names
182171
pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
183172

184173
pretrained_cfg = cfg["pretrained_cfg"]
@@ -198,8 +187,8 @@ def _parse_model_cfg(
198187

199188

200189
def load_model_config_from_hf(
201-
model_id: str,
202-
cache_dir: Optional[Union[str, Path]] = None,
190+
model_id: str,
191+
cache_dir: Optional[Union[str, Path]] = None,
203192
):
204193
"""Original HF‑Hub loader (unchanged download, shared parsing)."""
205194
assert has_hf_hub(True)
@@ -209,7 +198,7 @@ def load_model_config_from_hf(
209198

210199

211200
def load_model_config_from_path(
212-
model_path: Union[str, Path],
201+
model_path: Union[str, Path],
213202
):
214203
"""Load from ``<model_path>/config.json`` on the local filesystem."""
215204
model_path = Path(model_path)
@@ -222,10 +211,10 @@ def load_model_config_from_path(
222211

223212

224213
def load_state_dict_from_hf(
225-
model_id: str,
226-
filename: str = HF_WEIGHTS_NAME,
227-
weights_only: bool = False,
228-
cache_dir: Optional[Union[str, Path]] = None,
214+
model_id: str,
215+
filename: str = HF_WEIGHTS_NAME,
216+
weights_only: bool = False,
217+
cache_dir: Optional[Union[str, Path]] = None,
229218
):
230219
assert has_hf_hub(True)
231220
hf_model_id, hf_revision = hf_split(model_id)
@@ -242,8 +231,7 @@ def load_state_dict_from_hf(
242231
)
243232
_logger.info(
244233
f"[{model_id}] Safe alternative available for '{filename}' "
245-
f"(as '{safe_filename}'). Loading weights using safetensors."
246-
)
234+
f"(as '{safe_filename}'). Loading weights using safetensors.")
247235
return safetensors.torch.load_file(cached_safe_file, device="cpu")
248236
except EntryNotFoundError:
249237
pass
@@ -275,10 +263,9 @@ def load_state_dict_from_hf(
275263
)
276264
_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
277265

278-
279266
def load_state_dict_from_path(
280-
path: str,
281-
weights_only: bool = False,
267+
path: str,
268+
weights_only: bool = False,
282269
):
283270
found_file = None
284271
for fname in _PREFERRED_FILES:
@@ -293,7 +280,10 @@ def load_state_dict_from_path(
293280
files = sorted(path.glob(f"*{ext}"))
294281
if files:
295282
if len(files) > 1:
296-
logging.warning(f"Multiple {ext} checkpoints in {path}: {names}. " f"Using '{files[0].name}'.")
283+
logging.warning(
284+
f"Multiple {ext} checkpoints in {path}: {names}. "
285+
f"Using '{files[0].name}'."
286+
)
297287
found_file = files[0]
298288

299289
if not found_file:
@@ -307,10 +297,10 @@ def load_state_dict_from_path(
307297

308298

309299
def load_custom_from_hf(
310-
model_id: str,
311-
filename: str,
312-
model: torch.nn.Module,
313-
cache_dir: Optional[Union[str, Path]] = None,
300+
model_id: str,
301+
filename: str,
302+
model: torch.nn.Module,
303+
cache_dir: Optional[Union[str, Path]] = None,
314304
):
315305
assert has_hf_hub(True)
316306
hf_model_id, hf_revision = hf_split(model_id)
@@ -324,7 +314,10 @@ def load_custom_from_hf(
324314

325315

326316
def save_config_for_hf(
327-
model: torch.nn.Module, config_path: str, model_config: Optional[dict] = None, model_args: Optional[dict] = None
317+
model: torch.nn.Module,
318+
config_path: str,
319+
model_config: Optional[dict] = None,
320+
model_args: Optional[dict] = None
328321
):
329322
model_config = model_config or {}
330323
hf_config = {}
@@ -343,8 +336,7 @@ def save_config_for_hf(
343336
if 'labels' in model_config:
344337
_logger.warning(
345338
"'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
346-
" Renaming provided 'labels' field to 'label_names'."
347-
)
339+
" Renaming provided 'labels' field to 'label_names'.")
348340
model_config.setdefault('label_names', model_config.pop('labels'))
349341

350342
label_names = model_config.pop('label_names', None)
@@ -371,11 +363,11 @@ def save_config_for_hf(
371363

372364

373365
def save_for_hf(
374-
model: torch.nn.Module,
375-
save_directory: str,
376-
model_config: Optional[dict] = None,
377-
model_args: Optional[dict] = None,
378-
safe_serialization: Union[bool, Literal["both"]] = False,
366+
model: torch.nn.Module,
367+
save_directory: str,
368+
model_config: Optional[dict] = None,
369+
model_args: Optional[dict] = None,
370+
safe_serialization: Union[bool, Literal["both"]] = False,
379371
):
380372
assert has_hf_hub(True)
381373
save_directory = Path(save_directory)
@@ -399,18 +391,18 @@ def save_for_hf(
399391

400392

401393
def push_to_hf_hub(
402-
model: torch.nn.Module,
403-
repo_id: str,
404-
commit_message: str = 'Add model',
405-
token: Optional[str] = None,
406-
revision: Optional[str] = None,
407-
private: bool = False,
408-
create_pr: bool = False,
409-
model_config: Optional[dict] = None,
410-
model_card: Optional[dict] = None,
411-
model_args: Optional[dict] = None,
412-
task_name: str = 'image-classification',
413-
safe_serialization: Union[bool, Literal["both"]] = 'both',
394+
model: torch.nn.Module,
395+
repo_id: str,
396+
commit_message: str = 'Add model',
397+
token: Optional[str] = None,
398+
revision: Optional[str] = None,
399+
private: bool = False,
400+
create_pr: bool = False,
401+
model_config: Optional[dict] = None,
402+
model_card: Optional[dict] = None,
403+
model_args: Optional[dict] = None,
404+
task_name: str = 'image-classification',
405+
safe_serialization: Union[bool, Literal["both"]] = 'both',
414406
):
415407
"""
416408
Arguments:
@@ -460,9 +452,9 @@ def push_to_hf_hub(
460452

461453

462454
def generate_readme(
463-
model_card: dict,
464-
model_name: str,
465-
task_name: str = 'image-classification',
455+
model_card: dict,
456+
model_name: str,
457+
task_name: str = 'image-classification',
466458
):
467459
tags = model_card.get('tags', None) or [task_name, 'timm', 'transformers']
468460
readme_text = "---\n"

0 commit comments

Comments
 (0)