1111from torch .utils .checkpoint import checkpoint
1212from torch .cuda .amp import autocast
1313
14- from torchtyping import TensorType
15-
1614from pytorch_custom_utils import save_load
1715
18- from beartype import beartype
1916from beartype .typing import Tuple , Callable , List , Dict , Any
17+ from meshgpt_pytorch .typing import Float , Int , Bool , typecheck
2018
2119from huggingface_hub import PyTorchModelHubMixin , hf_hub_download
2220
2624from einx import get_at
2725
2826from x_transformers import Decoder
29- from x_transformers .attend import Attend
3027from x_transformers .x_transformers import RMSNorm , FeedForward , LayerIntermediates
3128
3229from x_transformers .autoregressive_wrapper import (
@@ -78,8 +75,8 @@ def divisible_by(num, den):
7875def is_odd (n ):
7976 return not divisible_by (n , 2 )
8077
81- def is_empty (l ):
82- return len (l ) == 0
78+ def is_empty (x ):
79+ return len (x ) == 0
8380
8481def is_tensor_empty (t : Tensor ):
8582 return t .numel () == 0
@@ -157,7 +154,7 @@ def derive_angle(x, y, eps = 1e-5):
157154
158155@torch .no_grad ()
159156def get_derived_face_features (
160- face_coords : TensorType ['b' , 'nf' , ' nvf' , 3 , float ] # 3 or 4 vertices with 3 coordinates
157+ face_coords : Float ['b nf nvf 3' ] # 3 or 4 vertices with 3 coordinates
161158):
162159 shifted_face_coords = torch .cat ((face_coords [:, :, - 1 :], face_coords [:, :, :- 1 ]), dim = 2 )
163160
@@ -178,7 +175,7 @@ def get_derived_face_features(
178175
179176# tensor helper functions
180177
181- @beartype
178+ @typecheck
182179def discretize (
183180 t : Tensor ,
184181 * ,
@@ -194,7 +191,7 @@ def discretize(
194191
195192 return t .round ().long ().clamp (min = 0 , max = num_discrete - 1 )
196193
197- @beartype
194+ @typecheck
198195def undiscretize (
199196 t : Tensor ,
200197 * ,
@@ -210,7 +207,7 @@ def undiscretize(
210207 t /= num_discrete
211208 return t * (hi - lo ) + lo
212209
213- @beartype
210+ @typecheck
214211def gaussian_blur_1d (
215212 t : Tensor ,
216213 * ,
@@ -234,7 +231,7 @@ def gaussian_blur_1d(
234231 out = F .conv1d (t , kernel , padding = half_width , groups = channels )
235232 return rearrange (out , 'b c n -> b n c' )
236233
237- @beartype
234+ @typecheck
238235def scatter_mean (
239236 tgt : Tensor ,
240237 indices : Tensor ,
@@ -421,7 +418,7 @@ def forward(
421418
422419@save_load (version = __version__ )
423420class MeshAutoencoder (Module ):
424- @beartype
421+ @typecheck
425422 def __init__ (
426423 self ,
427424 num_discrete_coors = 128 ,
@@ -671,15 +668,15 @@ def _from_pretrained(
671668 model .to (map_location )
672669 return model
673670
674- @beartype
671+ @typecheck
675672 def encode (
676673 self ,
677674 * ,
678- vertices : TensorType ['b' , 'nv' , 3 , float ],
679- faces : TensorType ['b' , 'nf' , ' nvf', int ],
680- face_edges : TensorType ['b' , 'e' , 2 , int ],
681- face_mask : TensorType ['b' , ' nf', bool ],
682- face_edges_mask : TensorType ['b' , 'e' , bool ],
675+ vertices : Float ['b nv 3' ],
676+ faces : Int ['b nf nvf' ],
677+ face_edges : Int ['b e 2' ],
678+ face_mask : Bool ['b nf' ],
679+ face_edges_mask : Bool ['b e' ],
683680 return_face_coordinates = False
684681 ):
685682 """
@@ -692,7 +689,6 @@ def encode(
692689 d - embed dim
693690 """
694691
695- batch , num_vertices , num_coors , device = * vertices .shape , vertices .device
696692 _ , num_faces , num_vertices_per_face = faces .shape
697693
698694 assert self .num_vertices_per_face == num_vertices_per_face
@@ -773,18 +769,18 @@ def encode(
773769
774770 return face_embed , discrete_face_coords
775771
776- @beartype
772+ @typecheck
777773 def quantize (
778774 self ,
779775 * ,
780- faces : TensorType ['b' , 'nf' , ' nvf', int ],
781- face_mask : TensorType ['b' , 'n' , bool ],
782- face_embed : TensorType ['b' , 'nf' , 'd' , float ],
776+ faces : Int ['b nf nvf' ],
777+ face_mask : Bool ['b n' ],
778+ face_embed : Float ['b nf d' ],
783779 pad_id = None ,
784780 rvq_sample_codebook_temp = 1.
785781 ):
786782 pad_id = default (pad_id , self .pad_id )
787- batch , num_faces , device = * faces .shape [: 2 ], faces .device
783+ batch , device = faces .shape [0 ], faces .device
788784
789785 max_vertex_index = faces .amax ()
790786 num_vertices = int (max_vertex_index .item () + 1 )
@@ -858,11 +854,11 @@ def quantize_wrapper_fn(inp):
858854
859855 return face_embed_output , codes_output , commit_loss
860856
861- @beartype
857+ @typecheck
862858 def decode (
863859 self ,
864- quantized : TensorType ['b' , 'n' , 'd' , float ],
865- face_mask : TensorType ['b' , 'n' , bool ]
860+ quantized : Float ['b n d' ],
861+ face_mask : Bool ['b n' ]
866862 ):
867863 conv_face_mask = rearrange (face_mask , 'b n -> b 1 n' )
868864
@@ -884,12 +880,12 @@ def decode(
884880
885881 return rearrange (x , 'b d n -> b n d' )
886882
887- @beartype
883+ @typecheck
888884 @torch .no_grad ()
889885 def decode_from_codes_to_faces (
890886 self ,
891887 codes : Tensor ,
892- face_mask : TensorType ['b' , 'n' , bool ] | None = None ,
888+ face_mask : Bool ['b n' ] | None = None ,
893889 return_discrete_codes = False
894890 ):
895891 codes = rearrange (codes , 'b ... -> b (...)' )
@@ -964,13 +960,13 @@ def tokenize(self, vertices, faces, face_edges = None, **kwargs):
964960
965961 return codes
966962
967- @beartype
963+ @typecheck
968964 def forward (
969965 self ,
970966 * ,
971- vertices : TensorType ['b' , 'nv' , 3 , float ],
972- faces : TensorType ['b' , 'nf' , ' nvf', int ],
973- face_edges : TensorType ['b' , 'e' , 2 , int ] | None = None ,
967+ vertices : Float ['b nv 3' ],
968+ faces : Int ['b nf nvf' ],
969+ face_edges : Int ['b e 2' ] | None = None ,
974970 return_codes = False ,
975971 return_loss_breakdown = False ,
976972 return_recon_faces = False ,
@@ -980,7 +976,7 @@ def forward(
980976 if not exists (face_edges ):
981977 face_edges = derive_face_edges_from_faces (faces , pad_id = self .pad_id )
982978
983- num_faces , num_face_edges , device = faces . shape [ 1 ], face_edges . shape [ 1 ], faces .device
979+ device = faces .device
984980
985981 face_mask = reduce (faces != self .pad_id , 'b nf c -> b nf' , 'all' )
986982 face_edges_mask = reduce (face_edges != self .pad_id , 'b e ij -> b e' , 'all' )
@@ -1079,7 +1075,7 @@ def forward(
10791075
10801076@save_load (version = __version__ )
10811077class MeshTransformer (Module ,PyTorchModelHubMixin ):
1082- @beartype
1078+ @typecheck
10831079 def __init__ (
10841080 self ,
10851081 autoencoder : MeshAutoencoder ,
@@ -1270,7 +1266,7 @@ def _from_pretrained(
12701266 def device (self ):
12711267 return next (self .parameters ()).device
12721268
1273- @beartype
1269+ @typecheck
12741270 @torch .no_grad ()
12751271 def embed_texts (self , texts : str | List [str ]):
12761272 single_text = not isinstance (texts , list )
@@ -1287,7 +1283,7 @@ def embed_texts(self, texts: str | List[str]):
12871283
12881284 @eval_decorator
12891285 @torch .no_grad ()
1290- @beartype
1286+ @typecheck
12911287 def generate (
12921288 self ,
12931289 prompt : Tensor | None = None ,
@@ -1406,9 +1402,9 @@ def generate(
14061402 def forward (
14071403 self ,
14081404 * ,
1409- vertices : TensorType ['b' , 'nv' , 3 , int ],
1410- faces : TensorType ['b' , 'nf' , ' nvf', int ],
1411- face_edges : TensorType ['b' , 'e' , 2 , int ] | None = None ,
1405+ vertices : Int ['b nv 3' ],
1406+ faces : Int ['b nf nvf' ],
1407+ face_edges : Int ['b e 2' ] | None = None ,
14121408 codes : Tensor | None = None ,
14131409 cache : LayerIntermediates | None = None ,
14141410 ** kwargs
0 commit comments