Skip to content

Commit 6c240fa

Browse files
authored
Merge pull request #84 from MarcusLoppe/huggingface_hub
Huggingface "from_pretrained" support
2 parents 3ebeda0 + 2932a33 commit 6c240fa

File tree

2 files changed

+72
-3
lines changed

2 files changed

+72
-3
lines changed

meshgpt_pytorch/meshgpt_pytorch.py

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from pytorch_custom_utils import save_load
1717

1818
from beartype import beartype
19-
from beartype.typing import Tuple, Callable, List, Dict, Any
19+
from beartype.typing import Tuple, Callable, List, Dict, Any,Optional, Union
20+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
2021

2122
from einops import rearrange, repeat, reduce, pack, unpack
2223
from einops.layers.torch import Rearrange
@@ -635,6 +636,40 @@ def __init__(
635636
self.commit_loss_weight = commit_loss_weight
636637
self.bin_smooth_blur_sigma = bin_smooth_blur_sigma
637638

639+
@classmethod
640+
def _from_pretrained(
641+
cls,
642+
*,
643+
model_id: str,
644+
revision: Optional[str],
645+
cache_dir: Optional[Union[str, Path]],
646+
force_download: bool,
647+
proxies: Optional[Dict],
648+
resume_download: bool,
649+
local_files_only: bool,
650+
token: Union[str, bool, None],
651+
map_location: str = "cpu",
652+
strict: bool = False,
653+
**model_kwargs,
654+
):
655+
model_filename = "mesh-autoencoder.bin"
656+
model_file = Path(model_id) / model_filename
657+
if not model_file.exists():
658+
model_file = hf_hub_download(
659+
repo_id=model_id,
660+
filename=model_filename,
661+
revision=revision,
662+
cache_dir=cache_dir,
663+
force_download=force_download,
664+
proxies=proxies,
665+
resume_download=resume_download,
666+
token=token,
667+
local_files_only=local_files_only,
668+
)
669+
model = cls.init_and_load(model_file,strict=strict)
670+
model.to(map_location)
671+
return model
672+
638673
@beartype
639674
def encode(
640675
self,
@@ -1042,7 +1077,7 @@ def forward(
10421077
return recon_faces, total_loss, loss_breakdown
10431078

10441079
@save_load(version = __version__)
1045-
class MeshTransformer(Module):
1080+
class MeshTransformer(Module,PyTorchModelHubMixin):
10461081
@beartype
10471082
def __init__(
10481083
self,
@@ -1193,7 +1228,40 @@ def __init__(
11931228

11941229
self.pad_id = pad_id
11951230
autoencoder.pad_id = pad_id
1196-
1231+
1232+
@classmethod
1233+
def _from_pretrained(
1234+
cls,
1235+
*,
1236+
model_id: str,
1237+
revision: Optional[str],
1238+
cache_dir: Optional[Union[str, Path]],
1239+
force_download: bool,
1240+
proxies: Optional[Dict],
1241+
resume_download: bool,
1242+
local_files_only: bool,
1243+
token: Union[str, bool, None],
1244+
map_location: str = "cpu",
1245+
strict: bool = False,
1246+
**model_kwargs,
1247+
):
1248+
model_filename = "mesh-transformer.bin"
1249+
model_file = Path(model_id) / model_filename
1250+
if not model_file.exists():
1251+
model_file = hf_hub_download(
1252+
repo_id=model_id,
1253+
filename=model_filename,
1254+
revision=revision,
1255+
cache_dir=cache_dir,
1256+
force_download=force_download,
1257+
proxies=proxies,
1258+
resume_download=resume_download,
1259+
token=token,
1260+
local_files_only=local_files_only,
1261+
)
1262+
model = cls.init_and_load(model_file,strict=strict)
1263+
model.to(map_location)
1264+
return model
11971265
@property
11981266
def device(self):
11991267
return next(self.parameters()).device

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
install_requires=[
2323
'accelerate>=0.25.0',
2424
'beartype',
25+
"huggingface_hub>=0.21.4",
2526
'classifier-free-guidance-pytorch>=0.6.2',
2627
'einops>=0.7.0',
2728
'einx[torch]>=0.1.3',

0 commit comments

Comments
 (0)