1616from pytorch_custom_utils import save_load
1717
1818from beartype import beartype
19- from beartype .typing import Tuple , Callable , List , Dict , Any ,Optional , Union
19+ from beartype .typing import Tuple , Callable , List , Dict , Any
20+
2021from huggingface_hub import PyTorchModelHubMixin , hf_hub_download
2122
2223from einops import rearrange , repeat , reduce , pack , unpack
@@ -641,13 +642,13 @@ def _from_pretrained(
641642 cls ,
642643 * ,
643644 model_id : str ,
644- revision : Optional [ str ] ,
645- cache_dir : Optional [ Union [ str , Path ]] ,
645+ revision : str | None ,
646+ cache_dir : str | Path | None ,
646647 force_download : bool ,
647- proxies : Optional [ Dict ] ,
648+ proxies : Dict | None ,
648649 resume_download : bool ,
649650 local_files_only : bool ,
650- token : Union [ str , bool , None ] ,
651+ token : str | bool | None ,
651652 map_location : str = "cpu" ,
652653 strict : bool = False ,
653654 ** model_kwargs ,
@@ -1234,19 +1235,20 @@ def _from_pretrained(
12341235 cls ,
12351236 * ,
12361237 model_id : str ,
1237- revision : Optional [ str ] ,
1238- cache_dir : Optional [ Union [ str , Path ]] ,
1238+ revision : str | None ,
1239+ cache_dir : str | Path | None ,
12391240 force_download : bool ,
1240- proxies : Optional [ Dict ] ,
1241+ proxies : Dict | None ,
12411242 resume_download : bool ,
12421243 local_files_only : bool ,
1243- token : Union [ str , bool , None ] ,
1244+ token : str | bool | None ,
12441245 map_location : str = "cpu" ,
12451246 strict : bool = False ,
12461247 ** model_kwargs ,
12471248 ):
12481249 model_filename = "mesh-transformer.bin"
12491250 model_file = Path (model_id ) / model_filename
1251+
12501252 if not model_file .exists ():
12511253 model_file = hf_hub_download (
12521254 repo_id = model_id ,
@@ -1258,10 +1260,12 @@ def _from_pretrained(
12581260 resume_download = resume_download ,
12591261 token = token ,
12601262 local_files_only = local_files_only ,
1261- )
1263+ )
1264+
12621265 model = cls .init_and_load (model_file ,strict = strict )
12631266 model .to (map_location )
12641267 return model
1268+
12651269 @property
12661270 def device (self ):
12671271 return next (self .parameters ()).device
0 commit comments