@@ -44,14 +44,18 @@ def load_state_dict(
4444 checkpoint_path : str ,
4545 use_ema : bool = True ,
4646 device : Union [str , torch .device ] = 'cpu' ,
47+ weights_only : bool = False ,
4748) -> Dict [str , Any ]:
4849 if checkpoint_path and os .path .isfile (checkpoint_path ):
4950 # Check if safetensors or not and load weights accordingly
5051 if str (checkpoint_path ).endswith (".safetensors" ):
5152 assert _has_safetensors , "`pip install safetensors` to use .safetensors"
5253 checkpoint = safetensors .torch .load_file (checkpoint_path , device = device )
5354 else :
54- checkpoint = torch .load (checkpoint_path , map_location = device )
55+ try :
56+ checkpoint = torch .load (checkpoint_path , map_location = device , weights_only = weights_only )
57+ except TypeError :
58+ checkpoint = torch .load (checkpoint_path , map_location = device )
5559
5660 state_dict_key = ''
5761 if isinstance (checkpoint , dict ):
@@ -79,6 +83,7 @@ def load_checkpoint(
7983 strict : bool = True ,
8084 remap : bool = False ,
8185 filter_fn : Optional [Callable ] = None ,
86+ weights_only : bool = False ,
8287):
8388 if os .path .splitext (checkpoint_path )[- 1 ].lower () in ('.npz' , '.npy' ):
8489 # numpy checkpoint, try to load via model specific load_pretrained fn
@@ -88,7 +93,7 @@ def load_checkpoint(
8893 raise NotImplementedError ('Model cannot load numpy checkpoint' )
8994 return
9095
91- state_dict = load_state_dict (checkpoint_path , use_ema , device = device )
96+ state_dict = load_state_dict (checkpoint_path , use_ema , device = device , weights_only = weights_only )
9297 if remap :
9398 state_dict = remap_state_dict (state_dict , model )
9499 elif filter_fn :
@@ -126,7 +131,7 @@ def resume_checkpoint(
126131):
127132 resume_epoch = None
128133 if os .path .isfile (checkpoint_path ):
129- checkpoint = torch .load (checkpoint_path , map_location = 'cpu' )
134+ checkpoint = torch .load (checkpoint_path , map_location = 'cpu' , weights_only = False )
130135 if isinstance (checkpoint , dict ) and 'state_dict' in checkpoint :
131136 if log_info :
132137 _logger .info ('Restoring model state from checkpoint...' )
0 commit comments