
Commit 2df9f28

Merge pull request #2257 from huggingface/weights_only
Add weights only flag to avoid warning, try to keep bwd compat.
2 parents 531215e + a7b0bfc
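The pattern used throughout this change: ask torch.load() (and torch.hub's load_state_dict_from_url()) to load with weights_only, and fall back to the old call signature when an older PyTorch raises TypeError because it does not know the argument. A minimal standalone sketch of that fallback, assuming a hypothetical checkpoint path and helper name:

import torch

def load_compat(path, weights_only=True):
    # Newer PyTorch (>= 1.13) accepts weights_only and restricts unpickling
    # to tensors and plain containers, avoiding the pickle warning.
    try:
        return torch.load(path, map_location='cpu', weights_only=weights_only)
    except TypeError:
        # Older PyTorch: torch.load() has no weights_only argument.
        return torch.load(path, map_location='cpu')

state_dict = load_compat('checkpoint.pth')  # hypothetical file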

File tree

3 files changed: +34 -12 lines changed

timm/models/_builder.py
timm/models/_helpers.py
timm/models/_hub.py

timm/models/_builder.py

Lines changed: 16 additions & 7 deletions
@@ -177,12 +177,21 @@ def load_pretrained(
             model.load_pretrained(pretrained_loc)
             return
         else:
-            state_dict = load_state_dict_from_url(
-                pretrained_loc,
-                map_location='cpu',
-                progress=_DOWNLOAD_PROGRESS,
-                check_hash=_CHECK_HASH,
-            )
+            try:
+                state_dict = load_state_dict_from_url(
+                    pretrained_loc,
+                    map_location='cpu',
+                    progress=_DOWNLOAD_PROGRESS,
+                    check_hash=_CHECK_HASH,
+                    weights_only=True,
+                )
+            except TypeError:
+                state_dict = load_state_dict_from_url(
+                    pretrained_loc,
+                    map_location='cpu',
+                    progress=_DOWNLOAD_PROGRESS,
+                    check_hash=_CHECK_HASH,
+                )
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
@@ -193,7 +202,7 @@ def load_pretrained(
             else:
                 state_dict = load_state_dict_from_hf(*pretrained_loc)
         else:
-            state_dict = load_state_dict_from_hf(pretrained_loc)
+            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
     else:
         model_name = pretrained_cfg.get('architecture', 'this model')
         raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
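For context, load_pretrained() above is what runs when a model is created with pretrained weights, so the safer loading applies without any change to user code. A minimal usage sketch, assuming timm is installed and the 'resnet50' weights are downloadable:

import timm

# Pretrained weights are fetched through load_pretrained(); with a recent
# PyTorch the url / hub state dict is now read with weights_only=True.
model = timm.create_model('resnet50', pretrained=True)
model.eval()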

timm/models/_helpers.py

Lines changed: 8 additions & 3 deletions
@@ -44,14 +44,18 @@ def load_state_dict(
         checkpoint_path: str,
         use_ema: bool = True,
         device: Union[str, torch.device] = 'cpu',
+        weights_only: bool = False,
 ) -> Dict[str, Any]:
     if checkpoint_path and os.path.isfile(checkpoint_path):
         # Check if safetensors or not and load weights accordingly
         if str(checkpoint_path).endswith(".safetensors"):
             assert _has_safetensors, "`pip install safetensors` to use .safetensors"
             checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
         else:
-            checkpoint = torch.load(checkpoint_path, map_location=device)
+            try:
+                checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
+            except TypeError:
+                checkpoint = torch.load(checkpoint_path, map_location=device)
 
         state_dict_key = ''
         if isinstance(checkpoint, dict):
@@ -79,6 +83,7 @@ def load_checkpoint(
         strict: bool = True,
         remap: bool = False,
         filter_fn: Optional[Callable] = None,
+        weights_only: bool = False,
 ):
     if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
         # numpy checkpoint, try to load via model specific load_pretrained fn
@@ -88,7 +93,7 @@ def load_checkpoint(
             raise NotImplementedError('Model cannot load numpy checkpoint')
         return
 
-    state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
+    state_dict = load_state_dict(checkpoint_path, use_ema, device=device, weights_only=weights_only)
     if remap:
         state_dict = remap_state_dict(state_dict, model)
     elif filter_fn:
@@ -126,7 +131,7 @@ def resume_checkpoint(
 ):
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             if log_info:
                 _logger.info('Restoring model state from checkpoint...')
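load_state_dict() and load_checkpoint() now accept a weights_only flag (default False to keep existing behaviour), while resume_checkpoint() explicitly passes weights_only=False, presumably because resume checkpoints hold more than bare tensors (optimizer state, epoch counters, and the like). A sketch of opting into the stricter mode for a local checkpoint; the file path is hypothetical:

import timm
from timm.models import load_checkpoint

model = timm.create_model('resnet50', pretrained=False)
# weights_only=True is forwarded to torch.load() on newer PyTorch; on older
# versions the TypeError fallback simply loads without the flag.
load_checkpoint(model, 'output/train/model_best.pth.tar', use_ema=True, weights_only=True)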

timm/models/_hub.py

Lines changed: 10 additions & 2 deletions
@@ -168,7 +168,11 @@ def load_model_config_from_hf(model_id: str):
     return pretrained_cfg, model_name, model_args
 
 
-def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
+def load_state_dict_from_hf(
+        model_id: str,
+        filename: str = HF_WEIGHTS_NAME,
+        weights_only: bool = False,
+):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)
 
@@ -187,7 +191,11 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     # Otherwise, load using pytorch.load
     cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
     _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
-    return torch.load(cached_file, map_location='cpu')
+    try:
+        state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only)
+    except TypeError:
+        state_dict = torch.load(cached_file, map_location='cpu')
+    return state_dict
 
 
 def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
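As the debug message above indicates, the torch.load path in load_state_dict_from_hf() is only reached when no .safetensors alternative exists on the Hub; the new weights_only argument applies to that fallback. A sketch of calling it directly, assuming huggingface_hub is available; the repo id is just an example:

from timm.models._hub import load_state_dict_from_hf

# Downloads the weights file for the given Hub repo and returns a state dict;
# weights_only=True only takes effect when torch.load is used and supported.
state_dict = load_state_dict_from_hf('timm/resnet50.a1_in1k', weights_only=True)
print(f'{len(state_dict)} tensors loaded')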
