Skip to content

Commit 8be5ff9

Browse files
committed
cleanup
1 parent 1972a66 commit 8be5ff9

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

meshgpt_pytorch/meshgpt_pytorch.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from pytorch_custom_utils import save_load
1717

1818
from beartype import beartype
19-
from beartype.typing import Tuple, Callable, List, Dict, Any,Optional, Union
19+
from beartype.typing import Tuple, Callable, List, Dict, Any
20+
2021
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
2122

2223
from einops import rearrange, repeat, reduce, pack, unpack
@@ -641,13 +642,13 @@ def _from_pretrained(
641642
cls,
642643
*,
643644
model_id: str,
644-
revision: Optional[str],
645-
cache_dir: Optional[Union[str, Path]],
645+
revision: str | None,
646+
cache_dir: str | Path | None,
646647
force_download: bool,
647-
proxies: Optional[Dict],
648+
proxies: Dict | None,
648649
resume_download: bool,
649650
local_files_only: bool,
650-
token: Union[str, bool, None],
651+
token: str | bool | None,
651652
map_location: str = "cpu",
652653
strict: bool = False,
653654
**model_kwargs,
@@ -1234,19 +1235,20 @@ def _from_pretrained(
12341235
cls,
12351236
*,
12361237
model_id: str,
1237-
revision: Optional[str],
1238-
cache_dir: Optional[Union[str, Path]],
1238+
revision: str | None,
1239+
cache_dir: str | Path | None,
12391240
force_download: bool,
1240-
proxies: Optional[Dict],
1241+
proxies: Dict | None,
12411242
resume_download: bool,
12421243
local_files_only: bool,
1243-
token: Union[str, bool, None],
1244+
token: str | bool | None,
12441245
map_location: str = "cpu",
12451246
strict: bool = False,
12461247
**model_kwargs,
12471248
):
12481249
model_filename = "mesh-transformer.bin"
12491250
model_file = Path(model_id) / model_filename
1251+
12501252
if not model_file.exists():
12511253
model_file = hf_hub_download(
12521254
repo_id=model_id,
@@ -1258,10 +1260,12 @@ def _from_pretrained(
12581260
resume_download=resume_download,
12591261
token=token,
12601262
local_files_only=local_files_only,
1261-
)
1263+
)
1264+
12621265
model = cls.init_and_load(model_file,strict=strict)
12631266
model.to(map_location)
12641267
return model
1268+
12651269
@property
12661270
def device(self):
12671271
return next(self.parameters()).device

0 commit comments

Comments
 (0)