Skip to content

Commit bd0f791

Browse files
committed
Add weights_only flag to avoid warning, try to keep bwd compat. Default to True for remote load of pretrained weights, keep False for local checkpoint load to avoid training checkpoint breaks. Fix #2249
1 parent 531215e commit bd0f791

File tree

3 files changed

+34
-12
lines changed

3 files changed

+34
-12
lines changed

timm/models/_builder.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,21 @@ def load_pretrained(
177177
model.load_pretrained(pretrained_loc)
178178
return
179179
else:
180-
state_dict = load_state_dict_from_url(
181-
pretrained_loc,
182-
map_location='cpu',
183-
progress=_DOWNLOAD_PROGRESS,
184-
check_hash=_CHECK_HASH,
185-
)
180+
try:
181+
state_dict = load_state_dict_from_url(
182+
pretrained_loc,
183+
map_location='cpu',
184+
progress=_DOWNLOAD_PROGRESS,
185+
check_hash=_CHECK_HASH,
186+
weights_only=True,
187+
)
188+
except ValueError:
189+
state_dict = load_state_dict_from_url(
190+
pretrained_loc,
191+
map_location='cpu',
192+
progress=_DOWNLOAD_PROGRESS,
193+
check_hash=_CHECK_HASH,
194+
)
186195
elif load_from == 'hf-hub':
187196
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
188197
if isinstance(pretrained_loc, (list, tuple)):
@@ -193,7 +202,7 @@ def load_pretrained(
193202
else:
194203
state_dict = load_state_dict_from_hf(*pretrained_loc)
195204
else:
196-
state_dict = load_state_dict_from_hf(pretrained_loc)
205+
state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
197206
else:
198207
model_name = pretrained_cfg.get('architecture', 'this model')
199208
raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")

timm/models/_helpers.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,18 @@ def load_state_dict(
4444
checkpoint_path: str,
4545
use_ema: bool = True,
4646
device: Union[str, torch.device] = 'cpu',
47+
weights_only: bool = False,
4748
) -> Dict[str, Any]:
4849
if checkpoint_path and os.path.isfile(checkpoint_path):
4950
# Check if safetensors or not and load weights accordingly
5051
if str(checkpoint_path).endswith(".safetensors"):
5152
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
5253
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
5354
else:
54-
checkpoint = torch.load(checkpoint_path, map_location=device)
55+
try:
56+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
57+
except ValueError:
58+
checkpoint = torch.load(checkpoint_path, map_location=device)
5559

5660
state_dict_key = ''
5761
if isinstance(checkpoint, dict):
@@ -79,6 +83,7 @@ def load_checkpoint(
7983
strict: bool = True,
8084
remap: bool = False,
8185
filter_fn: Optional[Callable] = None,
86+
weights_only: bool = False,
8287
):
8388
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
8489
# numpy checkpoint, try to load via model specific load_pretrained fn
@@ -88,7 +93,7 @@ def load_checkpoint(
8893
raise NotImplementedError('Model cannot load numpy checkpoint')
8994
return
9095

91-
state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
96+
state_dict = load_state_dict(checkpoint_path, use_ema, device=device, weights_only=weights_only)
9297
if remap:
9398
state_dict = remap_state_dict(state_dict, model)
9499
elif filter_fn:
@@ -126,7 +131,7 @@ def resume_checkpoint(
126131
):
127132
resume_epoch = None
128133
if os.path.isfile(checkpoint_path):
129-
checkpoint = torch.load(checkpoint_path, map_location='cpu')
134+
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
130135
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
131136
if log_info:
132137
_logger.info('Restoring model state from checkpoint...')

timm/models/_hub.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,11 @@ def load_model_config_from_hf(model_id: str):
168168
return pretrained_cfg, model_name, model_args
169169

170170

171-
def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
171+
def load_state_dict_from_hf(
172+
model_id: str,
173+
filename: str = HF_WEIGHTS_NAME,
174+
weights_only: bool = False,
175+
):
172176
assert has_hf_hub(True)
173177
hf_model_id, hf_revision = hf_split(model_id)
174178

@@ -187,7 +191,11 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
187191
# Otherwise, load using pytorch.load
188192
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
189193
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
190-
return torch.load(cached_file, map_location='cpu')
194+
try:
195+
state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only)
196+
except ValueError:
197+
state_dict = torch.load(cached_file, map_location='cpu')
198+
return state_dict
191199

192200

193201
def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):

0 commit comments

Comments
 (0)