|
16 | 16 | from pytorch_custom_utils import save_load |
17 | 17 |
|
18 | 18 | from beartype import beartype |
19 | | -from beartype.typing import Tuple, Callable, List, Dict, Any |
| 19 | +from beartype.typing import Tuple, Callable, List, Dict, Any,Optional, Union |
| 20 | +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download |
20 | 21 |
|
21 | 22 | from einops import rearrange, repeat, reduce, pack, unpack |
22 | 23 | from einops.layers.torch import Rearrange |
@@ -635,6 +636,40 @@ def __init__( |
635 | 636 | self.commit_loss_weight = commit_loss_weight |
636 | 637 | self.bin_smooth_blur_sigma = bin_smooth_blur_sigma |
637 | 638 |
|
| 639 | + @classmethod |
| 640 | + def _from_pretrained( |
| 641 | + cls, |
| 642 | + *, |
| 643 | + model_id: str, |
| 644 | + revision: Optional[str], |
| 645 | + cache_dir: Optional[Union[str, Path]], |
| 646 | + force_download: bool, |
| 647 | + proxies: Optional[Dict], |
| 648 | + resume_download: bool, |
| 649 | + local_files_only: bool, |
| 650 | + token: Union[str, bool, None], |
| 651 | + map_location: str = "cpu", |
| 652 | + strict: bool = False, |
| 653 | + **model_kwargs, |
| 654 | + ): |
| 655 | + model_filename = "mesh-autoencoder.bin" |
| 656 | + model_file = Path(model_id) / model_filename |
| 657 | + if not model_file.exists(): |
| 658 | + model_file = hf_hub_download( |
| 659 | + repo_id=model_id, |
| 660 | + filename=model_filename, |
| 661 | + revision=revision, |
| 662 | + cache_dir=cache_dir, |
| 663 | + force_download=force_download, |
| 664 | + proxies=proxies, |
| 665 | + resume_download=resume_download, |
| 666 | + token=token, |
| 667 | + local_files_only=local_files_only, |
| 668 | + ) |
| 669 | + model = cls.init_and_load(model_file,strict=strict) |
| 670 | + model.to(map_location) |
| 671 | + return model |
| 672 | + |
638 | 673 | @beartype |
639 | 674 | def encode( |
640 | 675 | self, |
@@ -1042,7 +1077,7 @@ def forward( |
1042 | 1077 | return recon_faces, total_loss, loss_breakdown |
1043 | 1078 |
|
1044 | 1079 | @save_load(version = __version__) |
1045 | | -class MeshTransformer(Module): |
| 1080 | +class MeshTransformer(Module,PyTorchModelHubMixin): |
1046 | 1081 | @beartype |
1047 | 1082 | def __init__( |
1048 | 1083 | self, |
@@ -1193,7 +1228,40 @@ def __init__( |
1193 | 1228 |
|
1194 | 1229 | self.pad_id = pad_id |
1195 | 1230 | autoencoder.pad_id = pad_id |
1196 | | - |
| 1231 | + |
| 1232 | + @classmethod |
| 1233 | + def _from_pretrained( |
| 1234 | + cls, |
| 1235 | + *, |
| 1236 | + model_id: str, |
| 1237 | + revision: Optional[str], |
| 1238 | + cache_dir: Optional[Union[str, Path]], |
| 1239 | + force_download: bool, |
| 1240 | + proxies: Optional[Dict], |
| 1241 | + resume_download: bool, |
| 1242 | + local_files_only: bool, |
| 1243 | + token: Union[str, bool, None], |
| 1244 | + map_location: str = "cpu", |
| 1245 | + strict: bool = False, |
| 1246 | + **model_kwargs, |
| 1247 | + ): |
| 1248 | + model_filename = "mesh-transformer.bin" |
| 1249 | + model_file = Path(model_id) / model_filename |
| 1250 | + if not model_file.exists(): |
| 1251 | + model_file = hf_hub_download( |
| 1252 | + repo_id=model_id, |
| 1253 | + filename=model_filename, |
| 1254 | + revision=revision, |
| 1255 | + cache_dir=cache_dir, |
| 1256 | + force_download=force_download, |
| 1257 | + proxies=proxies, |
| 1258 | + resume_download=resume_download, |
| 1259 | + token=token, |
| 1260 | + local_files_only=local_files_only, |
| 1261 | + ) |
| 1262 | + model = cls.init_and_load(model_file,strict=strict) |
| 1263 | + model.to(map_location) |
| 1264 | + return model |
1197 | 1265 | @property |
1198 | 1266 | def device(self): |
1199 | 1267 | return next(self.parameters()).device |
|
0 commit comments