Commit da5ef5e

add FiLM conditioning of the coarse and fine transformers with the pooled text embedding before it is sent to the transformers
1 parent 8d7032d commit da5ef5e

File tree

2 files changed: 40 additions, 2 deletions

meshgpt_pytorch/meshgpt_pytorch.py

Lines changed: 39 additions & 1 deletion
@@ -247,6 +247,24 @@ def scatter_mean(
 
 # resnet block
 
+class FiLM(Module):
+    def __init__(self, dim, dim_out = None):
+        super().__init__()
+        dim_out = default(dim_out, dim)
+        linear = nn.Linear(dim, dim_out * 2)
+
+        self.to_gamma_beta = nn.Sequential(
+            linear,
+            Rearrange('b (gb d) -> gb b 1 d', gb = 2)
+        )
+
+        nn.init.zeros_(linear.weight)
+        nn.init.constant_(linear.bias, 1.)
+
+    def forward(self, x, cond):
+        gamma, beta = self.to_gamma_beta(cond)
+        return x * gamma + beta
+
 class PixelNorm(Module):
     def __init__(self, dim, eps = 1e-4):
         super().__init__()
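
For readers unfamiliar with the block being added: FiLM (feature-wise linear modulation, Perez et al. 2018) predicts a per-channel scale `gamma` and shift `beta` from a conditioning vector and applies `x * gamma + beta` to the sequence features. Zero-initializing the linear weights makes the modulation text-independent at the start of training. A minimal self-contained sketch of the block with hypothetical dimensions (assumes `torch` and `einops` are installed; `default` is the usual two-argument fallback helper):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

def default(val, d):
    return val if val is not None else d

class FiLM(nn.Module):
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        linear = nn.Linear(dim, dim_out * 2)

        # project the condition to one (gamma, beta) pair per feature channel,
        # shaped (2, batch, 1, dim_out) so it broadcasts over the sequence axis
        self.to_gamma_beta = nn.Sequential(
            linear,
            Rearrange('b (gb d) -> gb b 1 d', gb = 2)
        )

        # zero weights: at init, gamma and beta come purely from the bias,
        # so the modulation ignores the condition until training updates it
        nn.init.zeros_(linear.weight)
        nn.init.constant_(linear.bias, 1.)

    def forward(self, x, cond):
        gamma, beta = self.to_gamma_beta(cond)
        return x * gamma + beta

# hypothetical sizes: a 512-dim pooled text embedding conditions 256-dim codes
film = FiLM(512, 256)
face_codes = torch.randn(2, 100, 256)   # (batch, faces, dim)
pooled_text = torch.randn(2, 512)       # (batch, dim_text)
assert film(face_codes, pooled_text).shape == (2, 100, 256)
```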
@@ -1100,7 +1118,8 @@ def __init__(
             dim_text = self.conditioner.dim_latent
             cross_attn_dim_context = dim_text
 
-            self.to_sos_text_cond = nn.Linear(dim_text, dim_fine)
+            self.text_coarse_film_cond = FiLM(dim_text, dim)
+            self.text_fine_film_cond = FiLM(dim_text, dim_fine)
 
         # for summarizing the vertices of each face
 
@@ -1352,6 +1371,8 @@ def forward_on_codes(
 
             text_embed, text_mask = maybe_dropped_text_embeds
 
+            pooled_text_embed = masked_mean(text_embed, text_mask, dim = 1)
+
             attn_context_kwargs = dict(
                 context = text_embed,
                 context_mask = text_mask
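
The new `pooled_text_embed` collapses the per-token text embeddings to one vector per sample; `masked_mean` averages only over positions the mask marks as valid, so padding tokens do not dilute the pooled vector. The repository ships its own helper, which may differ in detail; a sketch of the idea with hypothetical shapes:

```python
import torch

def masked_mean(t, mask, dim = 1):
    # zero out padded positions, then divide by the number of valid tokens
    t = t.masked_fill(~mask.unsqueeze(-1), 0.)
    num = t.sum(dim = dim)
    den = mask.sum(dim = dim, keepdim = True).clamp(min = 1)
    return num / den

text_embed = torch.randn(2, 7, 512)                  # (batch, tokens, dim_text)
text_mask = torch.tensor([[True] * 7,
                          [True] * 4 + [False] * 3]) # second sample has 3 pad tokens
pooled = masked_mean(text_embed, text_mask, dim = 1) # (batch, dim_text)
assert pooled.shape == (2, 512)
```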
@@ -1465,6 +1486,11 @@ def forward_on_codes(
 
             should_cache_fine = not divisible_by(curr_vertex_pos + 1, num_tokens_per_face)
 
+        # condition face codes with text if needed
+
+        if self.condition_on_text:
+            face_codes = self.text_coarse_film_cond(face_codes, pooled_text_embed)
+
         # attention on face codes (coarse)
 
         if need_call_first_transformer:
@@ -1543,6 +1569,8 @@ def forward_on_codes(
 
         fine_attn_context_kwargs = dict()
 
+        # optional text cross attention conditioning for fine transformer
+
         if self.fine_cross_attend_text:
             repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]
 
@@ -1554,6 +1582,16 @@ def forward_on_codes(
                 context_mask = text_mask
             )
 
+        # also film condition the fine vertex codes
+
+        if self.condition_on_text:
+            repeat_batch = fine_vertex_codes.shape[0] // pooled_text_embed.shape[0]
+
+            pooled_text_embed = repeat(pooled_text_embed, 'b ... -> (b r) ...', r = repeat_batch)
+            fine_vertex_codes = self.text_fine_film_cond(fine_vertex_codes, pooled_text_embed)
+
+        # fine transformer
+
         attended_vertex_codes, fine_cache = self.fine_decoder(
             fine_vertex_codes,
             cache = fine_cache,
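
The fine transformer runs with faces folded into the batch dimension, so `fine_vertex_codes` has `batch * num_faces` rows while the pooled text embedding has only `batch`. The `repeat` call above expands the embedding to match before FiLM is applied; a small sketch of that expansion with hypothetical shapes:

```python
import torch
from einops import repeat

batch, num_faces, dim_text = 2, 100, 512
pooled_text_embed = torch.randn(batch, dim_text)

# fine vertex codes arrive as (batch * num_faces, tokens_per_face, dim_fine)
fine_vertex_codes = torch.randn(batch * num_faces, 6, 512)

repeat_batch = fine_vertex_codes.shape[0] // pooled_text_embed.shape[0]  # 100
pooled_text_embed = repeat(pooled_text_embed, 'b ... -> (b r) ...', r = repeat_batch)

assert pooled_text_embed.shape == (batch * num_faces, dim_text)
```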

meshgpt_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.2.10'
+__version__ = '1.2.11'
