@@ -247,6 +247,24 @@ def scatter_mean(
 
 # resnet block
 
+class FiLM(Module):
+    def __init__(self, dim, dim_out = None):
+        super().__init__()
+        dim_out = default(dim_out, dim)
+        linear = nn.Linear(dim, dim_out * 2)
+
+        self.to_gamma_beta = nn.Sequential(
+            linear,
+            Rearrange('b (gb d) -> gb b 1 d', gb = 2)
+        )
+
+        nn.init.zeros_(linear.weight)
+        nn.init.constant_(linear.bias, 1.)
+
+    def forward(self, x, cond):
+        gamma, beta = self.to_gamma_beta(cond)
+        return x * gamma + beta
+
 class PixelNorm(Module):
     def __init__(self, dim, eps = 1e-4):
         super().__init__()
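The added FiLM module ("feature-wise linear modulation", Perez et al. 2018) predicts a per-channel scale (gamma) and shift (beta) from a conditioning vector; with zeroed weights and a bias of 1., both start out as the constant 1, so the text signal only gradually learns to modulate the codes. Below is a minimal shape-check sketch of the broadcast, using made-up dimensions (dim_text = 512, dim = 1024) and plain torch/einops rather than this repo's helpers:

import torch
from torch import nn
from einops.layers.torch import Rearrange

# hypothetical sizes: conditioning dim 512, feature dim 1024, batch 4, 128 faces
to_gamma_beta = nn.Sequential(
    nn.Linear(512, 1024 * 2),
    Rearrange('b (gb d) -> gb b 1 d', gb = 2)   # split into gamma and beta, add a singleton seq axis
)

cond = torch.randn(4, 512)        # pooled text embedding: (batch, dim_text)
x = torch.randn(4, 128, 1024)     # face codes: (batch, num_faces, dim)

gamma, beta = to_gamma_beta(cond) # each (4, 1, 1024), broadcast across the face axis
assert (x * gamma + beta).shape == x.shape

The singleton axis inserted by the Rearrange is what lets a single pooled vector modulate every position in the sequence.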
@@ -1100,7 +1118,8 @@ def __init__(
             dim_text = self.conditioner.dim_latent
             cross_attn_dim_context = dim_text
 
-            self.to_sos_text_cond = nn.Linear(dim_text, dim_fine)
+            self.text_coarse_film_cond = FiLM(dim_text, dim)
+            self.text_fine_film_cond = FiLM(dim_text, dim_fine)
 
         # for summarizing the vertices of each face
 
@@ -1352,6 +1371,8 @@ def forward_on_codes(
 
             text_embed, text_mask = maybe_dropped_text_embeds
 
+            pooled_text_embed = masked_mean(text_embed, text_mask, dim = 1)
+
             attn_context_kwargs = dict(
                 context = text_embed,
                 context_mask = text_mask
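masked_mean collapses the per-token text embeddings into one vector per sample so it can drive FiLM, averaging only over non-padded positions. A sketch of the assumed behavior (the real helper is defined elsewhere in this file; the name and signature are taken from the call site, the body is an assumption):

import torch

def masked_mean(t, mask, dim = 1, eps = 1e-5):
    # t: (b, n, d) token embeddings, mask: (b, n) bool, True where tokens are valid
    t = t.masked_fill(~mask.unsqueeze(-1), 0.)
    num = t.sum(dim = dim)                                              # (b, d)
    den = mask.float().sum(dim = dim, keepdim = True).clamp(min = eps)  # (b, 1)
    return num / den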
@@ -1465,6 +1486,11 @@ def forward_on_codes(
 
         should_cache_fine = not divisible_by(curr_vertex_pos + 1, num_tokens_per_face)
 
+        # condition face codes with text if needed
+
+        if self.condition_on_text:
+            face_codes = self.text_coarse_film_cond(face_codes, pooled_text_embed)
+
         # attention on face codes (coarse)
 
         if need_call_first_transformer:
@@ -1543,6 +1569,8 @@ def forward_on_codes(
 
         fine_attn_context_kwargs = dict()
 
+        # optional text cross attention conditioning for fine transformer
+
         if self.fine_cross_attend_text:
             repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]
 
@@ -1554,6 +1582,16 @@ def forward_on_codes(
                 context_mask = text_mask
             )
 
+        # also film condition the fine vertex codes
+
+        if self.condition_on_text:
+            repeat_batch = fine_vertex_codes.shape[0] // pooled_text_embed.shape[0]
+
+            pooled_text_embed = repeat(pooled_text_embed, 'b ... -> (b r) ...', r = repeat_batch)
+            fine_vertex_codes = self.text_fine_film_cond(fine_vertex_codes, pooled_text_embed)
+
+        # fine transformer
+
         attended_vertex_codes, fine_cache = self.fine_decoder(
             fine_vertex_codes,
             cache = fine_cache,
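By the time the fine transformer runs, faces have been folded into the batch axis, so fine_vertex_codes has batch size b * num_faces while pooled_text_embed still has batch size b; the repeat call tiles the pooled embedding once per face before applying FiLM. A shape-only sketch with made-up sizes:

import torch
from einops import repeat

b, num_faces, tokens_per_face, dim_fine, dim_text = 2, 6, 9, 512, 768

fine_vertex_codes = torch.randn(b * num_faces, tokens_per_face, dim_fine)
pooled_text_embed = torch.randn(b, dim_text)

r = fine_vertex_codes.shape[0] // pooled_text_embed.shape[0]   # = num_faces
pooled_text_embed = repeat(pooled_text_embed, 'b ... -> (b r) ...', r = r)

assert pooled_text_embed.shape == (b * num_faces, dim_text)

The (b r) grouping keeps the r copies of each sample contiguous, matching how the face axis was folded into the batch dimension.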